"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "211e130811145053440c56eec62ac6229d9d90b0"
Unverified Commit 7f921bcf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix add-new-model-like when old model checkpoint is not found (#15805)

* Fix add-new-model-like command when old checkpoint can't be recovered

* Style
parent bb7949b3
......@@ -1115,6 +1115,7 @@ def create_new_model_like(
new_model_patterns: ModelPatterns,
add_copied_from: bool = True,
frameworks: Optional[List[str]] = None,
old_checkpoint: Optional[str] = None,
):
"""
Creates a new model module like a given model of the Transformers library.
......@@ -1126,11 +1127,22 @@ def create_new_model_like(
Whether or not to add "Copied from" statements to all classes in the new model modeling files.
frameworks (`List[str]`, *optional*):
If passed, will limit the duplicate to the frameworks specified.
old_checkpoint (`str`, *optional*):
The name of the base checkpoint for the old model. Should be passed along when it can't be automatically
recovered from the `model_type`.
"""
# Retrieve all the old model info.
model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
model_files = model_info["model_files"]
old_model_patterns = model_info["model_patterns"]
if old_checkpoint is not None:
old_model_patterns.checkpoint = old_checkpoint
if len(old_model_patterns.checkpoint) == 0:
raise ValueError(
"The old model checkpoint could not be recovered from the model type. Please pass it to the "
"`old_checkpoint` argument."
)
keep_old_processing = True
for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]:
if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
......@@ -1291,8 +1303,15 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
self.model_patterns = ModelPatterns(**config["new_model_patterns"])
self.add_copied_from = config.get("add_copied_from", True)
self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
self.old_checkpoint = config.get("old_checkpoint", None)
else:
self.old_model_type, self.model_patterns, self.add_copied_from, self.frameworks = get_user_input()
(
self.old_model_type,
self.model_patterns,
self.add_copied_from,
self.frameworks,
self.old_checkpoint,
) = get_user_input()
self.path_to_repo = path_to_repo
......@@ -1310,6 +1329,7 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
new_model_patterns=self.model_patterns,
add_copied_from=self.add_copied_from,
frameworks=self.frameworks,
old_checkpoint=self.old_checkpoint,
)
......@@ -1402,6 +1422,12 @@ def get_user_input():
old_processor_class = old_model_info["model_patterns"].processor_class
old_frameworks = old_model_info["frameworks"]
old_checkpoint = None
if len(old_model_info["model_patterns"].checkpoint) == 0:
old_checkpoint = get_user_field(
"We couldn't find the name of the base checkpoint for that model, please enter it here."
)
model_name = get_user_field("What is the name for your new model?")
default_patterns = ModelPatterns(model_name, model_name)
......@@ -1497,4 +1523,4 @@ def get_user_input():
)
frameworks = list(set(frameworks.split(" ")))
return (old_model_type, model_patterns, add_copied_from, frameworks)
return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment