Unverified Commit c89bdfbe authored by Sylvain Gugger, committed by GitHub

Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in convert command
parent 90150733
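
Every hunk below follows the same pattern: each model's modules move from the package root into a dedicated subfolder under transformers/models/, so deep module paths gain a models.<model_name> segment and relative imports gain extra leading dots. A minimal sketch of the path change, assuming (as the hunks suggest) that the top-level public API is left untouched:

# Before the reorganization, per-model modules lived at the package root:
#   from transformers.modeling_roberta import RobertaForMaskedLM
#   from transformers.tokenization_bert import BertTokenizer

# After it, each model has its own subpackage under transformers.models:
from transformers.models.roberta.modeling_roberta import RobertaForMaskedLM
from transformers.models.bert.tokenization_bert import BertTokenizer

# Top-level imports are assumed to keep working exactly as before:
from transformers import BertTokenizer, RobertaForMaskedLM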
......@@ -21,8 +21,7 @@ from typing import Iterable, List, Optional, Tuple
import numpy as np
from .configuration_rag import RagConfig
from .file_utils import (
from ...file_utils import (
cached_path,
is_datasets_available,
is_faiss_available,
......@@ -30,9 +29,10 @@ from .file_utils import (
requires_datasets,
requires_faiss,
)
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer
from .tokenization_utils_base import BatchEncoding
from .utils import logging
if is_datasets_available():
......@@ -105,7 +105,7 @@ class LegacyIndex(Index):
The dimension of indexed vectors.
index_path (:obj:`str`):
A path to a `directory` containing index files compatible with
:class:`~transformers.retrieval_rag.LegacyIndex`
:class:`~transformers.models.rag.retrieval_rag.LegacyIndex`
"""
INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
......@@ -344,7 +344,7 @@ class RagRetriever:
generator_tokenizer.
generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer used for the generator part of the RagModel.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
index (:class:`~transformers.models.rag.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, use this index instead of the one built using the configuration
Examples::
......
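
retrieval_rag.py keeps its optional-dependency guards (is_datasets_available, is_faiss_available, requires_datasets, requires_faiss); only their import source changes from .file_utils to ...file_utils. A minimal sketch of that guard pattern, assuming a faiss-backed index and not reproducing the actual RagRetriever code:

import numpy as np

from transformers.file_utils import is_faiss_available

if is_faiss_available():
    import faiss  # only imported when the optional dependency is installed


def build_index(vectors: np.ndarray):
    # Fail with a clear message instead of a NameError when faiss is missing.
    if not is_faiss_available():
        raise ImportError("faiss is required to build this index (pip install faiss-cpu).")
    index = faiss.IndexFlatIP(vectors.shape[1])  # inner-product index over float32 vectors
    index.add(vectors)
    return index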
......@@ -16,10 +16,10 @@
import os
from typing import List, Optional
from ...file_utils import add_start_docstrings
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
from ...utils import logging
from .configuration_rag import RagConfig
from .file_utils import add_start_docstrings
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__)
......@@ -42,7 +42,7 @@ class RagTokenizer:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
# dynamically import AutoTokenizer
from .tokenization_auto import AutoTokenizer
from ..auto.tokenization_auto import AutoTokenizer
config = kwargs.pop("config", None)
......
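
The AutoTokenizer import above stays inside from_pretrained rather than at module level: importing it eagerly would pull in the auto mappings, which themselves import the concrete tokenizer modules and would create a circular import. A sketch of the same deferred-import pattern on a hypothetical wrapper (MyComposedTokenizer is illustrative, not the actual RagTokenizer):

class MyComposedTokenizer:
    """Hypothetical tokenizer that wraps two sub-tokenizers, RAG-style."""

    def __init__(self, question_encoder, generator):
        self.question_encoder = question_encoder
        self.generator = generator

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Deferred import: resolved only at call time, which breaks the circular
        # dependency between this module and the auto tokenizer mappings.
        from transformers.models.auto.tokenization_auto import AutoTokenizer

        sub_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(question_encoder=sub_tokenizer, generator=sub_tokenizer)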
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_sentencepiece_available, is_tokenizers_available, is_torch_available
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
if is_sentencepiece_available():
from .tokenization_reformer import ReformerTokenizer
if is_tokenizers_available():
from .tokenization_reformer_fast import ReformerTokenizerFast
if is_torch_available():
from .modeling_reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,
ReformerForMaskedLM,
ReformerForQuestionAnswering,
ReformerForSequenceClassification,
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
)
......@@ -15,8 +15,8 @@
# limitations under the License.
""" Reformer model configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -28,9 +28,8 @@ from torch import nn
from torch.autograd.function import Function
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import ACT2FN
from .configuration_reformer import ReformerConfig
from .file_utils import (
from ...activations import ACT2FN
from ...file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
ModelOutput,
......@@ -38,9 +37,10 @@ from .file_utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
from .utils import logging
from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...utils import logging
from .configuration_reformer import ReformerConfig
logger = logging.get_logger(__name__)
......
......@@ -21,8 +21,8 @@ from typing import Dict, Optional, Tuple
import sentencepiece as spm
from .tokenization_utils import PreTrainedTokenizer
from .utils import logging
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -19,9 +19,9 @@ import os
from shutil import copyfile
from typing import Optional, Tuple
from .file_utils import is_sentencepiece_available
from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
from ...file_utils import is_sentencepiece_available
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
if is_sentencepiece_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tokenizers_available, is_torch_available
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .tokenization_retribert import RetriBertTokenizer
if is_tokenizers_available():
from .tokenization_retribert_fast import RetriBertTokenizerFast
if is_torch_available():
from .modeling_retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
......@@ -14,8 +14,8 @@
# limitations under the License.
""" RetriBERT model configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
......
......@@ -23,11 +23,11 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from ...file_utils import add_start_docstrings
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ..bert.modeling_bert import BertModel
from .configuration_retribert import RetriBertConfig
from .file_utils import add_start_docstrings
from .modeling_bert import BertModel
from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.get_logger(__name__)
......
......@@ -14,8 +14,8 @@
# limitations under the License.
"""Tokenization classes for RetriBERT."""
from .tokenization_bert import BertTokenizer
from .utils import logging
from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer
logger = logging.get_logger(__name__)
......
......@@ -14,9 +14,9 @@
# limitations under the License.
"""Tokenization classes for RetriBERT."""
from .tokenization_bert_fast import BertTokenizerFast
from ...utils import logging
from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_retribert import RetriBertTokenizer
from .utils import logging
logger = logging.get_logger(__name__)
......
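
With each model now living at src/transformers/models/<model_name>/, the number of leading dots in a relative import encodes how far the target sits from that folder: three dots reach the package root, two dots reach a sibling model folder, one dot stays in the same folder. The absolute equivalents for the retribert hunk above, assuming that layout:

# from ...utils import logging                   (... climbs two packages, up to `transformers`)
from transformers.utils import logging

# from ..bert.tokenization_bert_fast import ...  (.. climbs to `transformers.models`, then into the sibling `bert` folder)
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast

# from .tokenization_retribert import ...        (. stays in `transformers.models.retribert`)
from transformers.models.retribert.tokenization_retribert import RetriBertTokenizer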
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .tokenization_roberta import RobertaTokenizer
if is_tokenizers_available():
from .tokenization_roberta_fast import RobertaTokenizerFast
if is_torch_available():
from .modeling_roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
RobertaForCausalLM,
RobertaForMaskedLM,
RobertaForMultipleChoice,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
)
if is_tf_available():
from .modeling_tf_roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaMainLayer,
TFRobertaModel,
TFRobertaPreTrainedModel,
)
if is_flax_available():
from .modeling_flax_roberta import FlaxRobertaModel
......@@ -15,8 +15,8 @@
# limitations under the License.
""" RoBERTa configuration """
from .configuration_bert import BertConfig
from .utils import logging
from ...utils import logging
from ..bert.configuration_bert import BertConfig
logger = logging.get_logger(__name__)
......
......@@ -24,8 +24,18 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version
from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
from transformers.models.bert.modeling_bert import (
BertIntermediate,
BertLayer,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
from transformers.models.roberta.modeling_roberta import (
RobertaConfig,
RobertaForMaskedLM,
RobertaForSequenceClassification,
)
from transformers.utils import logging
......
......@@ -20,10 +20,10 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
from ...utils import logging
from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
from .utils import logging
logger = logging.get_logger(__name__)
......@@ -89,7 +89,7 @@ ROBERTA_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
class FlaxRobertaLayerNorm(nn.Module):
"""
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
......@@ -130,7 +130,7 @@ class FlaxRobertaLayerNorm(nn.Module):
return y
# Copied from transformers.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
class FlaxRobertaEmbedding(nn.Module):
"""
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
......@@ -147,7 +147,7 @@ class FlaxRobertaEmbedding(nn.Module):
return jnp.take(embedding, inputs, axis=0)
# Copied from transformers.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
class FlaxRobertaEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
......@@ -179,7 +179,7 @@ class FlaxRobertaEmbeddings(nn.Module):
return layer_norm
# Copied from transformers.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
num_heads: int
head_size: int
......@@ -194,7 +194,7 @@ class FlaxRobertaAttention(nn.Module):
return layer_norm
# Copied from transformers.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
output_size: int
......@@ -205,7 +205,7 @@ class FlaxRobertaIntermediate(nn.Module):
return gelu(dense)
# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
@nn.compact
def __call__(self, intermediate_output, attention_output):
......@@ -230,7 +230,7 @@ class FlaxRobertaLayer(nn.Module):
return output
# Copied from transformers.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
class FlaxRobertaLayerCollection(nn.Module):
"""
Stores N RobertaLayer(s)
......@@ -255,7 +255,7 @@ class FlaxRobertaLayerCollection(nn.Module):
return input_i
# Copied from transformers.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
class FlaxRobertaEncoder(nn.Module):
num_layers: int
num_heads: int
......@@ -270,7 +270,7 @@ class FlaxRobertaEncoder(nn.Module):
return layer
# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
@nn.compact
def __call__(self, hidden_state):
......@@ -279,7 +279,7 @@ class FlaxRobertaPooler(nn.Module):
return jax.lax.tanh(out)
# Copied from transformers.modeling_flax_bert.FlaxBertModule with Bert->Roberta
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
vocab_size: int
hidden_size: int
......
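
The "# Copied from transformers.models.bert. ..." markers rewritten throughout the RoBERTa files are not plain comments: the repository's copy-consistency check reads the path in each marker, looks up the referenced definition, applies the optional "with Bert->Roberta" renaming, and verifies that the local class still matches it. The markers therefore have to follow the modules into their new locations, presumably what the "Fix path for copies" bullet in the commit message refers to. A hedged sketch of such a check (simplified, not the actual utils/check_copies.py implementation):

import importlib
import inspect


def check_copy_marker(target_cls, marker: str) -> bool:
    """Verify a class against its 'Copied from' reference.

    `marker` is the text after '# Copied from ', e.g.
    'transformers.models.bert.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta'.
    """
    path, _, rule = marker.partition(" with ")
    module_name, _, obj_name = path.rpartition(".")
    source_cls = getattr(importlib.import_module(module_name), obj_name)
    reference = inspect.getsource(source_cls)
    if rule:
        old, new = (part.strip() for part in rule.split("->"))
        reference = reference.replace(old, new)  # e.g. rename Bert to Roberta in the reference source
    return inspect.getsource(target_cls) == reference

The real tool presumably compares source text rather than imported objects, but it depends on the marker path in the same way, which is why a stale path would make the check fail after the move.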
......@@ -22,15 +22,14 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import ACT2FN, gelu
from .configuration_roberta import RobertaConfig
from .file_utils import (
from ...activations import ACT2FN, gelu
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_outputs import (
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
......@@ -40,13 +39,14 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
from ...utils import logging
from .configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__)
......@@ -70,7 +70,7 @@ class RobertaEmbeddings(nn.Module):
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
# Copied from transformers.modeling_bert.BertEmbeddings.__init__
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
......@@ -99,7 +99,7 @@ class RobertaEmbeddings(nn.Module):
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
# Copied from transformers.modeling_bert.BertEmbeddings.forward
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
if input_ids is not None:
input_shape = input_ids.size()
else:
......@@ -141,7 +141,7 @@ class RobertaEmbeddings(nn.Module):
return position_ids.unsqueeze(0).expand(input_shape)
# Copied from transformers.modeling_bert.BertSelfAttention with Bert->Roberta
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
class RobertaSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -220,7 +220,7 @@ class RobertaSelfAttention(nn.Module):
return outputs
# Copied from transformers.modeling_bert.BertSelfOutput
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class RobertaSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -235,7 +235,7 @@ class RobertaSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.modeling_bert.BertAttention with Bert->Roberta
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
class RobertaAttention(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -283,7 +283,7 @@ class RobertaAttention(nn.Module):
return outputs
# Copied from transformers.modeling_bert.BertIntermediate
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class RobertaIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -299,7 +299,7 @@ class RobertaIntermediate(nn.Module):
return hidden_states
# Copied from transformers.modeling_bert.BertOutput
# Copied from transformers.models.bert.modeling_bert.BertOutput
class RobertaOutput(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -314,7 +314,7 @@ class RobertaOutput(nn.Module):
return hidden_states
# Copied from transformers.modeling_bert.BertLayer with Bert->Roberta
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
class RobertaLayer(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -374,7 +374,7 @@ class RobertaLayer(nn.Module):
return layer_output
# Copied from transformers.modeling_bert.BertEncoder with Bert->Roberta
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
class RobertaEncoder(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -449,7 +449,7 @@ class RobertaEncoder(nn.Module):
)
# Copied from transformers.modeling_bert.BertPooler
# Copied from transformers.models.bert.modeling_bert.BertPooler
class RobertaPooler(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -474,7 +474,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
# Copied from transformers.modeling_bert.BertPreTrainedModel._init_weights
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
......@@ -579,7 +579,7 @@ class RobertaModel(RobertaPreTrainedModel):
authorized_missing_keys = [r"position_ids"]
# Copied from transformers.modeling_bert.BertModel.__init__ with Bert->Roberta
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
......@@ -612,7 +612,7 @@ class RobertaModel(RobertaPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.modeling_bert.BertModel.forward
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
def forward(
self,
input_ids=None,
......
......@@ -18,15 +18,14 @@
import tensorflow as tf
from .activations_tf import get_tf_activation
from .configuration_roberta import RobertaConfig
from .file_utils import (
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_tf_outputs import (
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
......@@ -35,7 +34,7 @@ from .modeling_tf_outputs import (
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
......@@ -46,8 +45,9 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
from .tokenization_utils_base import BatchEncoding
from .utils import logging
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__)
......@@ -223,7 +223,7 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
return tf.reshape(logits, [batch_size, length, self.vocab_size])
# Copied from transformers.modeling_tf_bert.TFBertPooler
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler
class TFRobertaPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -244,7 +244,7 @@ class TFRobertaPooler(tf.keras.layers.Layer):
return pooled_output
# Copied from transformers.modeling_tf_bert.TFBertSelfAttention
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention
class TFRobertaSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -316,7 +316,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
return outputs
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput
class TFRobertaSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -335,7 +335,7 @@ class TFRobertaSelfOutput(tf.keras.layers.Layer):
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertAttention with Bert->Roberta
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Roberta
class TFRobertaAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -356,7 +356,7 @@ class TFRobertaAttention(tf.keras.layers.Layer):
return outputs
# Copied from transformers.modeling_tf_bert.TFBertIntermediate
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate
class TFRobertaIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -377,7 +377,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertOutput
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput
class TFRobertaOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -396,7 +396,7 @@ class TFRobertaOutput(tf.keras.layers.Layer):
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertLayer with Bert->Roberta
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Roberta
class TFRobertaLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -417,7 +417,7 @@ class TFRobertaLayer(tf.keras.layers.Layer):
return outputs
# Copied from transformers.modeling_tf_bert.TFBertEncoder with Bert->Roberta
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Roberta
class TFRobertaEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -478,16 +478,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# The embeddings must be the last declaration in order to follow the weights order
self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
# Copied from transformers.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self):
return self.embeddings
# Copied from transformers.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
# Copied from transformers.modeling_tf_bert.TFBertMainLayer._prune_heads
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
......@@ -495,7 +495,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
# Copied from transformers.modeling_tf_bert.TFBertMainLayer.call
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
def call(
self,
inputs,
......
......@@ -17,9 +17,9 @@
import warnings
from typing import List, Optional
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_utils import AddedToken
from .utils import logging
from ...tokenization_utils import AddedToken
from ...utils import logging
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
logger = logging.get_logger(__name__)
......