Unverified Commit 2afe9cd2 authored by Xia's avatar Xia Committed by GitHub
Browse files

Add argument "cache_dir" for transformers.onnx (#16284)

* Add argument "cache_dir" for transformers.onnx

* Reformate files that can't pass CI.
parent 3f0f75e4
...@@ -42,6 +42,7 @@ def main(): ...@@ -42,6 +42,7 @@ def main():
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export." "--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
) )
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
# Retrieve CLI arguments # Retrieve CLI arguments
args = parser.parse_args() args = parser.parse_args()
...@@ -61,7 +62,9 @@ def main(): ...@@ -61,7 +62,9 @@ def main():
raise ValueError(f"Unsupported model type: {config.model_type}") raise ValueError(f"Unsupported model type: {config.model_type}")
# Allocate the model # Allocate the model
model = FeaturesManager.get_model_from_feature(args.feature, args.model, framework=args.framework) model = FeaturesManager.get_model_from_feature(
args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir
)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature) model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config) onnx_config = model_onnx_config(model.config)
......
...@@ -325,7 +325,7 @@ class FeaturesManager: ...@@ -325,7 +325,7 @@ class FeaturesManager:
return task_to_automodel[task] return task_to_automodel[task]
def get_model_from_feature( def get_model_from_feature(
feature: str, model: str, framework: str = "pt" feature: str, model: str, framework: str = "pt", cache_dir: str = None
) -> Union[PreTrainedModel, TFPreTrainedModel]: ) -> Union[PreTrainedModel, TFPreTrainedModel]:
""" """
Attempts to retrieve a model from a model's name and the feature to be enabled. Attempts to retrieve a model from a model's name and the feature to be enabled.
...@@ -344,12 +344,12 @@ class FeaturesManager: ...@@ -344,12 +344,12 @@ class FeaturesManager:
""" """
model_class = FeaturesManager.get_model_class_for_feature(feature, framework) model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
try: try:
model = model_class.from_pretrained(model) model = model_class.from_pretrained(model, cache_dir=cache_dir)
except OSError: except OSError:
if framework == "pt": if framework == "pt":
model = model_class.from_pretrained(model, from_tf=True) model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
else: else:
model = model_class.from_pretrained(model, from_pt=True) model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
return model return model
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment