"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7490a97cac20cef6858f32e5f39a61f31ad64552"
Unverified Commit fc21c9be authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[CookieCutter] Clarify questions (#18959)



* Clarify cookiecutter questions

* Update first question
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 6f8f2f6a
......@@ -1442,7 +1442,9 @@ def get_user_input():
# Get old model type
valid_model_type = False
while not valid_model_type:
old_model_type = input("What is the model you would like to duplicate? ")
old_model_type = input(
"What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): "
)
if old_model_type in model_types:
valid_model_type = True
else:
......@@ -1465,38 +1467,42 @@ def get_user_input():
"We couldn't find the name of the base checkpoint for that model, please enter it here."
)
model_name = get_user_field("What is the name for your new model?")
model_name = get_user_field(
"What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? "
)
default_patterns = ModelPatterns(model_name, model_name)
model_type = get_user_field(
"What identifier would you like to use for the model type of this model?",
"What identifier would you like to use for the `model_type` of this model? ",
default_value=default_patterns.model_type,
)
model_lower_cased = get_user_field(
"What name would you like to use for the module of this model?",
"What lowercase name would you like to use for the module (folder) of this model? ",
default_value=default_patterns.model_lower_cased,
)
model_camel_cased = get_user_field(
"What prefix (camel-cased) would you like to use for the model classes of this model?",
"What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? ",
default_value=default_patterns.model_camel_cased,
)
model_upper_cased = get_user_field(
"What prefix (upper-cased) would you like to use for the constants relative to this model?",
"What prefix (upper-cased) would you like to use for the constants relative to this model? ",
default_value=default_patterns.model_upper_cased,
)
config_class = get_user_field(
"What will be the name of the config class for this model?", default_value=f"{model_camel_cased}Config"
"What will be the name of the config class for this model? ", default_value=f"{model_camel_cased}Config"
)
checkpoint = get_user_field(
"Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/roberta-base): "
)
checkpoint = get_user_field("Please give a checkpoint identifier (on the model Hub) for this new model.")
old_processing_classes = [
c for c in [old_feature_extractor_class, old_tokenizer_class, old_processor_class] if c is not None
]
old_processing_classes = ", ".join(old_processing_classes)
keep_processing = get_user_field(
f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes})?",
f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes}) (yes/no)? ",
convert_to=convert_to_bool,
fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
)
if keep_processing:
feature_extractor_class = old_feature_extractor_class
......@@ -1505,21 +1511,21 @@ def get_user_input():
else:
if old_tokenizer_class is not None:
tokenizer_class = get_user_field(
"What will be the name of the tokenizer class for this model?",
"What will be the name of the tokenizer class for this model? ",
default_value=f"{model_camel_cased}Tokenizer",
)
else:
tokenizer_class = None
if old_feature_extractor_class is not None:
feature_extractor_class = get_user_field(
"What will be the name of the feature extractor class for this model?",
"What will be the name of the feature extractor class for this model? ",
default_value=f"{model_camel_cased}FeatureExtractor",
)
else:
feature_extractor_class = None
if old_processor_class is not None:
processor_class = get_user_field(
"What will be the name of the processor class for this model?",
"What will be the name of the processor class for this model? ",
default_value=f"{model_camel_cased}Processor",
)
else:
......@@ -1539,7 +1545,7 @@ def get_user_input():
)
add_copied_from = get_user_field(
"Should we add # Copied from statements when creating the new modeling file?",
"Should we add # Copied from statements when creating the new modeling file (yes/no)? ",
convert_to=convert_to_bool,
default_value="yes",
fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
......@@ -1547,7 +1553,7 @@ def get_user_input():
all_frameworks = get_user_field(
"Should we add a version of your new model in all the frameworks implemented by"
f" {old_model_type} ({old_frameworks})?",
f" {old_model_type} ({old_frameworks}) (yes/no)? ",
convert_to=convert_to_bool,
default_value="yes",
fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment