Unverified Commit c89bdfbe authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in conver command
parent 90150733
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
from typing import List, Optional from typing import List, Optional
from .tokenization_gpt2_fast import GPT2TokenizerFast from ...tokenization_utils_base import AddedToken
from ...utils import logging
from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from .tokenization_roberta import RobertaTokenizer from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils_base import AddedToken
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tokenizers_available, is_torch_available
from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from .tokenization_squeezebert import SqueezeBertTokenizer
if is_tokenizers_available():
from .tokenization_squeezebert_fast import SqueezeBertTokenizerFast
if is_torch_available():
from .modeling_squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
SqueezeBertForMaskedLM,
SqueezeBertForMultipleChoice,
SqueezeBertForQuestionAnswering,
SqueezeBertForSequenceClassification,
SqueezeBertForTokenClassification,
SqueezeBertModel,
SqueezeBertModule,
SqueezeBertPreTrainedModel,
)
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
""" SqueezeBERT model configuration """ """ SqueezeBERT model configuration """
from .configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -21,10 +21,9 @@ import torch ...@@ -21,10 +21,9 @@ import torch
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from .activations import ACT2FN from ...activations import ACT2FN
from .configuration_squeezebert import SqueezeBertConfig from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_outputs import (
from .modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
MaskedLMOutput, MaskedLMOutput,
...@@ -33,8 +32,9 @@ from .modeling_outputs import ( ...@@ -33,8 +32,9 @@ from .modeling_outputs import (
SequenceClassifierOutput, SequenceClassifierOutput,
TokenClassifierOutput, TokenClassifierOutput,
) )
from .modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from .utils import logging from ...utils import logging
from .configuration_squeezebert import SqueezeBertConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
"""Tokenization classes for SqueezeBERT.""" """Tokenization classes for SqueezeBERT."""
from .tokenization_bert import BertTokenizer from ...utils import logging
from .utils import logging from ..bert.tokenization_bert import BertTokenizer
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
"""Tokenization classes for SqueezeBERT.""" """Tokenization classes for SqueezeBERT."""
from .tokenization_bert_fast import BertTokenizerFast from ...utils import logging
from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_squeezebert import SqueezeBertTokenizer from .tokenization_squeezebert import SqueezeBertTokenizer
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
if is_sentencepiece_available():
from .tokenization_t5 import T5Tokenizer
if is_tokenizers_available():
from .tokenization_t5_fast import T5TokenizerFast
if is_torch_available():
from .modeling_t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5ForConditionalGeneration,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,
)
if is_tf_available():
from .modeling_tf_t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5ForConditionalGeneration,
TFT5Model,
TFT5PreTrainedModel,
)
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
""" T5 model configuration """ """ T5 model configuration """
from .configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -25,22 +25,22 @@ import torch.nn.functional as F ...@@ -25,22 +25,22 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from .configuration_t5 import T5Config from ...file_utils import (
from .file_utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
DUMMY_MASK, DUMMY_MASK,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
replace_return_docstrings, replace_return_docstrings,
) )
from .modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput, Seq2SeqLMOutput,
Seq2SeqModelOutput, Seq2SeqModelOutput,
) )
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .utils import logging from ...utils import logging
from .configuration_t5 import T5Config
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -26,16 +26,15 @@ import tensorflow as tf ...@@ -26,16 +26,15 @@ import tensorflow as tf
from transformers.modeling_tf_utils import TFWrappedEmbeddings from transformers.modeling_tf_utils import TFWrappedEmbeddings
from .configuration_t5 import T5Config from ...file_utils import (
from .file_utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
DUMMY_MASK, DUMMY_MASK,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
replace_return_docstrings, replace_return_docstrings,
) )
from .modeling_tf_outputs import TFSeq2SeqLMOutput, TFSeq2SeqModelOutput from ...modeling_tf_outputs import TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
from .modeling_tf_utils import ( from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss, TFCausalLanguageModelingLoss,
TFPreTrainedModel, TFPreTrainedModel,
TFSharedEmbeddings, TFSharedEmbeddings,
...@@ -43,8 +42,9 @@ from .modeling_tf_utils import ( ...@@ -43,8 +42,9 @@ from .modeling_tf_utils import (
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from ...tokenization_utils import BatchEncoding
from .utils import logging from ...utils import logging
from .configuration_t5 import T5Config
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -23,10 +23,10 @@ from typing import List, Optional, Tuple ...@@ -23,10 +23,10 @@ from typing import List, Optional, Tuple
import sentencepiece as spm import sentencepiece as spm
from .file_utils import add_start_docstrings from ...file_utils import add_start_docstrings
from .tokenization_utils import BatchEncoding, PreTrainedTokenizer from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -19,11 +19,11 @@ import os ...@@ -19,11 +19,11 @@ import os
from shutil import copyfile from shutil import copyfile
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from .file_utils import add_start_docstrings, is_sentencepiece_available from ...file_utils import add_start_docstrings, is_sentencepiece_available
from .tokenization_utils import BatchEncoding from ...tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging from ...utils import logging
if is_sentencepiece_available(): if is_sentencepiece_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_torch_available
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
if is_torch_available():
from .modeling_transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding,
TransfoXLLMHeadModel,
TransfoXLModel,
TransfoXLPreTrainedModel,
load_tf_weights_in_transfo_xl,
)
if is_tf_available():
from .modeling_tf_transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
TFTransfoXLLMHeadModel,
TFTransfoXLMainLayer,
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
import warnings import warnings
from .configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -22,7 +22,7 @@ import sys ...@@ -22,7 +22,7 @@ import sys
import torch import torch
import transformers.tokenization_transfo_xl as data_utils import transformers.models.transfo_xl.tokenization_transfo_xl as data_utils
from transformers import ( from transformers import (
CONFIG_NAME, CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
...@@ -30,7 +30,7 @@ from transformers import ( ...@@ -30,7 +30,7 @@ from transformers import (
TransfoXLLMHeadModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl, load_tf_weights_in_transfo_xl,
) )
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
from transformers.utils import logging from transformers.utils import logging
......
...@@ -22,17 +22,17 @@ from typing import List, Optional, Tuple ...@@ -22,17 +22,17 @@ from typing import List, Optional, Tuple
import tensorflow as tf import tensorflow as tf
from .configuration_transfo_xl import TransfoXLConfig from ...file_utils import (
from .file_utils import (
ModelOutput, ModelOutput,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
import tensorflow as tf import tensorflow as tf
from .modeling_tf_utils import shape_list from ...modeling_tf_utils import shape_list
class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer): class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
......
...@@ -25,16 +25,16 @@ import torch ...@@ -25,16 +25,16 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .configuration_transfo_xl import TransfoXLConfig from ...file_utils import (
from .file_utils import (
ModelOutput, ModelOutput,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment