"docs/primer.md" did not exist on "309e8a271e9aca1ef4aab899ce5d2d07c42123bb"
Unverified Commit 1ed2ebf6 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[style] consistent nn. and nn.functional (#12124)

* consistent nn. and nn.functional

* fix glitch

* fix glitch #2
parent ff7c8168
...@@ -23,7 +23,7 @@ from dataclasses import dataclass ...@@ -23,7 +23,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn as nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import gelu_new, silu from ...activations import gelu_new, silu
......
...@@ -21,7 +21,6 @@ from typing import Optional, Tuple ...@@ -21,7 +21,6 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
...@@ -239,7 +238,7 @@ class PegasusAttention(nn.Module): ...@@ -239,7 +238,7 @@ class PegasusAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -259,7 +258,7 @@ class PegasusAttention(nn.Module): ...@@ -259,7 +258,7 @@ class PegasusAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -321,15 +320,15 @@ class PegasusEncoderLayer(nn.Module): ...@@ -321,15 +320,15 @@ class PegasusEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and ( if hidden_states.dtype == torch.float16 and (
...@@ -417,7 +416,7 @@ class PegasusDecoderLayer(nn.Module): ...@@ -417,7 +416,7 @@ class PegasusDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Cross-Attention Block # Cross-Attention Block
...@@ -437,7 +436,7 @@ class PegasusDecoderLayer(nn.Module): ...@@ -437,7 +436,7 @@ class PegasusDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple # add cross-attn to positions 3,4 of present_key_value tuple
...@@ -447,9 +446,9 @@ class PegasusDecoderLayer(nn.Module): ...@@ -447,9 +446,9 @@ class PegasusDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)
...@@ -629,7 +628,7 @@ class PegasusEncoder(PegasusPreTrainedModel): ...@@ -629,7 +628,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
Args: Args:
config: PegasusConfig config: PegasusConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -729,7 +728,7 @@ class PegasusEncoder(PegasusPreTrainedModel): ...@@ -729,7 +728,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
...@@ -797,7 +796,7 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -797,7 +796,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
Args: Args:
config: PegasusConfig config: PegasusConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -969,7 +968,7 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -969,7 +968,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import argparse import argparse
import torch from torch import nn
from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
...@@ -107,15 +107,15 @@ def convert_prophetnet_checkpoint_to_pytorch(prophetnet_checkpoint_path: str, py ...@@ -107,15 +107,15 @@ def convert_prophetnet_checkpoint_to_pytorch(prophetnet_checkpoint_path: str, py
param.weight.shape == old_model.in_proj_weight[:embed_dim, :].shape, "Shapes have to match" param.weight.shape == old_model.in_proj_weight[:embed_dim, :].shape, "Shapes have to match"
param.bias.shape == old_model.in_proj_bias[:embed_dim].shape, "Shapes have to match" param.bias.shape == old_model.in_proj_bias[:embed_dim].shape, "Shapes have to match"
if attribute == "query_proj": if attribute == "query_proj":
model.query_proj.weight = torch.nn.Parameter(old_model.in_proj_weight[:embed_dim, :]) model.query_proj.weight = nn.Parameter(old_model.in_proj_weight[:embed_dim, :])
model.query_proj.bias = torch.nn.Parameter(old_model.in_proj_bias[:embed_dim]) model.query_proj.bias = nn.Parameter(old_model.in_proj_bias[:embed_dim])
elif attribute == "key_proj": elif attribute == "key_proj":
model.key_proj.weight = torch.nn.Parameter(old_model.in_proj_weight[embed_dim : 2 * embed_dim, :]) model.key_proj.weight = nn.Parameter(old_model.in_proj_weight[embed_dim : 2 * embed_dim, :])
model.key_proj.bias = torch.nn.Parameter(old_model.in_proj_bias[embed_dim : 2 * embed_dim]) model.key_proj.bias = nn.Parameter(old_model.in_proj_bias[embed_dim : 2 * embed_dim])
elif attribute == "value_proj": elif attribute == "value_proj":
model.value_proj.weight = torch.nn.Parameter(old_model.in_proj_weight[2 * embed_dim :, :]) model.value_proj.weight = nn.Parameter(old_model.in_proj_weight[2 * embed_dim :, :])
model.value_proj.bias = torch.nn.Parameter(old_model.in_proj_bias[2 * embed_dim :]) model.value_proj.bias = nn.Parameter(old_model.in_proj_bias[2 * embed_dim :])
is_key_init = True is_key_init = True
break break
elif attribute == "position_embeddings": elif attribute == "position_embeddings":
...@@ -123,7 +123,7 @@ def convert_prophetnet_checkpoint_to_pytorch(prophetnet_checkpoint_path: str, py ...@@ -123,7 +123,7 @@ def convert_prophetnet_checkpoint_to_pytorch(prophetnet_checkpoint_path: str, py
model.position_embeddings.weight.shape[-1] == old_model.embed_positions.weight.shape[-1] model.position_embeddings.weight.shape[-1] == old_model.embed_positions.weight.shape[-1]
), "Hidden size has to match" ), "Hidden size has to match"
assert model.position_embeddings.weight.shape[0] == 512, "We want 512 position_embeddings." assert model.position_embeddings.weight.shape[0] == 512, "We want 512 position_embeddings."
model.position_embeddings.weight = torch.nn.Parameter(old_model.embed_positions.weight[:512, :]) model.position_embeddings.weight = nn.Parameter(old_model.embed_positions.weight[:512, :])
is_key_init = True is_key_init = True
break break
......
...@@ -21,7 +21,6 @@ from dataclasses import dataclass ...@@ -21,7 +21,6 @@ from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
...@@ -183,9 +182,9 @@ PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r""" ...@@ -183,9 +182,9 @@ PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r"""
def softmax(hidden_state, dim, onnx_trace=False): def softmax(hidden_state, dim, onnx_trace=False):
if onnx_trace: if onnx_trace:
return F.softmax(hidden_state.float(), dim=dim) return nn.functional.softmax(hidden_state.float(), dim=dim)
else: else:
return F.softmax(hidden_state, dim=dim, dtype=torch.float32) return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)
def ngram_attention_bias(sequence_length, ngram, device, dtype): def ngram_attention_bias(sequence_length, ngram, device, dtype):
...@@ -732,7 +731,7 @@ class ProphetNetAttention(nn.Module): ...@@ -732,7 +731,7 @@ class ProphetNetAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
assert layer_head_mask.size() == ( assert layer_head_mask.size() == (
...@@ -746,7 +745,7 @@ class ProphetNetAttention(nn.Module): ...@@ -746,7 +745,7 @@ class ProphetNetAttention(nn.Module):
# apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
attn_probs = F.dropout( attn_probs = nn.functional.dropout(
attn_weights, attn_weights,
p=self.attention_dropout, p=self.attention_dropout,
training=self.training, training=self.training,
...@@ -767,7 +766,7 @@ class ProphetNetAttention(nn.Module): ...@@ -767,7 +766,7 @@ class ProphetNetAttention(nn.Module):
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
return attn_output, attn_weights_reshaped, past_key_value return attn_output, attn_weights_reshaped, past_key_value
...@@ -788,9 +787,9 @@ class ProphetNetFeedForward(nn.Module): ...@@ -788,9 +787,9 @@ class ProphetNetFeedForward(nn.Module):
hidden_states = self.intermediate(hidden_states) hidden_states = self.intermediate(hidden_states)
hidden_states = self.activation_fn(hidden_states) hidden_states = self.activation_fn(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.output(hidden_states) hidden_states = self.output(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return hidden_states return hidden_states
...@@ -924,7 +923,7 @@ class ProphetNetNgramSelfAttention(nn.Module): ...@@ -924,7 +923,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
) )
main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length) main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
# project to attn_output # project to attn_output
main_attn_output = torch.bmm(main_attn_probs, main_value_states) main_attn_output = torch.bmm(main_attn_probs, main_value_states)
...@@ -989,7 +988,9 @@ class ProphetNetNgramSelfAttention(nn.Module): ...@@ -989,7 +988,9 @@ class ProphetNetNgramSelfAttention(nn.Module):
self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
) )
predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training) predict_attn_probs = nn.functional.dropout(
predict_attn_probs, p=self.attention_dropout, training=self.training
)
# project to attention output # project to attention output
# [ngram, B*head, T, c] # [ngram, B*head, T, c]
predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states)) predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
...@@ -1012,7 +1013,7 @@ class ProphetNetNgramSelfAttention(nn.Module): ...@@ -1012,7 +1013,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
self.ngram, batch_size, self.num_attn_heads, sequence_length, -1 self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
).transpose(0, 1) ).transpose(0, 1)
attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
return attn_output, main_attn_probs, predict_attn_probs, past_key_value return attn_output, main_attn_probs, predict_attn_probs, past_key_value
...@@ -1321,7 +1322,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): ...@@ -1321,7 +1322,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
hidden_states = inputs_embeds + position_embeddings hidden_states = inputs_embeds + position_embeddings
hidden_states = self.embeddings_layer_norm(hidden_states) hidden_states = self.embeddings_layer_norm(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training)
encoder_hidden_states = () if output_hidden_states else None encoder_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
...@@ -1538,7 +1539,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1538,7 +1539,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
if self.embeddings_layer_norm: if self.embeddings_layer_norm:
hidden_states = self.embeddings_layer_norm(hidden_states) hidden_states = self.embeddings_layer_norm(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# init attentions, hidden_states and cache with empty tuples # init attentions, hidden_states and cache with empty tuples
all_main_stream_hidden_states = () if output_hidden_states else None all_main_stream_hidden_states = () if output_hidden_states else None
...@@ -1995,13 +1996,13 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1995,13 +1996,13 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
break break
expend_targets[i, :, :] = labels expend_targets[i, :, :] = labels
lprobs = F.log_softmax( lprobs = nn.functional.log_softmax(
logits.view(-1, logits.size(-1)), logits.view(-1, logits.size(-1)),
dim=-1, dim=-1,
dtype=torch.float32, dtype=torch.float32,
) )
loss = F.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
if self.config.eps > 0.0: if self.config.eps > 0.0:
smooth_loss = -lprobs.sum(dim=-1, keepdim=True) smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
...@@ -2239,13 +2240,13 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -2239,13 +2240,13 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
break break
expend_targets[i, :, :] = labels expend_targets[i, :, :] = labels
lprobs = F.log_softmax( lprobs = nn.functional.log_softmax(
logits.view(-1, logits.size(-1)), logits.view(-1, logits.size(-1)),
dim=-1, dim=-1,
dtype=torch.float32, dtype=torch.float32,
) )
loss = F.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
if self.config.eps > 0.0: if self.config.eps > 0.0:
smooth_loss = -lprobs.sum(dim=-1, keepdim=True) smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
......
...@@ -18,6 +18,7 @@ from dataclasses import dataclass ...@@ -18,6 +18,7 @@ from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import torch import torch
from torch import nn
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings from ...file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
...@@ -1065,10 +1066,10 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -1065,10 +1066,10 @@ class RagSequenceForGeneration(RagPreTrainedModel):
return ll.squeeze(-1), smooth_obj.squeeze(-1) return ll.squeeze(-1), smooth_obj.squeeze(-1)
# seq_logits dim = (batch*n_docs, tgt_len , #vocabs) # seq_logits dim = (batch*n_docs, tgt_len , #vocabs)
seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view( seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1) seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
) # batch_size x n_docs x tgt_len x #vocab_size ) # batch_size x n_docs x tgt_len x #vocab_size
doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1) doc_logprobs = nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
# RAG-sequence marginalization # RAG-sequence marginalization
first_token_scores = seq_logprobs[:, :, :1, :] first_token_scores = seq_logprobs[:, :, :1, :]
...@@ -1212,7 +1213,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1212,7 +1213,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
n_docs = n_docs if n_docs is not None else self.config.n_docs n_docs = n_docs if n_docs is not None else self.config.n_docs
# RAG-token marginalization # RAG-token marginalization
seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view( seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1) seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
) )
doc_logprobs = torch.log_softmax(doc_scores, dim=1) doc_logprobs = torch.log_softmax(doc_scores, dim=1)
......
...@@ -20,6 +20,7 @@ import pickle ...@@ -20,6 +20,7 @@ import pickle
import numpy as np import numpy as np
import torch import torch
from torch import nn
from transformers import ReformerConfig, ReformerModelWithLMHead from transformers import ReformerConfig, ReformerModelWithLMHead
from transformers.utils import logging from transformers.utils import logging
...@@ -31,10 +32,10 @@ logging.set_verbosity_info() ...@@ -31,10 +32,10 @@ logging.set_verbosity_info()
def set_param(torch_layer, weight, bias=None): def set_param(torch_layer, weight, bias=None):
# set parameter of one layer # set parameter of one layer
assert torch_layer.weight.shape == weight.shape, f"{torch_layer} layer.weight does not match" assert torch_layer.weight.shape == weight.shape, f"{torch_layer} layer.weight does not match"
torch_layer.weight = torch.nn.Parameter(weight) torch_layer.weight = nn.Parameter(weight)
if bias is not None: if bias is not None:
assert torch_layer.bias.shape == bias.shape, f"{torch_layer} layer.bias does not match" assert torch_layer.bias.shape == bias.shape, f"{torch_layer} layer.bias does not match"
torch_layer.bias = torch.nn.Parameter(bias) torch_layer.bias = nn.Parameter(bias)
def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size): def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size):
...@@ -153,7 +154,7 @@ def set_model_weights_in_torch(weights, torch_model, hidden_size): ...@@ -153,7 +154,7 @@ def set_model_weights_in_torch(weights, torch_model, hidden_size):
assert ( assert (
position_embeddings.weights[emb_idx].shape == emb_weights.shape position_embeddings.weights[emb_idx].shape == emb_weights.shape
), f"{position_embeddings[emb_idx]} emb does not match" ), f"{position_embeddings[emb_idx]} emb does not match"
position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) position_embeddings.weights[emb_idx] = nn.Parameter(torch.tensor(emb_weights))
trax_layer_weights = weights[5] trax_layer_weights = weights[5]
assert len(torch_model_reformer.encoder.layers) * 4 == len( assert len(torch_model_reformer.encoder.layers) * 4 == len(
......
...@@ -1782,7 +1782,7 @@ class ReformerPreTrainedModel(PreTrainedModel): ...@@ -1782,7 +1782,7 @@ class ReformerPreTrainedModel(PreTrainedModel):
"""Initialize the weights""" """Initialize the weights"""
if isinstance(module, AxialPositionEmbeddings): if isinstance(module, AxialPositionEmbeddings):
for weight in module.weights: for weight in module.weights:
torch.nn.init.normal_(weight, std=self.config.axial_norm_std) nn.init.normal_(weight, std=self.config.axial_norm_std)
elif isinstance(module, nn.Embedding): elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None: if module.padding_idx is not None:
......
...@@ -20,8 +20,8 @@ RetriBERT model ...@@ -20,8 +20,8 @@ RetriBERT model
import math import math
import torch import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from torch import nn
from ...file_utils import add_start_docstrings from ...file_utils import add_start_docstrings
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
import math import math
import torch import torch
import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu from ...activations import ACT2FN, gelu
......
...@@ -20,7 +20,6 @@ import random ...@@ -20,7 +20,6 @@ import random
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
...@@ -306,7 +305,7 @@ class Speech2TextAttention(nn.Module): ...@@ -306,7 +305,7 @@ class Speech2TextAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -326,7 +325,7 @@ class Speech2TextAttention(nn.Module): ...@@ -326,7 +325,7 @@ class Speech2TextAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -387,15 +386,15 @@ class Speech2TextEncoderLayer(nn.Module): ...@@ -387,15 +386,15 @@ class Speech2TextEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and ( if hidden_states.dtype == torch.float16 and (
...@@ -482,7 +481,7 @@ class Speech2TextDecoderLayer(nn.Module): ...@@ -482,7 +481,7 @@ class Speech2TextDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Cross-Attention Block # Cross-Attention Block
...@@ -502,7 +501,7 @@ class Speech2TextDecoderLayer(nn.Module): ...@@ -502,7 +501,7 @@ class Speech2TextDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple # add cross-attn to positions 3,4 of present_key_value tuple
...@@ -512,9 +511,9 @@ class Speech2TextDecoderLayer(nn.Module): ...@@ -512,9 +511,9 @@ class Speech2TextDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)
...@@ -686,7 +685,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): ...@@ -686,7 +685,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
Args: Args:
config: Speech2TextConfig config: Speech2TextConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: Speech2TextConfig): def __init__(self, config: Speech2TextConfig):
...@@ -772,7 +771,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): ...@@ -772,7 +771,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
embed_pos = self.embed_positions(padding_mask) embed_pos = self.embed_positions(padding_mask)
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
...@@ -840,7 +839,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -840,7 +839,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
Args: Args:
config: Speech2TextConfig config: Speech2TextConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: Speech2TextConfig): def __init__(self, config: Speech2TextConfig):
...@@ -1008,7 +1007,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -1008,7 +1007,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
......
...@@ -92,7 +92,7 @@ class SqueezeBertEmbeddings(nn.Module): ...@@ -92,7 +92,7 @@ class SqueezeBertEmbeddings(nn.Module):
return embeddings return embeddings
class MatMulWrapper(torch.nn.Module): class MatMulWrapper(nn.Module):
""" """
Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul. torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
......
...@@ -21,7 +21,6 @@ import os ...@@ -21,7 +21,6 @@ import os
import warnings import warnings
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
...@@ -179,7 +178,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): ...@@ -179,7 +178,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
#################################################### ####################################################
# PyTorch Models are constructed by sub-classing # PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and # - torch.nn.Module for the layers and
# - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module) # - PreTrainedModel for the models (it-self a sub-class of nn.Module)
#################################################### ####################################################
PARALLELIZE_DOCSTRING = r""" PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is a subject to change at a moment's notice. This is an experimental feature and is a subject to change at a moment's notice.
...@@ -257,7 +256,7 @@ class T5DenseReluDense(nn.Module): ...@@ -257,7 +256,7 @@ class T5DenseReluDense(nn.Module):
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.wi(hidden_states) hidden_states = self.wi(hidden_states)
hidden_states = F.relu(hidden_states) hidden_states = nn.functional.relu(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states) hidden_states = self.wo(hidden_states)
return hidden_states return hidden_states
...@@ -502,10 +501,10 @@ class T5Attention(nn.Module): ...@@ -502,10 +501,10 @@ class T5Attention(nn.Module):
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
scores += position_bias scores += position_bias
attn_weights = F.softmax(scores.float(), dim=-1).type_as( attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores scores
) # (batch_size, n_heads, seq_length, key_length) ) # (batch_size, n_heads, seq_length, key_length)
attn_weights = F.dropout( attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length) ) # (batch_size, n_heads, seq_length, key_length)
......
...@@ -22,8 +22,8 @@ from dataclasses import dataclass ...@@ -22,8 +22,8 @@ from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
...@@ -2096,10 +2096,8 @@ def _calculate_aggregation_loss_known( ...@@ -2096,10 +2096,8 @@ def _calculate_aggregation_loss_known(
# Use aggregation supervision as the target. # Use aggregation supervision as the target.
target_aggregation = aggregation_labels target_aggregation = aggregation_labels
one_hot_labels = torch.nn.functional.one_hot(target_aggregation, num_classes=num_aggregation_labels).type( one_hot_labels = nn.functional.one_hot(target_aggregation, num_classes=num_aggregation_labels).type(torch.float32)
torch.float32 log_probs = nn.functional.log_softmax(logits_aggregation, dim=-1)
)
log_probs = torch.nn.functional.log_softmax(logits_aggregation, dim=-1)
# torch.FloatTensor[batch_size] # torch.FloatTensor[batch_size]
per_example_aggregation_intermediate = -torch.sum(one_hot_labels * log_probs, dim=-1) per_example_aggregation_intermediate = -torch.sum(one_hot_labels * log_probs, dim=-1)
...@@ -2243,7 +2241,7 @@ def _calculate_expected_result( ...@@ -2243,7 +2241,7 @@ def _calculate_expected_result(
aggregation_op_only_probs = gumbel_dist.sample() aggregation_op_only_probs = gumbel_dist.sample()
else: else:
# <float32>[batch_size, num_aggregation_labels - 1] # <float32>[batch_size, num_aggregation_labels - 1]
aggregation_op_only_probs = torch.nn.functional.softmax( aggregation_op_only_probs = nn.functional.softmax(
logits_aggregation[:, 1:] / config.aggregation_temperature, dim=-1 logits_aggregation[:, 1:] / config.aggregation_temperature, dim=-1
) )
......
...@@ -21,8 +21,7 @@ from dataclasses import dataclass ...@@ -21,8 +21,7 @@ from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from ...file_utils import ( from ...file_utils import (
...@@ -344,7 +343,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module): ...@@ -344,7 +343,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score) attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = nn.functional.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob) attn_prob = self.dropatt(attn_prob)
# Mask heads if we want to # Mask heads if we want to
...@@ -434,7 +433,7 @@ class AdaptiveEmbedding(nn.Module): ...@@ -434,7 +433,7 @@ class AdaptiveEmbedding(nn.Module):
if self.div_val == 1: if self.div_val == 1:
embed = self.emb_layers[0](inp) embed = self.emb_layers[0](inp)
if self.d_proj != self.d_embed: if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0]) embed = nn.functional.linear(embed, self.emb_projs[0])
else: else:
param = next(self.parameters()) param = next(self.parameters())
inp_flat = inp.view(-1) inp_flat = inp.view(-1)
...@@ -450,7 +449,7 @@ class AdaptiveEmbedding(nn.Module): ...@@ -450,7 +449,7 @@ class AdaptiveEmbedding(nn.Module):
inp_i = inp_flat.index_select(0, indices_i) - l_idx inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i) emb_i = self.emb_layers[i](inp_i)
emb_i = F.linear(emb_i, self.emb_projs[i]) emb_i = nn.functional.linear(emb_i, self.emb_projs[i])
emb_flat.index_copy_(0, indices_i, emb_i) emb_flat.index_copy_(0, indices_i, emb_i)
......
...@@ -19,8 +19,7 @@ ...@@ -19,8 +19,7 @@
import torch import torch
import torch.nn as nn from torch import nn
import torch.nn.functional as F
# CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) # CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
...@@ -71,11 +70,11 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -71,11 +70,11 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
def _compute_logit(self, hidden, weight, bias, proj): def _compute_logit(self, hidden, weight, bias, proj):
if proj is None: if proj is None:
logit = F.linear(hidden, weight, bias=bias) logit = nn.functional.linear(hidden, weight, bias=bias)
else: else:
# if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
proj_hid = F.linear(hidden, proj.t().contiguous()) proj_hid = nn.functional.linear(hidden, proj.t().contiguous())
logit = F.linear(proj_hid, weight, bias=bias) logit = nn.functional.linear(proj_hid, weight, bias=bias)
# else: # else:
# logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
# if bias is not None: # if bias is not None:
...@@ -110,9 +109,9 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -110,9 +109,9 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if self.n_clusters == 0: if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0]) logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
if labels is not None: if labels is not None:
out = -F.log_softmax(logit, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1) out = -nn.functional.log_softmax(logit, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
else: else:
out = F.log_softmax(logit, dim=-1) out = nn.functional.log_softmax(logit, dim=-1)
else: else:
# construct weights and biases # construct weights and biases
weights, biases = [], [] weights, biases = [], []
...@@ -135,7 +134,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -135,7 +134,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1) head_logprob = nn.functional.log_softmax(head_logit, dim=1)
if labels is None: if labels is None:
out = hidden.new_empty((head_logit.size(0), self.n_token)) out = hidden.new_empty((head_logit.size(0), self.n_token))
...@@ -169,7 +168,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -169,7 +168,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) tail_logprob_i = nn.functional.log_softmax(tail_logit_i, dim=1)
cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
if labels is not None: if labels is not None:
logprob_i = head_logprob_i[:, cluster_prob_idx] + tail_logprob_i.gather( logprob_i = head_logprob_i[:, cluster_prob_idx] + tail_logprob_i.gather(
...@@ -205,7 +204,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -205,7 +204,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
""" """
if self.n_clusters == 0: if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0]) logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
return F.log_softmax(logit, dim=-1) return nn.functional.log_softmax(logit, dim=-1)
else: else:
# construct weights and biases # construct weights and biases
weights, biases = [], [] weights, biases = [], []
...@@ -229,7 +228,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -229,7 +228,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
out = hidden.new_empty((head_logit.size(0), self.n_token)) out = hidden.new_empty((head_logit.size(0), self.n_token))
head_logprob = F.log_softmax(head_logit, dim=1) head_logprob = nn.functional.log_softmax(head_logit, dim=1)
cutoff_values = [0] + self.cutoffs cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1): for i in range(len(cutoff_values) - 1):
...@@ -241,7 +240,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -241,7 +240,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i) tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) tail_logprob_i = nn.functional.log_softmax(tail_logit_i, dim=1)
logprob_i = head_logprob[:, -i] + tail_logprob_i logprob_i = head_logprob[:, -i] + tail_logprob_i
out[:, start_idx, stop_idx] = logprob_i out[:, start_idx, stop_idx] = logprob_i
......
...@@ -89,10 +89,10 @@ class VisualBertEmbeddings(nn.Module): ...@@ -89,10 +89,10 @@ class VisualBertEmbeddings(nn.Module):
self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
if config.special_visual_initialize: if config.special_visual_initialize:
self.visual_token_type_embeddings.weight.data = torch.nn.Parameter( self.visual_token_type_embeddings.weight.data = nn.Parameter(
self.token_type_embeddings.weight.data.clone(), requires_grad=True self.token_type_embeddings.weight.data.clone(), requires_grad=True
) )
self.visual_position_embeddings.weight.data = torch.nn.Parameter( self.visual_position_embeddings.weight.data = nn.Parameter(
self.position_embeddings.weight.data.clone(), requires_grad=True self.position_embeddings.weight.data.clone(), requires_grad=True
) )
...@@ -1253,8 +1253,8 @@ class VisualBertForQuestionAnswering(VisualBertPreTrainedModel): ...@@ -1253,8 +1253,8 @@ class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
loss_fct = torch.nn.KLDivLoss(reduction="batchmean") loss_fct = nn.KLDivLoss(reduction="batchmean")
log_softmax = torch.nn.LogSoftmax(dim=-1) log_softmax = nn.LogSoftmax(dim=-1)
reshaped_logits = log_softmax(reshaped_logits) reshaped_logits = log_softmax(reshaped_logits)
loss = loss_fct(reshaped_logits, labels.contiguous()) loss = loss_fct(reshaped_logits, labels.contiguous())
if not return_dict: if not return_dict:
......
...@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union ...@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
...@@ -449,7 +448,7 @@ class Wav2Vec2Attention(nn.Module): ...@@ -449,7 +448,7 @@ class Wav2Vec2Attention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -469,7 +468,7 @@ class Wav2Vec2Attention(nn.Module): ...@@ -469,7 +468,7 @@ class Wav2Vec2Attention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -805,9 +804,9 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): ...@@ -805,9 +804,9 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
if self.training: if self.training:
# sample code vector probs via gumbel in differentiateable way # sample code vector probs via gumbel in differentiateable way
codevector_probs = F.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True).type_as( codevector_probs = nn.functional.gumbel_softmax(
hidden_states hidden_states.float(), tau=self.temperature, hard=True
) ).type_as(hidden_states)
# compute perplexity # compute perplexity
codevector_soft_dist = torch.softmax( codevector_soft_dist = torch.softmax(
...@@ -867,12 +866,12 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): ...@@ -867,12 +866,12 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
if hasattr(module, "weight_v") and hasattr(module, "weight_g"): if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
torch.nn.init.kaiming_normal_(module.weight.data) nn.init.kaiming_normal_(module.weight.data)
else: else:
with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
torch.nn.init.kaiming_normal_(module.weight.data) nn.init.kaiming_normal_(module.weight.data)
else: else:
torch.nn.init.kaiming_normal_(module.weight.data) nn.init.kaiming_normal_(module.weight.data)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
...@@ -1296,7 +1295,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): ...@@ -1296,7 +1295,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
preds = logits.transpose(0, 2).reshape(-1, logits.size(0)) preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten() target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
contrastive_loss = F.cross_entropy(preds.float(), target, reduction="sum") contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="sum")
# 7. compute diversity loss: \mathbf{L}_d # 7. compute diversity loss: \mathbf{L}_d
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
...@@ -1502,10 +1501,10 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): ...@@ -1502,10 +1501,10 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
flattened_targets = labels.masked_select(labels_mask) flattened_targets = labels.masked_select(labels_mask)
# ctc_loss doesn't support fp16 # ctc_loss doesn't support fp16
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False): with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss( loss = nn.functional.ctc_loss(
log_probs, log_probs,
flattened_targets, flattened_targets,
input_lengths, input_lengths,
......
...@@ -503,7 +503,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -503,7 +503,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# encoder attention (for decoder only) # encoder attention (for decoder only)
# if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
# attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
# attn = F.dropout(attn, p=self.dropout, training=self.training) # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
# tensor = tensor + attn # tensor = tensor + attn
# tensor = self.layer_norm15[i](tensor) # tensor = self.layer_norm15[i](tensor)
......
...@@ -25,7 +25,6 @@ import numpy as np ...@@ -25,7 +25,6 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from ...activations import gelu from ...activations import gelu
from ...file_utils import ( from ...file_utils import (
...@@ -190,8 +189,8 @@ class MultiHeadAttention(nn.Module): ...@@ -190,8 +189,8 @@ class MultiHeadAttention(nn.Module):
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen)
weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to # Mask heads if we want to
if head_mask is not None: if head_mask is not None:
...@@ -212,7 +211,7 @@ class TransformerFFN(nn.Module): ...@@ -212,7 +211,7 @@ class TransformerFFN(nn.Module):
self.dropout = config.dropout self.dropout = config.dropout
self.lin1 = nn.Linear(in_dim, dim_hidden) self.lin1 = nn.Linear(in_dim, dim_hidden)
self.lin2 = nn.Linear(dim_hidden, out_dim) self.lin2 = nn.Linear(dim_hidden, out_dim)
self.act = gelu if config.gelu_activation else F.relu self.act = gelu if config.gelu_activation else nn.functional.relu
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
...@@ -223,7 +222,7 @@ class TransformerFFN(nn.Module): ...@@ -223,7 +222,7 @@ class TransformerFFN(nn.Module):
x = self.lin1(input) x = self.lin1(input)
x = self.act(x) x = self.act(x)
x = self.lin2(x) x = self.lin2(x)
x = F.dropout(x, p=self.dropout, training=self.training) x = nn.functional.dropout(x, p=self.dropout, training=self.training)
return x return x
...@@ -578,7 +577,7 @@ class XLMModel(XLMPreTrainedModel): ...@@ -578,7 +577,7 @@ class XLMModel(XLMPreTrainedModel):
if token_type_ids is not None: if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids) tensor = tensor + self.embeddings(token_type_ids)
tensor = self.layer_norm_emb(tensor) tensor = self.layer_norm_emb(tensor)
tensor = F.dropout(tensor, p=self.dropout, training=self.training) tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
tensor *= mask.unsqueeze(-1).to(tensor.dtype) tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# transformer layers # transformer layers
...@@ -599,14 +598,14 @@ class XLMModel(XLMPreTrainedModel): ...@@ -599,14 +598,14 @@ class XLMModel(XLMPreTrainedModel):
attn = attn_outputs[0] attn = attn_outputs[0]
if output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = F.dropout(attn, p=self.dropout, training=self.training) attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn tensor = tensor + attn
tensor = self.layer_norm1[i](tensor) tensor = self.layer_norm1[i](tensor)
# encoder attention (for decoder only) # encoder attention (for decoder only)
# if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
# attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
# attn = F.dropout(attn, p=self.dropout, training=self.training) # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
# tensor = tensor + attn # tensor = tensor + attn
# tensor = self.layer_norm15[i](tensor) # tensor = self.layer_norm15[i](tensor)
...@@ -661,7 +660,9 @@ class XLMPredLayer(nn.Module): ...@@ -661,7 +660,9 @@ class XLMPredLayer(nn.Module):
scores = self.proj(x) scores = self.proj(x)
outputs = (scores,) + outputs outputs = (scores,) + outputs
if y is not None: if y is not None:
loss = F.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="elementwise_mean") loss = nn.functional.cross_entropy(
scores.view(-1, self.n_words), y.view(-1), reduction="elementwise_mean"
)
outputs = (loss,) + outputs outputs = (loss,) + outputs
else: else:
scores = self.proj.log_prob(x) scores = self.proj.log_prob(x)
......
...@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple ...@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ( from ...file_utils import (
...@@ -305,7 +304,7 @@ class XLNetRelativeAttention(nn.Module): ...@@ -305,7 +304,7 @@ class XLNetRelativeAttention(nn.Module):
attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask) attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)
# attention probability # attention probability
attn_prob = F.softmax(attn_score, dim=3) attn_prob = nn.functional.softmax(attn_score, dim=3)
attn_prob = self.dropout(attn_prob) attn_prob = self.dropout(attn_prob)
# Mask heads if we want to # Mask heads if we want to
...@@ -1208,7 +1207,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -1208,7 +1207,7 @@ class XLNetModel(XLNetPreTrainedModel):
# `1` indicates not in the same segment [qlen x klen x bsz] # `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long() seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float) seg_mat = nn.functional.one_hot(seg_mat, num_classes=2).to(dtype_float)
else: else:
seg_mat = None seg_mat = None
...@@ -2034,7 +2033,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): ...@@ -2034,7 +2033,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
else: else:
# during inference, compute the end logits based on beam search # during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size() bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk( start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1 start_log_probs, self.start_n_top, dim=-1
...@@ -2048,7 +2047,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): ...@@ -2048,7 +2047,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
) # shape (bsz, slen, start_n_top, hsz) ) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk( end_top_log_probs, end_top_index = torch.topk(
end_log_probs, self.end_n_top, dim=1 end_log_probs, self.end_n_top, dim=1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment