Unverified Commit cec5f7ab authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Update summarization `run_pipeline_test` (#20623)



* update summarization run_pipeline_test

* update
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 3e4c9e5c
...@@ -17,11 +17,7 @@ import unittest ...@@ -17,11 +17,7 @@ import unittest
from transformers import ( from transformers import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
LEDConfig,
LongT5Config,
SummarizationPipeline, SummarizationPipeline,
SwitchTransformersConfig,
T5Config,
pipeline, pipeline,
) )
from transformers.testing_utils import require_tf, require_torch, slow, torch_device from transformers.testing_utils import require_tf, require_torch, slow, torch_device
...@@ -55,7 +51,17 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe ...@@ -55,7 +51,17 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
) )
self.assertEqual(outputs, [{"summary_text": ANY(str)}]) self.assertEqual(outputs, [{"summary_text": ANY(str)}])
if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)): model_can_handle_longer_seq = [
"SwitchTransformersConfig",
"T5Config",
"LongT5Config",
"LEDConfig",
"PegasusXConfig",
"FSMTConfig",
"M2M100Config",
"ProphetNetConfig", # positional embeddings up to a fixed maximum size (otherwise clamping the values)
]
if model.config.__class__.__name__ not in model_can_handle_longer_seq:
# Switch Transformers, LED, T5, LongT5 can handle it. # Switch Transformers, LED, T5, LongT5 can handle it.
# Too long. # Too long.
with self.assertRaises(Exception): with self.assertRaises(Exception):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment