Unverified Commit 1ed2ebf6 authored by Stas Bekman, committed by GitHub

[style] consistent nn. and nn.functional (#12124)

* consistent nn. and nn.functional

* fix glitch

* fix glitch #2
parent ff7c8168
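
For reference, a minimal sketch of the import convention this commit standardizes on (the tensor, shapes, and layer below are illustrative, not taken from the patch itself): layers still come from `nn.`, and functional ops are written out as `nn.functional.`, replacing the old `import torch.nn as nn` / `import torch.nn.functional as F` aliases.

# Old style (removed by this commit):
#   import torch.nn as nn
#   import torch.nn.functional as F
#   x = F.dropout(x, p=0.1, training=True)

# New style (what the diff below converges on):
import torch
from torch import nn

x = torch.randn(2, 4)
x = nn.functional.dropout(x, p=0.1, training=True)  # functional ops spelled out via nn.functional
proj = nn.Linear(4, 4)                               # modules/layers still come from nn.
y = proj(x)

The hunks below apply this same substitution across the files touched by the commit.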
@@ -20,8 +20,8 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 import torch
-import torch.nn as nn
 import torch.utils.checkpoint
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN, get_activation
...
@@ -18,7 +18,7 @@
 import random
 import torch
-from torch.nn import functional as F
+from torch import nn
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput
@@ -234,7 +234,7 @@ class FlaubertModel(XLMModel):
 if token_type_ids is not None:
 tensor = tensor + self.embeddings(token_type_ids)
 tensor = self.layer_norm_emb(tensor)
-tensor = F.dropout(tensor, p=self.dropout, training=self.training)
+tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
 tensor *= mask.unsqueeze(-1).to(tensor.dtype)
 # transformer layers
@@ -261,7 +261,7 @@ class FlaubertModel(XLMModel):
 attn = attn_outputs[0]
 if output_attentions:
 attentions = attentions + (attn_outputs[1],)
-attn = F.dropout(attn, p=self.dropout, training=self.training)
+attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
 tensor = tensor + attn
 tensor = self.layer_norm1[i](tensor)
 else:
@@ -270,13 +270,13 @@ class FlaubertModel(XLMModel):
 attn = attn_outputs[0]
 if output_attentions:
 attentions = attentions + (attn_outputs[1],)
-attn = F.dropout(attn, p=self.dropout, training=self.training)
+attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
 tensor = tensor + attn
 # encoder attention (for decoder only)
 # if self.is_decoder and src_enc is not None:
 # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
-# attn = F.dropout(attn, p=self.dropout, training=self.training)
+# attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
 # tensor = tensor + attn
 # tensor = self.layer_norm15[i](tensor)
...
@@ -675,7 +675,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
 # encoder attention (for decoder only)
 # if self.is_decoder and src_enc is not None:
 # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
-# attn = F.dropout(attn, p=self.dropout, training=self.training)
+# attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
 # tensor = tensor + attn
 # tensor = self.layer_norm15[i](tensor)
...
@@ -32,7 +32,6 @@ import random
 from typing import Any, Dict, List, Optional, Tuple
 import torch
-import torch.nn.functional as F
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss, LayerNorm
@@ -430,15 +429,15 @@ class EncoderLayer(nn.Module):
 layer_head_mask=layer_head_mask,
 output_attentions=output_attentions,
 )
-x = F.dropout(x, p=self.dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 x = residual + x
 x = self.self_attn_layer_norm(x)
 residual = x
 x = self.activation_fn(self.fc1(x))
-x = F.dropout(x, p=self.activation_dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
 x = self.fc2(x)
-x = F.dropout(x, p=self.dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 x = residual + x
 x = self.final_layer_norm(x)
 return x, attn_weights
@@ -504,7 +503,7 @@ class FSMTEncoder(nn.Module):
 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 embed_pos = self.embed_positions(input_ids)
 x = inputs_embeds + embed_pos
-x = F.dropout(x, p=self.dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 # B x T x C -> T x B x C
 x = x.transpose(0, 1)
@@ -600,7 +599,7 @@ class DecoderLayer(nn.Module):
 layer_head_mask=layer_head_mask,
 output_attentions=output_attentions,
 )
-x = F.dropout(x, p=self.dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 x = residual + x
 x = self.self_attn_layer_norm(x)
@@ -615,16 +614,16 @@ class DecoderLayer(nn.Module):
 layer_head_mask=cross_attn_layer_head_mask,
 output_attentions=output_attentions,
 )
-x = F.dropout(x, p=self.dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 x = residual + x
 x = self.encoder_attn_layer_norm(x)
 # Fully Connected
 residual = x
 x = self.activation_fn(self.fc1(x))
-x = F.dropout(x, p=self.activation_dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
 x = self.fc2(x)
-x = F.dropout(x, p=self.dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 x = residual + x
 x = self.final_layer_norm(x)
 return (
@@ -641,7 +640,7 @@ class FSMTDecoder(nn.Module):
 Args:
 config: FSMTConfig
-embed_tokens (torch.nn.Embedding): output embedding
+embed_tokens (nn.Embedding): output embedding
 """
 def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
@@ -726,7 +725,7 @@ class FSMTDecoder(nn.Module):
 x = self.embed_tokens(input_ids) * self.embed_scale
 x += positions
-x = F.dropout(x, p=self.dropout, training=self.training)
+x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 # Convert to FSMT output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
 x = x.transpose(0, 1)
@@ -913,7 +912,7 @@ class Attention(nn.Module):
 attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-attn_weights = F.softmax(attn_weights, dim=-1)
+attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 if layer_head_mask is not None:
 assert layer_head_mask.size() == (
@@ -929,7 +928,7 @@ class Attention(nn.Module):
 else:
 attn_weights_reshaped = None
-attn_probs = F.dropout(
+attn_probs = nn.functional.dropout(
 attn_weights,
 p=self.dropout,
 training=self.training,
...
@@ -22,7 +22,6 @@ import numpy as np
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from torch.nn import functional as F
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -196,7 +195,7 @@ class FunnelAttentionStructure(nn.Module):
 position_embeds = self.get_position_embeds(seq_len, inputs_embeds.dtype, inputs_embeds.device)
 token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
 cls_mask = (
-F.pad(inputs_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
+nn.functional.pad(inputs_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
 if self.config.separate_cls
 else None
 )
@@ -368,11 +367,11 @@ class FunnelAttentionStructure(nn.Module):
 stride = (stride, 1)
 if mode == "mean":
-tensor = F.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
+tensor = nn.functional.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
 elif mode == "max":
-tensor = F.max_pool2d(tensor, stride, stride=stride, ceil_mode=True)
+tensor = nn.functional.max_pool2d(tensor, stride, stride=stride, ceil_mode=True)
 elif mode == "min":
-tensor = -F.max_pool2d(-tensor, stride, stride=stride, ceil_mode=True)
+tensor = -nn.functional.max_pool2d(-tensor, stride, stride=stride, ceil_mode=True)
 else:
 raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.")
...
@@ -20,8 +20,8 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 import torch
-import torch.nn as nn
 import torch.utils.checkpoint
+from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
...
@@ -19,7 +19,6 @@ import os
 from typing import Tuple
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
@@ -186,7 +185,7 @@ class GPTNeoAttentionMixin:
 else:
 raise ValueError(f"Input tensor rank should be one of [2, 3], but is: {len(tensor.shape)}")
-padded_tensor = F.pad(tensor, padding_side, value=pad_value)
+padded_tensor = nn.functional.pad(tensor, padding_side, value=pad_value)
 padded_tensor = padded_tensor.unfold(dimension=1, size=window_size + block_length, step=block_length)
 if is_key_value:
...
@@ -20,8 +20,8 @@
 import math
 import torch
-import torch.nn as nn
 import torch.utils.checkpoint
+from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from ...activations import gelu
...
@@ -19,8 +19,7 @@ import decimal
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.autograd import Function
 from ...utils import logging
@@ -79,7 +78,7 @@ class QuantEmbedding(nn.Module):
 def forward(self, x, positions=None, incremental_state=None):
 if not self.quant_mode:
 return (
-F.embedding(
+nn.functional.embedding(
 x,
 self.weight,
 self.padding_idx,
@@ -101,7 +100,7 @@ class QuantEmbedding(nn.Module):
 self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
 )
-emb_int = F.embedding(
+emb_int = nn.functional.embedding(
 x,
 self.weight_integer,
 self.padding_idx,
@@ -264,7 +263,7 @@ class QuantLinear(nn.Module):
 def forward(self, x, prev_act_scaling_factor=None):
 if not self.quant_mode:
-return F.linear(x, weight=self.weight, bias=self.bias), None
+return nn.functional.linear(x, weight=self.weight, bias=self.bias), None
 # assert that prev_act_scaling_factor is a scalar tensor
 assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
@@ -295,7 +294,7 @@ class QuantLinear(nn.Module):
 x_int = x / prev_act_scaling_factor
 return (
-F.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
+nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
 bias_scaling_factor,
 )
...
@@ -52,7 +52,7 @@ LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
-LayoutLMLayerNorm = torch.nn.LayerNorm
+LayoutLMLayerNorm = nn.LayerNorm
 class LayoutLMEmbeddings(nn.Module):
...
@@ -21,7 +21,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -250,7 +249,9 @@ class LEDEncoderSelfAttention(nn.Module):
 # free memory
 del global_key_attn_scores
-attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
+attn_probs = nn.functional.softmax(
+attn_scores, dim=-1, dtype=torch.float32
+) # use fp32 for numerical stability
 if layer_head_mask is not None:
 assert layer_head_mask.size() == (
@@ -266,7 +267,7 @@ class LEDEncoderSelfAttention(nn.Module):
 del attn_scores
 # apply dropout
-attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
+attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)
 value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
@@ -326,7 +327,7 @@ class LEDEncoderSelfAttention(nn.Module):
 @staticmethod
 def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
 """pads rows and then flips rows and columns"""
-hidden_states_padded = F.pad(
+hidden_states_padded = nn.functional.pad(
 hidden_states_padded, padding
 ) # padding value is not important because it will be overwritten
 hidden_states_padded = hidden_states_padded.view(
@@ -353,7 +354,7 @@ class LEDEncoderSelfAttention(nn.Module):
 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
 """
 total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
-chunked_hidden_states = F.pad(
+chunked_hidden_states = nn.functional.pad(
 chunked_hidden_states, (0, window_overlap + 1)
 ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
 chunked_hidden_states = chunked_hidden_states.view(
@@ -489,7 +490,7 @@ class LEDEncoderSelfAttention(nn.Module):
 value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
 # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
+padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
 # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
 chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
@@ -661,7 +662,7 @@ class LEDEncoderSelfAttention(nn.Module):
 global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
 # compute global attn probs
-global_attn_probs_float = F.softmax(
+global_attn_probs_float = nn.functional.softmax(
 global_attn_scores, dim=-1, dtype=torch.float32
 ) # use fp32 for numerical stability
@@ -677,7 +678,7 @@ class LEDEncoderSelfAttention(nn.Module):
 batch_size * self.num_heads, max_num_global_attn_indices, seq_len
 )
-global_attn_probs = F.dropout(
+global_attn_probs = nn.functional.dropout(
 global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
 )
@@ -833,7 +834,7 @@ class LEDDecoderAttention(nn.Module):
 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-attn_weights = F.softmax(attn_weights, dim=-1)
+attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 if layer_head_mask is not None:
 assert layer_head_mask.size() == (
 self.num_heads,
@@ -851,7 +852,7 @@ class LEDDecoderAttention(nn.Module):
 else:
 attn_weights_reshaped = None
-attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
+attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 attn_output = torch.bmm(attn_probs, value_states)
@@ -914,15 +915,15 @@ class LEDEncoderLayer(nn.Module):
 output_attentions=output_attentions,
 )
 hidden_states = attn_outputs[0]
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.self_attn_layer_norm(hidden_states)
 residual = hidden_states
 hidden_states = self.activation_fn(self.fc1(hidden_states))
-hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 hidden_states = self.fc2(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.final_layer_norm(hidden_states)
@@ -1002,7 +1003,7 @@ class LEDDecoderLayer(nn.Module):
 layer_head_mask=layer_head_mask,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -1022,7 +1023,7 @@ class LEDDecoderLayer(nn.Module):
 past_key_value=cross_attn_past_key_value,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.encoder_attn_layer_norm(hidden_states)
@@ -1032,9 +1033,9 @@ class LEDDecoderLayer(nn.Module):
 # Fully Connected
 residual = hidden_states
 hidden_states = self.activation_fn(self.fc1(hidden_states))
-hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 hidden_states = self.fc2(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.final_layer_norm(hidden_states)
@@ -1562,7 +1563,7 @@ class LEDEncoder(LEDPreTrainedModel):
 Args:
 config: LEDConfig
-embed_tokens (torch.nn.Embedding): output embedding
+embed_tokens (nn.Embedding): output embedding
 """
 def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None):
@@ -1637,7 +1638,7 @@ class LEDEncoder(LEDPreTrainedModel):
 f"`config.attention_window`: {attention_window}"
 )
 if input_ids is not None:
-input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
+input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
 if inputs_embeds is not None:
 input_ids_padding = inputs_embeds.new_full(
 (batch_size, padding_len),
@@ -1647,7 +1648,9 @@ class LEDEncoder(LEDPreTrainedModel):
 inputs_embeds_padding = self.embed_tokens(input_ids_padding)
 inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
-attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
+attention_mask = nn.functional.pad(
+attention_mask, (0, padding_len), value=False
+) # no attention on the padding tokens
 return padding_len, input_ids, attention_mask, inputs_embeds
@@ -1760,7 +1763,7 @@ class LEDEncoder(LEDPreTrainedModel):
 hidden_states = inputs_embeds + embed_pos
 hidden_states = self.layernorm_embedding(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 encoder_states = () if output_hidden_states else None
 all_attentions = () if output_attentions else None
@@ -1842,7 +1845,7 @@ class LEDDecoder(LEDPreTrainedModel):
 Args:
 config: LEDConfig
-embed_tokens (torch.nn.Embedding): output embedding
+embed_tokens (nn.Embedding): output embedding
 """
 def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None):
@@ -2008,7 +2011,7 @@ class LEDDecoder(LEDPreTrainedModel):
 hidden_states = inputs_embeds + positions
 hidden_states = self.layernorm_embedding(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 # decoder layers
 all_hidden_states = () if output_hidden_states else None
...
@@ -19,6 +19,7 @@ import argparse
 import pytorch_lightning as pl
 import torch
+from torch import nn
 from transformers import LongformerForQuestionAnswering, LongformerModel
@@ -28,7 +29,7 @@ class LightningModel(pl.LightningModule):
 super().__init__()
 self.model = model
 self.num_labels = 2
-self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels)
+self.qa_outputs = nn.Linear(self.model.config.hidden_size, self.num_labels)
 # implement only because lightning requires to do so
 def forward(self):
...
@@ -19,10 +19,9 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 import torch
-import torch.nn as nn
 import torch.utils.checkpoint
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from torch.nn import functional as F
 from ...activations import ACT2FN, gelu
 from ...file_utils import (
@@ -640,7 +639,9 @@ class LongformerSelfAttention(nn.Module):
 # free memory
 del global_key_attn_scores
-attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
+attn_probs = nn.functional.softmax(
+attn_scores, dim=-1, dtype=torch.float32
+) # use fp32 for numerical stability
 if layer_head_mask is not None:
 assert layer_head_mask.size() == (
@@ -656,7 +657,7 @@ class LongformerSelfAttention(nn.Module):
 del attn_scores
 # apply dropout
-attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
+attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)
 value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
@@ -716,7 +717,7 @@ class LongformerSelfAttention(nn.Module):
 @staticmethod
 def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
 """pads rows and then flips rows and columns"""
-hidden_states_padded = F.pad(
+hidden_states_padded = nn.functional.pad(
 hidden_states_padded, padding
 ) # padding value is not important because it will be overwritten
 hidden_states_padded = hidden_states_padded.view(
@@ -743,7 +744,7 @@ class LongformerSelfAttention(nn.Module):
 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
 """
 total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
-chunked_hidden_states = F.pad(
+chunked_hidden_states = nn.functional.pad(
 chunked_hidden_states, (0, window_overlap + 1)
 ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
 chunked_hidden_states = chunked_hidden_states.view(
@@ -879,7 +880,7 @@ class LongformerSelfAttention(nn.Module):
 value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
 # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
+padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
 # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
 chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
@@ -1051,7 +1052,7 @@ class LongformerSelfAttention(nn.Module):
 global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
 # compute global attn probs
-global_attn_probs_float = F.softmax(
+global_attn_probs_float = nn.functional.softmax(
 global_attn_scores, dim=-1, dtype=torch.float32
 ) # use fp32 for numerical stability
@@ -1067,7 +1068,7 @@ class LongformerSelfAttention(nn.Module):
 batch_size * self.num_heads, max_num_global_attn_indices, seq_len
 )
-global_attn_probs = F.dropout(
+global_attn_probs = nn.functional.dropout(
 global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
 )
@@ -1546,10 +1547,10 @@ class LongformerModel(LongformerPreTrainedModel):
 f"`config.attention_window`: {attention_window}"
 )
 if input_ids is not None:
-input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
+input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
 if position_ids is not None:
 # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
-position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
+position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id)
 if inputs_embeds is not None:
 input_ids_padding = inputs_embeds.new_full(
 (batch_size, padding_len),
@@ -1559,8 +1560,10 @@ class LongformerModel(LongformerPreTrainedModel):
 inputs_embeds_padding = self.embeddings(input_ids_padding)
 inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
-attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
-token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
+attention_mask = nn.functional.pad(
+attention_mask, (0, padding_len), value=False
+) # no attention on the padding tokens
+token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
 return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
...
@@ -19,9 +19,8 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
+from torch import nn
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1098,9 +1097,9 @@ class LukeForEntityClassification(LukePreTrainedModel):
 # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
 # cross entropy is used otherwise.
 if labels.ndim == 1:
-loss = F.cross_entropy(logits, labels)
+loss = nn.functional.cross_entropy(logits, labels)
 else:
-loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
+loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
 if not return_dict:
 output = (
@@ -1213,9 +1212,9 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
 # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
 # cross entropy is used otherwise.
 if labels.ndim == 1:
-loss = F.cross_entropy(logits, labels)
+loss = nn.functional.cross_entropy(logits, labels)
 else:
-loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
+loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
 if not return_dict:
 output = (
@@ -1351,9 +1350,9 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
 # When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary
 # cross entropy is used otherwise.
 if labels.ndim == 2:
-loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
+loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
 else:
-loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
+loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
 if not return_dict:
 output = (
...
@@ -20,7 +20,6 @@ import random
 from typing import Optional, Tuple
 import torch
-import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -293,7 +292,7 @@ class M2M100Attention(nn.Module):
 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-attn_weights = F.softmax(attn_weights, dim=-1)
+attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 if layer_head_mask is not None:
 if layer_head_mask.size() != (self.num_heads,):
@@ -313,7 +312,7 @@ class M2M100Attention(nn.Module):
 else:
 attn_weights_reshaped = None
-attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
+attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 attn_output = torch.bmm(attn_probs, value_states)
@@ -375,15 +374,15 @@ class M2M100EncoderLayer(nn.Module):
 layer_head_mask=layer_head_mask,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 residual = hidden_states
 hidden_states = self.final_layer_norm(hidden_states)
 hidden_states = self.activation_fn(self.fc1(hidden_states))
-hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 hidden_states = self.fc2(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 if hidden_states.dtype == torch.float16 and (
@@ -471,7 +470,7 @@ class M2M100DecoderLayer(nn.Module):
 layer_head_mask=layer_head_mask,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 # Cross-Attention Block
@@ -491,7 +490,7 @@ class M2M100DecoderLayer(nn.Module):
 past_key_value=cross_attn_past_key_value,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 # add cross-attn to positions 3,4 of present_key_value tuple
@@ -501,9 +500,9 @@ class M2M100DecoderLayer(nn.Module):
 residual = hidden_states
 hidden_states = self.final_layer_norm(hidden_states)
 hidden_states = self.activation_fn(self.fc1(hidden_states))
-hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 hidden_states = self.fc2(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 outputs = (hidden_states,)
@@ -665,7 +664,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
 Args:
 config: M2M100Config
-embed_tokens (torch.nn.Embedding): output embedding
+embed_tokens (nn.Embedding): output embedding
 """
 def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None):
@@ -764,7 +763,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
 embed_pos = self.embed_positions(input_ids, inputs_embeds)
 hidden_states = inputs_embeds + embed_pos
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 # expand attention_mask
 if attention_mask is not None:
@@ -832,7 +831,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
 Args:
 config: M2M100Config
-embed_tokens (torch.nn.Embedding): output embedding
+embed_tokens (nn.Embedding): output embedding
 """
 def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None):
@@ -989,7 +988,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
 hidden_states = inputs_embeds + positions
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 # decoder layers
 all_hidden_states = () if output_hidden_states else None
...
@@ -24,6 +24,7 @@ from zipfile import ZipFile
 import numpy as np
 import torch
+from torch import nn
 from tqdm import tqdm
 from transformers import MarianConfig, MarianMTModel, MarianTokenizer
@@ -53,7 +54,7 @@ def convert_encoder_layer(opus_dict, layer_prefix: str, converter: dict):
 return sd
-def load_layers_(layer_lst: torch.nn.ModuleList, opus_state: dict, converter, is_decoder=False):
+def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decoder=False):
 for i, layer in enumerate(layer_lst):
 layer_tag = f"decoder_l{i + 1}_" if is_decoder else f"encoder_l{i + 1}_"
 sd = convert_encoder_layer(opus_state, layer_tag, converter)
@@ -543,8 +544,8 @@ class OpusState:
 load_layers_(model.model.decoder.layers, state_dict, BART_CONVERTER, is_decoder=True)
 # handle tensors not associated with layers
-wemb_tensor = torch.nn.Parameter(torch.FloatTensor(self.wemb))
-bias_tensor = torch.nn.Parameter(torch.FloatTensor(self.final_bias))
+wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
+bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
 model.model.shared.weight = wemb_tensor
 model.model.encoder.embed_tokens = model.model.decoder.embed_tokens = model.model.shared
...
@@ -22,7 +22,6 @@ from typing import Optional, Tuple
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -239,7 +238,7 @@ class MarianAttention(nn.Module):
 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-attn_weights = F.softmax(attn_weights, dim=-1)
+attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 if layer_head_mask is not None:
 if layer_head_mask.size() != (self.num_heads,):
@@ -259,7 +258,7 @@ class MarianAttention(nn.Module):
 else:
 attn_weights_reshaped = None
-attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
+attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 attn_output = torch.bmm(attn_probs, value_states)
@@ -320,15 +319,15 @@ class MarianEncoderLayer(nn.Module):
 layer_head_mask=layer_head_mask,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.self_attn_layer_norm(hidden_states)
 residual = hidden_states
 hidden_states = self.activation_fn(self.fc1(hidden_states))
-hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 hidden_states = self.fc2(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.final_layer_norm(hidden_states)
@@ -416,7 +415,7 @@ class MarianDecoderLayer(nn.Module):
 layer_head_mask=layer_head_mask,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -436,7 +435,7 @@ class MarianDecoderLayer(nn.Module):
 past_key_value=cross_attn_past_key_value,
 output_attentions=output_attentions,
 )
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.encoder_attn_layer_norm(hidden_states)
@@ -446,9 +445,9 @@ class MarianDecoderLayer(nn.Module):
 # Fully Connected
 residual = hidden_states
 hidden_states = self.activation_fn(self.fc1(hidden_states))
-hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 hidden_states = self.fc2(hidden_states)
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 hidden_states = residual + hidden_states
 hidden_states = self.final_layer_norm(hidden_states)
@@ -630,7 +629,7 @@ class MarianEncoder(MarianPreTrainedModel):
 Args:
 config: MarianConfig
-embed_tokens (torch.nn.Embedding): output embedding
+embed_tokens (nn.Embedding): output embedding
 """
 def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
@@ -727,7 +726,7 @@ class MarianEncoder(MarianPreTrainedModel):
 embed_pos = self.embed_positions(input_shape)
 hidden_states = inputs_embeds + embed_pos
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 # expand attention_mask
 if attention_mask is not None:
@@ -793,7 +792,7 @@ class MarianDecoder(MarianPreTrainedModel):
 Args:
 config: MarianConfig
-embed_tokens (torch.nn.Embedding): output embedding
+embed_tokens (nn.Embedding): output embedding
 """
 def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
@@ -963,7 +962,7 @@ class MarianDecoder(MarianPreTrainedModel):
 hidden_states = inputs_embeds + positions
-hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 # decoder layers
 all_hidden_states = () if output_hidden_states else None
...
...@@ -19,7 +19,6 @@ import random ...@@ -19,7 +19,6 @@ import random
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
...@@ -230,7 +229,7 @@ class MBartAttention(nn.Module): ...@@ -230,7 +229,7 @@ class MBartAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -250,7 +249,7 @@ class MBartAttention(nn.Module): ...@@ -250,7 +249,7 @@ class MBartAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -311,15 +310,15 @@ class MBartEncoderLayer(nn.Module): ...@@ -311,15 +310,15 @@ class MBartEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and ( if hidden_states.dtype == torch.float16 and (
...@@ -406,7 +405,7 @@ class MBartDecoderLayer(nn.Module): ...@@ -406,7 +405,7 @@ class MBartDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Cross-Attention Block # Cross-Attention Block
...@@ -426,7 +425,7 @@ class MBartDecoderLayer(nn.Module): ...@@ -426,7 +425,7 @@ class MBartDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple # add cross-attn to positions 3,4 of present_key_value tuple
...@@ -436,9 +435,9 @@ class MBartDecoderLayer(nn.Module): ...@@ -436,9 +435,9 @@ class MBartDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)
...@@ -658,7 +657,7 @@ class MBartEncoder(MBartPreTrainedModel): ...@@ -658,7 +657,7 @@ class MBartEncoder(MBartPreTrainedModel):
Args: Args:
config: MBartConfig config: MBartConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -758,7 +757,7 @@ class MBartEncoder(MBartPreTrainedModel): ...@@ -758,7 +757,7 @@ class MBartEncoder(MBartPreTrainedModel):
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
...@@ -826,7 +825,7 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -826,7 +825,7 @@ class MBartDecoder(MBartPreTrainedModel):
Args: Args:
config: MBartConfig config: MBartConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -999,7 +998,7 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -999,7 +998,7 @@ class MBartDecoder(MBartPreTrainedModel):
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import torch import torch
import torch.nn as nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
......
...@@ -27,7 +27,6 @@ from dataclasses import dataclass ...@@ -27,7 +27,6 @@ from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
...@@ -155,7 +154,7 @@ class NoNorm(nn.Module): ...@@ -155,7 +154,7 @@ class NoNorm(nn.Module):
return input_tensor * self.weight + self.bias return input_tensor * self.weight + self.bias
NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm} NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
class MobileBertEmbeddings(nn.Module): class MobileBertEmbeddings(nn.Module):
...@@ -207,9 +206,9 @@ class MobileBertEmbeddings(nn.Module): ...@@ -207,9 +206,9 @@ class MobileBertEmbeddings(nn.Module):
# dimensional output. # dimensional output.
inputs_embeds = torch.cat( inputs_embeds = torch.cat(
[ [
F.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0), nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
inputs_embeds, inputs_embeds,
F.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0), nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
], ],
dim=2, dim=2,
) )
...@@ -920,7 +919,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel): ...@@ -920,7 +919,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
def set_output_embeddings(self, new_embeddigs): def set_output_embeddings(self, new_embeddigs):
self.cls.predictions.decoder = new_embeddigs self.cls.predictions.decoder = new_embeddigs
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
# resize dense output embedings at first # resize dense output embedings at first
self.cls.predictions.dense = self._get_resized_lm_head( self.cls.predictions.dense = self._get_resized_lm_head(
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
...@@ -1028,7 +1027,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel): ...@@ -1028,7 +1027,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
def set_output_embeddings(self, new_embeddigs): def set_output_embeddings(self, new_embeddigs):
self.cls.predictions.decoder = new_embeddigs self.cls.predictions.decoder = new_embeddigs
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
# resize dense output embedings at first # resize dense output embedings at first
self.cls.predictions.dense = self._get_resized_lm_head( self.cls.predictions.dense = self._get_resized_lm_head(
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
......
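As a quick sanity check (not part of the commit), the `nn.functional.*` spelling used throughout the diff resolves to the same objects as the removed `F.*` alias, so the change is purely stylistic. A minimal sketch, assuming only a standard PyTorch install:

# Not part of the commit: verify the `nn.functional.*` spelling is the same
# module/functions as the old `import torch.nn.functional as F` alias.
import torch
import torch.nn.functional as F  # the alias this commit removes
from torch import nn             # the spelling this commit keeps

assert nn.functional is F                    # same module object
assert nn.functional.dropout is F.dropout    # same functions
assert nn.functional.softmax is F.softmax
assert nn.functional.pad is F.pad

x = torch.randn(2, 3)
torch.manual_seed(0)
a = F.dropout(x, p=0.1, training=True)
torch.manual_seed(0)
b = nn.functional.dropout(x, p=0.1, training=True)
assert torch.equal(a, b)  # identical outputs under the same seed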