Unverified Commit a12c5cbc authored by Yih-Dar, committed by GitHub

Change a logic in pipeline test regarding TF (#20710)



* Fix the pipeline test regarding TF

* Fix the pipeline test regarding TF

* update comment
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 1af4bee8
@@ -18,9 +18,10 @@ from transformers import (
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     SummarizationPipeline,
+    TFPreTrainedModel,
     pipeline,
 )
-from transformers.testing_utils import require_tf, require_torch, slow, torch_device
+from transformers.testing_utils import get_gpu_count, require_tf, require_torch, slow, torch_device
 from transformers.tokenization_utils import TruncationStrategy

 from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -51,6 +52,7 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         )
         self.assertEqual(outputs, [{"summary_text": ANY(str)}])

+        # Some models (Switch Transformers, LED, T5, LongT5, etc.) can handle long sequences.
         model_can_handle_longer_seq = [
             "SwitchTransformersConfig",
             "T5Config",
@@ -62,10 +64,16 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             "ProphetNetConfig",  # positional embeddings up to a fixed maximum size (otherwise clamping the values)
         ]
         if model.config.__class__.__name__ not in model_can_handle_longer_seq:
-            # Switch Transformers, LED, T5, LongT5 can handle it.
-            # Too long.
-            with self.assertRaises(Exception):
-                outputs = summarizer("This " * 1000)
+            # Too long, and an exception is expected.
+            # For TF models, if the weights are initialized in a GPU context, we won't get the expected
+            # index error from the embedding layer.
+            if not (
+                isinstance(model, TFPreTrainedModel)
+                and get_gpu_count() > 0
+                and len(summarizer.model.trainable_weights) > 0
+            ):
+                with self.assertRaises(Exception):
+                    outputs = summarizer("This " * 1000)
         outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)

     @require_torch
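For background on the new guard, the TF behavior it works around can be reproduced in isolation. The sketch below is an illustration only, not part of the commit: on CPU, a gather with an out-of-vocabulary index raises tf.errors.InvalidArgumentError, while the GPU kernel skips index validation and returns zeros instead of raising, so assertRaises(Exception) would fail for a TF model whose weights were initialized on a GPU.

    import tensorflow as tf

    params = tf.random.normal([10, 4])  # toy "embedding matrix" with a vocab size of 10
    bad_idx = tf.constant([42])         # index well outside the valid range [0, 10)

    # CPU kernels validate indices and raise on out-of-range lookups.
    try:
        with tf.device("/CPU:0"):
            tf.gather(params, bad_idx)
    except tf.errors.InvalidArgumentError:
        print("CPU: out-of-range index raised InvalidArgumentError")

    # GPU kernels do not validate; out-of-range rows come back as zeros, no exception.
    if tf.config.list_physical_devices("GPU"):
        with tf.device("/GPU:0"):
            out = tf.gather(params, bad_idx)
        print("GPU: no error, result is", out.numpy())

The guard's extra check on len(summarizer.model.trainable_weights) presumably restricts the skip to TF models that have actually materialized their weights; models with no built weights still go through the assertRaises path.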