Unverified commit c89bdfbe, authored by Sylvain Gugger, committed by GitHub

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
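The diff that follows moves each model's modules into a dedicated subpackage under src/transformers/models/, so the relative imports inside those files gain one level (. becomes ...) and cross-model reuse goes through a sibling subpackage (for example ..bert.modeling_bert). A minimal sketch of what the new layout means for user-facing imports, assuming the top-level re-exports in transformers/__init__.py are preserved (as this PR intends) and that torch is installed:

# Hedged sketch, not part of the commit: three ways the same class resolves
# after the reorganization (all assume transformers and torch are installed).
from transformers import DistilBertModel                                        # top-level re-export, unchanged
from transformers.models.distilbert import DistilBertModel                      # new per-model subpackage
from transformers.models.distilbert.modeling_distilbert import DistilBertModel  # full module path

# Inside the moved files the relative imports shift one package level up:
#   before: from .file_utils import add_start_docstrings
#   after:  from ...file_utils import add_start_docstrings
# and cross-model imports go through the sibling subpackage:
#   before: from .modeling_bert import BertModel
#   after:  from ..bert.modeling_bert import BertModel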
@@ -26,8 +26,8 @@ import tqdm
import requests
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
try:
...
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .tokenization_distilbert import DistilBertTokenizer
if is_tokenizers_available():
    from .tokenization_distilbert_fast import DistilBertTokenizerFast

if is_torch_available():
    from .modeling_distilbert import (
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        DistilBertForMaskedLM,
        DistilBertForMultipleChoice,
        DistilBertForQuestionAnswering,
        DistilBertForSequenceClassification,
        DistilBertForTokenClassification,
        DistilBertModel,
        DistilBertPreTrainedModel,
    )

if is_tf_available():
    from .modeling_tf_distilbert import (
        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFDistilBertForMaskedLM,
        TFDistilBertForMultipleChoice,
        TFDistilBertForQuestionAnswering,
        TFDistilBertForSequenceClassification,
        TFDistilBertForTokenClassification,
        TFDistilBertMainLayer,
        TFDistilBertModel,
        TFDistilBertPreTrainedModel,
    )
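This new per-model __init__.py mirrors the guard pattern of the top-level package: PyTorch and TensorFlow classes are only imported when the corresponding backend is installed, so importing the subpackage never hard-fails on a missing framework. A minimal, hedged sketch of consuming the guarded exports listed above (config-only, no weights downloaded; assumes transformers with torch is installed):

from transformers.file_utils import is_torch_available

if is_torch_available():
    # DistilBertConfig comes from the unguarded block above, DistilBertModel
    # from the torch-guarded one.
    from transformers.models.distilbert import DistilBertConfig, DistilBertModel

    model = DistilBertModel(DistilBertConfig())  # randomly initialized, nothing downloaded
    print(model.config.dim)                      # DistilBERT's hidden size, 768 by default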
@@ -14,8 +14,8 @@
# limitations under the License.
""" DistilBERT model configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
logger = logging.get_logger(__name__)
...
@@ -27,15 +27,14 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
-from .activations import gelu
-from .configuration_distilbert import DistilBertConfig
-from .file_utils import (
+from ...activations import gelu
+from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
-from .modeling_outputs import (
+from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
@@ -43,13 +42,14 @@ from .modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from .modeling_utils import (
+from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
-from .utils import logging
+from ...utils import logging
+from .configuration_distilbert import DistilBertConfig
logger = logging.get_logger(__name__)
...
@@ -19,15 +19,14 @@
import tensorflow as tf
-from .activations_tf import get_tf_activation
-from .configuration_distilbert import DistilBertConfig
-from .file_utils import (
+from ...activations_tf import get_tf_activation
+from ...file_utils import (
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
-from .modeling_tf_outputs import (
+from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
@@ -35,7 +34,7 @@ from .modeling_tf_outputs import (
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
-from .modeling_tf_utils import (
+from ...modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
@@ -47,8 +46,9 @@ from .modeling_tf_utils import (
    keras_serializable,
    shape_list,
)
-from .tokenization_utils import BatchEncoding
-from .utils import logging
+from ...tokenization_utils import BatchEncoding
+from ...utils import logging
+from .configuration_distilbert import DistilBertConfig
logger = logging.get_logger(__name__)
...
@@ -14,8 +14,8 @@
# limitations under the License.
"""Tokenization classes for DistilBERT."""
-from .tokenization_bert import BertTokenizer
-from .utils import logging
+from ...utils import logging
+from ..bert.tokenization_bert import BertTokenizer
logger = logging.get_logger(__name__)
...
@@ -14,9 +14,9 @@
# limitations under the License.
"""Tokenization classes for DistilBERT."""
-from .tokenization_bert_fast import BertTokenizerFast
+from ...utils import logging
+from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_distilbert import DistilBertTokenizer
-from .utils import logging
logger = logging.get_logger(__name__)
...
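Both DistilBERT tokenizer modules now pull their base class from the sibling bert subpackage (..bert.tokenization_bert and ..bert.tokenization_bert_fast) instead of the old flat modules; nothing changes functionally, since DistilBERT reuses BERT's WordPiece tokenizer and only swaps the pretrained vocab maps. A small, hedged check of that relationship, written with absolute imports so it runs outside the package (assumes transformers is installed):

from transformers import BertTokenizer, DistilBertTokenizer

# At the time of this commit DistilBertTokenizer subclasses BertTokenizer,
# which is exactly what the ..bert relative import above encodes.
print(issubclass(DistilBertTokenizer, BertTokenizer))  # expected: True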
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from .tokenization_dpr import (
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    DPRReaderOutput,
    DPRReaderTokenizer,
)

if is_tokenizers_available():
    from .tokenization_dpr_fast import (
        DPRContextEncoderTokenizerFast,
        DPRQuestionEncoderTokenizerFast,
        DPRReaderTokenizerFast,
    )

if is_torch_available():
    from .modeling_dpr import (
        DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
        DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
        DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
        DPRContextEncoder,
        DPRPretrainedContextEncoder,
        DPRPretrainedQuestionEncoder,
        DPRPretrainedReader,
        DPRQuestionEncoder,
        DPRReader,
    )

if is_tf_available():
    from .modeling_tf_dpr import (
        TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
        TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
        TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFDPRContextEncoder,
        TFDPRPretrainedContextEncoder,
        TFDPRPretrainedQuestionEncoder,
        TFDPRPretrainedReader,
        TFDPRQuestionEncoder,
        TFDPRReader,
    )
@@ -14,8 +14,8 @@
# limitations under the License.
""" DPR model configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
logger = logging.get_logger(__name__)
...
@@ -21,17 +21,17 @@ from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
-from .configuration_dpr import DPRConfig
-from .file_utils import (
+from ...file_utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
-from .modeling_bert import BertModel
-from .modeling_outputs import BaseModelOutputWithPooling
-from .modeling_utils import PreTrainedModel
-from .utils import logging
+from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from ..bert.modeling_bert import BertModel
+from .configuration_dpr import DPRConfig
logger = logging.get_logger(__name__)
...
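The same treatment applies to DPR: modeling_dpr.py keeps the framework-level imports relative to the package root (three dots) and reaches into the sibling bert subpackage for BertModel, because DPR's context and question encoders wrap a BERT encoder with an optional projection layer. A hedged, config-only sketch (no weights downloaded; assumes transformers with torch is installed):

from transformers import DPRConfig, DPRQuestionEncoder

config = DPRConfig()                  # BERT-base-like defaults, projection_dim=0
encoder = DPRQuestionEncoder(config)  # internally builds a BERT-style encoder
print(config.hidden_size)             # 768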
@@ -22,18 +22,18 @@ import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Dense
-from .configuration_dpr import DPRConfig
-from .file_utils import (
+from ...file_utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
-from .modeling_tf_bert import TFBertMainLayer
-from .modeling_tf_outputs import TFBaseModelOutputWithPooling
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
-from .tokenization_utils import BatchEncoding
-from .utils import logging
+from ...modeling_tf_outputs import TFBaseModelOutputWithPooling
+from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from ...tokenization_utils import BatchEncoding
+from ...utils import logging
+from ..bert.modeling_tf_bert import TFBertMainLayer
+from .configuration_dpr import DPRConfig
logger = logging.get_logger(__name__)
...
@@ -18,10 +18,10 @@
import collections
from typing import List, Optional, Union
-from .file_utils import add_end_docstrings, add_start_docstrings
-from .tokenization_bert import BertTokenizer
-from .tokenization_utils_base import BatchEncoding, TensorType
-from .utils import logging
+from ...file_utils import add_end_docstrings, add_start_docstrings
+from ...tokenization_utils_base import BatchEncoding, TensorType
+from ...utils import logging
+from ..bert.tokenization_bert import BertTokenizer
logger = logging.get_logger(__name__)
...
@@ -18,11 +18,11 @@
import collections
from typing import List, Optional, Union
-from .file_utils import add_end_docstrings, add_start_docstrings
-from .tokenization_bert_fast import BertTokenizerFast
+from ...file_utils import add_end_docstrings, add_start_docstrings
+from ...tokenization_utils_base import BatchEncoding, TensorType
+from ...utils import logging
+from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer, DPRReaderTokenizer
-from .tokenization_utils_base import BatchEncoding, TensorType
-from .utils import logging
logger = logging.get_logger(__name__)
...
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .tokenization_electra import ElectraTokenizer
if is_tokenizers_available():
    from .tokenization_electra_fast import ElectraTokenizerFast

if is_torch_available():
    from .modeling_electra import (
        ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
        ElectraForMaskedLM,
        ElectraForMultipleChoice,
        ElectraForPreTraining,
        ElectraForQuestionAnswering,
        ElectraForSequenceClassification,
        ElectraForTokenClassification,
        ElectraModel,
        ElectraPreTrainedModel,
        load_tf_weights_in_electra,
    )

if is_tf_available():
    from .modeling_tf_electra import (
        TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFElectraForMaskedLM,
        TFElectraForMultipleChoice,
        TFElectraForPreTraining,
        TFElectraForQuestionAnswering,
        TFElectraForSequenceClassification,
        TFElectraForTokenClassification,
        TFElectraModel,
        TFElectraPreTrainedModel,
    )
@@ -15,8 +15,8 @@
# limitations under the License.
""" ELECTRA model configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
logger = logging.get_logger(__name__)
...
@@ -24,16 +24,15 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
-from .activations import ACT2FN, get_activation
-from .configuration_electra import ElectraConfig
-from .file_utils import (
+from ...activations import ACT2FN, get_activation
+from ...file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
-from .modeling_outputs import (
+from ...modeling_outputs import (
    BaseModelOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
@@ -41,14 +40,15 @@ from .modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from .modeling_utils import (
+from ...modeling_utils import (
    PreTrainedModel,
    SequenceSummary,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
-from .utils import logging
+from ...utils import logging
+from .configuration_electra import ElectraConfig
logger = logging.get_logger(__name__)
@@ -167,7 +167,7 @@ class ElectraEmbeddings(nn.Module):
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

-    # Copied from transformers.modeling_bert.BertEmbeddings.forward
+    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
@@ -193,7 +193,7 @@ class ElectraEmbeddings(nn.Module):
        return embeddings

-# Copied from transformers.modeling_bert.BertSelfAttention with Bert->Electra
+# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
class ElectraSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -272,7 +272,7 @@ class ElectraSelfAttention(nn.Module):
        return outputs

-# Copied from transformers.modeling_bert.BertSelfOutput
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class ElectraSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -287,7 +287,7 @@ class ElectraSelfOutput(nn.Module):
        return hidden_states

-# Copied from transformers.modeling_bert.BertAttention with Bert->Electra
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra
class ElectraAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -335,7 +335,7 @@ class ElectraAttention(nn.Module):
        return outputs

-# Copied from transformers.modeling_bert.BertIntermediate
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class ElectraIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -351,7 +351,7 @@ class ElectraIntermediate(nn.Module):
        return hidden_states

-# Copied from transformers.modeling_bert.BertOutput
+# Copied from transformers.models.bert.modeling_bert.BertOutput
class ElectraOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -366,7 +366,7 @@ class ElectraOutput(nn.Module):
        return hidden_states

-# Copied from transformers.modeling_bert.BertLayer with Bert->Electra
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
class ElectraLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -426,7 +426,7 @@ class ElectraLayer(nn.Module):
        return layer_output

-# Copied from transformers.modeling_bert.BertEncoder with Bert->Electra
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra
class ElectraEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -548,7 +548,7 @@ class ElectraPreTrainedModel(PreTrainedModel):
    authorized_missing_keys = [r"position_ids"]
    authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]

-    # Copied from transformers.modeling_bert.BertPreTrainedModel._init_weights
+    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
...
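Beyond the import block, the only change in modeling_electra.py is updating the "# Copied from transformers.modeling_bert..." markers to the new transformers.models.bert.modeling_bert paths. These comments are machine-readable: the repository's copy-consistency check resolves the dotted path and compares the copied source against the original, which is why the commit message above includes "Fix path for copies". An illustrative resolution sketch, not the repository's actual check script (assumes transformers with torch is installed):

import importlib

# The dotted path stored in a "Copied from" comment after this commit:
path = "transformers.models.bert.modeling_bert.BertSelfOutput"
module_path, _, name = path.rpartition(".")
module = importlib.import_module(module_path)  # would fail if the path went stale
print(getattr(module, name).__module__)        # -> transformers.models.bert.modeling_bert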