Unverified Commit 7e4844fc authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Enable ONNX export when PyTorch and TensorFlow installed in the same environment (#15625)

parent 6cf06d19
...@@ -303,8 +303,16 @@ class FeaturesManager: ...@@ -303,8 +303,16 @@ class FeaturesManager:
The instance of the model. The instance of the model.
""" """
# If PyTorch and TensorFlow are installed in the same environment, we
# load an AutoModel class by default
model_class = FeaturesManager.get_model_class_for_feature(feature) model_class = FeaturesManager.get_model_class_for_feature(feature)
return model_class.from_pretrained(model) try:
model = model_class.from_pretrained(model)
# Load TensorFlow weights in an AutoModel instance if PyTorch and
# TensorFlow are installed in the same environment
except OSError:
model = model_class.from_pretrained(model, from_tf=True)
return model
@staticmethod @staticmethod
def check_supported_model_or_raise( def check_supported_model_or_raise(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment