"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "66446909b236c17498276857fa88e23d2c91d004"
Unverified Commit fc21c9be authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[CookieCutter] Clarify questions (#18959)



* Clarify cookiecutter questions

* Update first question
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 6f8f2f6a
...@@ -1442,7 +1442,9 @@ def get_user_input(): ...@@ -1442,7 +1442,9 @@ def get_user_input():
# Get old model type # Get old model type
valid_model_type = False valid_model_type = False
while not valid_model_type: while not valid_model_type:
old_model_type = input("What is the model you would like to duplicate? ") old_model_type = input(
"What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): "
)
if old_model_type in model_types: if old_model_type in model_types:
valid_model_type = True valid_model_type = True
else: else:
...@@ -1465,38 +1467,42 @@ def get_user_input(): ...@@ -1465,38 +1467,42 @@ def get_user_input():
"We couldn't find the name of the base checkpoint for that model, please enter it here." "We couldn't find the name of the base checkpoint for that model, please enter it here."
) )
model_name = get_user_field("What is the name for your new model?") model_name = get_user_field(
"What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? "
)
default_patterns = ModelPatterns(model_name, model_name) default_patterns = ModelPatterns(model_name, model_name)
model_type = get_user_field( model_type = get_user_field(
"What identifier would you like to use for the model type of this model?", "What identifier would you like to use for the `model_type` of this model? ",
default_value=default_patterns.model_type, default_value=default_patterns.model_type,
) )
model_lower_cased = get_user_field( model_lower_cased = get_user_field(
"What name would you like to use for the module of this model?", "What lowercase name would you like to use for the module (folder) of this model? ",
default_value=default_patterns.model_lower_cased, default_value=default_patterns.model_lower_cased,
) )
model_camel_cased = get_user_field( model_camel_cased = get_user_field(
"What prefix (camel-cased) would you like to use for the model classes of this model?", "What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? ",
default_value=default_patterns.model_camel_cased, default_value=default_patterns.model_camel_cased,
) )
model_upper_cased = get_user_field( model_upper_cased = get_user_field(
"What prefix (upper-cased) would you like to use for the constants relative to this model?", "What prefix (upper-cased) would you like to use for the constants relative to this model? ",
default_value=default_patterns.model_upper_cased, default_value=default_patterns.model_upper_cased,
) )
config_class = get_user_field( config_class = get_user_field(
"What will be the name of the config class for this model?", default_value=f"{model_camel_cased}Config" "What will be the name of the config class for this model? ", default_value=f"{model_camel_cased}Config"
)
checkpoint = get_user_field(
"Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/roberta-base): "
) )
checkpoint = get_user_field("Please give a checkpoint identifier (on the model Hub) for this new model.")
old_processing_classes = [ old_processing_classes = [
c for c in [old_feature_extractor_class, old_tokenizer_class, old_processor_class] if c is not None c for c in [old_feature_extractor_class, old_tokenizer_class, old_processor_class] if c is not None
] ]
old_processing_classes = ", ".join(old_processing_classes) old_processing_classes = ", ".join(old_processing_classes)
keep_processing = get_user_field( keep_processing = get_user_field(
f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes})?", f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes}) (yes/no)? ",
convert_to=convert_to_bool, convert_to=convert_to_bool,
fallback_message="Please answer yes/no, y/n, true/false or 1/0.", fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
) )
if keep_processing: if keep_processing:
feature_extractor_class = old_feature_extractor_class feature_extractor_class = old_feature_extractor_class
...@@ -1505,21 +1511,21 @@ def get_user_input(): ...@@ -1505,21 +1511,21 @@ def get_user_input():
else: else:
if old_tokenizer_class is not None: if old_tokenizer_class is not None:
tokenizer_class = get_user_field( tokenizer_class = get_user_field(
"What will be the name of the tokenizer class for this model?", "What will be the name of the tokenizer class for this model? ",
default_value=f"{model_camel_cased}Tokenizer", default_value=f"{model_camel_cased}Tokenizer",
) )
else: else:
tokenizer_class = None tokenizer_class = None
if old_feature_extractor_class is not None: if old_feature_extractor_class is not None:
feature_extractor_class = get_user_field( feature_extractor_class = get_user_field(
"What will be the name of the feature extractor class for this model?", "What will be the name of the feature extractor class for this model? ",
default_value=f"{model_camel_cased}FeatureExtractor", default_value=f"{model_camel_cased}FeatureExtractor",
) )
else: else:
feature_extractor_class = None feature_extractor_class = None
if old_processor_class is not None: if old_processor_class is not None:
processor_class = get_user_field( processor_class = get_user_field(
"What will be the name of the processor class for this model?", "What will be the name of the processor class for this model? ",
default_value=f"{model_camel_cased}Processor", default_value=f"{model_camel_cased}Processor",
) )
else: else:
...@@ -1539,7 +1545,7 @@ def get_user_input(): ...@@ -1539,7 +1545,7 @@ def get_user_input():
) )
add_copied_from = get_user_field( add_copied_from = get_user_field(
"Should we add # Copied from statements when creating the new modeling file?", "Should we add # Copied from statements when creating the new modeling file (yes/no)? ",
convert_to=convert_to_bool, convert_to=convert_to_bool,
default_value="yes", default_value="yes",
fallback_message="Please answer yes/no, y/n, true/false or 1/0.", fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
...@@ -1547,7 +1553,7 @@ def get_user_input(): ...@@ -1547,7 +1553,7 @@ def get_user_input():
all_frameworks = get_user_field( all_frameworks = get_user_field(
"Should we add a version of your new model in all the frameworks implemented by" "Should we add a version of your new model in all the frameworks implemented by"
f" {old_model_type} ({old_frameworks})?", f" {old_model_type} ({old_frameworks}) (yes/no)? ",
convert_to=convert_to_bool, convert_to=convert_to_bool,
default_value="yes", default_value="yes",
fallback_message="Please answer yes/no, y/n, true/false or 1/0.", fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment