Unverified Commit c89bdfbe authored by Sylvain Gugger, committed by GitHub

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
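
The pattern throughout the diffs below: each model's files move from the flat src/transformers/ directory into a per-model subfolder under src/transformers/models/, so relative imports gain depth. Package-level modules are now reached with three dots (from ...utils import logging), sibling models with two (from ..roberta.tokenization_roberta import RobertaTokenizer), and same-folder modules keep a single dot. User-facing imports are unaffected because the top-level package still re-exports everything. A minimal sketch of that invariant, assuming a transformers install with the reorganized layout (the alias is illustrative, not from the commit):

# Minimal sketch, not part of the commit: the public import path is unchanged.
from transformers import BertConfig  # still re-exported at the top level

# The module itself now lives in a per-model subfolder and is reachable there too.
from transformers.models.bert.configuration_bert import BertConfig as MovedBertConfig

assert BertConfig is MovedBertConfig  # same class, two import paths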
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/models/bart/tokenization_bart.py
@@ -15,9 +15,9 @@
 from typing import List, Optional

-from .tokenization_roberta import RobertaTokenizer
-from .tokenization_utils_base import BatchEncoding
-from .utils import logging
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import logging
+from ..roberta.tokenization_roberta import RobertaTokenizer


 logger = logging.get_logger(__name__)
...
--- a/src/transformers/tokenization_bart_fast.py
+++ b/src/transformers/models/bart/tokenization_bart_fast.py
@@ -15,10 +15,10 @@
 from typing import List, Optional

+from ...tokenization_utils_base import BatchEncoding
+from ...utils import logging
+from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
 from .tokenization_bart import BartTokenizer
-from .tokenization_roberta_fast import RobertaTokenizerFast
-from .tokenization_utils_base import BatchEncoding
-from .utils import logging


 logger = logging.get_logger(__name__)
...
--- /dev/null
+++ b/src/transformers/models/bert/__init__.py
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer


if is_tokenizers_available():
    from .tokenization_bert_fast import BertTokenizerFast

if is_torch_available():
    from .modeling_bert import (
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        BertForMaskedLM,
        BertForMultipleChoice,
        BertForNextSentencePrediction,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        BertLayer,
        BertLMHeadModel,
        BertModel,
        BertPreTrainedModel,
        load_tf_weights_in_bert,
    )

if is_tf_available():
    from .modeling_tf_bert import (
        TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFBertEmbeddings,
        TFBertForMaskedLM,
        TFBertForMultipleChoice,
        TFBertForNextSentencePrediction,
        TFBertForPreTraining,
        TFBertForQuestionAnswering,
        TFBertForSequenceClassification,
        TFBertForTokenClassification,
        TFBertLMHeadModel,
        TFBertMainLayer,
        TFBertModel,
        TFBertPreTrainedModel,
    )

if is_flax_available():
    from .modeling_flax_bert import FlaxBertModel
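
The new models/bert/__init__.py above re-exports each backend's classes only when that backend is installed, via the is_*_available() guards. A short consumer-side sketch of the same pattern (the checkpoint name is illustrative, not part of the commit):

# Sketch of consuming the guarded exports; the first branch needs PyTorch installed.
from transformers.file_utils import is_torch_available

if is_torch_available():
    from transformers.models.bert import BertModel

    model = BertModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint
else:
    print("PyTorch is not installed, so the torch-only BERT classes are not exported.")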
--- a/src/transformers/configuration_bert.py
+++ b/src/transformers/models/bert/configuration_bert.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 """ BERT model configuration """

-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging


 logger = logging.get_logger(__name__)
...
--- a/src/transformers/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -27,16 +27,15 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss

-from .activations import ACT2FN
-from .configuration_bert import BertConfig
-from .file_utils import (
+from ...activations import ACT2FN
+from ...file_utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from .modeling_outputs import (
+from ...modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -47,13 +46,14 @@ from .modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import (
+from ...modeling_utils import (
     PreTrainedModel,
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
-from .utils import logging
+from ...utils import logging
+from .configuration_bert import BertConfig


 logger = logging.get_logger(__name__)
...
--- a/src/transformers/modeling_flax_bert.py
+++ b/src/transformers/models/bert/modeling_flax_bert.py
@@ -21,10 +21,10 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp

+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
+from ...utils import logging
 from .configuration_bert import BertConfig
-from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_flax_utils import FlaxPreTrainedModel, gelu
-from .utils import logging


 logger = logging.get_logger(__name__)
...
--- a/src/transformers/modeling_tf_bert.py
+++ b/src/transformers/models/bert/modeling_tf_bert.py
@@ -21,9 +21,8 @@ from typing import Optional, Tuple
 import tensorflow as tf

-from .activations_tf import get_tf_activation
-from .configuration_bert import BertConfig
-from .file_utils import (
+from ...activations_tf import get_tf_activation
+from ...file_utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     ModelOutput,
     add_code_sample_docstrings,
@@ -31,7 +30,7 @@ from .file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from .modeling_tf_outputs import (
+from ...modeling_tf_outputs import (
     TFBaseModelOutput,
     TFBaseModelOutputWithPooling,
     TFCausalLMOutput,
@@ -42,7 +41,7 @@ from .modeling_tf_outputs import (
     TFSequenceClassifierOutput,
     TFTokenClassifierOutput,
 )
-from .modeling_tf_utils import (
+from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
@@ -55,8 +54,9 @@ from .modeling_tf_utils import (
     keras_serializable,
     shape_list,
 )
-from .tokenization_utils import BatchEncoding
-from .utils import logging
+from ...tokenization_utils import BatchEncoding
+from ...utils import logging
+from .configuration_bert import BertConfig


 logger = logging.get_logger(__name__)
...
--- a/src/transformers/tokenization_bert.py
+++ b/src/transformers/models/bert/tokenization_bert.py
@@ -20,8 +20,8 @@ import os
 import unicodedata
 from typing import List, Optional, Tuple

-from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging


 logger = logging.get_logger(__name__)
...
--- a/src/transformers/tokenization_bert_fast.py
+++ b/src/transformers/models/bert/tokenization_bert_fast.py
@@ -19,9 +19,9 @@ from typing import List, Optional, Tuple
 from tokenizers import normalizers

+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
 from .tokenization_bert import BertTokenizer
-from .tokenization_utils_fast import PreTrainedTokenizerFast
-from .utils import logging


 logger = logging.get_logger(__name__)
...
--- /dev/null
+++ b/src/transformers/models/bert_generation/__init__.py
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_sentencepiece_available, is_torch_available
from .configuration_bert_generation import BertGenerationConfig


if is_sentencepiece_available():
    from .tokenization_bert_generation import BertGenerationTokenizer

if is_torch_available():
    from .modeling_bert_generation import (
        BertGenerationDecoder,
        BertGenerationEncoder,
        load_tf_weights_in_bert_generation,
    )
--- a/src/transformers/configuration_bert_generation.py
+++ b/src/transformers/models/bert_generation/configuration_bert_generation.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """ BertGeneration model configuration """

-from .configuration_utils import PretrainedConfig
+from ...configuration_utils import PretrainedConfig


 class BertGenerationConfig(PretrainedConfig):
...
--- a/src/transformers/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -20,17 +20,17 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss

-from .configuration_bert_generation import BertGenerationConfig
-from .file_utils import (
+from ...file_utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from .modeling_bert import BertEncoder
-from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
-from .modeling_utils import PreTrainedModel
-from .utils import logging
+from ...modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from ..bert.modeling_bert import BertEncoder
+from .configuration_bert_generation import BertGenerationConfig


 logger = logging.get_logger(__name__)
...
--- a/src/transformers/tokenization_bert_generation.py
+++ b/src/transformers/models/bert_generation/tokenization_bert_generation.py
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
 import sentencepiece as spm

-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging


 logger = logging.get_logger(__name__)
...
--- /dev/null
+++ b/src/transformers/models/bert_japanese/__init__.py
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
--- a/src/transformers/tokenization_bert_japanese.py
+++ b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
@@ -21,8 +21,8 @@ import os
 import unicodedata
 from typing import Optional

-from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer, load_vocab
-from .utils import logging
+from ...utils import logging
+from ..bert.tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer, load_vocab


 logger = logging.get_logger(__name__)
...
--- /dev/null
+++ b/src/transformers/models/bertweet/__init__.py
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from .tokenization_bertweet import BertweetTokenizer
--- a/src/transformers/tokenization_bertweet.py
+++ b/src/transformers/models/bertweet/tokenization_bertweet.py
@@ -24,8 +24,8 @@ from typing import List, Optional, Tuple
 import regex

-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging


 logger = logging.get_logger(__name__)
...