Unverified Commit c89bdfbe authored by Sylvain Gugger, committed by GitHub

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
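Note on the import changes below: after this reorganization each model lives in its own `src/transformers/models/<model_name>/` subfolder, so modules that previously sat next to `file_utils.py` and `modeling_utils.py` now need three leading dots to reach utilities at the package root, two dots plus the folder name to reach a sibling model, and one dot for files in the same model folder. A minimal sketch of the resulting layout and import style (abridged, for illustration only):

```python
# Layout after the reorganization (abridged sketch, not the full tree):
#
#   src/transformers/
#   |-- file_utils.py
#   |-- modeling_utils.py
#   `-- models/
#       |-- bart/
#       |   `-- modeling_bart.py
#       `-- pegasus/
#           |-- configuration_pegasus.py
#           `-- modeling_pegasus.py
#
# Inside models/pegasus/modeling_pegasus.py the relative imports therefore read
# (valid only from within the package, shown here for illustration):
#
#   from ...file_utils import add_start_docstrings                  # transformers package root
#   from ..bart.modeling_bart import BartForConditionalGeneration   # sibling model folder
#   from .configuration_pegasus import PegasusConfig                # same model folder
```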
@@ -21,17 +21,16 @@ from typing import Optional, Tuple
 import tensorflow as tf
-from .activations_tf import get_tf_activation
-from .configuration_openai import OpenAIGPTConfig
-from .file_utils import (
+from ...activations_tf import get_tf_activation
+from ...file_utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from .modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
-from .modeling_tf_utils import (
+from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
+from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFConv1D,
     TFPreTrainedModel,
@@ -41,8 +40,9 @@ from .modeling_tf_utils import (
     keras_serializable,
     shape_list,
 )
-from .tokenization_utils import BatchEncoding
-from .utils import logging
+from ...tokenization_utils import BatchEncoding
+from ...utils import logging
+from .configuration_openai import OpenAIGPTConfig
 logger = logging.get_logger(__name__)
......
@@ -20,9 +20,9 @@ import os
 import re
 from typing import Optional, Tuple
-from .tokenization_bert import BasicTokenizer
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+from ..bert.tokenization_bert import BasicTokenizer
 logger = logging.get_logger(__name__)
......
@@ -17,9 +17,9 @@
 from typing import Optional, Tuple
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
 from .tokenization_openai import OpenAIGPTTokenizer
-from .tokenization_utils_fast import PreTrainedTokenizerFast
-from .utils import logging
 logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_pegasus import PegasusConfig

if is_sentencepiece_available():
    from .tokenization_pegasus import PegasusTokenizer

if is_tokenizers_available():
    from .tokenization_pegasus_fast import PegasusTokenizerFast

if is_torch_available():
    from .modeling_pegasus import PegasusForConditionalGeneration

if is_tf_available():
    from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
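Each model subfolder now ships a per-model `__init__.py` like the Pegasus one above, re-exporting its public classes behind the usual availability guards. Assuming the top-level `transformers/__init__.py` keeps re-exporting these names (its diff is not shown here), downstream code should be able to import through either path; a small hedged sketch:

```python
# Hedged usage sketch: assuming the top-level transformers __init__ still
# re-exports the model classes (as it did before this reorganization), both
# import paths below should resolve to the same object.
from transformers import PegasusConfig
from transformers.models.pegasus import PegasusConfig as PegasusConfigFromSubpackage

assert PegasusConfig is PegasusConfigFromSubpackage

# Tokenizer and model classes stay behind the availability checks shown in the
# __init__ above, so importing e.g. PegasusTokenizer still requires the
# optional sentencepiece dependency.
```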
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ PEGASUS model configuration """
-from .configuration_bart import BartConfig
-from .utils import logging
+from ...utils import logging
+from ..bart.configuration_bart import BartConfig
 logger = logging.get_logger(__name__)
......
@@ -23,7 +23,7 @@ import torch
 from tqdm import tqdm
 from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
-from transformers.configuration_pegasus import DEFAULTS, task_specific_params
+from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params
 PATTERNS = [
......
@@ -15,9 +15,9 @@
 """PyTorch Pegasus model, ported from https://github.com/google-research/pegasus"""
+from ...file_utils import add_start_docstrings
+from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
 from .configuration_pegasus import PegasusConfig
-from .file_utils import add_start_docstrings
-from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
 @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
@@ -44,7 +44,7 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
         >>> assert summary == "California's largest electricity provider has turned off power to tens of thousands of customers."
     """
-    # All the code is in src/transformers/modeling_bart.py
+    # All the code is in src/transformers/models/bart/modeling_bart.py
     config_class = PegasusConfig
     authorized_missing_keys = [
         r"final_logits_bias",
......
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """TF Pegasus model, ported from the fairseq repo."""
+from ...file_utils import add_start_docstrings
+from ...utils import logging
+from ..bart.modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
 from .configuration_pegasus import PegasusConfig
-from .file_utils import add_start_docstrings
-from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
-from .utils import logging
 _CONFIG_FOR_DOC = "PegasusConfig"
@@ -38,4 +38,4 @@ class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
         r"model.decoder.embed_positions.weight",
     ]
     config_class = PegasusConfig
-    # All the code is in src/transformers/modeling_tf_bart.py
+    # All the code is in src/transformers/models/bart/modeling_tf_bart.py
@@ -14,9 +14,9 @@
 # limitations under the License.
 from typing import Dict, List, Optional
-from .file_utils import add_start_docstrings
-from .tokenization_reformer import ReformerTokenizer
-from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ...file_utils import add_start_docstrings
+from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ..reformer.tokenization_reformer import ReformerTokenizer
 SPIECE_UNDERLINE = "▁"
......
@@ -14,9 +14,9 @@
 # limitations under the License.
 from typing import List, Optional
-from .file_utils import add_start_docstrings, is_sentencepiece_available
-from .tokenization_reformer_fast import ReformerTokenizerFast
-from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ...file_utils import add_start_docstrings, is_sentencepiece_available
+from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
 if is_sentencepiece_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .tokenization_phobert import PhobertTokenizer
@@ -21,8 +21,8 @@ import re
 from shutil import copyfile
 from typing import List, Optional, Tuple
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
 logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_torch_available
from .configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from .tokenization_prophetnet import ProphetNetTokenizer

if is_torch_available():
    from .modeling_prophetnet import (
        PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
        ProphetNetDecoder,
        ProphetNetEncoder,
        ProphetNetForCausalLM,
        ProphetNetForConditionalGeneration,
        ProphetNetModel,
        ProphetNetPreTrainedModel,
    )
@@ -15,8 +15,8 @@
 """ ProphetNet model configuration """
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
 logger = logging.get_logger(__name__)
......
@@ -19,9 +19,7 @@ import argparse
 import torch
-from transformers import logging
-from transformers.modeling_prophetnet import ProphetNetForConditionalGeneration
-from transformers.modeling_xlm_prophetnet import XLMProphetNetForConditionalGeneration
+from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
 # transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
 # original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
......
@@ -24,17 +24,17 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from .activations import ACT2FN
-from .configuration_prophetnet import ProphetNetConfig
-from .file_utils import (
+from ...activations import ACT2FN
+from ...file_utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from .modeling_outputs import BaseModelOutput
-from .modeling_utils import PreTrainedModel
-from .utils import logging
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from .configuration_prophetnet import ProphetNetConfig
 logger = logging.get_logger(__name__)
......
@@ -17,11 +17,11 @@ import collections
 import os
 from typing import List, Optional, Tuple
-from .file_utils import add_start_docstrings
-from .tokenization_bert import BasicTokenizer, WordpieceTokenizer
-from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
-from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
-from .utils import logging
+from ...file_utils import add_start_docstrings
+from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
+from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from ...utils import logging
+from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer
 logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_torch_available
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever
from .tokenization_rag import RagTokenizer

if is_torch_available():
    from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
@@ -16,8 +16,8 @@
 import copy
-from .configuration_utils import PretrainedConfig
-from .file_utils import add_start_docstrings
+from ...configuration_utils import PretrainedConfig
+from ...file_utils import add_start_docstrings
 RAG_CONFIG_DOC = r"""
@@ -53,7 +53,7 @@ RAG_CONFIG_DOC = r"""
         The path to the serialized faiss index on disk.
     passages_path: (:obj:`str`, `optional`):
         A path to text passages compatible with the faiss index. Required if using
-        :class:`~transformers.retrieval_rag.LegacyIndex`
+        :class:`~transformers.models.rag.retrieval_rag.LegacyIndex`
     use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``)
         Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`.
     label_smoothing (:obj:`float`, `optional`, defaults to 0.0):
@@ -127,7 +127,7 @@ class RagConfig(PretrainedConfig):
             decoder_config = kwargs.pop("generator")
             decoder_model_type = decoder_config.pop("model_type")
-            from .configuration_auto import AutoConfig
+            from ..auto.configuration_auto import AutoConfig
             self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
             self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
......
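In the last hunk above, the `AutoConfig` import stays local to the method and only its target moves to the new `..auto` subpackage; keeping the import deferred presumably avoids a circular import between the RAG configuration and the auto classes. A minimal sketch of that deferred-import pattern, using the public `transformers` path and an illustrative class rather than the actual `RagConfig` code:

```python
from transformers import PretrainedConfig


class ComposedConfig(PretrainedConfig):
    """Illustrative config that composes a sub-model config (not the real RagConfig)."""

    def __init__(self, **kwargs):
        if "question_encoder" in kwargs:
            sub_config = kwargs.pop("question_encoder")
            sub_model_type = sub_config.pop("model_type")

            # Imported inside __init__ so that importing this module does not
            # pull in the auto classes at module load time, which keeps the
            # import graph free of cycles.
            from transformers import AutoConfig

            self.question_encoder = AutoConfig.for_model(sub_model_type, **sub_config)
        super().__init__(**kwargs)
```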
@@ -19,14 +19,14 @@ from typing import List, Optional, Tuple
 import torch
+from ...configuration_utils import PretrainedConfig
+from ...file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...generation_beam_search import BeamSearchScorer
+from ...modeling_outputs import ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
 from .configuration_rag import RagConfig
-from .configuration_utils import PretrainedConfig
-from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
-from .generation_beam_search import BeamSearchScorer
-from .modeling_outputs import ModelOutput
-from .modeling_utils import PreTrainedModel
 from .retrieval_rag import RagRetriever
-from .utils import logging
 logger = logging.get_logger(__name__)
@@ -316,10 +316,10 @@ class RagPreTrainedModel(PreTrainedModel):
             assert (
                 question_encoder_pretrained_model_name_or_path is not None
             ), "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined"
-            from .modeling_auto import AutoModel
+            from ..auto.modeling_auto import AutoModel
             if "config" not in kwargs_question_encoder:
-                from .configuration_auto import AutoConfig
+                from ..auto.configuration_auto import AutoConfig
                 question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
                 kwargs_question_encoder["config"] = question_encoder_config
@@ -333,10 +333,10 @@ class RagPreTrainedModel(PreTrainedModel):
             assert (
                 generator_pretrained_model_name_or_path is not None
             ), "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be defined"
-            from .modeling_auto import AutoModelForSeq2SeqLM
+            from ..auto.modeling_auto import AutoModelForSeq2SeqLM
             if "config" not in kwargs_generator:
-                from .configuration_auto import AutoConfig
+                from ..auto.configuration_auto import AutoConfig
                 generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
                 kwargs_generator["config"] = generator_config
@@ -484,12 +484,12 @@ class RagModel(RagPreTrainedModel):
         )
         super().__init__(config)
         if question_encoder is None:
-            from .modeling_auto import AutoModel
+            from ..auto.modeling_auto import AutoModel
             question_encoder = AutoModel.from_config(config.question_encoder)
         if generator is None:
-            from .modeling_auto import AutoModelForSeq2SeqLM
+            from ..auto.modeling_auto import AutoModelForSeq2SeqLM
             generator = AutoModelForSeq2SeqLM.from_config(config.generator)
......