Unverified Commit f0fd73a2 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Document check copies (#25291)

* Document check copies better and add tests

* Include header in check for copies

* Manual fixes

* Try autofix

* Fixes

* Clean tests

* Finalize doc

* Remove debug print

* More fixes
parent 29f04002
......@@ -242,7 +242,7 @@ def window_reverse(windows, window_size, height, width):
# Copied from transformers.models.swin.modeling_swin.drop_path
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
......@@ -56,7 +56,7 @@ remat = nn_partitioning.remat
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
......
......@@ -1603,7 +1603,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
)
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
......
......@@ -15,7 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
import regex as re
......@@ -25,6 +25,10 @@ from ...utils import logging
from .english_normalizer import EnglishTextNormalizer
if TYPE_CHECKING:
from ...pipelines.conversational import Conversation
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"tokenizer_file": "tokenizer.json",
......@@ -697,7 +701,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
return (text, kwargs)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids with GPT2 -> Whisper
def _build_conversation_input_ids(self, conversation) -> List[int]:
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
......
......@@ -15,7 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
import numpy as np
from tokenizers import pre_tokenizers, processors
......@@ -27,6 +27,10 @@ from .english_normalizer import EnglishTextNormalizer
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
if TYPE_CHECKING:
from ...pipelines.conversational import Conversation
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
......@@ -468,7 +472,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._build_conversation_input_ids
def _build_conversation_input_ids(self, conversation) -> List[int]:
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
......
......@@ -360,7 +360,7 @@ class XCLIPEncoderLayer(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
......@@ -135,7 +135,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
......
......@@ -770,7 +770,7 @@ class YolosImageProcessor(BaseImageProcessor):
return target
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare
def prepare(self, image, target, return_segmentation_masks=False, masks_path=None):
def prepare(self, image, target, return_segmentation_masks=None, masks_path=None):
logger.warning_once(
"The `prepare` method is deprecated and will be removed in a v4.33. "
"Please use `prepare_annotation` instead. Note: the `prepare_annotation` method "
......
......@@ -899,6 +899,13 @@ class BartModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class BartPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BartPretrainedModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -13,19 +13,19 @@
# limitations under the License.
import os
import re
import shutil
import sys
import tempfile
import unittest
import black
from contextlib import contextmanager
from pathlib import Path
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))
import check_copies # noqa: E402
from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent # noqa: E402
# This is the reference code that will be used in the tests.
......@@ -49,78 +49,137 @@ REFERENCE_CODE = """ def __init__(self, config):
return hidden_states
"""
MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel
def bert_function(x):
return x
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
class BertModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__()
self.bert = BertEncoder(config)
@add_docstring(BERT_DOCSTRING)
def forward(self, x):
return self.bert(x)
"""
MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel
# Copied from transformers.models.bert.modeling_bert.bert_function
def bert_copy_function(x):
return x
# Copied from transformers.models.bert.modeling_bert.BertAttention
class BertCopyAttention(nn.Module):
def __init__(self, config):
super().__init__()
class CopyCheckTester(unittest.TestCase):
def setUp(self):
self.transformer_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(self.transformer_dir, "models/bert/"))
check_copies.TRANSFORMER_PATH = self.transformer_dir
shutil.copy(
os.path.join(git_repo_path, "src/transformers/models/bert/modeling_bert.py"),
os.path.join(self.transformer_dir, "models/bert/modeling_bert.py"),
)
def tearDown(self):
check_copies.TRANSFORMER_PATH = "src/transformers"
shutil.rmtree(self.transformer_dir)
def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
if overwrite_result is not None:
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
code = black.format_str(code, mode=mode)
fname = os.path.join(self.transformer_dir, "new_code.py")
with open(fname, "w", newline="\n") as f:
# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
class BertCopyModel(BertCopyPreTrainedModel):
def __init__(self, config):
super().__init__()
self.bertcopy = BertCopyEncoder(config)
@add_docstring(BERTCOPY_DOCSTRING)
def forward(self, x):
return self.bertcopy(x)
"""
def replace_in_file(filename, old, new):
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
content = content.replace(old, new)
with open(filename, "w", encoding="utf-8") as f:
f.write(content)
def create_tmp_repo(tmp_dir):
"""
Creates a mock repository in a temporary folder for testing.
"""
tmp_dir = Path(tmp_dir)
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True)
model_dir = tmp_dir / "src" / "transformers" / "models"
model_dir.mkdir(parents=True, exist_ok=True)
models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
for model, code in models.items():
model_subdir = model_dir / model
model_subdir.mkdir(exist_ok=True)
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
f.write(code)
if overwrite_result is None:
self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
else:
check_copies.is_copy_consistent(f.name, overwrite=True)
with open(fname, "r") as f:
self.assertTrue(f.read(), expected)
@contextmanager
def patch_transformer_repo_path(new_folder):
"""
Temporarily patches the variables defines in `check_copies` to use a different location for the repo.
"""
old_repo_path = check_copies.REPO_PATH
old_doc_path = check_copies.PATH_TO_DOCS
old_transformer_path = check_copies.TRANSFORMERS_PATH
repo_path = Path(new_folder).resolve()
check_copies.REPO_PATH = str(repo_path)
check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
try:
yield
finally:
check_copies.REPO_PATH = old_repo_path
check_copies.PATH_TO_DOCS = old_doc_path
check_copies.TRANSFORMERS_PATH = old_transformer_path
class CopyCheckTester(unittest.TestCase):
def test_find_code_in_transformers(self):
code = check_copies.find_code_in_transformers("models.bert.modeling_bert.BertLMPredictionHead")
self.assertEqual(code, REFERENCE_CODE)
with tempfile.TemporaryDirectory() as tmp_folder:
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")
def test_is_copy_consistent(self):
# Base copy consistency
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
"BertLMPredictionHead",
REFERENCE_CODE + "\n",
reference_code = (
"class BertAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n"
)
self.assertEqual(code, reference_code)
# With no empty line at the end
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
"BertLMPredictionHead",
REFERENCE_CODE,
)
def test_is_copy_consistent(self):
path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
with tempfile.TemporaryDirectory() as tmp_folder:
# Base check
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [])
# Copy consistency with rename
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
"TestModelLMPredictionHead",
re.sub("Bert", "TestModel", REFERENCE_CODE),
)
# Base check with an inconsistency
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
# Copy consistency with a really long name
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
self.check_copy_consistency(
f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
f"{long_class_name}LMPredictionHead",
re.sub("Bert", long_class_name, REFERENCE_CODE),
)
replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
# Copy consistency with overwrite
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
"TestModelLMPredictionHead",
REFERENCE_CODE,
overwrite_result=re.sub("Bert", "TestModel", REFERENCE_CODE),
)
diffs = is_copy_consistent(file_to_check, overwrite=True)
with open(file_to_check, "r", encoding="utf-8") as f:
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
def test_convert_to_localized_md(self):
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
......@@ -168,14 +227,14 @@ class CopyCheckTester(unittest.TestCase):
" Christopher D. Manning 发布。\n"
)
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
num_models_equal, converted_md_list = convert_to_localized_md(
md_list, localized_md_list, localized_readme["format_model_list"]
)
self.assertFalse(num_models_equal)
self.assertEqual(converted_md_list, converted_md_list_sample)
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
num_models_equal, converted_md_list = convert_to_localized_md(
md_list, converted_md_list, localized_readme["format_model_list"]
)
......@@ -201,7 +260,7 @@ class CopyCheckTester(unittest.TestCase):
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
)
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
num_models_equal, converted_md_list = convert_to_localized_md(
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
)
......
......@@ -12,6 +12,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that checks whether the copies defined in the library match the original or not. This includes:
- All code commented with `# Copied from` comments,
- The list of models in the main README.md matches the ones in the localized READMEs and in the index.md,
- Files that are registered as full copies of one another in the `FULL_COPIES` constant of this script.
This also checks the list of models in the README is complete (has all models) and add a line to complete if there is
a model missing.
Use from the root of the repo with:
```bash
python utils/check_copies.py
```
for a check that will error in case of inconsistencies (used by `make repo-consistency`) or
```bash
python utils/check_copies.py --fix_and_overwrite
```
for a check that will fix all inconsistencies automatically (used by `make fix-copies`).
"""
import argparse
import glob
......@@ -103,7 +126,9 @@ transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
# Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
# function definition
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def find_code_in_transformers(object_name):
......@@ -140,7 +165,7 @@ def find_code_in_transformers(object_name):
raise ValueError(f" {object_name} does not match any function or class in {module}.")
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
start_index = line_index
start_index = line_index - 1
while line_index < len(lines) and _should_continue(lines[line_index], indent):
line_index += 1
# Clean up empty lines at the end (if any).
......@@ -179,6 +204,33 @@ def blackify(code):
return result[len("class Bla:\n") :] if has_indent else result
def check_codes_match(observed_code, theoretical_code):
"""
Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name.
Returns the index of the first line where there is a difference (if any) and `None` if the codes match.
"""
observed_code_header = observed_code.split("\n")[0]
theoretical_code_header = theoretical_code.split("\n")[0]
_re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
_re_func_match = re.compile(r"def\s+([^\(]+)\(")
for re_pattern in [_re_class_match, _re_func_match]:
if re_pattern.match(observed_code_header) is not None:
observed_obj_name = re_pattern.search(observed_code_header).groups()[0]
theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)
diff_index = 0
if theoretical_code_header != observed_code_header:
return 0
diff_index = 1
for observed_line, theoretical_line in zip(observed_code.split("\n")[1:], theoretical_code.split("\n")[1:]):
if observed_line != theoretical_line:
return diff_index
diff_index += 1
def is_copy_consistent(filename, overwrite=False):
"""
Check if the code commented as a copy in `filename` matches the original.
......@@ -201,10 +253,11 @@ def is_copy_consistent(filename, overwrite=False):
theoretical_code = find_code_in_transformers(object_name)
theoretical_indent = get_indent(theoretical_code)
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
indent = theoretical_indent
line_index = start_index
start_index = line_index + 1 if indent == theoretical_indent else line_index
line_index = start_index + 1
subcode = "\n".join(theoretical_code.split("\n")[1:])
indent = get_indent(subcode)
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
should_continue = True
while line_index < len(lines) and should_continue:
......@@ -212,6 +265,8 @@ def is_copy_consistent(filename, overwrite=False):
if line_index >= len(lines):
break
line = lines[line_index]
# There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be
# used.
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
# Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1:
......@@ -233,19 +288,12 @@ def is_copy_consistent(filename, overwrite=False):
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
# Blackify after replacement. To be able to do that, we need the header (class or function definition)
# from the previous line
theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
theoretical_code = blackify(theoretical_code)
# Test for a diff and act accordingly.
if observed_code != theoretical_code:
diff_index = start_index + 1
for observed_line, theoretical_line in zip(observed_code.split("\n"), theoretical_code.split("\n")):
if observed_line != theoretical_line:
break
diff_index += 1
diffs.append([object_name, diff_index])
diff_index = check_codes_match(observed_code, theoretical_code)
if diff_index is not None:
diffs.append([object_name, diff_index + start_index + 1])
if overwrite:
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
line_index = start_index + 1
......@@ -259,6 +307,10 @@ def is_copy_consistent(filename, overwrite=False):
def check_copies(overwrite: bool = False):
"""
Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the
model list in the main README and other READMEs/index.md are consistent.
"""
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
diffs = []
for filename in all_files:
......@@ -275,6 +327,10 @@ def check_copies(overwrite: bool = False):
def check_full_copies(overwrite: bool = False):
"""
Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe
`overwrite` to fix issues.
"""
diffs = []
for target, source in FULL_COPIES.items():
with open(source, "r", encoding="utf-8") as f:
......@@ -299,7 +355,7 @@ def check_full_copies(overwrite: bool = False):
def get_model_list(filename, start_prompt, end_prompt):
"""Extracts the model list from the README."""
"""Extracts the model list from a README, between `start_prompt` and `end_prompt`."""
with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start of the list.
......@@ -327,7 +383,20 @@ def get_model_list(filename, start_prompt, end_prompt):
def convert_to_localized_md(model_list, localized_model_list, format_str):
"""Convert `model_list` to each localized README."""
"""
Compare the model list from the main README to the one in a localized README.
Args:
model_list (`str`): The model list in the main README.
localized_model_list (`str`): The model list in one of the localized README.
format_str (`str`):
The template for a model entry in the localized README (look at the `format_model_list` in the entries of
`LOCALIZED_READMES` for examples).
Returns:
`Tuple[bool, str]`: A tuple where the first value indicates if the READMEs match or not, and the second value
is the correct localized README.
"""
def _rep(match):
title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
......@@ -341,7 +410,8 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
)
# This regex captures metadata from an English model description, including model title, model link,
# affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example).
# affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for
# example).
_re_capture_meta = re.compile(
r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
)
......@@ -389,6 +459,10 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
def convert_readme_to_index(model_list):
"""
Converts the model list of the README to the index.md format.
"""
# We need to replce both link to the main doc and stable doc (the order of the next two instructions is important).
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
return model_list.replace("https://huggingface.co/docs/transformers/", "")
......@@ -420,7 +494,9 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
def check_model_list_copy(overwrite=False, max_per_line=119):
"""Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
"""
Check the model lists in the README is consistent with the ones in the other READMES and also with `index.nmd`.
"""
# Fix potential doc links in the README
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
readme = f.read()
......@@ -490,6 +566,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
# Map a model name with the name it has in the README for the check_readme check
SPECIAL_MODEL_NAMES = {
"Bert Generation": "BERT For Sequence Generation",
"BigBird": "BigBird-RoBERTa",
......@@ -522,7 +599,7 @@ MODELS_NOT_IN_README = [
"VisionTextDualEncoder",
]
# Template for new entries to add in the main README when we have missing models.
README_TEMPLATE = (
"1. **[{model_name}](https://huggingface.co/docs/main/transformers/model_doc/{model_type})** (from "
"<FILL INSTITUTION>) released with the paper [<FILL PAPER TITLE>](<FILL ARKIV LINK>) by <FILL AUTHORS>."
......@@ -530,6 +607,10 @@ README_TEMPLATE = (
def check_readme(overwrite=False):
"""
Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the
missing models using `README_TEMPLATE`.
"""
info = LOCALIZED_READMES["README.md"]
models, start_index, end_index, lines = _find_text_in_file(
os.path.join(REPO_PATH, "README.md"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment