Unverified Commit cfc838dd authored by Denis Ismailaj's avatar Denis Ismailaj Committed by GitHub
Browse files

Respect explicitly set framework parameter in pipeline (#24322)

* Respect framework parameter

* Move check to pipeline()

* Add check inside infer_framework_load_model again
parent c5454eba
...@@ -781,9 +781,9 @@ def pipeline( ...@@ -781,9 +781,9 @@ def pipeline(
model_name = model if isinstance(model, str) else None model_name = model if isinstance(model, str) else None
# Infer the framework from the model # Load the correct model if possible
# Forced if framework already defined, inferred if it's None # Infer the framework from the model if not already defined
# Will load the correct model if possible if isinstance(model, str) or framework is None:
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]} model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
framework, model = infer_framework_load_model( framework, model = infer_framework_load_model(
model, model,
......
...@@ -277,6 +277,7 @@ def infer_framework_load_model( ...@@ -277,6 +277,7 @@ def infer_framework_load_model(
if isinstance(model, str): if isinstance(model, str):
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.") raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
if framework is None:
framework = infer_framework(model.__class__) framework = infer_framework(model.__class__)
return framework, model return framework, model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment