"tools/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "85a98fe23a973617ad451db0dee098eb936451f1"
Unverified Commit 27174bd4 authored by Sylvain Gugger, committed by GitHub

Make PyTorch model files independent from each other (#7352)

parent d161ed16
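The diff below follows one pattern throughout: instead of importing building blocks from modeling_bert.py or modeling_roberta.py, each model file now keeps its own copy of those modules, marked with "# Copied from transformers.modeling_bert..." comments, while shared activation functions move into .activations. As a purely illustrative sketch (the class name is a placeholder, not part of this commit), a model file goes from a cross-file import to a local, annotated copy:

# Before (illustrative): the module was pulled in from another model file.
# from .modeling_bert import BertSelfOutput

import torch.nn as nn


# Copied from transformers.modeling_bert.BertSelfOutput
class MyModelSelfOutput(nn.Module):  # hypothetical name; the real files use e.g. ElectraSelfOutput
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

The "# Copied from" marker is what keeps these duplicates maintainable: the copies can be compared against their source automatically (for example by a repository consistency script) rather than kept in sync by hand.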
@@ -40,6 +40,10 @@ def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def mish(x):
return x * torch.tanh(torch.nn.functional.softplus(x))
ACT2FN = {
"relu": F.relu,
"swish": swish,
@@ -47,6 +51,7 @@ ACT2FN = {
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"mish": mish,
}
...
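For context on the hunk above: ACT2FN is a plain string-to-callable mapping, so registering mish here makes it available to any model whose config sets hidden_act = "mish". A minimal usage sketch, assuming the package-level import path matches the relative .activations module shown in the diff:

import torch
from transformers.activations import ACT2FN  # path assumed from the .activations import in this diff

act = ACT2FN["mish"]
x = torch.randn(2, 3)
# Equivalent to x * tanh(softplus(x)), per the mish() definition added above.
print(act(x))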
@@ -24,6 +24,7 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import ACT2FN
from .configuration_albert import AlbertConfig
from .file_utils import (
ModelOutput,
@@ -32,7 +33,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@@ -42,7 +42,12 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
@@ -192,33 +197,81 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
return model
class AlbertEmbeddings(BertEmbeddings):
class AlbertEmbeddings(nn.Module):
"""
Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super().__init__(config)
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
self.LayerNorm = torch.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
# Copied from transformers.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class AlbertAttention(BertSelfAttention):
class AlbertAttention(nn.Module):
def __init__(self, config):
super().__init__(config)
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pruned_heads = set()
# Copied from transformers.modeling_bert.BertSelfAttention.transpose_for_scores
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def prune_heads(self, heads):
if len(heads) == 0:
return
...
@@ -27,7 +27,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu, gelu_new, swish
from .activations import ACT2FN
from .configuration_bert import BertConfig
from .file_utils import (
ModelOutput,
@@ -162,16 +162,6 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
return model
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
BertLayerNorm = torch.nn.LayerNorm
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@@ -183,7 +173,7 @@ class BertEmbeddings(nn.Module):
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -296,7 +286,7 @@ class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -372,7 +362,7 @@ class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -528,7 +518,7 @@ class BertPredictionHeadTransform(nn.Module):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -605,7 +595,7 @@ class BertPreTrainedModel(PreTrainedModel):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
...
# coding=utf-8
# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ELECTRA model. """
import math
import os
import warnings
from dataclasses import dataclass
@@ -7,7 +24,7 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import get_activation
from .activations import ACT2FN, get_activation
from .configuration_electra import ElectraConfig
from .file_utils import (
ModelOutput,
@@ -16,7 +33,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
from .modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@@ -25,7 +41,13 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import SequenceSummary
from .modeling_utils import (
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
@@ -128,18 +150,345 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
return model
class ElectraEmbeddings(BertEmbeddings):
class ElectraEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__(config)
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.embedding_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
# Copied from transformers.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
# Copied from transformers.modeling_bert.BertSelfAttention with Bert->Electra
class ElectraSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
if encoder_hidden_states is not None:
mixed_key_layer = self.key(encoder_hidden_states)
mixed_value_layer = self.value(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ElectraModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.modeling_bert.BertSelfOutput
class ElectraSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.modeling_bert.BertAttention with Bert->Electra
class ElectraAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = ElectraSelfAttention(config)
self.output = ElectraSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.modeling_bert.BertIntermediate
class ElectraIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.modeling_bert.BertOutput
class ElectraOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.modeling_bert.BertLayer with Bert->Electra
class ElectraLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ElectraAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
self.crossattention = ElectraAttention(config)
self.intermediate = ElectraIntermediate(config)
self.output = ElectraOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if self.is_decoder and encoder_hidden_states is not None:
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
# Copied from transformers.modeling_bert.BertEncoder with Bert->Electra
class ElectraEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
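One detail worth noting in the copied ElectraEncoder above: the layer loop checks getattr(self.config, "gradient_checkpointing", False) and, when set, runs each ElectraLayer through torch.utils.checkpoint. A minimal sketch of toggling it, assuming the flag is simply set as an attribute on the config object (how the flag is meant to be exposed is not shown in this diff):

from transformers import ElectraConfig, ElectraModel

config = ElectraConfig()
config.gradient_checkpointing = True  # picked up by the getattr(...) check in ElectraEncoder.forward above
model = ElectraModel(config)
model.train()
# During the backward pass, per-layer activations are recomputed instead of stored,
# trading extra compute for lower memory use.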
class ElectraDiscriminatorPredictions(nn.Module):
@@ -166,7 +515,7 @@ class ElectraGeneratorPredictions(nn.Module):
def __init__(self, config):
super().__init__()
self.LayerNorm = BertLayerNorm(config.embedding_size)
self.LayerNorm = nn.LayerNorm(config.embedding_size)
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
def forward(self, generator_hidden_states):
@@ -177,7 +526,7 @@ class ElectraGeneratorPredictions(nn.Module):
return hidden_states
class ElectraPreTrainedModel(BertPreTrainedModel):
class ElectraPreTrainedModel(PreTrainedModel):
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
@@ -187,6 +536,19 @@ class ElectraPreTrainedModel(BertPreTrainedModel):
base_model_prefix = "electra"
authorized_missing_keys = [r"position_ids"]
# Copied from transformers.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass
class ElectraForPreTrainingOutput(ModelOutput):
@@ -306,9 +668,6 @@ ELECTRA_INPUTS_DOCSTRING = r"""
ELECTRA_START_DOCSTRING,
)
class ElectraModel(ElectraPreTrainedModel):
config_class = ElectraConfig
def __init__(self, config):
super().__init__(config)
self.embeddings = ElectraEmbeddings(config)
@@ -316,7 +675,7 @@ class ElectraModel(ElectraPreTrainedModel):
if config.embedding_size != config.hidden_size:
self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
self.encoder = BertEncoder(config)
self.encoder = ElectraEncoder(config)
self.config = config
self.init_weights()
...
@@ -63,9 +63,6 @@ FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST = [
"funnel-transformer/xlarge", # B10-10-10H1024, no decoder
]
FunnelLayerNorm = nn.LayerNorm
INF = 1e6
@@ -163,7 +160,7 @@ class FunnelEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.layer_norm = FunnelLayerNorm(config.d_model, eps=config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
def forward(self, input_ids=None, inputs_embeds=None):
@@ -457,7 +454,7 @@ class FunnelRelMultiheadAttention(nn.Module):
self.seg_embed = nn.Parameter(torch.zeros([2, n_head, d_head]))
self.post_proj = nn.Linear(n_head * d_head, d_model)
self.layer_norm = FunnelLayerNorm(d_model, eps=config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(d_model, eps=config.layer_norm_eps)
self.scale = 1.0 / (d_head ** 0.5)
def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
@@ -581,7 +578,7 @@ class FunnelPositionwiseFFN(nn.Module):
self.activation_dropout = nn.Dropout(config.activation_dropout)
self.linear_2 = nn.Linear(config.d_inner, config.d_model)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = FunnelLayerNorm(config.d_model, config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)
def forward(self, hidden):
h = self.linear_1(hidden)
...
@@ -202,7 +202,7 @@ class LayoutLMSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -281,7 +281,7 @@ class LayoutLMOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -441,7 +441,7 @@ class LayoutLMPredictionHeadTransform(nn.Module):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
...
@@ -22,6 +22,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from .activations import ACT2FN, gelu
from .configuration_longformer import LongformerConfig
from .file_utils import (
add_code_sample_docstrings,
@@ -29,7 +30,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@@ -39,7 +39,6 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
@@ -100,6 +99,95 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
return attention_mask
# Copied from transformers.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param torch.Tensor x:
:return torch.Tensor:
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx
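To make the cumsum-and-mask logic above concrete, here is a small self-contained check of what create_position_ids_from_input_ids returns for a padded batch (the tensor values are arbitrary; padding_idx=1 is just an example):

import torch


def create_position_ids_from_input_ids(input_ids, padding_idx):
    # Same arithmetic as the function above: padded slots keep padding_idx,
    # real tokens get consecutive positions starting at padding_idx + 1.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx


input_ids = torch.tensor([[0, 5, 8, 1, 1]])  # 1 is the padding index here
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])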
class LongformerEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
# Copied from transformers.modeling_bert.BertEmbeddings.__init__
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
# End copy
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if position_ids is None:
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
# Copied from transformers.modeling_bert.BertEmbeddings.forward
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param torch.Tensor inputs_embeds:
:return torch.Tensor:
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape)
class LongformerSelfAttention(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
@@ -656,11 +744,26 @@ class LongformerSelfAttention(nn.Module):
return global_attn_output
# Copied from transformers.modeling_bert.BertSelfOutput
class LongformerSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class LongformerAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.self = LongformerSelfAttention(config, layer_id)
self.output = BertSelfOutput(config)
self.output = LongformerSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
@@ -697,12 +800,43 @@ class LongformerAttention(nn.Module):
return outputs
# Copied from transformers.modeling_bert.BertIntermediate
class LongformerIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.modeling_bert.BertOutput
class LongformerOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class LongformerLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.attention = LongformerAttention(config, layer_id)
self.intermediate = BertIntermediate(config)
self.intermediate = LongformerIntermediate(config)
self.output = BertOutput(config)
self.output = LongformerOutput(config)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
@@ -787,6 +921,48 @@ class LongformerEncoder(nn.Module):
)
# Copied from transformers.modeling_bert.BertPooler
class LongformerPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
# Copied from transformers.modeling_roberta.RobertaLMHead with Roberta->Longformer
class LongformerLMHead(nn.Module):
"""Longformer Head for masked language modeling."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = self.decoder(x)
return x
class LongformerPreTrainedModel(PreTrainedModel):
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained
@@ -803,7 +979,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
@@ -922,9 +1098,9 @@ class LongformerModel(LongformerPreTrainedModel):
f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
)
self.embeddings = RobertaEmbeddings(config)
self.embeddings = LongformerEmbeddings(config)
self.encoder = LongformerEncoder(config)
self.pooler = BertPooler(config)
self.pooler = LongformerPooler(config)
self.init_weights()
@@ -1121,7 +1297,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
super().__init__(config)
self.longformer = LongformerModel(config)
self.lm_head = RobertaLMHead(config)
self.lm_head = LongformerLMHead(config)
self.init_weights()
@@ -1218,10 +1394,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
on top of the pooled output) e.g. for GLUE tasks. """,
LONGFORMER_START_DOCSTRING,
)
class LongformerForSequenceClassification(BertPreTrainedModel):
class LongformerForSequenceClassification(LongformerPreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1326,10 +1499,7 @@ class LongformerClassificationHead(nn.Module):
TriviaQA (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
LONGFORMER_START_DOCSTRING,
)
class LongformerForQuestionAnswering(BertPreTrainedModel):
class LongformerForQuestionAnswering(LongformerPreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1457,10 +1627,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
LONGFORMER_START_DOCSTRING,
)
class LongformerForTokenClassification(BertPreTrainedModel):
class LongformerForTokenClassification(LongformerPreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1546,10 +1713,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
LONGFORMER_START_DOCSTRING,
)
class LongformerForMultipleChoice(BertPreTrainedModel):
class LongformerForMultipleChoice(LongformerPreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
...
@@ -25,7 +25,7 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss, SmoothL1Loss
from .activations import gelu, swish
from .activations import ACT2FN, gelu
from .configuration_lxmert import LxmertConfig
from .file_utils import (
ModelOutput,
@@ -275,11 +275,6 @@ def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path):
return model
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
LxmertLayerNorm = torch.nn.LayerNorm
class LxmertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@@ -291,7 +286,7 @@ class LxmertEmbeddings(nn.Module):
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = LxmertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
@@ -385,7 +380,7 @@ class LxmertAttentionOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LxmertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -447,7 +442,7 @@ class LxmertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LxmertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -573,11 +568,11 @@ class LxmertVisualFeatureEncoder(nn.Module):
# Object feature encoding
self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
self.visn_layer_norm = LxmertLayerNorm(config.hidden_size, eps=1e-12)
self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
# Box position encoding
self.box_fc = nn.Linear(pos_dim, config.hidden_size)
self.box_layer_norm = LxmertLayerNorm(config.hidden_size, eps=1e-12)
self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -694,7 +689,7 @@ class LxmertPredictionHeadTransform(nn.Module):
super(LxmertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act]
self.LayerNorm = LxmertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -731,7 +726,7 @@ class LxmertVisualAnswerHead(nn.Module):
self.logit_fc = nn.Sequential(
nn.Linear(hid_dim, hid_dim * 2),
GeLU(),
LxmertLayerNorm(hid_dim * 2, eps=1e-12),
nn.LayerNorm(hid_dim * 2, eps=1e-12),
nn.Linear(hid_dim * 2, num_labels),
)
@@ -797,7 +792,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, LxmertLayerNorm):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
...
@@ -31,7 +31,7 @@ import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu, gelu_new, swish
from .activations import ACT2FN
from .configuration_mobilebert import MobileBertConfig
from .file_utils import (
ModelOutput,
@@ -40,7 +40,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import BertIntermediate
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@@ -155,7 +154,6 @@ class NoNorm(nn.Module):
return input_tensor * self.weight + self.bias
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm}
@@ -358,10 +356,19 @@ class MobileBertAttention(nn.Module):
return outputs
class MobileBertIntermediate(BertIntermediate):
class MobileBertIntermediate(nn.Module):
def __init__(self, config):
super().__init__(config)
super().__init__()
self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class OutputBottleneck(nn.Module): class OutputBottleneck(nn.Module):
......
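`MobileBertIntermediate` above no longer subclasses `BertIntermediate`; it is self-contained and resolves `config.hidden_act` through the shared `ACT2FN` table. A hedged stand-alone sketch of the same pattern, with `SimpleNamespace` standing in for a real `MobileBertConfig` and a two-entry dict standing in for `transformers.activations.ACT2FN`:

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F

ACT2FN = {"relu": F.relu, "gelu": F.gelu}  # stand-in for the shared activations registry

class Intermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
        # accept either a registry key or a callable, as in the hunk
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))

cfg = SimpleNamespace(true_hidden_size=128, intermediate_size=512, hidden_act="relu")
out = Intermediate(cfg)(torch.randn(2, 4, 128))  # shape: (2, 4, 512)
```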
@@ -28,7 +28,7 @@ from torch import nn
from torch.autograd.function import Function from torch.autograd.function import Function
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu, gelu_fast, gelu_new, swish from .activations import ACT2FN
from .configuration_reformer import ReformerConfig from .configuration_reformer import ReformerConfig
from .file_utils import ( from .file_utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
@@ -55,20 +55,6 @@ REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
ACT2FN = {
"gelu": gelu,
"relu": torch.nn.functional.relu,
"swish": swish,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"mish": mish,
}
# Define named tuples for nn.Modules here # Define named tuples for nn.Modules here
LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"])
LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"])
......
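The Reformer file stops defining its own `mish` and `ACT2FN` and imports the shared registry instead. As a quick sanity check that nothing changes numerically, the `x * tanh(softplus(x))` formulation can be compared against PyTorch's built-in mish (available from PyTorch 1.9 onward); this is a standalone check, not part of the diff:

```python
import torch
import torch.nn.functional as F

def mish(x):
    # the formulation kept in the shared activations module
    return x * torch.tanh(F.softplus(x))

x = torch.randn(1024)
# F.mish is the native implementation (PyTorch >= 1.9); both should agree closely
assert torch.allclose(mish(x), F.mish(x), atol=1e-6)
```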
@@ -25,7 +25,7 @@ import torch.utils.checkpoint as checkpoint
from .configuration_retribert import RetriBertConfig from .configuration_retribert import RetriBertConfig
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .modeling_bert import BertLayerNorm, BertModel from .modeling_bert import BertModel
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .utils import logging from .utils import logging
@@ -52,7 +52,7 @@ class RetriBertPreTrainedModel(PreTrainedModel):
""" Initialize the weights """ """ Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm): elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
......
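For RetriBERT the only behavioural question is whether swapping the `BertLayerNorm` check for `nn.LayerNorm` can break anything. Assuming `BertLayerNorm` was, like the `RobertaLayerNorm` and `XLNetLayerNorm` aliases removed in the hunks below, just an alias for `torch.nn.LayerNorm`, the check is equivalent and checkpoint keys are untouched; a tiny sketch of that reasoning:

```python
import torch.nn as nn

BertLayerNorm = nn.LayerNorm  # what the removed alias amounted to (assumption stated above)

ln = nn.LayerNorm(16, eps=1e-12)
assert isinstance(ln, BertLayerNorm)               # the old-style isinstance check still holds
assert set(ln.state_dict()) == {"weight", "bias"}  # parameter names, hence checkpoints, unchanged
```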
@@ -22,6 +22,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from .activations import ACT2FN, gelu
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
from .file_utils import ( from .file_utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
@@ -29,7 +30,6 @@ from .file_utils import (
add_start_docstrings_to_callable, add_start_docstrings_to_callable,
replace_return_docstrings, replace_return_docstrings,
) )
from .modeling_bert import ACT2FN, gelu
from .modeling_outputs import ( from .modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
@@ -65,15 +65,12 @@ ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
RobertaLayerNorm = torch.nn.LayerNorm
class RobertaEmbeddings(nn.Module): class RobertaEmbeddings(nn.Module):
""" """
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
""" """
# Copied from transformers.modeling_bert.BertEmbeddings.__init__ with Bert->Roberta # Copied from transformers.modeling_bert.BertEmbeddings.__init__
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -82,7 +79,7 @@ class RobertaEmbeddings(nn.Module):
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file # any TensorFlow checkpoint file
self.LayerNorm = RobertaLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized # position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -221,12 +218,12 @@ class RobertaSelfAttention(nn.Module):
return outputs return outputs
# Copied from transformers.modeling_bert.BertSelfOutput with Bert->Roberta # Copied from transformers.modeling_bert.BertSelfOutput
class RobertaSelfOutput(nn.Module): class RobertaSelfOutput(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = RobertaLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states, input_tensor):
@@ -300,12 +297,12 @@ class RobertaIntermediate(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.modeling_bert.BertOutput with Bert->Roberta # Copied from transformers.modeling_bert.BertOutput
class RobertaOutput(nn.Module): class RobertaOutput(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = RobertaLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states, input_tensor):
@@ -465,14 +462,14 @@ class RobertaPreTrainedModel(PreTrainedModel):
base_model_prefix = "roberta" base_model_prefix = "roberta"
authorized_missing_keys = [r"position_ids"] authorized_missing_keys = [r"position_ids"]
# Copied from transformers.modeling_bert.BertPreTrainedModel._init_weights with Bert->Roberta # Copied from transformers.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization # Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, RobertaLayerNorm): elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
@@ -916,7 +913,7 @@ class RobertaLMHead(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = RobertaLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) self.bias = nn.Parameter(torch.zeros(config.vocab_size))
......
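In the RoBERTa file every norm now goes through `nn.LayerNorm` directly and the `RobertaLayerNorm` alias is deleted. The last hunk only shows `RobertaLMHead.__init__` (dense, layer norm, bias-free decoder plus a separate bias parameter); the forward pass is outside the hunk, so the chaining below, including the GELU between dense and norm, is an assumption about how a head of this shape is typically wired, not a quote from the file:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LMHead(nn.Module):
    """Masked-LM head with the layout shown in the RobertaLMHead hunk."""

    def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-5):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features):
        x = F.gelu(self.dense(features))     # assumed non-linearity between dense and norm
        x = self.layer_norm(x)
        return self.decoder(x) + self.bias   # vocabulary logits

logits = LMHead(hidden_size=32, vocab_size=100)(torch.randn(2, 5, 32))  # (2, 5, 100)
```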
@@ -25,7 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F from torch.nn import functional as F
from .activations import gelu_new, swish from .activations import ACT2FN
from .configuration_xlnet import XLNetConfig from .configuration_xlnet import XLNetConfig
from .file_utils import ( from .file_utils import (
ModelOutput, ModelOutput,
@@ -207,12 +207,6 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
return model return model
ACT2FN = {"gelu": gelu_new, "relu": torch.nn.functional.relu, "swish": swish}
XLNetLayerNorm = nn.LayerNorm
class XLNetRelativeAttention(nn.Module): class XLNetRelativeAttention(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -239,7 +233,7 @@ class XLNetRelativeAttention(nn.Module):
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head)) self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head))
self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps) self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.dropout) self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads): def prune_heads(self, heads):
@@ -476,7 +470,7 @@ class XLNetRelativeAttention(nn.Module):
class XLNetFeedForward(nn.Module): class XLNetFeedForward(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps) self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
self.layer_1 = nn.Linear(config.d_model, config.d_inner) self.layer_1 = nn.Linear(config.d_model, config.d_inner)
self.layer_2 = nn.Linear(config.d_inner, config.d_model) self.layer_2 = nn.Linear(config.d_inner, config.d_model)
self.dropout = nn.Dropout(config.dropout) self.dropout = nn.Dropout(config.dropout)
@@ -563,7 +557,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
elif isinstance(module, XLNetLayerNorm): elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
elif isinstance(module, XLNetRelativeAttention): elif isinstance(module, XLNetRelativeAttention):
......
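The XLNet hunks repeat the pattern one last time: the module-local `ACT2FN` dict and the `XLNetLayerNorm` alias are deleted, and the shared registry plus the stock `nn.LayerNorm` are used instead. A hedged sketch of the feed-forward layout from the hunk, with the activation resolved by name; the hunk only shows `__init__`, so the residual and post-norm ordering in `forward` is an assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

ACT2FN = {"relu": F.relu, "gelu": F.gelu}  # stand-in for the shared activations registry

class FeedForward(nn.Module):
    """Layout of the XLNetFeedForward hunk: LayerNorm, two linears, dropout."""

    def __init__(self, d_model, d_inner, dropout=0.1, ff_activation="gelu", layer_norm_eps=1e-12):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)  # no XLNetLayerNorm alias
        self.layer_1 = nn.Linear(d_model, d_inner)
        self.layer_2 = nn.Linear(d_inner, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation_function = ACT2FN[ff_activation]             # resolved from the registry

    def forward(self, inp):
        out = self.layer_2(self.dropout(self.activation_function(self.layer_1(inp))))
        return self.layer_norm(inp + self.dropout(out))              # assumed residual + post-norm

y = FeedForward(d_model=32, d_inner=128)(torch.randn(2, 7, 32))  # (2, 7, 32)
```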