Commit 4497e78d authored by Nathan Lambert's avatar Nathan Lambert
Browse files

merge unet-rl formatting

parents 49718b47 77aadfee
...@@ -22,7 +22,7 @@ import re ...@@ -22,7 +22,7 @@ import re
# All paths are set with the intent you should run this script from the root of the repo with the command # All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_table.py # python utils/check_table.py
TRANSFORMERS_PATH = "src/transformers" TRANSFORMERS_PATH = "src/diffusers"
PATH_TO_DOCS = "docs/source/en" PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "." REPO_PATH = "."
...@@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe ...@@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# This is to make sure the transformers module imported is the one in the repo. # This is to make sure the diffusers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
"transformers", "diffusers",
os.path.join(TRANSFORMERS_PATH, "__init__.py"), os.path.join(TRANSFORMERS_PATH, "__init__.py"),
submodule_search_locations=[TRANSFORMERS_PATH], submodule_search_locations=[TRANSFORMERS_PATH],
) )
transformers_module = spec.loader.load_module() diffusers_module = spec.loader.load_module()
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python # Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
...@@ -88,10 +88,10 @@ def _center_text(text, width): ...@@ -88,10 +88,10 @@ def _center_text(text, width):
def get_model_table_from_auto_modules(): def get_model_table_from_auto_modules():
"""Generates an up-to-date model table from the content of the auto modules.""" """Generates an up-to-date model table from the content of the auto modules."""
# Dictionary model names to config. # Dictionary model names to config.
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
model_name_to_config = { model_name_to_config = {
name: config_maping_names[code] name: config_maping_names[code]
for code, name in transformers_module.MODEL_NAMES_MAPPING.items() for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
if code in config_maping_names if code in config_maping_names
} }
model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()} model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
...@@ -103,8 +103,8 @@ def get_model_table_from_auto_modules(): ...@@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
tf_models = collections.defaultdict(bool) tf_models = collections.defaultdict(bool)
flax_models = collections.defaultdict(bool) flax_models = collections.defaultdict(bool)
# Let's lookup through all transformers object (once). # Let's lookup through all diffusers object (once).
for attr_name in dir(transformers_module): for attr_name in dir(diffusers_module):
lookup_dict = None lookup_dict = None
if attr_name.endswith("Tokenizer"): if attr_name.endswith("Tokenizer"):
lookup_dict = slow_tokenizers lookup_dict = slow_tokenizers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment