Unverified commit 36a19915, authored by Sylvain Gugger and committed by GitHub
Browse files

Fix model templates (#8595)

* First fixes

* Fix imports and add init

* Fix typo

* Move init to final dest

* Fix tokenization import

* More fixes

* Styling
parent 042a6aa7
...@@ -47,7 +47,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): ...@@ -47,7 +47,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
path_to_transformer_root = ( path_to_transformer_root = (
Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
) )
path_to_cookiecutter = path_to_transformer_root / "templates" / "cookiecutter" path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"
# Execute cookiecutter # Execute cookiecutter
if not self._testing: if not self._testing:
...@@ -75,9 +75,16 @@ class AddNewModelCommand(BaseTransformersCLICommand): ...@@ -75,9 +75,16 @@ class AddNewModelCommand(BaseTransformersCLICommand):
output_pytorch = "PyTorch" in pytorch_or_tensorflow output_pytorch = "PyTorch" in pytorch_or_tensorflow
output_tensorflow = "TensorFlow" in pytorch_or_tensorflow output_tensorflow = "TensorFlow" in pytorch_or_tensorflow
model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
os.makedirs(model_dir, exist_ok=True)
shutil.move(
f"{directory}/__init__.py",
f"{model_dir}/__init__.py",
)
shutil.move( shutil.move(
f"{directory}/configuration_{lowercase_model_name}.py", f"{directory}/configuration_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/configuration_{lowercase_model_name}.py", f"{model_dir}/configuration_{lowercase_model_name}.py",
) )
def remove_copy_lines(path): def remove_copy_lines(path):
...@@ -94,7 +101,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): ...@@ -94,7 +101,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
shutil.move( shutil.move(
f"{directory}/modeling_{lowercase_model_name}.py", f"{directory}/modeling_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/modeling_{lowercase_model_name}.py", f"{model_dir}/modeling_{lowercase_model_name}.py",
) )
shutil.move( shutil.move(
...@@ -111,7 +118,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): ...@@ -111,7 +118,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
shutil.move( shutil.move(
f"{directory}/modeling_tf_{lowercase_model_name}.py", f"{directory}/modeling_tf_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/modeling_tf_{lowercase_model_name}.py", f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
) )
shutil.move( shutil.move(
...@@ -129,7 +136,7 @@ class AddNewModelCommand(BaseTransformersCLICommand): ...@@ -129,7 +136,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
shutil.move( shutil.move(
f"{directory}/tokenization_{lowercase_model_name}.py", f"{directory}/tokenization_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/tokenization_{lowercase_model_name}.py", f"{model_dir}/tokenization_{lowercase_model_name}.py",
) )
from os import fdopen, remove from os import fdopen, remove
......
...@@ -21,6 +21,8 @@ from collections import OrderedDict ...@@ -21,6 +21,8 @@ from collections import OrderedDict
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings from ...file_utils import add_start_docstrings
from ...utils import logging from ...utils import logging
# Add modeling imports here
from ..albert.modeling_albert import ( from ..albert.modeling_albert import (
AlbertForMaskedLM, AlbertForMaskedLM,
AlbertForMultipleChoice, AlbertForMultipleChoice,
...@@ -228,8 +230,6 @@ from .configuration_auto import ( ...@@ -228,8 +230,6 @@ from .configuration_auto import (
) )
# Add modeling imports here
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -21,6 +21,8 @@ from collections import OrderedDict ...@@ -21,6 +21,8 @@ from collections import OrderedDict
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings from ...file_utils import add_start_docstrings
from ...utils import logging from ...utils import logging
# Add modeling imports here
from ..albert.modeling_tf_albert import ( from ..albert.modeling_tf_albert import (
TFAlbertForMaskedLM, TFAlbertForMaskedLM,
TFAlbertForMultipleChoice, TFAlbertForMultipleChoice,
...@@ -175,8 +177,6 @@ from .configuration_auto import ( ...@@ -175,8 +177,6 @@ from .configuration_auto import (
) )
# Add modeling imports here
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
{%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %}
from ...file_utils import is_tf_available, is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %}
from ...file_utils import is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %}
from ...file_utils import is_tf_available
{% endif %}
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
if is_torch_available():
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Layer,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
)
{% endif %}
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %}
if is_tf_available():
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Layer,
TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif %}
\ No newline at end of file
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
""" {{cookiecutter.modelname}} model configuration """ """ {{cookiecutter.modelname}} model configuration """
from .configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -17,15 +17,14 @@ ...@@ -17,15 +17,14 @@
import tensorflow as tf import tensorflow as tf
from .activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config from ...file_utils import (
from .file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS, MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from .modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPooling, TFBaseModelOutputWithPooling,
TFMaskedLMOutput, TFMaskedLMOutput,
...@@ -34,7 +33,7 @@ from .modeling_tf_outputs import ( ...@@ -34,7 +33,7 @@ from .modeling_tf_outputs import (
TFSequenceClassifierOutput, TFSequenceClassifierOutput,
TFTokenClassifierOutput, TFTokenClassifierOutput,
) )
from .modeling_tf_utils import ( from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss, TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss, TFMultipleChoiceLoss,
TFPreTrainedModel, TFPreTrainedModel,
...@@ -46,8 +45,9 @@ from .modeling_tf_utils import ( ...@@ -46,8 +45,9 @@ from .modeling_tf_utils import (
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from ...tokenization_utils import BatchEncoding
from .utils import logging from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -25,13 +25,12 @@ import torch.utils.checkpoint ...@@ -25,13 +25,12 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config from ...file_utils import (
from .file_utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from .modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
MaskedLMOutput, MaskedLMOutput,
...@@ -40,15 +39,16 @@ from .modeling_outputs import ( ...@@ -40,15 +39,16 @@ from .modeling_outputs import (
SequenceClassifierOutput, SequenceClassifierOutput,
TokenClassifierOutput, TokenClassifierOutput,
) )
from .modeling_utils import ( from ...modeling_utils import (
PreTrainedModel, PreTrainedModel,
SequenceSummary, SequenceSummary,
apply_chunking_to_forward, apply_chunking_to_forward,
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_linear_layer, prune_linear_layer,
) )
from .utils import logging from ...utils import logging
from .activations import ACT2FN from ...activations import ACT2FN
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# To replace in: "src/transformers/__init__.py" # To replace in: "src/transformers/__init__.py"
# Below: "if is_torch_available():" if generating PyTorch # Below: "if is_torch_available():" if generating PyTorch
# Replace with: # Replace with:
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM, {{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice, {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
# Below: "if is_tf_available():" if generating TensorFlow # Below: "if is_tf_available():" if generating TensorFlow
# Replace with: # Replace with:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
...@@ -44,14 +44,14 @@ ...@@ -44,14 +44,14 @@
# End. # End.
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig" # Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Replace with: # Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End. # End.
# To replace in: "src/transformers/configuration_auto.py" # To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here" # Below: "# Add configs here"
# Replace with: # Replace with:
("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config), ("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
...@@ -62,9 +62,9 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u ...@@ -62,9 +62,9 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
# End. # End.
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig", # Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Replace with: # Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End. # End.
# Below: "# Add full (and cased) model names here" # Below: "# Add full (and cased) model names here"
...@@ -83,7 +83,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u ...@@ -83,7 +83,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
# Below: "# Add modeling imports here" # Below: "# Add modeling imports here"
# Replace with: # Replace with:
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForMaskedLM, {{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice, {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
...@@ -138,7 +138,7 @@ from .modeling_{{cookiecutter.lowercase_modelname}} import ( ...@@ -138,7 +138,7 @@ from .modeling_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Add modeling imports here" # Below: "# Add modeling imports here"
# Replace with: # Replace with:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
"""Tokenization classes for {{cookiecutter.modelname}}.""" """Tokenization classes for {{cookiecutter.modelname}}."""
{%- if cookiecutter.tokenizer_type == "Based on BERT" %} {%- if cookiecutter.tokenizer_type == "Based on BERT" %}
from .tokenization_bert import BertTokenizer, BertTokenizerFast from ...utils import logging
from .utils import logging from ..bert.tokenization_bert import BertTokenizer
from ..bert.tokenization_bert_fast import BertTokenizerFast
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -73,14 +74,14 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast): ...@@ -73,14 +74,14 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
{%- elif cookiecutter.tokenizer_type == "Standalone" %} {%- elif cookiecutter.tokenizer_type == "Standalone" %}
import warnings import warnings
from typing import List, Optional
from tokenizers import ByteLevelBPETokenizer from tokenizers import ByteLevelBPETokenizer
from .tokenization_utils import AddedToken, PreTrainedTokenizer from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from .tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from typing import List, Optional from ...utils import logging
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment