Unverified Commit 758ed333 authored by Sylvain Gugger, committed by GitHub

Transformers fast import part 2 (#9446)



* Main init work

* Add version

* Change from absolute to relative imports

* Fix imports

* One more typo

* More typos

* Styling

* Make quality script pass

* Add necessary replace in template

* Fix typos

* Spaces are ignored in replace for some reason

* Forgot one model.

* Fixes for import
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>

* Add documentation

* Styling
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
parent a400fe89
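Every hunk below applies the same mechanical change in support of the fast-import work: scripts and modules inside the package stop importing through the top-level `transformers` namespace and use relative imports instead, so loading one submodule no longer depends on the whole package `__init__` resolving first. A minimal before/after sketch of the pattern (schematic, not a standalone runnable file; the surrounding file is illustrative, not one named in this diff):

```python
# Before: absolute imports route through transformers/__init__.py,
# which pulls in framework-dependent modules at import time.
from transformers import BertConfig, BertModel
from transformers.utils import logging

# After: relative imports stay inside the package. From a file in
# transformers/models/bert/, "..." climbs to the transformers root
# and "." is the model subpackage itself.
from ...utils import logging
from . import BertConfig, BertModel  # re-exported by models/bert/__init__.py
```

One practical consequence: a conversion script that uses relative imports can no longer be launched as a bare file with `python script.py`; it has to run with its package context, e.g. `python -m transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch`.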
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module while preserving other warnings, so don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
albert,
auto,
bart,
barthez,
bert,
bert_generation,
bert_japanese,
bertweet,
blenderbot,
blenderbot_small,
camembert,
ctrl,
deberta,
dialogpt,
distilbert,
dpr,
electra,
encoder_decoder,
flaubert,
fsmt,
funnel,
gpt2,
herbert,
layoutlm,
led,
longformer,
lxmert,
marian,
mbart,
mmbt,
mobilebert,
mpnet,
mt5,
openai,
pegasus,
phobert,
prophetnet,
rag,
reformer,
retribert,
roberta,
squeezebert,
t5,
tapas,
transfo_xl,
xlm,
xlm_roberta,
xlnet,
)
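This `models/__init__.py` imports every model subpackage, so each one becomes an attribute of `transformers.models`. (The `from . import BertConfig, ...` lines in the hunks below rely instead on the individual model `__init__.py` files re-exporting their classes, a detail not shown in this diff.) A quick sanity check you could run against a version of the library that includes this commit (hypothetical snippet, not part of the diff):

```python
# Each name in the import list above is reachable as an attribute
# of the transformers.models package.
from transformers import models

assert hasattr(models, "bert")
assert hasattr(models, "gpt2")
print(models.bert.__name__)  # prints: transformers.models.bert
```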
@@ -19,8 +19,8 @@ import argparse
 import torch

-from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
-from transformers.utils import logging
+from ...utils import logging
+from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert

 logging.set_verbosity_info()
@@ -23,15 +23,9 @@ import fairseq
 import torch
 from packaging import version

-from transformers import (
-    BartConfig,
-    BartForConditionalGeneration,
-    BartForSequenceClassification,
-    BartModel,
-    BartTokenizer,
-)
-from transformers.models.bart.modeling_bart import _make_linear_from_emb
-from transformers.utils import logging
+from ...utils import logging
+from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer
+from .modeling_bart import _make_linear_from_emb

 FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
@@ -15,8 +15,7 @@
 from typing import List, Optional

-from transformers import add_start_docstrings
+from ...file_utils import add_start_docstrings
 from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
 from ..roberta.tokenization_roberta import RobertaTokenizer
@@ -15,8 +15,7 @@
 from typing import List, Optional

-from transformers import add_start_docstrings
+from ...file_utils import add_start_docstrings
 from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
 from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
@@ -28,8 +28,8 @@ import re
 import tensorflow as tf
 import torch

-from transformers import BertConfig, BertModel
-from transformers.utils import logging
+from ...utils import logging
+from . import BertConfig, BertModel

 logging.set_verbosity_info()
@@ -19,8 +19,8 @@ import argparse
 import torch

-from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
-from transformers.utils import logging
+from ...utils import logging
+from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert

 logging.set_verbosity_info()
@@ -22,7 +22,7 @@ import numpy as np
 import tensorflow as tf
 import torch

-from transformers import BertModel
+from . import BertModel

 def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
@@ -18,8 +18,8 @@ import argparse
 import torch

-from transformers import BartConfig, BartForConditionalGeneration
-from transformers.utils import logging
+from ...models.bart import BartConfig, BartForConditionalGeneration
+from ...utils import logging

 logging.set_verbosity_info()
@@ -17,7 +17,7 @@ import os
 import torch

-from transformers.file_utils import WEIGHTS_NAME
+from ...file_utils import WEIGHTS_NAME

 DIALOGPT_MODELS = ["small", "medium", "large"]
@@ -19,7 +19,8 @@ from pathlib import Path
 import torch
 from torch.serialization import default_restore_location

-from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
+from ...models.bert import BertConfig
+from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader

 CheckpointState = collections.namedtuple(
@@ -19,8 +19,8 @@ import argparse
 import torch

-from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
-from transformers.utils import logging
+from ...utils import logging
+from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra

 logging.set_verbosity_info()
@@ -31,9 +31,10 @@ import torch
 from fairseq import hub_utils
 from fairseq.data.dictionary import Dictionary

-from transformers import WEIGHTS_NAME, logging
-from transformers.models.fsmt import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
-from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
+from ...file_utils import WEIGHTS_NAME
+from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
+from ...utils import logging
+from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration

 logging.set_verbosity_warning()
@@ -20,7 +20,7 @@ import logging
 import torch

-from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
+from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel

 logging.basicConfig(level=logging.INFO)
@@ -19,8 +19,9 @@ import argparse
 import torch

-from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
-from transformers.utils import logging
+from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
+from ...utils import logging
+from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2

 logging.set_verbosity_info()
@@ -20,7 +20,7 @@ import argparse
 import pytorch_lightning as pl
 import torch

-from transformers import LongformerForQuestionAnswering, LongformerModel
+from . import LongformerForQuestionAnswering, LongformerModel

 class LightningModel(pl.LightningModule):
@@ -20,7 +20,7 @@ import logging
 import torch

-from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
+from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert

 logging.basicConfig(level=logging.INFO)
@@ -17,7 +17,7 @@ import os
 from pathlib import Path
 from typing import List, Tuple

-from transformers.models.marian.convert_marian_to_pytorch import (
+from .convert_marian_to_pytorch import (
     FRONT_MATTER_TEMPLATE,
     _parse_readme,
     convert_all_sentencepiece_models,
@@ -26,8 +26,8 @@ import numpy as np
 import torch
 from tqdm import tqdm

-from transformers import MarianConfig, MarianMTModel, MarianTokenizer
-from transformers.hf_api import HfApi
+from ...hf_api import HfApi
+from . import MarianConfig, MarianMTModel, MarianTokenizer

 def remove_suffix(text: str, suffix: str):
@@ -16,9 +16,9 @@ import argparse
 import torch

-from transformers import BartForConditionalGeneration, MBartConfig
+from ..bart import BartForConditionalGeneration
+from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
+from . import MBartConfig

 def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):