"megatron/git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "6abf39be51d3ffed5ff78983aa57d272a6e6d820"
Unverified Commit c89bdfbe authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in conver command
parent 90150733
......@@ -29,9 +29,9 @@ import numpy as np
import sacremoses as sm
from .file_utils import cached_path, is_torch_available, torch_only_method
from .tokenization_utils import PreTrainedTokenizer
from .utils import logging
from ...file_utils import cached_path, is_torch_available, torch_only_method
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
if is_torch_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_torch_available
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .tokenization_xlm import XLMTokenizer
if is_torch_available():
from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice,
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
XLMForSequenceClassification,
XLMForTokenClassification,
XLMModel,
XLMPreTrainedModel,
XLMWithLMHeadModel,
)
if is_tf_available():
from .modeling_tf_xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMMainLayer,
TFXLMModel,
TFXLMPreTrainedModel,
TFXLMWithLMHeadModel,
)
......@@ -14,8 +14,8 @@
# limitations under the License.
""" XLM configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -22,7 +22,7 @@ import numpy
import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
from transformers.utils import logging
......
......@@ -25,23 +25,22 @@ from typing import Optional, Tuple
import numpy as np
import tensorflow as tf
from .activations_tf import get_tf_activation
from .configuration_xlm import XLMConfig
from .file_utils import (
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_tf_outputs import (
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFMultipleChoiceModelOutput,
TFQuestionAnsweringModelOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
from ...modeling_tf_utils import (
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
......@@ -53,8 +52,9 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_xlm import XLMConfig
logger = logging.get_logger(__name__)
......
......@@ -29,16 +29,15 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from .activations import gelu
from .configuration_xlm import XLMConfig
from .file_utils import (
from ...activations import gelu
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_outputs import (
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
......@@ -46,7 +45,7 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
SQuADHead,
......@@ -54,7 +53,8 @@ from .modeling_utils import (
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
from ...utils import logging
from .configuration_xlm import XLMConfig
logger = logging.get_logger(__name__)
......
......@@ -24,8 +24,8 @@ from typing import List, Optional, Tuple
import sacremoses as sm
from .tokenization_utils import PreTrainedTokenizer
from .utils import logging
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_sentencepiece_available, is_torch_available
from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
if is_sentencepiece_available():
from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer
if is_torch_available():
from .modeling_xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMProphetNetDecoder,
XLMProphetNetEncoder,
XLMProphetNetForCausalLM,
XLMProphetNetForConditionalGeneration,
XLMProphetNetModel,
)
......@@ -15,8 +15,8 @@
""" XLM-ProphetNet model configuration """
from .configuration_prophetnet import ProphetNetConfig
from .utils import logging
from ...utils import logging
from ..prophetnet.configuration_prophetnet import ProphetNetConfig
logger = logging.get_logger(__name__)
......
......@@ -14,15 +14,15 @@
# limitations under the License.
""" PyTorch XLM-ProphetNet model."""
from .configuration_xlm_prophetnet import XLMProphetNetConfig
from .modeling_prophetnet import (
from ...utils import logging
from ..prophetnet.modeling_prophetnet import (
ProphetNetDecoder,
ProphetNetEncoder,
ProphetNetForCausalLM,
ProphetNetForConditionalGeneration,
ProphetNetModel,
)
from .utils import logging
from .configuration_xlm_prophetnet import XLMProphetNetConfig
logger = logging.get_logger(__name__)
......
......@@ -18,8 +18,8 @@ import os
from shutil import copyfile
from typing import List, Optional, Tuple
from .tokenization_utils import PreTrainedTokenizer
from .utils import logging
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
if is_sentencepiece_available():
from .tokenization_xlm_roberta import XLMRobertaTokenizer
if is_tokenizers_available():
from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
if is_torch_available():
from .modeling_xlm_roberta import (
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaForCausalLM,
XLMRobertaForMaskedLM,
XLMRobertaForMultipleChoice,
XLMRobertaForQuestionAnswering,
XLMRobertaForSequenceClassification,
XLMRobertaForTokenClassification,
XLMRobertaModel,
)
if is_tf_available():
from .modeling_tf_xlm_roberta import (
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMRobertaForMaskedLM,
TFXLMRobertaForMultipleChoice,
TFXLMRobertaForQuestionAnswering,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TFXLMRobertaModel,
)
......@@ -15,8 +15,8 @@
# limitations under the License.
""" XLM-RoBERTa configuration """
from .configuration_roberta import RobertaConfig
from .utils import logging
from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__)
......
......@@ -15,9 +15,9 @@
# limitations under the License.
""" TF 2.0 XLM-RoBERTa model. """
from .configuration_xlm_roberta import XLMRobertaConfig
from .file_utils import add_start_docstrings
from .modeling_tf_roberta import (
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
......@@ -25,7 +25,7 @@ from .modeling_tf_roberta import (
TFRobertaForTokenClassification,
TFRobertaModel,
)
from .utils import logging
from .configuration_xlm_roberta import XLMRobertaConfig
logger = logging.get_logger(__name__)
......
......@@ -15,9 +15,9 @@
# limitations under the License.
"""PyTorch XLM-RoBERTa model. """
from .configuration_xlm_roberta import XLMRobertaConfig
from .file_utils import add_start_docstrings
from .modeling_roberta import (
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..roberta.modeling_roberta import (
RobertaForCausalLM,
RobertaForMaskedLM,
RobertaForMultipleChoice,
......@@ -26,7 +26,7 @@ from .modeling_roberta import (
RobertaForTokenClassification,
RobertaModel,
)
from .utils import logging
from .configuration_xlm_roberta import XLMRobertaConfig
logger = logging.get_logger(__name__)
......
......@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
import sentencepiece as spm
from .tokenization_utils import PreTrainedTokenizer
from .utils import logging
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -19,9 +19,9 @@ import os
from shutil import copyfile
from typing import List, Optional, Tuple
from .file_utils import is_sentencepiece_available
from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
from ...file_utils import is_sentencepiece_available
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
if is_sentencepiece_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
if is_sentencepiece_available():
from .tokenization_xlnet import XLNetTokenizer
if is_tokenizers_available():
from .tokenization_xlnet_fast import XLNetTokenizerFast
if is_torch_available():
from .modeling_xlnet import (
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLNetForMultipleChoice,
XLNetForQuestionAnswering,
XLNetForQuestionAnsweringSimple,
XLNetForSequenceClassification,
XLNetForTokenClassification,
XLNetLMHeadModel,
XLNetModel,
XLNetPreTrainedModel,
load_tf_weights_in_xlnet,
)
if is_tf_available():
from .modeling_tf_xlnet import (
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLNetForMultipleChoice,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetLMHeadModel,
TFXLNetMainLayer,
TFXLNetModel,
TFXLNetPreTrainedModel,
)
......@@ -15,8 +15,8 @@
# limitations under the License.
""" XLNet configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment