Unverified commit 175da8d1, authored by Sylvain Gugger, committed by GitHub

Fix custom init sorting script (#16864)

parent 67ed0e43
src/transformers/__init__.py
@@ -446,20 +446,16 @@ else:
 # tokenizers-backed objects
 if is_tokenizers_available():
     # Fast tokenizers
-    _import_structure["models.realm"].append("RealmTokenizerFast")
-    _import_structure["models.xglm"].append("XGLMTokenizerFast")
-    _import_structure["models.fnet"].append("FNetTokenizerFast")
-    _import_structure["models.roformer"].append("RoFormerTokenizerFast")
-    _import_structure["models.clip"].append("CLIPTokenizerFast")
-    _import_structure["models.convbert"].append("ConvBertTokenizerFast")
-    _import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
     _import_structure["models.albert"].append("AlbertTokenizerFast")
     _import_structure["models.bart"].append("BartTokenizerFast")
     _import_structure["models.barthez"].append("BarthezTokenizerFast")
     _import_structure["models.bert"].append("BertTokenizerFast")
     _import_structure["models.big_bird"].append("BigBirdTokenizerFast")
     _import_structure["models.blenderbot"].append("BlenderbotTokenizerFast")
+    _import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
     _import_structure["models.camembert"].append("CamembertTokenizerFast")
+    _import_structure["models.clip"].append("CLIPTokenizerFast")
+    _import_structure["models.convbert"].append("ConvBertTokenizerFast")
     _import_structure["models.deberta"].append("DebertaTokenizerFast")
     _import_structure["models.deberta_v2"].append("DebertaV2TokenizerFast")
     _import_structure["models.distilbert"].append("DistilBertTokenizerFast")
@@ -467,6 +463,7 @@ if is_tokenizers_available():
["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"]
)
_import_structure["models.electra"].append("ElectraTokenizerFast")
_import_structure["models.fnet"].append("FNetTokenizerFast")
_import_structure["models.funnel"].append("FunnelTokenizerFast")
_import_structure["models.gpt2"].append("GPT2TokenizerFast")
_import_structure["models.herbert"].append("HerbertTokenizerFast")
@@ -483,13 +480,16 @@ if is_tokenizers_available():
_import_structure["models.mt5"].append("MT5TokenizerFast")
_import_structure["models.openai"].append("OpenAIGPTTokenizerFast")
_import_structure["models.pegasus"].append("PegasusTokenizerFast")
_import_structure["models.realm"].append("RealmTokenizerFast")
_import_structure["models.reformer"].append("ReformerTokenizerFast")
_import_structure["models.rembert"].append("RemBertTokenizerFast")
_import_structure["models.retribert"].append("RetriBertTokenizerFast")
_import_structure["models.roberta"].append("RobertaTokenizerFast")
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
_import_structure["models.splinter"].append("SplinterTokenizerFast")
_import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast")
_import_structure["models.t5"].append("T5TokenizerFast")
_import_structure["models.xglm"].append("XGLMTokenizerFast")
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast")
_import_structure["models.xlnet"].append("XLNetTokenizerFast")
_import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]
src/transformers/models/marian/__init__.py
@@ -48,6 +48,7 @@ if is_tf_available():
 if is_flax_available():
     _import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"]

 if TYPE_CHECKING:
+    from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig
utils/custom_init_isort.py
@@ -183,11 +183,20 @@ def sort_imports(file, check_only=True):
         # Check if the block contains some `_import_structure`s thingy to sort.
         block = main_blocks[block_idx]
         block_lines = block.split("\n")
-        if len(block_lines) < 3 or "_import_structure" not in "".join(block_lines[:2]):
+
+        # Get to the start of the imports.
+        line_idx = 0
+        while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
+            # Skip dummy import blocks
+            if "import dummy" in block_lines[line_idx]:
+                line_idx = len(block_lines)
+            else:
+                line_idx += 1
+        if line_idx >= len(block_lines):
             continue
-        # Ignore first and last line: they don't contain anything.
-        internal_block_code = "\n".join(block_lines[1:-1])
+
+        # Ignore beginning and last line: they don't contain anything.
+        internal_block_code = "\n".join(block_lines[line_idx:-1])
         indent = get_indent(block_lines[1])
         # Slit the internal block into blocks of indent level 1.
         internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
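
This hunk is the heart of the fix. The old guard only looked at the first two lines of each block (`"".join(block_lines[:2])`), so any block whose first `_import_structure` line came later, for example behind a comment line such as `# Fast tokenizers`, was silently skipped and never sorted, which is presumably how the unsorted entries in `src/transformers/__init__.py` above survived. The new loop walks forward to the first line mentioning `_import_structure` and bails out entirely on dummy-import blocks. A standalone sketch of the same scan on a toy block (not part of the commit, but the loop body mirrors the patched code):

```python
# Toy block: the first import line is at index 2, past the old two-line window.
block_lines = [
    "if is_tokenizers_available():",
    "    # Fast tokenizers",
    '    _import_structure["models.albert"].append("AlbertTokenizerFast")',
    "",
]

# Old check: only inspects the first two lines, so it skips this block.
old_skips = "_import_structure" not in "".join(block_lines[:2])

# New scan: same logic as the patched loop above.
line_idx = 0
while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
    if "import dummy" in block_lines[line_idx]:
        line_idx = len(block_lines)  # dummy import block: skip it entirely
    else:
        line_idx += 1

assert old_skips      # the old guard would have ignored the block
assert line_idx == 2  # the new scan finds the first import line
```

The `line_idx` found here then replaces the hard-coded `1` when slicing out the block body, which is what the second hunk below completes.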
@@ -211,7 +220,7 @@ def sort_imports(file, check_only=True):
                 count += 1

         # And we put our main block back together with its first and last line.
-        main_blocks[block_idx] = "\n".join([block_lines[0]] + reorderded_blocks + [block_lines[-1]])
+        main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])

     if code != "\n".join(main_blocks):
         if check_only:
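
With both hunks applied, `sort_imports` joins the untouched prefix (`block_lines[:line_idx]`) back onto the reordered body instead of assuming the prefix is exactly one line. A short sketch of calling it programmatically; the import path is an assumption (the script normally runs standalone from the repo root, typically as `python utils/custom_init_isort.py`, with `--check_only` in quality checks), and the truthy return in check-only mode is inferred from the tail of the diff rather than shown in it:

```python
import sys

# Assumption: working directory is the repository root, so utils/ holds the script.
sys.path.append("utils")
from custom_init_isort import sort_imports

# With check_only=True the file is not rewritten; the function appears to
# report, rather than fix, a file whose imports would be re-sorted.
needs_sorting = sort_imports("src/transformers/__init__.py", check_only=True)
print("would be re-sorted" if needs_sorting else "already sorted")
```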