Unverified commit c89bdfbe, authored by Sylvain Gugger, committed by GitHub

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
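Every hunk below applies the same import-path migration that comes with moving each model into its own `models/<model_name>/` subfolder. A minimal sketch of the pattern, assembled from illustrative lines that appear in the hunks of this diff (the grouping into one snippet is for illustration only, not a single real file):

# inside a model file now living under models/<model_name>/ :
from ...modeling_utils import PreTrainedModel      # was: from .modeling_utils import PreTrainedModel
from .configuration_roberta import RobertaConfig   # same-folder imports keep a single dot
from ..bert.modeling_bert import BertModel         # cross-model reuse goes through the sibling package
                                                   # (was: from .modeling_bert import BertModel)

# absolute imports in scripts and tests gain the models.<name> segment:
from transformers.models.roberta.modeling_roberta import RobertaForMaskedLM
# was: from transformers.modeling_roberta import RobertaForMaskedLM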
@@ -21,8 +21,7 @@ from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
 
-from .configuration_rag import RagConfig
-from .file_utils import (
+from ...file_utils import (
     cached_path,
     is_datasets_available,
     is_faiss_available,
@@ -30,9 +29,10 @@ from .file_utils import (
     requires_datasets,
     requires_faiss,
 )
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import logging
+from .configuration_rag import RagConfig
 from .tokenization_rag import RagTokenizer
-from .tokenization_utils_base import BatchEncoding
-from .utils import logging
 
 
 if is_datasets_available():
@@ -105,7 +105,7 @@ class LegacyIndex(Index):
            The dimension of indexed vectors.
        index_path (:obj:`str`):
            A path to a `directory` containing index files compatible with
-           :class:`~transformers.retrieval_rag.LegacyIndex`
+           :class:`~transformers.models.rag.retrieval_rag.LegacyIndex`
    """
 
    INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
@@ -344,7 +344,7 @@ class RagRetriever:
            generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
-       index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
+       index (:class:`~transformers.models.rag.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration
 
    Examples::
......
@@ -16,10 +16,10 @@
 import os
 from typing import List, Optional
 
+from ...file_utils import add_start_docstrings
+from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ...utils import logging
 from .configuration_rag import RagConfig
-from .file_utils import add_start_docstrings
-from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
-from .utils import logging
 
 
 logger = logging.get_logger(__name__)
@@ -42,7 +42,7 @@ class RagTokenizer:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         # dynamically import AutoTokenizer
-        from .tokenization_auto import AutoTokenizer
+        from ..auto.tokenization_auto import AutoTokenizer
 
         config = kwargs.pop("config", None)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_sentencepiece_available, is_tokenizers_available, is_torch_available
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig


if is_sentencepiece_available():
    from .tokenization_reformer import ReformerTokenizer

if is_tokenizers_available():
    from .tokenization_reformer_fast import ReformerTokenizerFast

if is_torch_available():
    from .modeling_reformer import (
        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
        ReformerAttention,
        ReformerForMaskedLM,
        ReformerForQuestionAnswering,
        ReformerForSequenceClassification,
        ReformerLayer,
        ReformerModel,
        ReformerModelWithLMHead,
    )
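The block above is the new per-model `__init__.py`: it re-exports the Reformer classes behind backend-availability checks, so the subfolder itself becomes an import surface. A short usage sketch, assuming torch and sentencepiece are installed (the top-level `transformers` re-exports are updated elsewhere and are not part of this excerpt):

# both paths resolve to the same classes after the reorganization
from transformers.models.reformer import ReformerConfig, ReformerTokenizer
from transformers.models.reformer.modeling_reformer import ReformerModel, ReformerModelWithLMHead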
@@ -15,8 +15,8 @@
 # limitations under the License.
 """ Reformer model configuration """
 
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
 
 
 logger = logging.get_logger(__name__)
......
@@ -28,9 +28,8 @@ from torch import nn
 from torch.autograd.function import Function
 from torch.nn import CrossEntropyLoss, MSELoss
 
-from .activations import ACT2FN
-from .configuration_reformer import ReformerConfig
-from .file_utils import (
+from ...activations import ACT2FN
+from ...file_utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
     ModelOutput,
@@ -38,9 +37,10 @@ from .file_utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
 )
-from .modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
-from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
-from .utils import logging
+from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
+from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...utils import logging
+from .configuration_reformer import ReformerConfig
 
 
 logger = logging.get_logger(__name__)
......
@@ -21,8 +21,8 @@ from typing import Dict, Optional, Tuple
 
 import sentencepiece as spm
 
-from .tokenization_utils import PreTrainedTokenizer
-from .utils import logging
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
 
 
 logger = logging.get_logger(__name__)
......
@@ -19,9 +19,9 @@ import os
 from shutil import copyfile
 from typing import Optional, Tuple
 
-from .file_utils import is_sentencepiece_available
-from .tokenization_utils_fast import PreTrainedTokenizerFast
-from .utils import logging
+from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
 
 
 if is_sentencepiece_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_tokenizers_available, is_torch_available
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .tokenization_retribert import RetriBertTokenizer


if is_tokenizers_available():
    from .tokenization_retribert_fast import RetriBertTokenizerFast

if is_torch_available():
    from .modeling_retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ RetriBERT model configuration """
 
-from .configuration_utils import PretrainedConfig
-from .utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
 
 
 logger = logging.get_logger(__name__)
......
@@ -23,11 +23,11 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
 
+from ...file_utils import add_start_docstrings
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from ..bert.modeling_bert import BertModel
 from .configuration_retribert import RetriBertConfig
-from .file_utils import add_start_docstrings
-from .modeling_bert import BertModel
-from .modeling_utils import PreTrainedModel
-from .utils import logging
 
 
 logger = logging.get_logger(__name__)
......
@@ -14,8 +14,8 @@
 # limitations under the License.
 """Tokenization classes for RetriBERT."""
 
-from .tokenization_bert import BertTokenizer
-from .utils import logging
+from ...utils import logging
+from ..bert.tokenization_bert import BertTokenizer
 
 
 logger = logging.get_logger(__name__)
......
@@ -14,9 +14,9 @@
 # limitations under the License.
 """Tokenization classes for RetriBERT."""
 
-from .tokenization_bert_fast import BertTokenizerFast
+from ...utils import logging
+from ..bert.tokenization_bert_fast import BertTokenizerFast
 from .tokenization_retribert import RetriBertTokenizer
-from .utils import logging
 
 
 logger = logging.get_logger(__name__)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .tokenization_roberta import RobertaTokenizer


if is_tokenizers_available():
    from .tokenization_roberta_fast import RobertaTokenizerFast

if is_torch_available():
    from .modeling_roberta import (
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        RobertaForCausalLM,
        RobertaForMaskedLM,
        RobertaForMultipleChoice,
        RobertaForQuestionAnswering,
        RobertaForSequenceClassification,
        RobertaForTokenClassification,
        RobertaModel,
    )

if is_tf_available():
    from .modeling_tf_roberta import (
        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFRobertaForMaskedLM,
        TFRobertaForMultipleChoice,
        TFRobertaForQuestionAnswering,
        TFRobertaForSequenceClassification,
        TFRobertaForTokenClassification,
        TFRobertaMainLayer,
        TFRobertaModel,
        TFRobertaPreTrainedModel,
    )

if is_flax_available():
    from .modeling_flax_roberta import FlaxRobertaModel
@@ -15,8 +15,8 @@
 # limitations under the License.
 """ RoBERTa configuration """
 
-from .configuration_bert import BertConfig
-from .utils import logging
+from ...utils import logging
+from ..bert.configuration_bert import BertConfig
 
 
 logger = logging.get_logger(__name__)
......
@@ -24,8 +24,18 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from packaging import version
 
-from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
-from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
+from transformers.models.bert.modeling_bert import (
+    BertIntermediate,
+    BertLayer,
+    BertOutput,
+    BertSelfAttention,
+    BertSelfOutput,
+)
+from transformers.models.roberta.modeling_roberta import (
+    RobertaConfig,
+    RobertaForMaskedLM,
+    RobertaForSequenceClassification,
+)
 from transformers.utils import logging
......
@@ -20,10 +20,10 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
+from ...utils import logging
 from .configuration_roberta import RobertaConfig
-from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_flax_utils import FlaxPreTrainedModel, gelu
-from .utils import logging
 
 
 logger = logging.get_logger(__name__)
@@ -89,7 +89,7 @@ ROBERTA_INPUTS_DOCSTRING = r"""
 """
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
 class FlaxRobertaLayerNorm(nn.Module):
     """
     Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
@@ -130,7 +130,7 @@ class FlaxRobertaLayerNorm(nn.Module):
         return y
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
 class FlaxRobertaEmbedding(nn.Module):
     """
     Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
@@ -147,7 +147,7 @@ class FlaxRobertaEmbedding(nn.Module):
         return jnp.take(embedding, inputs, axis=0)
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
 class FlaxRobertaEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
@@ -179,7 +179,7 @@ class FlaxRobertaEmbeddings(nn.Module):
         return layer_norm
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
 class FlaxRobertaAttention(nn.Module):
     num_heads: int
     head_size: int
@@ -194,7 +194,7 @@ class FlaxRobertaAttention(nn.Module):
         return layer_norm
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
 class FlaxRobertaIntermediate(nn.Module):
     output_size: int
@@ -205,7 +205,7 @@ class FlaxRobertaIntermediate(nn.Module):
         return gelu(dense)
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
 class FlaxRobertaOutput(nn.Module):
     @nn.compact
     def __call__(self, intermediate_output, attention_output):
@@ -230,7 +230,7 @@ class FlaxRobertaLayer(nn.Module):
         return output
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
 class FlaxRobertaLayerCollection(nn.Module):
     """
     Stores N RobertaLayer(s)
@@ -255,7 +255,7 @@ class FlaxRobertaLayerCollection(nn.Module):
         return input_i
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
 class FlaxRobertaEncoder(nn.Module):
     num_layers: int
     num_heads: int
@@ -270,7 +270,7 @@ class FlaxRobertaEncoder(nn.Module):
         return layer
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
 class FlaxRobertaPooler(nn.Module):
     @nn.compact
     def __call__(self, hidden_state):
@@ -279,7 +279,7 @@ class FlaxRobertaPooler(nn.Module):
         return jax.lax.tanh(out)
 
 
-# Copied from transformers.modeling_flax_bert.FlaxBertModule with Bert->Roberta
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
 class FlaxRobertaModule(nn.Module):
     vocab_size: int
     hidden_size: int
......
@@ -22,15 +22,14 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
-from .activations import ACT2FN, gelu
-from .configuration_roberta import RobertaConfig
-from .file_utils import (
+from ...activations import ACT2FN, gelu
+from ...file_utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from .modeling_outputs import (
+from ...modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -40,13 +39,14 @@ from .modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import (
+from ...modeling_utils import (
     PreTrainedModel,
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
-from .utils import logging
+from ...utils import logging
+from .configuration_roberta import RobertaConfig
 
 
 logger = logging.get_logger(__name__)
@@ -70,7 +70,7 @@ class RobertaEmbeddings(nn.Module):
     Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
     """
 
-    # Copied from transformers.modeling_bert.BertEmbeddings.__init__
+    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
     def __init__(self, config):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -99,7 +99,7 @@ class RobertaEmbeddings(nn.Module):
             else:
                 position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
 
-        # Copied from transformers.modeling_bert.BertEmbeddings.forward
+        # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -141,7 +141,7 @@ class RobertaEmbeddings(nn.Module):
         return position_ids.unsqueeze(0).expand(input_shape)
 
 
-# Copied from transformers.modeling_bert.BertSelfAttention with Bert->Roberta
+# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
 class RobertaSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -220,7 +220,7 @@ class RobertaSelfAttention(nn.Module):
         return outputs
 
 
-# Copied from transformers.modeling_bert.BertSelfOutput
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
 class RobertaSelfOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -235,7 +235,7 @@ class RobertaSelfOutput(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.modeling_bert.BertAttention with Bert->Roberta
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
 class RobertaAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -283,7 +283,7 @@ class RobertaAttention(nn.Module):
         return outputs
 
 
-# Copied from transformers.modeling_bert.BertIntermediate
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
 class RobertaIntermediate(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -299,7 +299,7 @@ class RobertaIntermediate(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.modeling_bert.BertOutput
+# Copied from transformers.models.bert.modeling_bert.BertOutput
 class RobertaOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -314,7 +314,7 @@ class RobertaOutput(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.modeling_bert.BertLayer with Bert->Roberta
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
 class RobertaLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -374,7 +374,7 @@ class RobertaLayer(nn.Module):
         return layer_output
 
 
-# Copied from transformers.modeling_bert.BertEncoder with Bert->Roberta
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
 class RobertaEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -449,7 +449,7 @@ class RobertaEncoder(nn.Module):
         )
 
 
-# Copied from transformers.modeling_bert.BertPooler
+# Copied from transformers.models.bert.modeling_bert.BertPooler
 class RobertaPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -474,7 +474,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
 
-    # Copied from transformers.modeling_bert.BertPreTrainedModel._init_weights
+    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
         """ Initialize the weights """
         if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -579,7 +579,7 @@ class RobertaModel(RobertaPreTrainedModel):
 
     authorized_missing_keys = [r"position_ids"]
 
-    # Copied from transformers.modeling_bert.BertModel.__init__ with Bert->Roberta
+    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
@@ -612,7 +612,7 @@ class RobertaModel(RobertaPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.modeling_bert.BertModel.forward
+    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
    def forward(
        self,
        input_ids=None,
......
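A large share of the hunks in the modeling files above only rewrite `# Copied from ...` annotations. These comments are machine-checked: the repository's copy-consistency tooling resolves the dotted path to the reference implementation, so the paths have to follow the classes into `transformers.models.<name>....` (presumably the "Fix path for copies" step in the commit message). The convention after this change, using names already shown above:

# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class RobertaSelfOutput(nn.Module):
    ...  # body kept identical to the referenced Bert class

An optional `with Bert->Roberta` suffix covers copies that also rename the class.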
@@ -18,15 +18,14 @@
 
 import tensorflow as tf
 
-from .activations_tf import get_tf_activation
-from .configuration_roberta import RobertaConfig
-from .file_utils import (
+from ...activations_tf import get_tf_activation
+from ...file_utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
 )
-from .modeling_tf_outputs import (
+from ...modeling_tf_outputs import (
     TFBaseModelOutput,
     TFBaseModelOutputWithPooling,
     TFMaskedLMOutput,
@@ -35,7 +34,7 @@ from .modeling_tf_outputs import (
     TFSequenceClassifierOutput,
     TFTokenClassifierOutput,
 )
-from .modeling_tf_utils import (
+from ...modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
     TFPreTrainedModel,
@@ -46,8 +45,9 @@ from .modeling_tf_utils import (
     keras_serializable,
     shape_list,
 )
-from .tokenization_utils_base import BatchEncoding
-from .utils import logging
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import logging
+from .configuration_roberta import RobertaConfig
 
 
 logger = logging.get_logger(__name__)
@@ -223,7 +223,7 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
         return tf.reshape(logits, [batch_size, length, self.vocab_size])
 
 
-# Copied from transformers.modeling_tf_bert.TFBertPooler
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler
 class TFRobertaPooler(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -244,7 +244,7 @@ class TFRobertaPooler(tf.keras.layers.Layer):
         return pooled_output
 
 
-# Copied from transformers.modeling_tf_bert.TFBertSelfAttention
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention
 class TFRobertaSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -316,7 +316,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
         return outputs
 
 
-# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput
 class TFRobertaSelfOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -335,7 +335,7 @@ class TFRobertaSelfOutput(tf.keras.layers.Layer):
         return hidden_states
 
 
-# Copied from transformers.modeling_tf_bert.TFBertAttention with Bert->Roberta
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Roberta
 class TFRobertaAttention(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -356,7 +356,7 @@ class TFRobertaAttention(tf.keras.layers.Layer):
         return outputs
 
 
-# Copied from transformers.modeling_tf_bert.TFBertIntermediate
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate
 class TFRobertaIntermediate(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -377,7 +377,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
         return hidden_states
 
 
-# Copied from transformers.modeling_tf_bert.TFBertOutput
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput
 class TFRobertaOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -396,7 +396,7 @@ class TFRobertaOutput(tf.keras.layers.Layer):
         return hidden_states
 
 
-# Copied from transformers.modeling_tf_bert.TFBertLayer with Bert->Roberta
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Roberta
 class TFRobertaLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -417,7 +417,7 @@ class TFRobertaLayer(tf.keras.layers.Layer):
         return outputs
 
 
-# Copied from transformers.modeling_tf_bert.TFBertEncoder with Bert->Roberta
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Roberta
 class TFRobertaEncoder(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
@@ -478,16 +478,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
         # The embeddings must be the last declaration in order to follow the weights order
         self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
 
-    # Copied from transformers.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
     def get_input_embeddings(self):
         return self.embeddings
 
-    # Copied from transformers.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value
         self.embeddings.vocab_size = value.shape[0]
 
-    # Copied from transformers.modeling_tf_bert.TFBertMainLayer._prune_heads
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
     def _prune_heads(self, heads_to_prune):
         """
         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
@@ -495,7 +495,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
-    # Copied from transformers.modeling_tf_bert.TFBertMainLayer.call
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
     def call(
         self,
         inputs,
......
@@ -17,9 +17,9 @@
 import warnings
 from typing import List, Optional
 
-from .tokenization_gpt2 import GPT2Tokenizer
-from .tokenization_utils import AddedToken
-from .utils import logging
+from ...tokenization_utils import AddedToken
+from ...utils import logging
+from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
 
 
 logger = logging.get_logger(__name__)
......