Unverified Commit c89bdfbe authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in conver command
parent 90150733
......@@ -14,8 +14,8 @@
# limitations under the License.
""" MBART configuration """
from .configuration_bart import BartConfig
from .utils import logging
from ...utils import logging
from ..bart.configuration_bart import BartConfig
logger = logging.get_logger(__name__)
......
......@@ -4,7 +4,7 @@ import torch
from transformers import BartForConditionalGeneration, MBartConfig
from .convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
......
from ..bart.modeling_bart import BartForConditionalGeneration
from .configuration_mbart import MBartConfig
from .modeling_bart import BartForConditionalGeneration
_CONFIG_FOR_DOC = "MBartConfig"
......
......@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF mBART model, originally from fairseq."""
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..bart.modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
from .configuration_mbart import MBartConfig
from .file_utils import add_start_docstrings
from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
from .utils import logging
_CONFIG_FOR_DOC = "MBartConfig"
......@@ -33,4 +33,4 @@ logger = logging.get_logger(__name__)
@add_start_docstrings("mBART (multilingual BART) model for machine translation", START_DOCSTRING)
class TFMBartForConditionalGeneration(TFBartForConditionalGeneration):
config_class = MBartConfig
# All the code is in src/transformers/modeling_tf_bart.py
# All the code is in src/transformers/models/bart/modeling_tf_bart.py
......@@ -15,11 +15,11 @@
from typing import List, Optional
from .file_utils import add_start_docstrings
from .tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .utils import logging
from ...file_utils import add_start_docstrings
from ...tokenization_utils import BatchEncoding
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from ...utils import logging
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
logger = logging.get_logger(__name__)
......
......@@ -17,11 +17,11 @@ from typing import List, Optional
from tokenizers import processors
from .file_utils import add_start_docstrings, is_sentencepiece_available
from .tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
from .utils import logging
from ...file_utils import add_start_docstrings, is_sentencepiece_available
from ...tokenization_utils import BatchEncoding
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from ...utils import logging
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
if is_sentencepiece_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_torch_available
from .configuration_mmbt import MMBTConfig
if is_torch_available():
from .modeling_mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
......@@ -15,7 +15,7 @@
# limitations under the License.
""" MMBT configuration """
from .utils import logging
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -20,10 +20,10 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from .modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
from .modeling_utils import ModuleUtilsMixin
from .utils import logging
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import ModuleUtilsMixin
from ...utils import logging
logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .tokenization_mobilebert import MobileBertTokenizer
if is_tokenizers_available():
from .tokenization_mobilebert_fast import MobileBertTokenizerFast
if is_torch_available():
from .modeling_mobilebert import (
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
MobileBertForNextSentencePrediction,
MobileBertForPreTraining,
MobileBertForQuestionAnswering,
MobileBertForSequenceClassification,
MobileBertForTokenClassification,
MobileBertLayer,
MobileBertModel,
MobileBertPreTrainedModel,
load_tf_weights_in_mobilebert,
)
if is_tf_available():
from .modeling_tf_mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
TFMobileBertForNextSentencePrediction,
TFMobileBertForPreTraining,
TFMobileBertForQuestionAnswering,
TFMobileBertForSequenceClassification,
TFMobileBertForTokenClassification,
TFMobileBertMainLayer,
TFMobileBertModel,
TFMobileBertPreTrainedModel,
)
......@@ -12,8 +12,8 @@
# limitations under the License.
""" MobileBERT model configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -31,16 +31,15 @@ import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import ACT2FN
from .configuration_mobilebert import MobileBertConfig
from .file_utils import (
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_outputs import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
......@@ -50,8 +49,9 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .utils import logging
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_mobilebert import MobileBertConfig
logger = logging.get_logger(__name__)
......
......@@ -21,9 +21,8 @@ from typing import Optional, Tuple
import tensorflow as tf
from . import MobileBertConfig
from .activations_tf import get_tf_activation
from .file_utils import (
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
......@@ -31,7 +30,7 @@ from .file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_tf_outputs import (
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
......@@ -41,7 +40,7 @@ from .modeling_tf_outputs import (
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFNextSentencePredictionLoss,
......@@ -53,8 +52,9 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_mobilebert import MobileBertConfig
logger = logging.get_logger(__name__)
......
......@@ -13,8 +13,8 @@
# limitations under the License.
"""Tokenization classes for MobileBERT."""
from .tokenization_bert import BertTokenizer
from .utils import logging
from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer
logger = logging.get_logger(__name__)
......
......@@ -13,9 +13,9 @@
# limitations under the License.
"""Tokenization classes for MobileBERT."""
from .tokenization_bert_fast import BertTokenizerFast
from ...utils import logging
from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_mobilebert import MobileBertTokenizer
from .utils import logging
logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .tokenization_openai import OpenAIGPTTokenizer
if is_tokenizers_available():
from .tokenization_openai_fast import OpenAIGPTTokenizerFast
if is_torch_available():
from .modeling_openai import (
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
OpenAIGPTDoubleHeadsModel,
OpenAIGPTForSequenceClassification,
OpenAIGPTLMHeadModel,
OpenAIGPTModel,
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
if is_tf_available():
from .modeling_tf_openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
TFOpenAIGPTLMHeadModel,
TFOpenAIGPTMainLayer,
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
......@@ -15,8 +15,8 @@
# limitations under the License.
""" OpenAI GPT configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -27,24 +27,24 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu_new, silu
from .configuration_openai import OpenAIGPTConfig
from .file_utils import (
from ...activations import gelu_new, silu
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from .modeling_utils import (
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import (
Conv1D,
PreTrainedModel,
SequenceSummary,
find_pruneable_heads_and_indices,
prune_conv1d_layer,
)
from .utils import logging
from ...utils import logging
from .configuration_openai import OpenAIGPTConfig
logger = logging.get_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment