"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0b693e90e0748e16427a2764d516e9f5ba801bcc"
Unverified Commit cfc838dd authored by Denis Ismailaj's avatar Denis Ismailaj Committed by GitHub
Browse files

Respect explicitly set framework parameter in pipeline (#24322)

* Respect framework parameter

* Move check to pipeline()

* Add check inside infer_framework_load_model again
parent c5454eba
......@@ -781,19 +781,19 @@ def pipeline(
model_name = model if isinstance(model, str) else None
# Infer the framework from the model
# Forced if framework already defined, inferred if it's None
# Will load the correct model if possible
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
framework, model = infer_framework_load_model(
model,
model_classes=model_classes,
config=config,
framework=framework,
task=task,
**hub_kwargs,
**model_kwargs,
)
# Load the correct model if possible
# Infer the framework from the model if not already defined
if isinstance(model, str) or framework is None:
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
framework, model = infer_framework_load_model(
model,
model_classes=model_classes,
config=config,
framework=framework,
task=task,
**hub_kwargs,
**model_kwargs,
)
model_config = model.config
hub_kwargs["_commit_hash"] = model.config._commit_hash
......
......@@ -277,7 +277,8 @@ def infer_framework_load_model(
if isinstance(model, str):
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
framework = infer_framework(model.__class__)
if framework is None:
framework = infer_framework(model.__class__)
return framework, model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment