Unverified commit c89bdfbe authored by Sylvain Gugger, committed by GitHub

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
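In short: each model's files move from the flat src/transformers/ directory into their own subpackage under src/transformers/models/, so intra-library relative imports gain levels (`.` becomes `...` for shared utilities and `..other_model.` for cross-model imports). A hedged sketch of what this means for import paths — the claim that public top-level imports are preserved is an assumption based on the re-exporting __init__.py files shown below:

# Internal module paths before and after this commit:
#   before: from transformers.modeling_ctrl import CTRLModel
#   after:  from transformers.models.ctrl.modeling_ctrl import CTRLModel
# Public imports should be unaffected, since each new subpackage
# __init__.py re-exports its classes:
from transformers import CTRLConfig, CTRLModel, CTRLTokenizer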
src/transformers/models/blenderbot/__init__.py (new file)
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_tf_available, is_torch_available
from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokenizer

if is_torch_available():
    from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration

if is_tf_available():
    from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
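Because the subpackage __init__.py above only exports the PyTorch and TensorFlow classes when the corresponding backend is installed, importing the config and tokenizer never forces a framework install. A minimal sketch, assuming PyTorch is available (the default-config construction leans on the blenderbot-90M defaults documented in the next diff):

from transformers.file_utils import is_torch_available
from transformers.models.blenderbot import BlenderbotConfig, BlenderbotTokenizer

if is_torch_available():
    from transformers.models.blenderbot import BlenderbotForConditionalGeneration

    # Instantiate from the default (blenderbot-90M) configuration.
    model = BlenderbotForConditionalGeneration(BlenderbotConfig())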
src/transformers/models/blenderbot/configuration_blenderbot.py
......
@@ -18,7 +18,7 @@
 BlenderbotConfig has the same signature as BartConfig. We only rewrite the signature in order to document
 blenderbot-90M defaults.
 """
-from .configuration_bart import BartConfig
+from ..bart.configuration_bart import BartConfig
 BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
......
src/transformers/models/blenderbot/modeling_blenderbot.py
......
@@ -18,9 +18,9 @@
 import torch
+from ...file_utils import add_start_docstrings
+from ..bart.modeling_bart import BartForConditionalGeneration
 from .configuration_blenderbot import BlenderbotConfig
-from .file_utils import add_start_docstrings
-from .modeling_bart import BartForConditionalGeneration
 BLENDER_START_DOCSTRING = r"""
......
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
......
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """TF BlenderBot model, ported from the fairseq repo."""
+from ...file_utils import add_start_docstrings, is_tf_available
+from ...utils import logging
+from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
 from .configuration_blenderbot import BlenderbotConfig
-from .file_utils import add_start_docstrings, is_tf_available
-from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
-from .utils import logging
 if is_tf_available():
......
src/transformers/models/blenderbot/tokenization_blenderbot.py
......
@@ -21,9 +21,9 @@ from typing import Dict, List, Optional, Tuple
 import regex as re
-from .tokenization_roberta import RobertaTokenizer
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+from ..roberta.tokenization_roberta import RobertaTokenizer
 logger = logging.get_logger(__name__)
......
src/transformers/models/camembert/__init__.py (new file)
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig

if is_sentencepiece_available():
    from .tokenization_camembert import CamembertTokenizer

if is_tokenizers_available():
    from .tokenization_camembert_fast import CamembertTokenizerFast

if is_torch_available():
    from .modeling_camembert import (
        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        CamembertForCausalLM,
        CamembertForMaskedLM,
        CamembertForMultipleChoice,
        CamembertForQuestionAnswering,
        CamembertForSequenceClassification,
        CamembertForTokenClassification,
        CamembertModel,
    )

if is_tf_available():
    from .modeling_tf_camembert import (
        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFCamembertForMaskedLM,
        TFCamembertForMultipleChoice,
        TFCamembertForQuestionAnswering,
        TFCamembertForSequenceClassification,
        TFCamembertForTokenClassification,
        TFCamembertModel,
    )
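The four availability guards above mean a user only gets the exports their environment supports: CamembertTokenizer needs sentencepiece, CamembertTokenizerFast needs the tokenizers library, and the model classes need torch or TF. A hedged sketch of selecting whichever tokenizer backend is installed ("camembert-base" is an illustrative checkpoint name):

from transformers.file_utils import is_sentencepiece_available, is_tokenizers_available

if is_tokenizers_available():
    from transformers.models.camembert import CamembertTokenizerFast as CamembertTok
elif is_sentencepiece_available():
    from transformers.models.camembert import CamembertTokenizer as CamembertTok
else:
    raise ImportError("CamemBERT needs either `tokenizers` or `sentencepiece`.")

tokenizer = CamembertTok.from_pretrained("camembert-base")  # fetches from the model hub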
src/transformers/models/camembert/configuration_camembert.py
......
@@ -15,8 +15,8 @@
 # limitations under the License.
 """ CamemBERT configuration """
-from .configuration_roberta import RobertaConfig
-from .utils import logging
+from ...utils import logging
+from ..roberta.configuration_roberta import RobertaConfig
 logger = logging.get_logger(__name__)
......
src/transformers/models/camembert/modeling_camembert.py
......
@@ -15,9 +15,9 @@
 # limitations under the License.
 """PyTorch CamemBERT model. """
-from .configuration_camembert import CamembertConfig
-from .file_utils import add_start_docstrings
-from .modeling_roberta import (
+from ...file_utils import add_start_docstrings
+from ...utils import logging
+from ..roberta.modeling_roberta import (
     RobertaForCausalLM,
     RobertaForMaskedLM,
     RobertaForMultipleChoice,
......
@@ -26,7 +26,7 @@ from .modeling_roberta import (
     RobertaForTokenClassification,
     RobertaModel,
 )
-from .utils import logging
+from .configuration_camembert import CamembertConfig
 logger = logging.get_logger(__name__)
......
src/transformers/models/camembert/modeling_tf_camembert.py
......
@@ -15,9 +15,9 @@
 # limitations under the License.
 """ TF 2.0 CamemBERT model. """
-from .configuration_camembert import CamembertConfig
-from .file_utils import add_start_docstrings
-from .modeling_tf_roberta import (
+from ...file_utils import add_start_docstrings
+from ...utils import logging
+from ..roberta.modeling_tf_roberta import (
     TFRobertaForMaskedLM,
     TFRobertaForMultipleChoice,
     TFRobertaForQuestionAnswering,
......
@@ -25,7 +25,7 @@ from .modeling_tf_roberta import (
     TFRobertaForTokenClassification,
     TFRobertaModel,
 )
-from .utils import logging
+from .configuration_camembert import CamembertConfig
 logger = logging.get_logger(__name__)
......
src/transformers/models/camembert/tokenization_camembert.py
......
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
 import sentencepiece as spm
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
 logger = logging.get_logger(__name__)
......
src/transformers/models/camembert/tokenization_camembert_fast.py
......
@@ -19,9 +19,9 @@ import os
 from shutil import copyfile
 from typing import List, Optional, Tuple
-from .file_utils import is_sentencepiece_available
-from .tokenization_utils_fast import PreTrainedTokenizerFast
-from .utils import logging
+from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
 if is_sentencepiece_available():
......
src/transformers/models/ctrl/__init__.py (new file)
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_tf_available, is_torch_available
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .tokenization_ctrl import CTRLTokenizer

if is_torch_available():
    from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel

if is_tf_available():
    from .modeling_tf_ctrl import (
        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFCTRLLMHeadModel,
        TFCTRLModel,
        TFCTRLPreTrainedModel,
    )
src/transformers/models/ctrl/configuration_ctrl.py
......
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ Salesforce CTRL configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
 logger = logging.get_logger(__name__)
......
src/transformers/models/ctrl/modeling_ctrl.py
......
@@ -23,11 +23,11 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import logging
 from .configuration_ctrl import CTRLConfig
-from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from .utils import logging
 logger = logging.get_logger(__name__)
......
src/transformers/models/ctrl/modeling_tf_ctrl.py
......
@@ -19,18 +19,18 @@
 import numpy as np
 import tensorflow as tf
-from .configuration_ctrl import CTRLConfig
-from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
-from .modeling_tf_utils import (
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
+from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     keras_serializable,
     shape_list,
 )
-from .tokenization_utils import BatchEncoding
-from .utils import logging
+from ...tokenization_utils import BatchEncoding
+from ...utils import logging
+from .configuration_ctrl import CTRLConfig
 logger = logging.get_logger(__name__)
......
src/transformers/models/ctrl/tokenization_ctrl.py
......
@@ -21,8 +21,8 @@ from typing import Optional, Tuple
 import regex as re
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
 logger = logging.get_logger(__name__)
......
src/transformers/models/deberta/__init__.py (new file)
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_torch_available
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
from .tokenization_deberta import DebertaTokenizer

if is_torch_available():
    from .modeling_deberta import (
        DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        DebertaForSequenceClassification,
        DebertaModel,
        DebertaPreTrainedModel,
    )
src/transformers/models/deberta/configuration_deberta.py
......
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ DeBERTa model configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
 logger = logging.get_logger(__name__)
......
src/transformers/models/deberta/modeling_deberta.py
......
@@ -22,12 +22,12 @@ from packaging import version
 from torch import _softmax_backward_data, nn
 from torch.nn import CrossEntropyLoss
-from .activations import ACT2FN
+from ...activations import ACT2FN
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
 from .configuration_deberta import DebertaConfig
-from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_outputs import BaseModelOutput, SequenceClassifierOutput
-from .modeling_utils import PreTrainedModel
-from .utils import logging
 logger = logging.get_logger(__name__)
......
@@ -74,7 +74,7 @@ class XSoftmax(torch.autograd.Function):
     Example::
         import torch
-        from transformers.modeling_deroberta import XSoftmax
+        from transformers.models.deberta import XSoftmax
         # Make a tensor
         x = torch.randn([4,20,100])
         # Create a mask
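The docstring example above is cut off by the diff view; a completed sketch of the masked-softmax usage follows. It assumes the mask convention from this docstring (nonzero = keep) and imports XSoftmax from its defining module, modeling_deberta, since the subpackage __init__ shown earlier does not re-export it:

import torch
from transformers.models.deberta.modeling_deberta import XSoftmax

x = torch.randn([4, 20, 100])   # attention scores
mask = (x > 0).int()            # 1 = keep, 0 = mask out
y = XSoftmax.apply(x, mask, -1) # masked softmax over the last dimension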
......
@@ -278,7 +278,7 @@ class DebertaAttention(nn.Module):
         return attention_output
 
-# Copied from transformers.modeling_bert.BertIntermediate with Bert->Deberta
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
 class DebertaIntermediate(nn.Module):
     def __init__(self, config):
         super().__init__()
......
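The "# Copied from ..." annotation above is machine-checked metadata: a repo utility re-derives the copy from the named source with the declared rename and flags any drift, which is why this commit's "Fix path for copies" step updates the annotations to the new transformers.models.* paths. A rough, hypothetical sketch of how such a check can work (not the repo's actual checker; the function name and rename format are illustrative):

import inspect

def copy_matches(source_cls, copied_cls, rename):
    """Check that copied_cls matches source_cls's source with `Old->New` applied."""
    old, new = rename.split("->")
    expected = inspect.getsource(source_cls).replace(old, new)
    return expected.strip() == inspect.getsource(copied_cls).strip()

# e.g. copy_matches(BertIntermediate, DebertaIntermediate, "Bert->Deberta")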