Unverified commit fdf84096, authored by Matt and committed by GitHub
Browse files

pt-to-tf model architecture override (#22055)

* Add an argument to pt-to-tf to allow overriding the model class

* make fixup

* Minor fix to error message

* Remove unused extra conversion from the script
parent 04bfac83
...@@ -68,6 +68,7 @@ def convert_command_factory(args: Namespace): ...@@ -68,6 +68,7 @@ def convert_command_factory(args: Namespace):
args.no_pr, args.no_pr,
args.push, args.push,
args.extra_commit_description, args.extra_commit_description,
args.override_model_class,
) )
...@@ -126,6 +127,13 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -126,6 +127,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
default="", default="",
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).", help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
) )
train_parser.add_argument(
"--override-model-class",
type=str,
default=None,
help="If you think you know better than the auto-detector, you can specify the model class here. "
"Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
)
train_parser.set_defaults(func=convert_command_factory) train_parser.set_defaults(func=convert_command_factory)
@staticmethod @staticmethod
...@@ -175,6 +183,7 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -175,6 +183,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
no_pr: bool, no_pr: bool,
push: bool, push: bool,
extra_commit_description: str, extra_commit_description: str,
override_model_class: str,
*args, *args,
): ):
self._logger = logging.get_logger("transformers-cli/pt_to_tf") self._logger = logging.get_logger("transformers-cli/pt_to_tf")
...@@ -185,6 +194,7 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -185,6 +194,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self._no_pr = no_pr self._no_pr = no_pr
self._push = push self._push = push
self._extra_commit_description = extra_commit_description self._extra_commit_description = extra_commit_description
self._override_model_class = override_model_class
def get_inputs(self, pt_model, config): def get_inputs(self, pt_model, config):
""" """
...@@ -269,7 +279,20 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -269,7 +279,20 @@ class PTtoTFCommand(BaseTransformersCLICommand):
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
config = AutoConfig.from_pretrained(self._local_dir) config = AutoConfig.from_pretrained(self._local_dir)
architectures = config.architectures architectures = config.architectures
if architectures is None: # No architecture defined -- use auto classes if self._override_model_class is not None:
if self._override_model_class.startswith("TF"):
architectures = [self._override_model_class[2:]]
else:
architectures = [self._override_model_class]
try:
pt_class = getattr(import_module("transformers"), architectures[0])
except AttributeError:
raise ValueError(f"Model class {self._override_model_class} not found in transformers.")
try:
tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
except AttributeError:
raise ValueError(f"TF model class TF{self._override_model_class} not found in transformers.")
elif architectures is None: # No architecture defined -- use auto classes
pt_class = getattr(import_module("transformers"), "AutoModel") pt_class = getattr(import_module("transformers"), "AutoModel")
tf_class = getattr(import_module("transformers"), "TFAutoModel") tf_class = getattr(import_module("transformers"), "TFAutoModel")
self._logger.warning("No detected architecture, using AutoModel/TFAutoModel") self._logger.warning("No detected architecture, using AutoModel/TFAutoModel")
...@@ -287,7 +310,6 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -287,7 +310,6 @@ class PTtoTFCommand(BaseTransformersCLICommand):
pt_model = pt_class.from_pretrained(self._local_dir) pt_model = pt_class.from_pretrained(self._local_dir)
pt_model.eval() pt_model.eval()
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
pt_input, tf_input = self.get_inputs(pt_model, config) pt_input, tf_input = self.get_inputs(pt_model, config)
with torch.no_grad(): with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment