Unverified commit c0df963e, authored by Sylvain Gugger, committed by GitHub

Make the big table creation/check platform independent (#8856)

parent d366228d
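The gist of the change, as a hedged sketch (toy names, not the real script): the model table used to be built from framework-specific mappings such as `MODEL_MAPPING`, which only exist when the corresponding backend is installed, whereas the rewritten check infers backend support purely from the names the package exports, which works on any platform.

```python
# Illustrative sketch only: infer backend support from exported names instead
# of importing framework-specific mapping tables. `exported` stands in for
# dir(transformers) in the real script.
import re

exported = ["BertModel", "TFBertModel", "BertTokenizer", "BertTokenizerFast"]

# A name like "TFBertModel" implies a TensorFlow implementation; a plain
# "BertModel" (no TF/Flax prefix) implies a PyTorch one.
has_tf_bert = any(re.fullmatch(r"TFBert(?:Model|Encoder|Decoder)", n) for n in exported)
has_pt_bert = any(re.fullmatch(r"Bert(?:Model|Encoder|Decoder)", n) for n in exported)
print(has_pt_bert, has_tf_bert)  # True True
```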
@@ -214,6 +214,7 @@ if is_sentencepiece_available():
     from .models.camembert import CamembertTokenizer
     from .models.marian import MarianTokenizer
     from .models.mbart import MBartTokenizer
+    from .models.mt5 import MT5Tokenizer
     from .models.pegasus import PegasusTokenizer
     from .models.reformer import ReformerTokenizer
     from .models.t5 import T5Tokenizer
@@ -240,6 +241,7 @@ if is_tokenizers_available():
     from .models.lxmert import LxmertTokenizerFast
     from .models.mbart import MBartTokenizerFast
     from .models.mobilebert import MobileBertTokenizerFast
+    from .models.mt5 import MT5TokenizerFast
     from .models.openai import OpenAIGPTTokenizerFast
     from .models.pegasus import PegasusTokenizerFast
     from .models.reformer import ReformerTokenizerFast
...
@@ -98,6 +98,7 @@ if is_sentencepiece_available():
     from ..camembert.tokenization_camembert import CamembertTokenizer
     from ..marian.tokenization_marian import MarianTokenizer
     from ..mbart.tokenization_mbart import MBartTokenizer
+    from ..mt5 import MT5Tokenizer
     from ..pegasus.tokenization_pegasus import PegasusTokenizer
     from ..reformer.tokenization_reformer import ReformerTokenizer
     from ..t5.tokenization_t5 import T5Tokenizer
@@ -111,6 +112,7 @@ else:
     CamembertTokenizer = None
     MarianTokenizer = None
    MBartTokenizer = None
+    MT5Tokenizer = None
     PegasusTokenizer = None
     ReformerTokenizer = None
     T5Tokenizer = None
@@ -135,6 +137,7 @@ if is_tokenizers_available():
     from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
     from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
     from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast
+    from ..mt5 import MT5TokenizerFast
     from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast
     from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast
     from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
@@ -161,6 +164,7 @@ else:
     LxmertTokenizerFast = None
     MBartTokenizerFast = None
     MobileBertTokenizerFast = None
+    MT5TokenizerFast = None
     OpenAIGPTTokenizerFast = None
     PegasusTokenizerFast = None
     ReformerTokenizerFast = None
@@ -178,7 +182,7 @@ TOKENIZER_MAPPING = OrderedDict(
     [
         (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
         (T5Config, (T5Tokenizer, T5TokenizerFast)),
-        (MT5Config, (T5Tokenizer, T5TokenizerFast)),
+        (MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
         (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
         (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
         (AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
...
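With the mapping fix above, mT5 checkpoints resolve to the new MT5Tokenizer classes rather than plain T5Tokenizer. A usage sketch ("google/mt5-small" is an example checkpoint name; since MT5Tokenizer is an alias of T5Tokenizer, see the mt5/__init__.py hunk below, the printed class name may still read T5TokenizerFast):

```python
from transformers import AutoTokenizer

# AutoTokenizer looks up MT5Config in TOKENIZER_MAPPING and now gets the
# (MT5Tokenizer, MT5TokenizerFast) pair registered by this commit.
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
print(type(tokenizer).__name__)
```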
@@ -2,10 +2,20 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
-from ...file_utils import is_tf_available, is_torch_available
+from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
 from .configuration_mt5 import MT5Config
 
+
+if is_sentencepiece_available():
+    from ..t5.tokenization_t5 import T5Tokenizer
+
+    MT5Tokenizer = T5Tokenizer
+
+if is_tokenizers_available():
+    from ..t5.tokenization_t5_fast import T5TokenizerFast
+
+    MT5TokenizerFast = T5TokenizerFast
+
 if is_torch_available():
     from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
...
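The mt5/__init__.py hunk above uses a guard-and-alias pattern: mT5 shares T5's tokenizer, so the module only defines MT5Tokenizer when the optional dependency is installed. A self-contained sketch of the pattern (class and helper names here are illustrative, not the real transformers layout):

```python
def is_sentencepiece_available():
    # Mirrors the spirit of the file_utils availability check.
    try:
        import sentencepiece  # noqa: F401
        return True
    except ImportError:
        return False


if is_sentencepiece_available():
    class T5Tokenizer:
        """Stand-in for the real SentencePiece-backed tokenizer."""

    # mT5 reuses T5's tokenizer, so the new public name is just an alias.
    MT5Tokenizer = T5Tokenizer
```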
@@ -56,6 +56,15 @@ class MBartTokenizer:
         requires_sentencepiece(self)
 
 
+class MT5Tokenizer:
+    def __init__(self, *args, **kwargs):
+        requires_sentencepiece(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_sentencepiece(self)
+
+
 class PegasusTokenizer:
     def __init__(self, *args, **kwargs):
         requires_sentencepiece(self)
...
@@ -164,6 +164,15 @@ class MobileBertTokenizerFast:
         requires_tokenizers(self)
 
 
+class MT5TokenizerFast:
+    def __init__(self, *args, **kwargs):
+        requires_tokenizers(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tokenizers(self)
+
+
 class OpenAIGPTTokenizerFast:
     def __init__(self, *args, **kwargs):
         requires_tokenizers(self)
...
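The dummy classes above keep `from transformers import MT5Tokenizer` working even without sentencepiece installed, deferring the failure to first use. A sketch of the mechanism (this `requires_sentencepiece` is a hypothetical stand-in for the real helper in file_utils):

```python
def requires_sentencepiece(obj):
    # Hypothetical stand-in: the real helper raises with a pip-install hint.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the SentencePiece library: pip install sentencepiece")


class MT5Tokenizer:
    def __init__(self, *args, **kwargs):
        requires_sentencepiece(self)


# Importing the class is always fine; instantiating it raises a readable error:
# MT5Tokenizer()  -> ImportError: MT5Tokenizer requires the SentencePiece library...
```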
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import argparse
+import collections
 import glob
 import importlib
 import os
@@ -298,6 +299,22 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
     )
 
 
+# Add here suffixes that are used to identify models, separated by |
+ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
+
+# Regexes that match TF/Flax/PT model names.
+_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
+_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
+# Will also match any TF or Flax model, so it needs to be tried in an else branch after the two previous regexes.
+_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
+
+
+# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
+def camel_case_split(identifier):
+    """Split a camel-cased `identifier` into words."""
+    matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
+    return [m.group(0) for m in matches]
+
+
 def _center_text(text, width):
     text_length = 2 if text == "✅" or text == "❌" else len(text)
     left_indent = (width - text_length) // 2
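A quick demonstration of the helpers added in this hunk (regex and splitter copied verbatim so the snippet runs standalone):

```python
import re

_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")


def camel_case_split(identifier):
    matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
    return [m.group(0) for m in matches]


print(camel_case_split("TFBertModel"))                 # ['TF', 'Bert', 'Model']
print(camel_case_split("OpenAIGPT"))                   # ['Open', 'AIGPT']
print(_re_tf_models.match("TFBertModel").groups()[0])  # 'Bert'
```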
@@ -319,44 +336,43 @@ def get_model_table_from_auto_modules():
     model_name_to_config = {
         name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
     }
-    # All tokenizer tuples.
-    tokenizers = {
-        name: transformers.TOKENIZER_MAPPING[config]
-        for name, config in model_name_to_config.items()
-        if config in transformers.TOKENIZER_MAPPING
+    model_name_to_prefix = {
+        name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
     }
-    # Model names that have a slow/fast tokenizer.
-    has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None]
-    has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None]
-    # Model names that have a PyTorch implementation.
-    has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING]
-    # Some of the GenerationModel don't have a base model.
-    has_pt_model.extend(
-        [
-            name
-            for name, config in model_name_to_config.items()
-            if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
-        ]
-    )
-    # Special exception for RAG
-    has_pt_model.append("RAG")
-    # Model names that have a TensorFlow implementation.
-    has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING]
-    # Some of the GenerationModel don't have a base model.
-    has_tf_model.extend(
-        [
-            name
-            for name, config in model_name_to_config.items()
-            if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
-        ]
-    )
-    # Model names that have a Flax implementation.
-    has_flax_model = [
-        name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING
-    ]
+
+    # Dictionaries flagging whether each model prefix has a slow/fast tokenizer and a PT/TF/Flax backend.
+    slow_tokenizers = collections.defaultdict(bool)
+    fast_tokenizers = collections.defaultdict(bool)
+    pt_models = collections.defaultdict(bool)
+    tf_models = collections.defaultdict(bool)
+    flax_models = collections.defaultdict(bool)
+
+    # Let's look through all transformers objects (once).
+    for attr_name in dir(transformers):
+        lookup_dict = None
+        if attr_name.endswith("Tokenizer"):
+            lookup_dict = slow_tokenizers
+            attr_name = attr_name[:-9]
+        elif attr_name.endswith("TokenizerFast"):
+            lookup_dict = fast_tokenizers
+            attr_name = attr_name[:-13]
+        elif _re_tf_models.match(attr_name) is not None:
+            lookup_dict = tf_models
+            attr_name = _re_tf_models.match(attr_name).groups()[0]
+        elif _re_flax_models.match(attr_name) is not None:
+            lookup_dict = flax_models
+            attr_name = _re_flax_models.match(attr_name).groups()[0]
+        elif _re_pt_models.match(attr_name) is not None:
+            lookup_dict = pt_models
+            attr_name = _re_pt_models.match(attr_name).groups()[0]
+
+        if lookup_dict is not None:
+            while len(attr_name) > 0:
+                if attr_name in model_name_to_prefix.values():
+                    lookup_dict[attr_name] = True
+                    break
+                # Try again after removing the last word in the name
+                attr_name = "".join(camel_case_split(attr_name)[:-1])
 
     # Let's build that table!
     model_names = list(model_name_to_config.keys())
@@ -374,13 +390,14 @@ def get_model_table_from_auto_modules():
     check = {True: "✅", False: "❌"}
     for name in model_names:
+        prefix = model_name_to_prefix[name]
         line = [
             name,
-            check[name in has_slow_tokenizers],
-            check[name in has_fast_tokenizers],
-            check[name in has_pt_model],
-            check[name in has_tf_model],
-            check[name in has_flax_model],
+            check[slow_tokenizers[prefix]],
+            check[fast_tokenizers[prefix]],
+            check[pt_models[prefix]],
+            check[tf_models[prefix]],
+            check[flax_models[prefix]],
         ]
         table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
     table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
...
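Putting the new pieces together, a toy run of the lookup logic above (a miniature stand-in for `dir(transformers)`; `collections.defaultdict(bool)` makes unknown prefixes simply read as ❌):

```python
import collections
import re

_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")

pt_models = collections.defaultdict(bool)
tf_models = collections.defaultdict(bool)

# Miniature stand-in for dir(transformers).
for attr_name in ["BertModel", "TFBertModel", "T5ForConditionalGeneration"]:
    if _re_tf_models.match(attr_name):
        tf_models[_re_tf_models.match(attr_name).groups()[0]] = True
    elif _re_pt_models.match(attr_name):
        pt_models[_re_pt_models.match(attr_name).groups()[0]] = True

check = {True: "✅", False: "❌"}
print(check[pt_models["Bert"]], check[tf_models["Bert"]])  # ✅ ✅
print(check[pt_models["T5"]], check[tf_models["T5"]])      # ✅ ❌
```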