Unverified commit c89bdfbe authored by Sylvain Gugger, committed by GitHub

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
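All the per-model diffs below follow one pattern: each model's files move from src/transformers/ into a src/transformers/models/<model_name>/ subfolder, so relative imports of shared modules gain two extra dots, and imports from another model route through the sibling package. A minimal sketch of the rewrite rule, using modeling_blenderbot.py as the worked example (comments only, nothing model-specific):

# Seen from src/transformers/models/blenderbot/modeling_blenderbot.py:
#   from .configuration_blenderbot import BlenderbotConfig         # same model folder
#   from ..bart.modeling_bart import BartForConditionalGeneration  # sibling model folder
#   from ...file_utils import add_start_docstrings                 # transformers package root
#
# Before the move, the file sat directly in src/transformers/, so a single
# dot reached everything:
#   from .file_utils import add_start_docstrings
#   from .modeling_bart import BartForConditionalGeneration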
src/transformers/models/blenderbot/__init__.py
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+from ...file_utils import is_tf_available, is_torch_available
+from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
+from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokenizer
+
+if is_torch_available():
+    from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
+
+if is_tf_available():
+    from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
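These new per-model __init__.py files gate the heavy imports on which backends are installed. A minimal usage sketch of the resulting subpackage, assuming a PyTorch environment (the config-only construction here is illustrative):

from transformers.file_utils import is_torch_available
from transformers.models.blenderbot import BlenderbotConfig

config = BlenderbotConfig()  # blenderbot-90M defaults, per the docstring below

if is_torch_available():
    from transformers.models.blenderbot import BlenderbotForConditionalGeneration

    model = BlenderbotForConditionalGeneration(config)  # randomly initialized weights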
src/transformers/models/blenderbot/configuration_blenderbot.py
@@ -18,7 +18,7 @@
 BlenderbotConfig has the same signature as BartConfig. We only rewrite the signature in order to document
 blenderbot-90M defaults.
 """
-from .configuration_bart import BartConfig
+from ..bart.configuration_bart import BartConfig
 BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
...
src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -18,9 +18,9 @@
 import torch
-from .configuration_blenderbot import BlenderbotConfig
-from .file_utils import add_start_docstrings
-from .modeling_bart import BartForConditionalGeneration
+from ...file_utils import add_start_docstrings
+from ..bart.modeling_bart import BartForConditionalGeneration
+from .configuration_blenderbot import BlenderbotConfig
 BLENDER_START_DOCSTRING = r"""
...
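The modeling file only needed its import paths touched because Blenderbot's PyTorch model is a thin subclass of Bart's. Roughly, the pattern looks like this (a sketch, not the file's exact contents):

from transformers.models.bart.modeling_bart import BartForConditionalGeneration
from transformers.models.blenderbot.configuration_blenderbot import BlenderbotConfig

class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
    # Reuses Bart's forward pass wholesale; only the config type differs.
    config_class = BlenderbotConfig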
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """TF BlenderBot model, ported from the fairseq repo."""
-from .configuration_blenderbot import BlenderbotConfig
-from .file_utils import add_start_docstrings, is_tf_available
-from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
-from .utils import logging
+from ...file_utils import add_start_docstrings, is_tf_available
+from ...utils import logging
+from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
+from .configuration_blenderbot import BlenderbotConfig
 if is_tf_available():
...
src/transformers/models/blenderbot/tokenization_blenderbot.py
@@ -21,9 +21,9 @@ from typing import Dict, List, Optional, Tuple
 import regex as re
-from .tokenization_roberta import RobertaTokenizer
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+from ..roberta.tokenization_roberta import RobertaTokenizer
 logger = logging.get_logger(__name__)
...
src/transformers/models/camembert/__init__.py
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
+from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
+
+if is_sentencepiece_available():
+    from .tokenization_camembert import CamembertTokenizer
+
+if is_tokenizers_available():
+    from .tokenization_camembert_fast import CamembertTokenizerFast
+
+if is_torch_available():
+    from .modeling_camembert import (
+        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        CamembertForCausalLM,
+        CamembertForMaskedLM,
+        CamembertForMultipleChoice,
+        CamembertForQuestionAnswering,
+        CamembertForSequenceClassification,
+        CamembertForTokenClassification,
+        CamembertModel,
+    )
+
+if is_tf_available():
+    from .modeling_tf_camembert import (
+        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFCamembertForMaskedLM,
+        TFCamembertForMultipleChoice,
+        TFCamembertForQuestionAnswering,
+        TFCamembertForSequenceClassification,
+        TFCamembertForTokenClassification,
+        TFCamembertModel,
+    )
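For downstream code, the reorganization does not change the public top-level imports; only the deep module paths move. A short sketch of the stable form versus the old and new deep forms:

# Stable public path, unchanged by this commit:
from transformers import CamembertConfig

# New deep path after the move:
from transformers.models.camembert.configuration_camembert import CamembertConfig

# Old deep path, removed by this commit:
#   from transformers.configuration_camembert import CamembertConfig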
src/transformers/models/camembert/configuration_camembert.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 """ CamemBERT configuration """
-from .configuration_roberta import RobertaConfig
-from .utils import logging
+from ...utils import logging
+from ..roberta.configuration_roberta import RobertaConfig
 logger = logging.get_logger(__name__)
...
src/transformers/models/camembert/modeling_camembert.py
@@ -15,9 +15,9 @@
 # limitations under the License.
 """PyTorch CamemBERT model. """
-from .configuration_camembert import CamembertConfig
-from .file_utils import add_start_docstrings
-from .modeling_roberta import (
+from ...file_utils import add_start_docstrings
+from ...utils import logging
+from ..roberta.modeling_roberta import (
     RobertaForCausalLM,
     RobertaForMaskedLM,
     RobertaForMultipleChoice,
@@ -26,7 +26,7 @@ from .modeling_roberta import (
     RobertaForTokenClassification,
     RobertaModel,
 )
-from .utils import logging
+from .configuration_camembert import CamembertConfig
 logger = logging.get_logger(__name__)
...
src/transformers/models/camembert/modeling_tf_camembert.py
@@ -15,9 +15,9 @@
 # limitations under the License.
 """ TF 2.0 CamemBERT model. """
-from .configuration_camembert import CamembertConfig
-from .file_utils import add_start_docstrings
-from .modeling_tf_roberta import (
+from ...file_utils import add_start_docstrings
+from ...utils import logging
+from ..roberta.modeling_tf_roberta import (
     TFRobertaForMaskedLM,
     TFRobertaForMultipleChoice,
     TFRobertaForQuestionAnswering,
@@ -25,7 +25,7 @@ from .modeling_tf_roberta import (
     TFRobertaForTokenClassification,
     TFRobertaModel,
 )
-from .utils import logging
+from .configuration_camembert import CamembertConfig
 logger = logging.get_logger(__name__)
...
src/transformers/models/camembert/tokenization_camembert.py
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
 import sentencepiece as spm
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
 logger = logging.get_logger(__name__)
...
src/transformers/models/camembert/tokenization_camembert_fast.py
@@ -19,9 +19,9 @@ import os
 from shutil import copyfile
 from typing import List, Optional, Tuple
-from .file_utils import is_sentencepiece_available
-from .tokenization_utils_fast import PreTrainedTokenizerFast
-from .utils import logging
+from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
 if is_sentencepiece_available():
...
src/transformers/models/ctrl/__init__.py
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+from ...file_utils import is_tf_available, is_torch_available
+from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
+from .tokenization_ctrl import CTRLTokenizer
+
+if is_torch_available():
+    from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
+
+if is_tf_available():
+    from .modeling_tf_ctrl import (
+        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFCTRLLMHeadModel,
+        TFCTRLModel,
+        TFCTRLPreTrainedModel,
+    )
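The "Fix dummy files" step in the commit message refers to the placeholder objects transformers exposes when a backend is missing, so that from transformers import CTRLModel still resolves without PyTorch and only fails when the class is used. A rough sketch of that pattern (simplified; not the generated file's exact code):

class CTRLModel:
    def __init__(self, *args, **kwargs):
        # Fail loudly at use time rather than at import time.
        raise ImportError("CTRLModel requires PyTorch, which was not found in your environment.")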
src/transformers/models/ctrl/configuration_ctrl.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ Salesforce CTRL configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
 logger = logging.get_logger(__name__)
...
src/transformers/models/ctrl/modeling_ctrl.py
@@ -23,11 +23,11 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import logging
 from .configuration_ctrl import CTRLConfig
-from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from .utils import logging
 logger = logging.get_logger(__name__)
...
src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -19,18 +19,18 @@
 import numpy as np
 import tensorflow as tf
-from .configuration_ctrl import CTRLConfig
-from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
-from .modeling_tf_utils import (
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
+from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     keras_serializable,
     shape_list,
 )
-from .tokenization_utils import BatchEncoding
-from .utils import logging
+from ...tokenization_utils import BatchEncoding
+from ...utils import logging
+from .configuration_ctrl import CTRLConfig
 logger = logging.get_logger(__name__)
...
src/transformers/models/ctrl/tokenization_ctrl.py
@@ -21,8 +21,8 @@ from typing import Optional, Tuple
 import regex as re
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
 logger = logging.get_logger(__name__)
...
src/transformers/models/deberta/__init__.py
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+from ...file_utils import is_torch_available
+from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
+from .tokenization_deberta import DebertaTokenizer
+
+if is_torch_available():
+    from .modeling_deberta import (
+        DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
+        DebertaForSequenceClassification,
+        DebertaModel,
+        DebertaPreTrainedModel,
+    )
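A symbol's __module__ attribute is a quick way to confirm where it landed after the move, which helps when updating old deep imports:

from transformers import DebertaConfig

# Prints the defining module's new location:
print(DebertaConfig.__module__)  # transformers.models.deberta.configuration_deberta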
src/transformers/models/deberta/configuration_deberta.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ DeBERTa model configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
 logger = logging.get_logger(__name__)
...
src/transformers/models/deberta/modeling_deberta.py
@@ -22,12 +22,12 @@ from packaging import version
 from torch import _softmax_backward_data, nn
 from torch.nn import CrossEntropyLoss
-from .activations import ACT2FN
+from ...activations import ACT2FN
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
 from .configuration_deberta import DebertaConfig
-from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_outputs import BaseModelOutput, SequenceClassifierOutput
-from .modeling_utils import PreTrainedModel
-from .utils import logging
 logger = logging.get_logger(__name__)
@@ -74,7 +74,7 @@ class XSoftmax(torch.autograd.Function):
     Example::
       import torch
-      from transformers.modeling_deroberta import XSoftmax
+      from transformers.models.deberta import XSoftmax
       # Make a tensor
       x = torch.randn([4,20,100])
       # Create a mask
@@ -278,7 +278,7 @@ class DebertaAttention(nn.Module):
         return attention_output
-# Copied from transformers.modeling_bert.BertIntermediate with Bert->Deberta
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
 class DebertaIntermediate(nn.Module):
     def __init__(self, config):
         super().__init__()
...
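The XSoftmax docstring example above is cut off by the diff view; a hedged completion of the visible lines follows (the mask construction and the dim argument are illustrative, and the deep module path is used since it is guaranteed to resolve):

import torch
from transformers.models.deberta.modeling_deberta import XSoftmax

x = torch.randn([4, 20, 100])
mask = (x > 0).int()             # 1 = keep, 0 = mask out
y = XSoftmax.apply(x, mask, -1)  # masked softmax over the last dimension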