Unverified Commit 7f921bcf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix add-new-model-like when old model checkpoint is not found (#15805)

* Fix add-new-model-like command when old checkpoint can't be recovered

* Style
parent bb7949b3
...@@ -1115,6 +1115,7 @@ def create_new_model_like( ...@@ -1115,6 +1115,7 @@ def create_new_model_like(
new_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns,
add_copied_from: bool = True, add_copied_from: bool = True,
frameworks: Optional[List[str]] = None, frameworks: Optional[List[str]] = None,
old_checkpoint: Optional[str] = None,
): ):
""" """
Creates a new model module like a given model of the Transformers library. Creates a new model module like a given model of the Transformers library.
...@@ -1126,11 +1127,22 @@ def create_new_model_like( ...@@ -1126,11 +1127,22 @@ def create_new_model_like(
Whether or not to add "Copied from" statements to all classes in the new model modeling files. Whether or not to add "Copied from" statements to all classes in the new model modeling files.
frameworks (`List[str]`, *optional*): frameworks (`List[str]`, *optional*):
If passed, will limit the duplicate to the frameworks specified. If passed, will limit the duplicate to the frameworks specified.
old_checkpoint (`str`, *optional*):
The name of the base checkpoint for the old model. Should be passed along when it can't be automatically
recovered from the `model_type`.
""" """
# Retrieve all the old model info. # Retrieve all the old model info.
model_info = retrieve_info_for_model(model_type, frameworks=frameworks) model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
model_files = model_info["model_files"] model_files = model_info["model_files"]
old_model_patterns = model_info["model_patterns"] old_model_patterns = model_info["model_patterns"]
if old_checkpoint is not None:
old_model_patterns.checkpoint = old_checkpoint
if len(old_model_patterns.checkpoint) == 0:
raise ValueError(
"The old model checkpoint could not be recovered from the model type. Please pass it to the "
"`old_checkpoint` argument."
)
keep_old_processing = True keep_old_processing = True
for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]: for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]:
if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr): if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
...@@ -1291,8 +1303,15 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand): ...@@ -1291,8 +1303,15 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
self.model_patterns = ModelPatterns(**config["new_model_patterns"]) self.model_patterns = ModelPatterns(**config["new_model_patterns"])
self.add_copied_from = config.get("add_copied_from", True) self.add_copied_from = config.get("add_copied_from", True)
self.frameworks = config.get("frameworks", ["pt", "tf", "flax"]) self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
self.old_checkpoint = config.get("old_checkpoint", None)
else: else:
self.old_model_type, self.model_patterns, self.add_copied_from, self.frameworks = get_user_input() (
self.old_model_type,
self.model_patterns,
self.add_copied_from,
self.frameworks,
self.old_checkpoint,
) = get_user_input()
self.path_to_repo = path_to_repo self.path_to_repo = path_to_repo
...@@ -1310,6 +1329,7 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand): ...@@ -1310,6 +1329,7 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
new_model_patterns=self.model_patterns, new_model_patterns=self.model_patterns,
add_copied_from=self.add_copied_from, add_copied_from=self.add_copied_from,
frameworks=self.frameworks, frameworks=self.frameworks,
old_checkpoint=self.old_checkpoint,
) )
...@@ -1402,6 +1422,12 @@ def get_user_input(): ...@@ -1402,6 +1422,12 @@ def get_user_input():
old_processor_class = old_model_info["model_patterns"].processor_class old_processor_class = old_model_info["model_patterns"].processor_class
old_frameworks = old_model_info["frameworks"] old_frameworks = old_model_info["frameworks"]
old_checkpoint = None
if len(old_model_info["model_patterns"].checkpoint) == 0:
old_checkpoint = get_user_field(
"We couldn't find the name of the base checkpoint for that model, please enter it here."
)
model_name = get_user_field("What is the name for your new model?") model_name = get_user_field("What is the name for your new model?")
default_patterns = ModelPatterns(model_name, model_name) default_patterns = ModelPatterns(model_name, model_name)
...@@ -1497,4 +1523,4 @@ def get_user_input(): ...@@ -1497,4 +1523,4 @@ def get_user_input():
) )
frameworks = list(set(frameworks.split(" "))) frameworks = list(set(frameworks.split(" ")))
return (old_model_type, model_patterns, add_copied_from, frameworks) return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment