Unverified Commit 6f0723a9 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Restore original task in test_warning_logs (#17985)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 009171d1
...@@ -777,9 +777,17 @@ class PipelineRegistryTest(unittest.TestCase): ...@@ -777,9 +777,17 @@ class PipelineRegistryTest(unittest.TestCase):
logger_ = transformers_logging.get_logger("transformers.pipelines.base") logger_ = transformers_logging.get_logger("transformers.pipelines.base")
alias = "text-classification" alias = "text-classification"
# Get the original task, so we can restore it at the end.
# (otherwise the subsequential tests in `TextClassificationPipelineTests` will fail)
original_task, original_task_options = PIPELINE_REGISTRY.check_task(alias)
try:
with CaptureLogger(logger_) as cm: with CaptureLogger(logger_) as cm:
PIPELINE_REGISTRY.register_pipeline(alias, {}) PIPELINE_REGISTRY.register_pipeline(alias, {})
self.assertIn(f"{alias} is already registered", cm.out) self.assertIn(f"{alias} is already registered", cm.out)
finally:
# restore
PIPELINE_REGISTRY.register_pipeline(alias, original_task)
@require_torch @require_torch
def test_register_pipeline(self): def test_register_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment