Unverified Commit f0fd73a2 authored by Sylvain Gugger, committed by GitHub

Document check copies (#25291)

* Document check copies better and add tests

* Include header in check for copies

* Manual fixes

* Try autofix

* Fixes

* Clean tests

* Finalize doc

* Remove debug print

* More fixes
parent 29f04002
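
For context on what this commit documents: a `# Copied from` comment can carry a rename pattern such as `with Bert->BertCopy` (optionally `all-casing`). A minimal sketch of how such a pattern is applied, mirroring the `re.sub` calls visible as context in the `is_copy_consistent` hunks further down (the class names here are made up for the example):

```python
import re

# Apply a hypothetical "Bert->BertCopy" pattern the way the checker does:
# once as-is, once lowercased, once uppercased (the "all-casing" behavior).
code = "class BertModel(BertPreTrainedModel):\n    BERT_CONSTANT = 1\n    name = 'bert'"
obj1, obj2 = "Bert", "BertCopy"
code = re.sub(obj1, obj2, code)
code = re.sub(obj1.lower(), obj2.lower(), code)
code = re.sub(obj1.upper(), obj2.upper(), code)
print(code)
# class BertCopyModel(BertCopyPreTrainedModel):
#     BERTCOPY_CONSTANT = 1
#     name = 'bertcopy'
```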
@@ -242,7 +242,7 @@ def window_reverse(windows, window_size, height, width):

 # Copied from transformers.models.swin.modeling_swin.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
...
@@ -56,7 +56,7 @@ remat = nn_partitioning.remat

 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
...
@@ -1603,7 +1603,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
 )
 class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
-    def __init__(self, config, target_lang=None):
+    def __init__(self, config, target_lang: Optional[str] = None):
         super().__init__(config)
         self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
...
@@ -15,7 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import numpy as np
 import regex as re
@@ -25,6 +25,10 @@ from ...utils import logging
 from .english_normalizer import EnglishTextNormalizer

+if TYPE_CHECKING:
+    from ...pipelines.conversational import Conversation
+
+
 VOCAB_FILES_NAMES = {
     "vocab_file": "vocab.json",
     "tokenizer_file": "tokenizer.json",
@@ -697,7 +701,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         return (text, kwargs)

     # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids with GPT2 -> Whisper
-    def _build_conversation_input_ids(self, conversation) -> List[int]:
+    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
         input_ids = []
         for is_user, text in conversation.iter_texts():
             input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
...
@@ -15,7 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import numpy as np
 from tokenizers import pre_tokenizers, processors
@@ -27,6 +27,10 @@ from .english_normalizer import EnglishTextNormalizer
 from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr

+if TYPE_CHECKING:
+    from ...pipelines.conversational import Conversation
+
+
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {
@@ -468,7 +472,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._build_conversation_input_ids
-    def _build_conversation_input_ids(self, conversation) -> List[int]:
+    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
         input_ids = []
         for is_user, text in conversation.iter_texts():
             input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
...
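
Both Whisper tokenizer hunks use the standard `TYPE_CHECKING` guard so that `Conversation` exists for type checkers without triggering a runtime (and potentially circular) import. A minimal self-contained sketch of the pattern, with an illustrative module path:

```python
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never executed at runtime.
    from some_package.conversational import Conversation  # illustrative path


def build_conversation_input_ids(conversation: "Conversation") -> List[int]:
    # The quoted annotation keeps the name usable without the runtime import.
    return []
```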
@@ -360,7 +360,7 @@ class XCLIPEncoderLayer(nn.Module):

 # Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
...
@@ -135,7 +135,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i

 # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
     """
     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
     """
...
@@ -770,7 +770,7 @@ class YolosImageProcessor(BaseImageProcessor):
         return target

     # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare
-    def prepare(self, image, target, return_segmentation_masks=False, masks_path=None):
+    def prepare(self, image, target, return_segmentation_masks=None, masks_path=None):
         logger.warning_once(
             "The `prepare` method is deprecated and will be removed in a v4.33. "
             "Please use `prepare_annotation` instead. Note: the `prepare_annotation` method "
...
@@ -899,6 +899,13 @@ class BartModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class BartPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BartPretrainedModel(metaclass=DummyObject):
     _backends = ["torch"]
...
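
The added dummy class follows the usual pattern in `utils/dummy_pt_objects.py`: it is importable without torch installed, but instantiating it raises a helpful error. A simplified, self-contained sketch of the mechanism (not the actual transformers implementation; this `requires_backends` stand-in always raises, whereas the real helper first checks backend availability):

```python
def requires_backends(obj, backends):
    # Stand-in: the real helper only raises when a backend is missing.
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {backends}")


class DummyObject(type):
    # Marker metaclass; the real one also guards class-level attribute access.
    pass


class BartPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


try:
    BartPreTrainedModel()
except ImportError as e:
    print(e)  # BartPreTrainedModel requires the following backends: ['torch']
```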
@@ -13,19 +13,19 @@
 # limitations under the License.

 import os
-import re
 import shutil
 import sys
 import tempfile
 import unittest
+from contextlib import contextmanager
+from pathlib import Path

-import black

 git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 sys.path.append(os.path.join(git_repo_path, "utils"))

 import check_copies  # noqa: E402
+from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent  # noqa: E402


 # This is the reference code that will be used in the tests.
@@ -49,78 +49,137 @@ REFERENCE_CODE = """    def __init__(self, config):
         return hidden_states
 """

-
-class CopyCheckTester(unittest.TestCase):
-    def setUp(self):
-        self.transformer_dir = tempfile.mkdtemp()
-        os.makedirs(os.path.join(self.transformer_dir, "models/bert/"))
-        check_copies.TRANSFORMER_PATH = self.transformer_dir
-        shutil.copy(
-            os.path.join(git_repo_path, "src/transformers/models/bert/modeling_bert.py"),
-            os.path.join(self.transformer_dir, "models/bert/modeling_bert.py"),
-        )
-
-    def tearDown(self):
-        check_copies.TRANSFORMER_PATH = "src/transformers"
-        shutil.rmtree(self.transformer_dir)
-
-    def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
-        code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
-        if overwrite_result is not None:
-            expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
-        mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
-        code = black.format_str(code, mode=mode)
-        fname = os.path.join(self.transformer_dir, "new_code.py")
-        with open(fname, "w", newline="\n") as f:
+MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel
+
+
+def bert_function(x):
+    return x
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+
+class BertModel(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__()
+        self.bert = BertEncoder(config)
+
+    @add_docstring(BERT_DOCSTRING)
+    def forward(self, x):
+        return self.bert(x)
+"""
+
+MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel
+
+
+# Copied from transformers.models.bert.modeling_bert.bert_function
+def bert_copy_function(x):
+    return x
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention
+class BertCopyAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+
+# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
+class BertCopyModel(BertCopyPreTrainedModel):
+    def __init__(self, config):
+        super().__init__()
+        self.bertcopy = BertCopyEncoder(config)
+
+    @add_docstring(BERTCOPY_DOCSTRING)
+    def forward(self, x):
+        return self.bertcopy(x)
+"""
+
+
+def replace_in_file(filename, old, new):
+    with open(filename, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    content = content.replace(old, new)
+
+    with open(filename, "w", encoding="utf-8") as f:
+        f.write(content)
+
+
+def create_tmp_repo(tmp_dir):
+    """
+    Creates a mock repository in a temporary folder for testing.
+    """
+    tmp_dir = Path(tmp_dir)
+    if tmp_dir.exists():
+        shutil.rmtree(tmp_dir)
+    tmp_dir.mkdir(exist_ok=True)
+
+    model_dir = tmp_dir / "src" / "transformers" / "models"
+    model_dir.mkdir(parents=True, exist_ok=True)
+
+    models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
+    for model, code in models.items():
+        model_subdir = model_dir / model
+        model_subdir.mkdir(exist_ok=True)
+        with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
             f.write(code)
-        if overwrite_result is None:
-            self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
-        else:
-            check_copies.is_copy_consistent(f.name, overwrite=True)
-            with open(fname, "r") as f:
-                self.assertTrue(f.read(), expected)

+
+@contextmanager
+def patch_transformer_repo_path(new_folder):
+    """
+    Temporarily patches the variables defined in `check_copies` to use a different location for the repo.
+    """
+    old_repo_path = check_copies.REPO_PATH
+    old_doc_path = check_copies.PATH_TO_DOCS
+    old_transformer_path = check_copies.TRANSFORMERS_PATH
+    repo_path = Path(new_folder).resolve()
+    check_copies.REPO_PATH = str(repo_path)
+    check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
+    check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
+    try:
+        yield
+    finally:
+        check_copies.REPO_PATH = old_repo_path
+        check_copies.PATH_TO_DOCS = old_doc_path
+        check_copies.TRANSFORMERS_PATH = old_transformer_path
+
+
+class CopyCheckTester(unittest.TestCase):
     def test_find_code_in_transformers(self):
-        code = check_copies.find_code_in_transformers("models.bert.modeling_bert.BertLMPredictionHead")
-        self.assertEqual(code, REFERENCE_CODE)
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            create_tmp_repo(tmp_folder)
+            with patch_transformer_repo_path(tmp_folder):
+                code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")

-    def test_is_copy_consistent(self):
-        # Base copy consistency
-        self.check_copy_consistency(
-            "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
-            "BertLMPredictionHead",
-            REFERENCE_CODE + "\n",
+        reference_code = (
+            "class BertAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n"
         )
+        self.assertEqual(code, reference_code)

-        # With no empty line at the end
-        self.check_copy_consistency(
-            "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
-            "BertLMPredictionHead",
-            REFERENCE_CODE,
-        )
-
-        # Copy consistency with rename
-        self.check_copy_consistency(
-            "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
-            "TestModelLMPredictionHead",
-            re.sub("Bert", "TestModel", REFERENCE_CODE),
-        )
-
-        # Copy consistency with a really long name
-        long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
-        self.check_copy_consistency(
-            f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
-            f"{long_class_name}LMPredictionHead",
-            re.sub("Bert", long_class_name, REFERENCE_CODE),
-        )
-
-        # Copy consistency with overwrite
-        self.check_copy_consistency(
-            "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
-            "TestModelLMPredictionHead",
-            REFERENCE_CODE,
-            overwrite_result=re.sub("Bert", "TestModel", REFERENCE_CODE),
-        )
+    def test_is_copy_consistent(self):
+        path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            # Base check
+            create_tmp_repo(tmp_folder)
+            with patch_transformer_repo_path(tmp_folder):
+                file_to_check = os.path.join(tmp_folder, *path_to_check)
+                diffs = is_copy_consistent(file_to_check)
+                self.assertEqual(diffs, [])
+
+            # Base check with an inconsistency
+            create_tmp_repo(tmp_folder)
+            with patch_transformer_repo_path(tmp_folder):
+                file_to_check = os.path.join(tmp_folder, *path_to_check)
+
+                replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
+                diffs = is_copy_consistent(file_to_check)
+                self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
+
+                diffs = is_copy_consistent(file_to_check, overwrite=True)
+
+                with open(file_to_check, "r", encoding="utf-8") as f:
+                    self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)

     def test_convert_to_localized_md(self):
         localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
@@ -168,14 +227,14 @@ class CopyCheckTester(unittest.TestCase):
             " Christopher D. Manning 发布。\n"
         )

-        num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
+        num_models_equal, converted_md_list = convert_to_localized_md(
             md_list, localized_md_list, localized_readme["format_model_list"]
         )

         self.assertFalse(num_models_equal)
         self.assertEqual(converted_md_list, converted_md_list_sample)

-        num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
+        num_models_equal, converted_md_list = convert_to_localized_md(
             md_list, converted_md_list, localized_readme["format_model_list"]
         )
@@ -201,7 +260,7 @@ class CopyCheckTester(unittest.TestCase):
             " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
         )

-        num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
+        num_models_equal, converted_md_list = convert_to_localized_md(
             link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
         )
...
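
The `patch_transformer_repo_path` helper above is the classic save/override/restore shape for temporarily patching module-level globals. A self-contained sketch of the same idea, demonstrated on `tempfile`'s `tempdir` global:

```python
import tempfile
from contextlib import contextmanager


@contextmanager
def patch_tempdir(new_dir):
    # Save, override, and always restore a module-level global, the same
    # shape used by patch_transformer_repo_path in the test file above.
    old = tempfile.tempdir
    tempfile.tempdir = new_dir
    try:
        yield
    finally:
        tempfile.tempdir = old


with patch_tempdir("/tmp"):
    print(tempfile.gettempdir())  # /tmp while patched
```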
@@ -12,6 +12,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Utility that checks whether the copies defined in the library match the original or not. This includes:
+- All code commented with `# Copied from` comments,
+- The list of models in the main README.md matches the ones in the localized READMEs and in the index.md,
+- Files that are registered as full copies of one another in the `FULL_COPIES` constant of this script.
+
+This also checks the list of models in the README is complete (has all models) and adds a line to complete it if there
+is a model missing.
+
+Use from the root of the repo with:
+
+```bash
+python utils/check_copies.py
+```
+
+for a check that will error in case of inconsistencies (used by `make repo-consistency`) or
+
+```bash
+python utils/check_copies.py --fix_and_overwrite
+```
+
+for a check that will fix all inconsistencies automatically (used by `make fix-copies`).
+"""
 import argparse
 import glob
@@ -103,7 +126,9 @@ transformers_module = direct_transformers_import(TRANSFORMERS_PATH)

 def _should_continue(line, indent):
-    return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
+    # Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
+    # function definition.
+    return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None


 def find_code_in_transformers(object_name):
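
A quick illustration of the updated helper; the function body below is copied verbatim from the hunk above so the example runs standalone:

```python
import re


def _should_continue(line, indent):
    return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None


print(_should_continue("    x = 1", "    "))           # True: still inside the block
print(_should_continue("   ", "    "))                 # True: blank line (the new len(line.strip()) == 0 case)
print(_should_continue(") -> torch.Tensor:", "    "))  # True: closing line of a multi-line signature
print(_should_continue("def other():", "    "))        # False: a new top-level definition ends the block
```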
@@ -140,7 +165,7 @@ def find_code_in_transformers(object_name):
         raise ValueError(f" {object_name} does not match any function or class in {module}.")

     # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
-    start_index = line_index
+    start_index = line_index - 1
     while line_index < len(lines) and _should_continue(lines[line_index], indent):
         line_index += 1
     # Clean up empty lines at the end (if any).
@@ -179,6 +204,33 @@ def blackify(code):
     return result[len("class Bla:\n") :] if has_indent else result


+def check_codes_match(observed_code, theoretical_code):
+    """
+    Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name.
+    Returns the index of the first line where there is a difference (if any) and `None` if the codes match.
+    """
+    observed_code_header = observed_code.split("\n")[0]
+    theoretical_code_header = theoretical_code.split("\n")[0]
+
+    _re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
+    _re_func_match = re.compile(r"def\s+([^\(]+)\(")
+    for re_pattern in [_re_class_match, _re_func_match]:
+        if re_pattern.match(observed_code_header) is not None:
+            observed_obj_name = re_pattern.search(observed_code_header).groups()[0]
+            theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
+            theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)
+
+    diff_index = 0
+    if theoretical_code_header != observed_code_header:
+        return 0
+
+    diff_index = 1
+    for observed_line, theoretical_line in zip(observed_code.split("\n")[1:], theoretical_code.split("\n")[1:]):
+        if observed_line != theoretical_line:
+            return diff_index
+        diff_index += 1
+
+
 def is_copy_consistent(filename, overwrite=False):
     """
     Check if the code commented as a copy in `filename` matches the original.
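
Hypothetical usage of the new `check_codes_match` helper (assuming `utils/` is on `sys.path` so it can be imported from `check_copies`):

```python
from check_copies import check_codes_match

original = "def add_one(x):\n    return x + 1\n"
copy_ok = "def add_one_copy(x):\n    return x + 1\n"   # same body, different name
copy_bad = "def add_one_copy(x):\n    return x + 2\n"  # body differs on line 1

print(check_codes_match(copy_ok, original))   # None: codes match up to the name
print(check_codes_match(copy_bad, original))  # 1: index of the first differing line
```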
@@ -201,10 +253,11 @@ def is_copy_consistent(filename, overwrite=False):
         theoretical_code = find_code_in_transformers(object_name)
         theoretical_indent = get_indent(theoretical_code)

-        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
-        indent = theoretical_indent
-        line_index = start_index
+        start_index = line_index + 1 if indent == theoretical_indent else line_index
+        line_index = start_index + 1
+
+        subcode = "\n".join(theoretical_code.split("\n")[1:])
+        indent = get_indent(subcode)

         # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
         should_continue = True
         while line_index < len(lines) and should_continue:
@@ -212,6 +265,8 @@ def is_copy_consistent(filename, overwrite=False):
             if line_index >= len(lines):
                 break
             line = lines[line_index]
+            # There is a special pattern `# End copy` to stop early. It's not documented because it shouldn't really
+            # be used.
             should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
         # Clean up empty lines at the end (if any).
         while len(lines[line_index - 1]) <= 1:
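
For reference, a hypothetical use of the (deliberately undocumented) `# End copy` marker; everything after the marker is excluded from the comparison:

```python
import math


# Copied from transformers.models.bert.modeling_bert.gelu  (hypothetical target)
def gelu(x):
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    # End copy
    # Lines from the marker onwards may diverge from the original without
    # failing the check.
```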
@@ -233,19 +288,12 @@ def is_copy_consistent(filename, overwrite=False):
                 theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
                 theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

-        # Blackify after replacement. To be able to do that, we need the header (class or function definition)
-        # from the previous line
-        theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
-        theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
+        theoretical_code = blackify(theoretical_code)

         # Test for a diff and act accordingly.
-        if observed_code != theoretical_code:
-            diff_index = start_index + 1
-            for observed_line, theoretical_line in zip(observed_code.split("\n"), theoretical_code.split("\n")):
-                if observed_line != theoretical_line:
-                    break
-                diff_index += 1
-            diffs.append([object_name, diff_index])
+        diff_index = check_codes_match(observed_code, theoretical_code)
+        if diff_index is not None:
+            diffs.append([object_name, diff_index + start_index + 1])
             if overwrite:
                 lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                 line_index = start_index + 1
@@ -259,6 +307,10 @@ def is_copy_consistent(filename, overwrite=False):

 def check_copies(overwrite: bool = False):
+    """
+    Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the
+    model list in the main README and other READMEs/index.md are consistent.
+    """
     all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
     diffs = []
     for filename in all_files:
def check_full_copies(overwrite: bool = False): def check_full_copies(overwrite: bool = False):
"""
Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe
`overwrite` to fix issues.
"""
diffs = [] diffs = []
for target, source in FULL_COPIES.items(): for target, source in FULL_COPIES.items():
with open(source, "r", encoding="utf-8") as f: with open(source, "r", encoding="utf-8") as f:
@@ -299,7 +355,7 @@ def check_full_copies(overwrite: bool = False):

 def get_model_list(filename, start_prompt, end_prompt):
-    """Extracts the model list from the README."""
+    """Extracts the model list from a README, between `start_prompt` and `end_prompt`."""
     with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()
     # Find the start of the list.
def convert_to_localized_md(model_list, localized_model_list, format_str): def convert_to_localized_md(model_list, localized_model_list, format_str):
"""Convert `model_list` to each localized README.""" """
Compare the model list from the main README to the one in a localized README.
Args:
model_list (`str`): The model list in the main README.
localized_model_list (`str`): The model list in one of the localized README.
format_str (`str`):
The template for a model entry in the localized README (look at the `format_model_list` in the entries of
`LOCALIZED_READMES` for examples).
Returns:
`Tuple[bool, str]`: A tuple where the first value indicates if the READMEs match or not, and the second value
is the correct localized README.
"""
def _rep(match): def _rep(match):
title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups() title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
@@ -341,7 +410,8 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
     )

     # This regex captures metadata from an English model description, including model title, model link,
-    # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example).
+    # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for
+    # example).
     _re_capture_meta = re.compile(
         r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
     )
@@ -389,6 +459,10 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):

 def convert_readme_to_index(model_list):
+    """
+    Converts the model list of the README to the index.md format.
+    """
+    # We need to replace both links to the main doc and stable doc (the order of the next two instructions is important).
     model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
     return model_list.replace("https://huggingface.co/docs/transformers/", "")
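
A quick check of why the comment insists on the order of the two replacements (the stable-doc prefix is a substring of the main-doc prefix):

```python
model_list = "1. **[BERT](https://huggingface.co/docs/transformers/main/model_doc/bert)**"
step1 = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
print(step1.replace("https://huggingface.co/docs/transformers/", ""))
# 1. **[BERT](model_doc/bert)**
# Replacing the stable prefix first would leave "main/model_doc/bert" behind.
```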
@@ -420,7 +494,9 @@ def _find_text_in_file(filename, start_prompt, end_prompt):

 def check_model_list_copy(overwrite=False, max_per_line=119):
-    """Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
+    """
+    Check the model list in the README is consistent with the ones in the other READMEs and also with `index.md`.
+    """
     # Fix potential doc links in the README
     with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
         readme = f.read()
) )
# Map a model name with the name it has in the README for the check_readme check
SPECIAL_MODEL_NAMES = { SPECIAL_MODEL_NAMES = {
"Bert Generation": "BERT For Sequence Generation", "Bert Generation": "BERT For Sequence Generation",
"BigBird": "BigBird-RoBERTa", "BigBird": "BigBird-RoBERTa",
@@ -522,7 +599,7 @@ MODELS_NOT_IN_README = [
     "VisionTextDualEncoder",
 ]

+# Template for new entries to add in the main README when we have missing models.
 README_TEMPLATE = (
     "1. **[{model_name}](https://huggingface.co/docs/main/transformers/model_doc/{model_type})** (from "
     "<FILL INSTITUTION>) released with the paper [<FILL PAPER TITLE>](<FILL ARKIV LINK>) by <FILL AUTHORS>."
def check_readme(overwrite=False): def check_readme(overwrite=False):
"""
Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the
missing models using `README_TEMPLATE`.
"""
info = LOCALIZED_READMES["README.md"] info = LOCALIZED_READMES["README.md"]
models, start_index, end_index, lines = _find_text_in_file( models, start_index, end_index, lines = _find_text_in_file(
os.path.join(REPO_PATH, "README.md"), os.path.join(REPO_PATH, "README.md"),
......