Unverified Commit 321ef388 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Include image processor in add-new-model-like (#20439)

parent 0bae286d
......@@ -62,6 +62,9 @@ class ModelPatterns:
The tokenizer class associated with this model. Will default to `"{model_camel_cased}Config"`.
tokenizer_class (`str`, *optional*):
The tokenizer class associated with this model (leave to `None` for models that don't use a tokenizer).
image_processor_class (`str`, *optional*):
The image processor class associated with this model (leave to `None` for models that don't use an image
processor).
feature_extractor_class (`str`, *optional*):
The feature extractor class associated with this model (leave to `None` for models that don't use a feature
extractor).
......@@ -77,6 +80,7 @@ class ModelPatterns:
model_upper_cased: Optional[str] = None
config_class: Optional[str] = None
tokenizer_class: Optional[str] = None
image_processor_class: Optional[str] = None
feature_extractor_class: Optional[str] = None
processor_class: Optional[str] = None
......@@ -101,6 +105,7 @@ class ModelPatterns:
ATTRIBUTE_TO_PLACEHOLDER = {
"config_class": "[CONFIG_CLASS]",
"tokenizer_class": "[TOKENIZER_CLASS]",
"image_processor_class": "[IMAGE_PROCESSOR_CLASS]",
"feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",
"processor_class": "[PROCESSOR_CLASS]",
"checkpoint": "[CHECKPOINT]",
......@@ -283,7 +288,7 @@ def replace_model_patterns(
# contains the camel-cased named, but will be treated before.
attributes_to_check = ["config_class"]
# Add relevant preprocessing classes
for attr in ["tokenizer_class", "feature_extractor_class", "processor_class"]:
for attr in ["tokenizer_class", "image_processor_class", "feature_extractor_class", "processor_class"]:
if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
attributes_to_check.append(attr)
......@@ -553,6 +558,7 @@ def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) ->
f"test_modeling_tf_{module_name}.py",
f"test_modeling_flax_{module_name}.py",
f"test_tokenization_{module_name}.py",
f"test_image_processing_{module_name}.py",
f"test_feature_extraction_{module_name}.py",
f"test_processor_{module_name}.py",
]
......@@ -687,6 +693,7 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
else:
tokenizer_class = None
image_processor_class = auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type, None)
feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
......@@ -731,6 +738,7 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
model_upper_cased=model_upper_cased,
config_class=config_class,
tokenizer_class=tokenizer_class,
image_processor_class=image_processor_class,
feature_extractor_class=feature_extractor_class,
processor_class=processor_class,
)
......@@ -748,14 +756,15 @@ def clean_frameworks_in_init(
):
"""
Removes all the import lines that don't belong to a given list of frameworks or concern tokenizers/feature
extractors/processors in an init.
extractors/image processors/processors in an init.
Args:
init_file (`str` or `os.PathLike`): The path to the init to treat.
frameworks (`List[str]`, *optional*):
If passed, this will remove all imports that are subject to a framework not in frameworks
keep_processing (`bool`, *optional*, defaults to `True`):
Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init.
Whether or not to keep the preprocessing (tokenizer, feature extractor, image processor, processor) imports
in the init.
"""
if frameworks is None:
frameworks = get_default_frameworks()
......@@ -808,8 +817,9 @@ def clean_frameworks_in_init(
idx += 1
# Otherwise we keep the line, except if it's a tokenizer import and we don't want to keep it.
elif keep_processing or (
re.search('^\s*"(tokenization|processing|feature_extraction)', lines[idx]) is None
and re.search("^\s*from .(tokenization|processing|feature_extraction)", lines[idx]) is None
re.search('^\s*"(tokenization|processing|feature_extraction|image_processing)', lines[idx]) is None
and re.search("^\s*from .(tokenization|processing|feature_extraction|image_processing)", lines[idx])
is None
):
new_lines.append(lines[idx])
idx += 1
......@@ -885,6 +895,7 @@ def add_model_to_main_init(
if not with_processing:
processing_classes = [
old_model_patterns.tokenizer_class,
old_model_patterns.image_processor_class,
old_model_patterns.feature_extractor_class,
old_model_patterns.processor_class,
]
......@@ -962,6 +973,7 @@ AUTO_CLASSES_PATTERNS = {
' ("{model_type}", "{pretrained_archive_map}"),',
],
"feature_extraction_auto.py": [' ("{model_type}", "{feature_extractor_class}"),'],
"image_processing_auto.py": [' ("{model_type}", "{image_processor_class}"),'],
"modeling_auto.py": [' ("{model_type}", "{any_pt_class}"),'],
"modeling_tf_auto.py": [' ("{model_type}", "{any_tf_class}"),'],
"modeling_flax_auto.py": [' ("{model_type}", "{any_flax_class}"),'],
......@@ -995,6 +1007,14 @@ def add_model_to_auto_classes(
)
elif "{config_class}" in pattern:
new_patterns.append(pattern.replace("{config_class}", old_model_patterns.config_class))
elif "{image_processor_class}" in pattern:
if (
old_model_patterns.image_processor_class is not None
and new_model_patterns.image_processor_class is not None
):
new_patterns.append(
pattern.replace("{image_processor_class}", old_model_patterns.image_processor_class)
)
elif "{feature_extractor_class}" in pattern:
if (
old_model_patterns.feature_extractor_class is not None
......@@ -1121,6 +1141,10 @@ def duplicate_doc_file(
# We only add the tokenizer if necessary
if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class:
new_blocks.append(new_block)
elif "ImageProcessor" in block_class:
# We only add the image processor if necessary
if old_model_patterns.image_processor_class != new_model_patterns.image_processor_class:
new_blocks.append(new_block)
elif "FeatureExtractor" in block_class:
# We only add the feature extractor if necessary
if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class:
......@@ -1182,7 +1206,7 @@ def create_new_model_like(
)
keep_old_processing = True
for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]:
for processing_attr in ["image_processor_class", "feature_extractor_class", "processor_class", "tokenizer_class"]:
if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
keep_old_processing = False
......@@ -1198,7 +1222,10 @@ def create_new_model_like(
files_to_adapt = [
f
for f in files_to_adapt
if "tokenization" not in str(f) and "processing" not in str(f) and "feature_extraction" not in str(f)
if "tokenization" not in str(f)
and "processing" not in str(f)
and "feature_extraction" not in str(f)
and "image_processing" not in str(f)
]
os.makedirs(module_folder, exist_ok=True)
......@@ -1236,7 +1263,10 @@ def create_new_model_like(
files_to_adapt = [
f
for f in files_to_adapt
if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
if "tokenization" not in str(f)
and "processor" not in str(f)
and "feature_extraction" not in str(f)
and "image_processing" not in str(f)
]
def disable_fx_test(filename: Path) -> bool:
......@@ -1458,6 +1488,7 @@ def get_user_input():
old_model_info = retrieve_info_for_model(old_model_type)
old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class
old_image_processor_class = old_model_info["model_patterns"].image_processor_class
old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class
old_processor_class = old_model_info["model_patterns"].processor_class
old_frameworks = old_model_info["frameworks"]
......@@ -1497,7 +1528,9 @@ def get_user_input():
)
old_processing_classes = [
c for c in [old_feature_extractor_class, old_tokenizer_class, old_processor_class] if c is not None
c
for c in [old_image_processor_class, old_feature_extractor_class, old_tokenizer_class, old_processor_class]
if c is not None
]
old_processing_classes = ", ".join(old_processing_classes)
keep_processing = get_user_field(
......@@ -1506,6 +1539,7 @@ def get_user_input():
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
)
if keep_processing:
image_processor_class = old_image_processor_class
feature_extractor_class = old_feature_extractor_class
processor_class = old_processor_class
tokenizer_class = old_tokenizer_class
......@@ -1517,6 +1551,11 @@ def get_user_input():
)
else:
tokenizer_class = None
if old_image_processor_class is not None:
image_processor_class = get_user_field(
"What will be the name of the image processor class for this model? ",
default_value=f"{model_camel_cased}ImageProcessor",
)
if old_feature_extractor_class is not None:
feature_extractor_class = get_user_field(
"What will be the name of the feature extractor class for this model? ",
......@@ -1541,6 +1580,7 @@ def get_user_input():
model_upper_cased=model_upper_cased,
config_class=config_class,
tokenizer_class=tokenizer_class,
image_processor_class=image_processor_class,
feature_extractor_class=feature_extractor_class,
processor_class=processor_class,
)
......
......@@ -44,12 +44,14 @@ BERT_MODEL_FILES = {
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/tokenization_bert.py",
"src/transformers/models/bert/tokenization_bert_fast.py",
"src/transformers/models/bert/tokenization_bert_tf.py",
"src/transformers/models/bert/modeling_bert.py",
"src/transformers/models/bert/modeling_flax_bert.py",
"src/transformers/models/bert/modeling_tf_bert.py",
"src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py",
"src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py",
"src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py",
"src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py",
}
VIT_MODEL_FILES = {
......@@ -58,6 +60,7 @@ VIT_MODEL_FILES = {
"src/transformers/models/vit/convert_dino_to_pytorch.py",
"src/transformers/models/vit/convert_vit_timm_to_pytorch.py",
"src/transformers/models/vit/feature_extraction_vit.py",
"src/transformers/models/vit/image_processing_vit.py",
"src/transformers/models/vit/modeling_vit.py",
"src/transformers/models/vit/modeling_tf_vit.py",
"src/transformers/models/vit/modeling_flax_vit.py",
......@@ -89,7 +92,8 @@ class TestAddNewModelLike(unittest.TestCase):
def check_result(self, file_name, expected_result):
with open(file_name, "r", encoding="utf-8") as f:
self.assertEqual(f.read(), expected_result)
result = f.read()
self.assertEqual(result, expected_result)
def test_re_class_func(self):
self.assertEqual(_re_class_func.search("def my_function(x, y):").groups()[0], "my_function")
......@@ -439,7 +443,7 @@ NEW_BERT_CONSTANT = "value"
self.check_result(dest_file_name, bert_expected)
def test_filter_framework_files(self):
files = ["modeling_tf_bert.py", "modeling_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
files = ["modeling_bert.py", "modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
self.assertEqual(filter_framework_files(files), files)
self.assertEqual(set(filter_framework_files(files, ["pt", "tf", "flax"])), set(files))
......@@ -467,7 +471,7 @@ NEW_BERT_CONSTANT = "value"
bert_files = get_model_files("bert")
doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
self.assertEqual(model_files, BERT_MODEL_FILES)
......@@ -476,17 +480,17 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
bert_test_files = {
"tests/test_tokenization_bert.py",
"tests/test_modeling_bert.py",
"tests/test_modeling_tf_bert.py",
"tests/test_modeling_flax_bert.py",
"tests/models/bert/test_tokenization_bert.py",
"tests/models/bert/test_modeling_bert.py",
"tests/models/bert/test_modeling_tf_bert.py",
"tests/models/bert/test_modeling_flax_bert.py",
}
self.assertEqual(test_files, bert_test_files)
# VIT
vit_files = get_model_files("vit")
doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
self.assertEqual(model_files, VIT_MODEL_FILES)
......@@ -495,17 +499,17 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
vit_test_files = {
"tests/test_feature_extraction_vit.py",
"tests/test_modeling_vit.py",
"tests/test_modeling_tf_vit.py",
"tests/test_modeling_flax_vit.py",
"tests/models/vit/test_feature_extraction_vit.py",
"tests/models/vit/test_modeling_vit.py",
"tests/models/vit/test_modeling_tf_vit.py",
"tests/models/vit/test_modeling_flax_vit.py",
}
self.assertEqual(test_files, vit_test_files)
# Wav2Vec2
wav2vec2_files = get_model_files("wav2vec2")
doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
self.assertEqual(model_files, WAV2VEC2_MODEL_FILES)
......@@ -514,12 +518,12 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
wav2vec2_test_files = {
"tests/test_feature_extraction_wav2vec2.py",
"tests/test_modeling_wav2vec2.py",
"tests/test_modeling_tf_wav2vec2.py",
"tests/test_modeling_flax_wav2vec2.py",
"tests/test_processor_wav2vec2.py",
"tests/test_tokenization_wav2vec2.py",
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
"tests/models/wav2vec2/test_processor_wav2vec2.py",
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
}
self.assertEqual(test_files, wav2vec2_test_files)
......@@ -528,7 +532,7 @@ NEW_BERT_CONSTANT = "value"
bert_files = get_model_files("bert", frameworks=["pt"])
doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
bert_model_files = BERT_MODEL_FILES - {
......@@ -541,15 +545,15 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
bert_test_files = {
"tests/test_tokenization_bert.py",
"tests/test_modeling_bert.py",
"tests/models/bert/test_tokenization_bert.py",
"tests/models/bert/test_modeling_bert.py",
}
self.assertEqual(test_files, bert_test_files)
# VIT
vit_files = get_model_files("vit", frameworks=["pt"])
doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
vit_model_files = VIT_MODEL_FILES - {
......@@ -562,15 +566,15 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
vit_test_files = {
"tests/test_feature_extraction_vit.py",
"tests/test_modeling_vit.py",
"tests/models/vit/test_feature_extraction_vit.py",
"tests/models/vit/test_modeling_vit.py",
}
self.assertEqual(test_files, vit_test_files)
# Wav2Vec2
wav2vec2_files = get_model_files("wav2vec2", frameworks=["pt"])
doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {
......@@ -583,10 +587,10 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
wav2vec2_test_files = {
"tests/test_feature_extraction_wav2vec2.py",
"tests/test_modeling_wav2vec2.py",
"tests/test_processor_wav2vec2.py",
"tests/test_tokenization_wav2vec2.py",
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_wav2vec2.py",
"tests/models/wav2vec2/test_processor_wav2vec2.py",
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
}
self.assertEqual(test_files, wav2vec2_test_files)
......@@ -595,7 +599,7 @@ NEW_BERT_CONSTANT = "value"
bert_files = get_model_files("bert", frameworks=["tf", "flax"])
doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_bert.py"}
......@@ -605,16 +609,16 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
bert_test_files = {
"tests/test_tokenization_bert.py",
"tests/test_modeling_tf_bert.py",
"tests/test_modeling_flax_bert.py",
"tests/models/bert/test_tokenization_bert.py",
"tests/models/bert/test_modeling_tf_bert.py",
"tests/models/bert/test_modeling_flax_bert.py",
}
self.assertEqual(test_files, bert_test_files)
# VIT
vit_files = get_model_files("vit", frameworks=["tf", "flax"])
doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
vit_model_files = VIT_MODEL_FILES - {"src/transformers/models/vit/modeling_vit.py"}
......@@ -624,16 +628,16 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
vit_test_files = {
"tests/test_feature_extraction_vit.py",
"tests/test_modeling_tf_vit.py",
"tests/test_modeling_flax_vit.py",
"tests/models/vit/test_feature_extraction_vit.py",
"tests/models/vit/test_modeling_tf_vit.py",
"tests/models/vit/test_modeling_flax_vit.py",
}
self.assertEqual(test_files, vit_test_files)
# Wav2Vec2
wav2vec2_files = get_model_files("wav2vec2", frameworks=["tf", "flax"])
doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {"src/transformers/models/wav2vec2/modeling_wav2vec2.py"}
......@@ -643,11 +647,11 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
wav2vec2_test_files = {
"tests/test_feature_extraction_wav2vec2.py",
"tests/test_modeling_tf_wav2vec2.py",
"tests/test_modeling_flax_wav2vec2.py",
"tests/test_processor_wav2vec2.py",
"tests/test_tokenization_wav2vec2.py",
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
"tests/models/wav2vec2/test_processor_wav2vec2.py",
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
}
self.assertEqual(test_files, wav2vec2_test_files)
......@@ -688,7 +692,7 @@ NEW_BERT_CONSTANT = "value"
expected_model_classes = {
"pt": set(bert_classes),
"tf": {f"TF{m}" for m in bert_classes},
"flax": {f"Flax{m}" for m in bert_classes[:-1]},
"flax": {f"Flax{m}" for m in bert_classes[:-1] + ["BertForCausalLM"]},
}
self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf", "flax"})
......@@ -701,15 +705,15 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
bert_test_files = {
"tests/test_tokenization_bert.py",
"tests/test_modeling_bert.py",
"tests/test_modeling_tf_bert.py",
"tests/test_modeling_flax_bert.py",
"tests/models/bert/test_tokenization_bert.py",
"tests/models/bert/test_modeling_bert.py",
"tests/models/bert/test_modeling_tf_bert.py",
"tests/models/bert/test_modeling_flax_bert.py",
}
self.assertEqual(test_files, bert_test_files)
doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
self.assertEqual(all_bert_files["module_name"], "bert")
......@@ -751,14 +755,14 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
bert_test_files = {
"tests/test_tokenization_bert.py",
"tests/test_modeling_bert.py",
"tests/test_modeling_tf_bert.py",
"tests/models/bert/test_tokenization_bert.py",
"tests/models/bert/test_modeling_bert.py",
"tests/models/bert/test_modeling_tf_bert.py",
}
self.assertEqual(test_files, bert_test_files)
doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.mdx")
self.assertEqual(all_bert_files["module_name"], "bert")
......@@ -777,8 +781,9 @@ NEW_BERT_CONSTANT = "value"
def test_retrieve_info_for_model_with_vit(self):
vit_info = retrieve_info_for_model("vit")
vit_classes = ["ViTForImageClassification", "ViTModel"]
pt_only_classes = ["ViTForMaskedImageModeling"]
expected_model_classes = {
"pt": set(vit_classes),
"pt": set(vit_classes + pt_only_classes),
"tf": {f"TF{m}" for m in vit_classes},
"flax": {f"Flax{m}" for m in vit_classes},
}
......@@ -793,27 +798,28 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["test_files"]}
vit_test_files = {
"tests/test_feature_extraction_vit.py",
"tests/test_modeling_vit.py",
"tests/test_modeling_tf_vit.py",
"tests/test_modeling_flax_vit.py",
"tests/models/vit/test_feature_extraction_vit.py",
"tests/models/vit/test_modeling_vit.py",
"tests/models/vit/test_modeling_tf_vit.py",
"tests/models/vit/test_modeling_flax_vit.py",
}
self.assertEqual(test_files, vit_test_files)
doc_file = str(Path(all_vit_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/vit.mdx")
self.assertEqual(all_vit_files["module_name"], "vit")
vit_model_patterns = vit_info["model_patterns"]
self.assertEqual(vit_model_patterns.model_name, "ViT")
self.assertEqual(vit_model_patterns.checkpoint, "google/vit-base-patch16-224")
self.assertEqual(vit_model_patterns.checkpoint, "google/vit-base-patch16-224-in21k")
self.assertEqual(vit_model_patterns.model_type, "vit")
self.assertEqual(vit_model_patterns.model_lower_cased, "vit")
self.assertEqual(vit_model_patterns.model_camel_cased, "ViT")
self.assertEqual(vit_model_patterns.model_upper_cased, "VIT")
self.assertEqual(vit_model_patterns.config_class, "ViTConfig")
self.assertEqual(vit_model_patterns.feature_extractor_class, "ViTFeatureExtractor")
self.assertEqual(vit_model_patterns.image_processor_class, "ViTImageProcessor")
self.assertIsNone(vit_model_patterns.tokenizer_class)
self.assertIsNone(vit_model_patterns.processor_class)
......@@ -844,17 +850,17 @@ NEW_BERT_CONSTANT = "value"
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["test_files"]}
wav2vec2_test_files = {
"tests/test_feature_extraction_wav2vec2.py",
"tests/test_modeling_wav2vec2.py",
"tests/test_modeling_tf_wav2vec2.py",
"tests/test_modeling_flax_wav2vec2.py",
"tests/test_processor_wav2vec2.py",
"tests/test_tokenization_wav2vec2.py",
"tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
"tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
"tests/models/wav2vec2/test_processor_wav2vec2.py",
"tests/models/wav2vec2/test_tokenization_wav2vec2.py",
}
self.assertEqual(test_files, wav2vec2_test_files)
doc_file = str(Path(all_wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.mdx")
self.assertEqual(all_wav2vec2_files["module_name"], "wav2vec2")
......@@ -881,32 +887,72 @@ _import_structure = {
"tokenization_gpt2": ["GPT2Tokenizer"],
}
if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gpt2"] = ["GPT2Model"]
if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
from .tokenization_gpt2 import GPT2Tokenizer
if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gpt2_fast import GPT2TokenizerFast
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gpt2 import GPT2Model
if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_gpt2 import TFGPT2Model
if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_gpt2 import FlaxGPT2Model
else:
......@@ -924,25 +970,55 @@ _import_structure = {
"configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
}
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gpt2"] = ["GPT2Model"]
if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gpt2 import GPT2Model
if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_gpt2 import TFGPT2Model
if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_gpt2 import FlaxGPT2Model
else:
......@@ -961,20 +1037,40 @@ _import_structure = {
"tokenization_gpt2": ["GPT2Tokenizer"],
}
if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gpt2"] = ["GPT2Model"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
from .tokenization_gpt2 import GPT2Tokenizer
if is_tokenizers_available():
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gpt2_fast import GPT2TokenizerFast
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gpt2 import GPT2Model
else:
......@@ -992,13 +1088,23 @@ _import_structure = {
"configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
}
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gpt2"] = ["GPT2Model"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gpt2 import GPT2Model
else:
......@@ -1032,32 +1138,72 @@ _import_structure = {
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
if is_vision_available():
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_vit"] = ["ViTImageProcessor"]
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit"] = ["ViTModel"]
if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_vit"] = ["TFViTModel"]
if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
if is_vision_available():
from .feature_extraction_vit import ViTFeatureExtractor
if is_torch_available():
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_vit import ViTImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit import ViTModel
if is_tf_available():
from .modeling_tf_vit import ViTModel
if is_flax_available():
from .modeling_flax_vit import ViTModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_vit import TFViTModel
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_vit import FlaxViTModel
else:
import sys
......@@ -1074,26 +1220,56 @@ _import_structure = {
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit"] = ["ViTModel"]
if is_tf_available():
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_vit"] = ["TFViTModel"]
if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit import ViTModel
if is_tf_available():
from .modeling_tf_vit import ViTModel
if is_flax_available():
from .modeling_flax_vit import ViTModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_vit import TFViTModel
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_vit import FlaxViTModel
else:
import sys
......@@ -1110,19 +1286,39 @@ _import_structure = {
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
if is_vision_available():
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_vit"] = ["ViTImageProcessor"]
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit"] = ["ViTModel"]
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
if is_vision_available():
from .feature_extraction_vit import ViTFeatureExtractor
if is_torch_available():
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_vit import ViTImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit import ViTModel
else:
......@@ -1140,13 +1336,23 @@ _import_structure = {
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit"] = ["ViTModel"]
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit import ViTModel
else:
......@@ -1218,7 +1424,7 @@ Overview of the model.
## Overview
The GPT-New New model was proposed in [<INSERT PAPER NAME HERE>(<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
The GPT-New New model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
......@@ -1229,7 +1435,7 @@ Tips:
<INSERT TIPS ABOUT MODEL HERE>
This model was contributed by [INSERT YOUR HF USERNAME HERE](<https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment