Unverified Commit c89bdfbe authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
...@@ -29,9 +29,9 @@ import numpy as np ...@@ -29,9 +29,9 @@ import numpy as np
import sacremoses as sm import sacremoses as sm
from .file_utils import cached_path, is_torch_available, torch_only_method from ...file_utils import cached_path, is_torch_available, torch_only_method
from .tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from .utils import logging from ...utils import logging
if is_torch_available(): if is_torch_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Package init for the XLM model subpackage: re-exports the configuration,
# tokenizer, and framework-specific model classes so callers can import them
# from `transformers.models.xlm` directly.
from ...file_utils import is_tf_available, is_torch_available
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .tokenization_xlm import XLMTokenizer
# PyTorch model classes are only exposed when torch is installed.
if is_torch_available():
    from .modeling_xlm import (
        XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
        XLMForMultipleChoice,
        XLMForQuestionAnswering,
        XLMForQuestionAnsweringSimple,
        XLMForSequenceClassification,
        XLMForTokenClassification,
        XLMModel,
        XLMPreTrainedModel,
        XLMWithLMHeadModel,
    )
# TensorFlow model classes are only exposed when TF is installed.
if is_tf_available():
    from .modeling_tf_xlm import (
        TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFXLMForMultipleChoice,
        TFXLMForQuestionAnsweringSimple,
        TFXLMForSequenceClassification,
        TFXLMForTokenClassification,
        TFXLMMainLayer,
        TFXLMModel,
        TFXLMPreTrainedModel,
        TFXLMWithLMHeadModel,
    )
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
""" XLM configuration """ """ XLM configuration """
from .configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -22,7 +22,7 @@ import numpy ...@@ -22,7 +22,7 @@ import numpy
import torch import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers.tokenization_xlm import VOCAB_FILES_NAMES from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
from transformers.utils import logging from transformers.utils import logging
......
...@@ -25,23 +25,22 @@ from typing import Optional, Tuple ...@@ -25,23 +25,22 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from .activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from .configuration_xlm import XLMConfig from ...file_utils import (
from .file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS, MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput, ModelOutput,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from .modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFMultipleChoiceModelOutput, TFMultipleChoiceModelOutput,
TFQuestionAnsweringModelOutput, TFQuestionAnsweringModelOutput,
TFSequenceClassifierOutput, TFSequenceClassifierOutput,
TFTokenClassifierOutput, TFTokenClassifierOutput,
) )
from .modeling_tf_utils import ( from ...modeling_tf_utils import (
TFMultipleChoiceLoss, TFMultipleChoiceLoss,
TFPreTrainedModel, TFPreTrainedModel,
TFQuestionAnsweringLoss, TFQuestionAnsweringLoss,
...@@ -53,8 +52,9 @@ from .modeling_tf_utils import ( ...@@ -53,8 +52,9 @@ from .modeling_tf_utils import (
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from ...tokenization_utils import BatchEncoding
from .utils import logging from ...utils import logging
from .configuration_xlm import XLMConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -29,16 +29,15 @@ from torch import nn ...@@ -29,16 +29,15 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F from torch.nn import functional as F
from .activations import gelu from ...activations import gelu
from .configuration_xlm import XLMConfig from ...file_utils import (
from .file_utils import (
ModelOutput, ModelOutput,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
replace_return_docstrings, replace_return_docstrings,
) )
from .modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
MaskedLMOutput, MaskedLMOutput,
MultipleChoiceModelOutput, MultipleChoiceModelOutput,
...@@ -46,7 +45,7 @@ from .modeling_outputs import ( ...@@ -46,7 +45,7 @@ from .modeling_outputs import (
SequenceClassifierOutput, SequenceClassifierOutput,
TokenClassifierOutput, TokenClassifierOutput,
) )
from .modeling_utils import ( from ...modeling_utils import (
PreTrainedModel, PreTrainedModel,
SequenceSummary, SequenceSummary,
SQuADHead, SQuADHead,
...@@ -54,7 +53,8 @@ from .modeling_utils import ( ...@@ -54,7 +53,8 @@ from .modeling_utils import (
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_linear_layer, prune_linear_layer,
) )
from .utils import logging from ...utils import logging
from .configuration_xlm import XLMConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -24,8 +24,8 @@ from typing import List, Optional, Tuple ...@@ -24,8 +24,8 @@ from typing import List, Optional, Tuple
import sacremoses as sm import sacremoses as sm
from .tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Package init for the XLM-ProphetNet model subpackage: re-exports the
# configuration plus tokenizer/model classes guarded by optional dependencies.
from ...file_utils import is_sentencepiece_available, is_torch_available
from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
# The tokenizer requires the optional sentencepiece dependency.
if is_sentencepiece_available():
    from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer
# PyTorch model classes are only exposed when torch is installed.
if is_torch_available():
    from .modeling_xlm_prophetnet import (
        XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
        XLMProphetNetDecoder,
        XLMProphetNetEncoder,
        XLMProphetNetForCausalLM,
        XLMProphetNetForConditionalGeneration,
        XLMProphetNetModel,
    )
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
""" XLM-ProphetNet model configuration """ """ XLM-ProphetNet model configuration """
from .configuration_prophetnet import ProphetNetConfig from ...utils import logging
from .utils import logging from ..prophetnet.configuration_prophetnet import ProphetNetConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
# limitations under the License. # limitations under the License.
""" PyTorch XLM-ProphetNet model.""" """ PyTorch XLM-ProphetNet model."""
from .configuration_xlm_prophetnet import XLMProphetNetConfig from ...utils import logging
from .modeling_prophetnet import ( from ..prophetnet.modeling_prophetnet import (
ProphetNetDecoder, ProphetNetDecoder,
ProphetNetEncoder, ProphetNetEncoder,
ProphetNetForCausalLM, ProphetNetForCausalLM,
ProphetNetForConditionalGeneration, ProphetNetForConditionalGeneration,
ProphetNetModel, ProphetNetModel,
) )
from .utils import logging from .configuration_xlm_prophetnet import XLMProphetNetConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -18,8 +18,8 @@ import os ...@@ -18,8 +18,8 @@ import os
from shutil import copyfile from shutil import copyfile
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from .tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Package init for the XLM-RoBERTa model subpackage: re-exports the
# configuration, tokenizers, and framework-specific model classes, each
# guarded by the availability of its optional dependency.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
# The slow tokenizer requires the optional sentencepiece dependency.
if is_sentencepiece_available():
    from .tokenization_xlm_roberta import XLMRobertaTokenizer
# The fast tokenizer requires the optional tokenizers dependency.
if is_tokenizers_available():
    from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
# PyTorch model classes are only exposed when torch is installed.
if is_torch_available():
    from .modeling_xlm_roberta import (
        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        XLMRobertaForCausalLM,
        XLMRobertaForMaskedLM,
        XLMRobertaForMultipleChoice,
        XLMRobertaForQuestionAnswering,
        XLMRobertaForSequenceClassification,
        XLMRobertaForTokenClassification,
        XLMRobertaModel,
    )
# TensorFlow model classes are only exposed when TF is installed.
if is_tf_available():
    from .modeling_tf_xlm_roberta import (
        TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFXLMRobertaForMaskedLM,
        TFXLMRobertaForMultipleChoice,
        TFXLMRobertaForQuestionAnswering,
        TFXLMRobertaForSequenceClassification,
        TFXLMRobertaForTokenClassification,
        TFXLMRobertaModel,
    )
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
# limitations under the License. # limitations under the License.
""" XLM-RoBERTa configuration """ """ XLM-RoBERTa configuration """
from .configuration_roberta import RobertaConfig from ...utils import logging
from .utils import logging from ..roberta.configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 XLM-RoBERTa model. """ """ TF 2.0 XLM-RoBERTa model. """
from .configuration_xlm_roberta import XLMRobertaConfig from ...file_utils import add_start_docstrings
from .file_utils import add_start_docstrings from ...utils import logging
from .modeling_tf_roberta import ( from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
TFRobertaForMultipleChoice, TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering, TFRobertaForQuestionAnswering,
...@@ -25,7 +25,7 @@ from .modeling_tf_roberta import ( ...@@ -25,7 +25,7 @@ from .modeling_tf_roberta import (
TFRobertaForTokenClassification, TFRobertaForTokenClassification,
TFRobertaModel, TFRobertaModel,
) )
from .utils import logging from .configuration_xlm_roberta import XLMRobertaConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
# limitations under the License. # limitations under the License.
"""PyTorch XLM-RoBERTa model. """ """PyTorch XLM-RoBERTa model. """
from .configuration_xlm_roberta import XLMRobertaConfig from ...file_utils import add_start_docstrings
from .file_utils import add_start_docstrings from ...utils import logging
from .modeling_roberta import ( from ..roberta.modeling_roberta import (
RobertaForCausalLM, RobertaForCausalLM,
RobertaForMaskedLM, RobertaForMaskedLM,
RobertaForMultipleChoice, RobertaForMultipleChoice,
...@@ -26,7 +26,7 @@ from .modeling_roberta import ( ...@@ -26,7 +26,7 @@ from .modeling_roberta import (
RobertaForTokenClassification, RobertaForTokenClassification,
RobertaModel, RobertaModel,
) )
from .utils import logging from .configuration_xlm_roberta import XLMRobertaConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple ...@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
import sentencepiece as spm import sentencepiece as spm
from .tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -19,9 +19,9 @@ import os ...@@ -19,9 +19,9 @@ import os
from shutil import copyfile from shutil import copyfile
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from .file_utils import is_sentencepiece_available from ...file_utils import is_sentencepiece_available
from .tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging from ...utils import logging
if is_sentencepiece_available(): if is_sentencepiece_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Package init for the XLNet model subpackage: re-exports the configuration,
# tokenizers, and framework-specific model classes, each guarded by the
# availability of its optional dependency.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
# The slow tokenizer requires the optional sentencepiece dependency.
if is_sentencepiece_available():
    from .tokenization_xlnet import XLNetTokenizer
# The fast tokenizer requires the optional tokenizers dependency.
if is_tokenizers_available():
    from .tokenization_xlnet_fast import XLNetTokenizerFast
# PyTorch model classes (and the TF-checkpoint loader) require torch.
if is_torch_available():
    from .modeling_xlnet import (
        XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
        XLNetForMultipleChoice,
        XLNetForQuestionAnswering,
        XLNetForQuestionAnsweringSimple,
        XLNetForSequenceClassification,
        XLNetForTokenClassification,
        XLNetLMHeadModel,
        XLNetModel,
        XLNetPreTrainedModel,
        load_tf_weights_in_xlnet,
    )
# TensorFlow model classes are only exposed when TF is installed.
if is_tf_available():
    from .modeling_tf_xlnet import (
        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFXLNetForMultipleChoice,
        TFXLNetForQuestionAnsweringSimple,
        TFXLNetForSequenceClassification,
        TFXLNetForTokenClassification,
        TFXLNetLMHeadModel,
        TFXLNetMainLayer,
        TFXLNetModel,
        TFXLNetPreTrainedModel,
    )
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
# limitations under the License. # limitations under the License.
""" XLNet configuration """ """ XLNet configuration """
from .configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from .utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment