Unverified Commit e7f33e8c authored by Alex Hedges's avatar Alex Hedges Committed by GitHub
Browse files

Pass `model_kwargs` when loading a model in `pipeline()` (#12449)

* Pass model_kwargs when loading a model in pipeline

* Add test for model_kwargs parameter of pipeline()

* Rewrite test to not download model

* Fix failing style checks
parent 18ca59e1
...@@ -426,7 +426,13 @@ def pipeline( ...@@ -426,7 +426,13 @@ def pipeline(
# Will load the correct model if possible # Will load the correct model if possible
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]} model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
framework, model = infer_framework_load_model( framework, model = infer_framework_load_model(
model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task model,
model_classes=model_classes,
config=config,
framework=framework,
revision=revision,
task=task,
**model_kwargs,
) )
model_config = model.config model_config = model.config
......
...@@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. ...@@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
for key in output_keys: for key in output_keys:
self.assertIn(key, result) self.assertIn(key, result)
@require_torch
def test_model_kwargs_passed_to_model_load(self):
ner_pipeline = pipeline(task="ner", model=self.small_models[0])
self.assertFalse(ner_pipeline.model.config.output_attentions)
ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True})
self.assertTrue(ner_pipeline.model.config.output_attentions)
@require_torch @require_torch
@slow @slow
def test_spanish_bert(self): def test_spanish_bert(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment