Unverified Commit 1ed2ebf6 authored by Stas Bekman, committed by GitHub

[style] consistent nn. and nn.functional (#12124)

* consistent nn. and nn.functional

* fix glitch

* fix glitch #2
parent ff7c8168
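
For orientation, the convention this commit standardizes on is: import the `nn` namespace once with `from torch import nn` and reach both modules and functional ops through it, instead of mixing `torch.nn.X`, `import torch.nn as nn`, and `import torch.nn.functional as F`. A minimal illustrative sketch of the target style (not part of the diff):

```python
import torch
from torch import nn  # single import; covers nn.Module, nn.Linear, nn.functional, ...

class TinyClassifier(nn.Module):
    """Toy module written in the style this commit standardizes on."""

    def __init__(self, hidden: int, num_labels: int):
        super().__init__()
        self.proj = nn.Linear(hidden, num_labels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dropout(x)
        logits = self.proj(x)
        # previously written as F.softmax(logits, dim=-1)
        return nn.functional.softmax(logits, dim=-1)
```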
......@@ -15,8 +15,8 @@
import math
import torch
import torch.nn.functional as F
from packaging import version
from torch import nn
from .utils import logging
......@@ -28,8 +28,8 @@ def _gelu_python(x):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in
torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
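
As a quick sanity check of the equivalence the version gate below relies on, a sketch assuming torch >= 1.4, where `nn.functional.gelu` implements the same exact erf form:

```python
import math

import torch
from torch import nn

def _gelu_python(x):
    # exact (erf-based) GELU, matching the docstring above
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.randn(4, 8)
assert torch.allclose(_gelu_python(x), nn.functional.gelu(x), atol=1e-6)
```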
......@@ -45,7 +45,7 @@ def gelu_new(x):
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python
else:
gelu = F.gelu
gelu = nn.functional.gelu
def gelu_fast(x):
......@@ -70,11 +70,11 @@ def _silu_python(x):
if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = F.silu
silu = nn.functional.silu
def mish(x):
return x * torch.tanh(torch.nn.functional.softplus(x))
return x * torch.tanh(nn.functional.softplus(x))
def linear_act(x):
......@@ -82,7 +82,7 @@ def linear_act(x):
ACT2FN = {
"relu": F.relu,
"relu": nn.functional.relu,
"silu": silu,
"swish": silu,
"gelu": gelu,
......
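
For context, `ACT2FN` is the string-to-callable lookup models use to resolve `config.hidden_act`; a minimal usage sketch (abridged mapping, hypothetical config value):

```python
import torch
from torch import nn

ACT2FN = {"relu": nn.functional.relu, "gelu": nn.functional.gelu}  # abridged

hidden_act = "gelu"                  # normally read from a model config
activation_fn = ACT2FN[hidden_act]
out = activation_fn(torch.randn(2, 16))
```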
......@@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.nn import functional as F
from torch import nn
from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
......@@ -1564,7 +1564,7 @@ class GenerationMixin:
)
# sample
probs = F.softmax(next_token_scores, dim=-1)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
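
The sampling step above turns the processed scores into a distribution and draws one token id per sequence; a standalone sketch of the same pattern with toy shapes (not the `GenerationMixin` API):

```python
import torch
from torch import nn

next_token_scores = torch.randn(3, 50257)                  # (batch_size, vocab_size) logits
probs = nn.functional.softmax(next_token_scores, dim=-1)   # normalize to probabilities
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # one sampled id per row
print(next_tokens.shape)  # torch.Size([3])
```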
......@@ -1801,9 +1801,11 @@ class GenerationMixin:
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
......@@ -2098,9 +2100,11 @@ class GenerationMixin:
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
......@@ -2128,7 +2132,7 @@ class GenerationMixin:
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
probs = F.softmax(next_token_scores, dim=-1)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
......@@ -2426,9 +2430,11 @@ class GenerationMixin:
next_token_logits = outputs.logits[batch_group_indices, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * group_size, vocab_size)
vocab_size = next_token_scores.shape[-1]
next_token_scores = logits_processor(
......
......@@ -4,6 +4,7 @@ import inspect
from typing import Any, Dict, List, Optional, Union
import torch
from torch import nn
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx.node import Argument
......@@ -277,7 +278,7 @@ class HFTracer(Tracer):
return path
def path_of_module(self, mod: torch.nn.Module) -> str:
def path_of_module(self, mod: nn.Module) -> str:
"""
Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if
``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function
......
......@@ -25,7 +25,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import Tensor, device, dtype, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from .activations import get_activation
from .configuration_utils import PretrainedConfig
......@@ -355,9 +354,7 @@ class ModuleUtilsMixin:
"""
def parameter_filter(x):
return (x.requires_grad or not only_trainable) and not (
isinstance(x, torch.nn.Embedding) and exclude_embeddings
)
return (x.requires_grad or not only_trainable) and not (isinstance(x, nn.Embedding) and exclude_embeddings)
params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
......@@ -549,7 +546,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
):
assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
if hasattr(decoder_pointer, "weight"):
assert hasattr(encoder_pointer, "weight")
encoder_pointer.weight = decoder_pointer.weight
......@@ -613,7 +610,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
output_embeddings.weight = input_embeddings.weight
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data = nn.functional.pad(
output_embeddings.bias.data,
(
0,
......@@ -625,7 +622,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
"""
Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
......@@ -668,8 +665,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return self.get_input_embeddings()
def _get_resized_embeddings(
self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
) -> torch.nn.Embedding:
self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
) -> nn.Embedding:
"""
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
initialized vectors at the end. Reducing the size will remove vectors from the end
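
A rough sketch of what the resize described above amounts to (copy the overlapping rows, leave any newly added rows at their fresh initialization); simplified relative to the actual `_get_resized_embeddings`:

```python
import torch
from torch import nn

def resize_embedding(old: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    """Simplified resize: keep overlapping rows, new rows stay randomly initialized."""
    old_num_tokens, dim = old.weight.shape
    new = nn.Embedding(new_num_tokens, dim)
    num_to_copy = min(old_num_tokens, new_num_tokens)
    with torch.no_grad():
        new.weight[:num_to_copy, :] = old.weight[:num_to_copy, :]
    return new

resized = resize_embedding(nn.Embedding(100, 32), 120)
print(resized.weight.shape)  # torch.Size([120, 32])
```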
......@@ -732,8 +729,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return new_embeddings
def _get_resized_lm_head(
self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
) -> torch.nn.Linear:
self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
) -> nn.Linear:
"""
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end
......@@ -1681,7 +1678,7 @@ class SQuADHead(nn.Module):
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1
......@@ -1695,7 +1692,7 @@ class SQuADHead(nn.Module):
) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(
end_log_probs, self.end_n_top, dim=1
......@@ -1820,7 +1817,7 @@ class SequenceSummary(nn.Module):
return output
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
def unwrap_model(model: nn.Module) -> nn.Module:
"""
Recursively unwraps a model from potential containers (as used in distributed training).
......@@ -1834,7 +1831,7 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
return model
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
"""
Prune a linear layer to keep only entries in index.
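
A rough sketch of what pruning to an index means here (select the kept rows or columns of the weight with `index_select` and rebuild a smaller `nn.Linear`); simplified relative to the library helper:

```python
import torch
from torch import nn

def prune_linear(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    """Keep only the entries of `index` along `dim` (0 = output features, 1 = input features)."""
    weight = layer.weight.index_select(dim, index).clone().detach()
    bias = None
    if layer.bias is not None:
        bias = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
    with torch.no_grad():
        new_layer.weight.copy_(weight)
        if bias is not None:
            new_layer.bias.copy_(bias)
    return new_layer

pruned = prune_linear(nn.Linear(16, 8), torch.tensor([0, 2, 5]), dim=0)
print(pruned.weight.shape)  # torch.Size([3, 16])
```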
......@@ -1902,8 +1899,8 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
def prune_layer(
layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[torch.nn.Linear, Conv1D]:
layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[nn.Linear, Conv1D]:
"""
Prune a Conv1D or linear layer to keep only entries in index.
......
......@@ -20,7 +20,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
......
......@@ -20,7 +20,6 @@ import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
......@@ -223,7 +222,7 @@ class BartAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
......@@ -243,7 +242,7 @@ class BartAttention(nn.Module):
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
......@@ -303,15 +302,15 @@ class BartEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
......@@ -398,7 +397,7 @@ class BartDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
......@@ -418,7 +417,7 @@ class BartDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
......@@ -428,9 +427,9 @@ class BartDecoderLayer(nn.Module):
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
......@@ -661,7 +660,7 @@ class BartEncoder(BartPretrainedModel):
Args:
config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -760,7 +759,7 @@ class BartEncoder(BartPretrainedModel):
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
......@@ -826,7 +825,7 @@ class BartDecoder(BartPretrainedModel):
Args:
config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -997,7 +996,7 @@ class BartDecoder(BartPretrainedModel):
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
......
......@@ -139,7 +139,7 @@ class BertGenerationEmbeddings(nn.Module):
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
......
......@@ -22,7 +22,6 @@ from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
......@@ -379,7 +378,7 @@ class BigBirdSelfAttention(nn.Module):
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......@@ -608,7 +607,9 @@ class BigBirdBlockSparseAttention(nn.Module):
first_product = first_product * rsqrt_d
first_product += (1.0 - to_mask) * attn_mask_penalty
first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len]
first_attn_weights = nn.functional.softmax(
first_product, dim=-1
) # [bsz, n_heads, from_block_size, to_seq_len]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4)
......@@ -660,7 +661,7 @@ class BigBirdBlockSparseAttention(nn.Module):
)
second_product = second_product * rsqrt_d
second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty
second_attn_weights = F.softmax(
second_attn_weights = nn.functional.softmax(
second_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
......@@ -721,7 +722,7 @@ class BigBirdBlockSparseAttention(nn.Module):
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
# safely doing softmax since attention matrix is completed
attn_weights = F.softmax(
attn_weights = nn.functional.softmax(
band_product, dim=-1
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
......@@ -794,7 +795,7 @@ class BigBirdBlockSparseAttention(nn.Module):
)
second_last_product = second_last_product * rsqrt_d
second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty
second_last_attn_weights = F.softmax(
second_last_attn_weights = nn.functional.softmax(
second_last_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
......@@ -810,7 +811,7 @@ class BigBirdBlockSparseAttention(nn.Module):
last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4)
last_product = last_product * rsqrt_d
last_product += (1.0 - to_mask) * attn_mask_penalty
last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n]
last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4)
......@@ -2210,10 +2211,10 @@ class BigBirdModel(BigBirdPreTrainedModel):
f"`config.block_size`: {block_size}"
)
if input_ids is not None:
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_bigbird.BigBirdEmbeddings
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id)
if inputs_embeds is not None:
input_ids_padding = inputs_embeds.new_full(
(batch_size, padding_len),
......@@ -2223,8 +2224,10 @@ class BigBirdModel(BigBirdPreTrainedModel):
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
attention_mask = nn.functional.pad(
attention_mask, (0, padding_len), value=False
) # no attention on the padding tokens
token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
......
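
The padding logic above pads every tensor up to a multiple of `config.block_size` so block-sparse attention can tile the sequence; a minimal sketch of the padding arithmetic and the `nn.functional.pad` call (toy values):

```python
import torch
from torch import nn

block_size, pad_token_id = 64, 0
input_ids = torch.randint(5, 100, (2, 100))            # (batch_size, seq_len)

padding_len = (block_size - input_ids.shape[1] % block_size) % block_size
if padding_len > 0:
    # pad only the last dimension, on the right, with the pad token id
    input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)

print(input_ids.shape)  # torch.Size([2, 128]); 128 is a multiple of block_size
```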
......@@ -22,7 +22,6 @@ from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
......@@ -206,7 +205,7 @@ class BigBirdPegasusSelfAttention(nn.Module):
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......@@ -436,7 +435,9 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
first_product = first_product * rsqrt_d
first_product += (1.0 - to_mask) * attn_mask_penalty
first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len]
first_attn_weights = nn.functional.softmax(
first_product, dim=-1
) # [bsz, n_heads, from_block_size, to_seq_len]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4)
......@@ -488,7 +489,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
)
second_product = second_product * rsqrt_d
second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty
second_attn_weights = F.softmax(
second_attn_weights = nn.functional.softmax(
second_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
......@@ -549,7 +550,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
# safely doing softmax since attention matrix is completed
attn_weights = F.softmax(
attn_weights = nn.functional.softmax(
band_product, dim=-1
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
......@@ -622,7 +623,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
)
second_last_product = second_last_product * rsqrt_d
second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty
second_last_attn_weights = F.softmax(
second_last_attn_weights = nn.functional.softmax(
second_last_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
......@@ -638,7 +639,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4)
last_product = last_product * rsqrt_d
last_product += (1.0 - to_mask) * attn_mask_penalty
last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n]
last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4)
......@@ -1295,7 +1296,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
......@@ -1315,7 +1316,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
......@@ -1384,7 +1385,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
)
hidden_states = self_attention_outputs[0]
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
......@@ -1392,7 +1393,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
......@@ -1492,7 +1493,7 @@ class BigBirdPegasusDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
......@@ -1512,7 +1513,7 @@ class BigBirdPegasusDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple
......@@ -1522,9 +1523,9 @@ class BigBirdPegasusDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
......@@ -1733,7 +1734,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
Args:
config: BigBirdPegasusConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -1829,7 +1830,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=hidden_states.device)
......@@ -2015,7 +2016,9 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
inputs_embeds_padding = self.embed_tokens(input_ids_padding)
hidden_states = torch.cat([hidden_states, inputs_embeds_padding], dim=-2)
attention_mask = F.pad(attention_mask, (0, padding_len), value=0) # no attention on the padding tokens
attention_mask = nn.functional.pad(
attention_mask, (0, padding_len), value=0
) # no attention on the padding tokens
return padding_len, hidden_states, attention_mask
......@@ -2027,7 +2030,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
Args:
config: BigBirdPegasusConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -2198,7 +2201,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
......
......@@ -23,7 +23,6 @@ import warnings
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
......@@ -224,7 +223,7 @@ class BlenderbotAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
......@@ -244,7 +243,7 @@ class BlenderbotAttention(nn.Module):
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
......@@ -306,15 +305,15 @@ class BlenderbotEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
......@@ -402,7 +401,7 @@ class BlenderbotDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
......@@ -422,7 +421,7 @@ class BlenderbotDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple
......@@ -432,9 +431,9 @@ class BlenderbotDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
......@@ -617,7 +616,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
Args:
config: BlenderbotConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -715,7 +714,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
......@@ -784,7 +783,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
Args:
config: BlenderbotConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -956,7 +955,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
......
......@@ -21,7 +21,6 @@ import random
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
......@@ -222,7 +221,7 @@ class BlenderbotSmallAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
......@@ -242,7 +241,7 @@ class BlenderbotSmallAttention(nn.Module):
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
......@@ -303,15 +302,15 @@ class BlenderbotSmallEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
......@@ -399,7 +398,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
......@@ -419,7 +418,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
......@@ -429,9 +428,9 @@ class BlenderbotSmallDecoderLayer(nn.Module):
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
......@@ -618,7 +617,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
Args:
config: BlenderbotSmallConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -717,7 +716,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
......@@ -784,7 +783,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
Args:
config: BlenderbotSmallConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -957,7 +956,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
inputs_embeds = self.layernorm_embedding(inputs_embeds)
hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
......
......@@ -21,6 +21,7 @@ import os
import numpy as np
import torch
from packaging import version
from torch import nn
import gluonnlp as nlp
import mxnet as mx
......@@ -170,8 +171,8 @@ def convert_bort_checkpoint_to_pytorch(bort_checkpoint_path: str, pytorch_dump_f
# | `encoder.transformer_cells.*.proj.weight` | `bert.encoder.layer.*.output.dense.weight`
# Helper function to convert MXNet arrays to PyTorch
def to_torch(mx_array) -> torch.nn.Parameter:
return torch.nn.Parameter(torch.FloatTensor(mx_array.data().asnumpy()))
def to_torch(mx_array) -> nn.Parameter:
return nn.Parameter(torch.FloatTensor(mx_array.data().asnumpy()))
# Check param shapes and map new HF param back
def check_and_map_params(hf_param, gluon_param):
......
......@@ -18,7 +18,6 @@
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
......@@ -62,7 +61,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
return -neg_ce.mean()
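
The helper above takes the mean negative log-likelihood of the diagonal, i.e. of the matching image/text pairs; a toy sketch of how a CLIP-style loss combines it over both axes of a square similarity matrix (simplified, illustrative shapes only):

```python
import torch
from torch import nn

def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
    neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
    return -neg_ce.mean()

logits_per_image = torch.randn(4, 4)   # image-text similarity matrix for a batch of 4 pairs
caption_loss = contrastive_loss(logits_per_image, dim=0)
image_loss = contrastive_loss(logits_per_image, dim=1)
loss = (caption_loss + image_loss) / 2.0
```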
......@@ -235,7 +234,7 @@ class CLIPAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
......@@ -247,7 +246,7 @@ class CLIPAttention(nn.Module):
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
......@@ -493,7 +492,7 @@ class CLIPEncoder(nn.Module):
Args:
config: CLIPConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: CLIPConfig):
......
......@@ -383,7 +383,7 @@ class ConvBertSelfAttention(nn.Module):
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......
......@@ -19,7 +19,7 @@ from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
......@@ -87,7 +87,7 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
return output, attention_weights
class MultiHeadAttention(torch.nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, d_model_size, num_heads):
super().__init__()
self.num_heads = num_heads
......@@ -95,11 +95,11 @@ class MultiHeadAttention(torch.nn.Module):
self.depth = int(d_model_size / self.num_heads)
self.Wq = torch.nn.Linear(d_model_size, d_model_size)
self.Wk = torch.nn.Linear(d_model_size, d_model_size)
self.Wv = torch.nn.Linear(d_model_size, d_model_size)
self.Wq = nn.Linear(d_model_size, d_model_size)
self.Wk = nn.Linear(d_model_size, d_model_size)
self.Wv = nn.Linear(d_model_size, d_model_size)
self.dense = torch.nn.Linear(d_model_size, d_model_size)
self.dense = nn.Linear(d_model_size, d_model_size)
self.pruned_heads = set()
def prune_heads(self, heads):
......@@ -167,21 +167,21 @@ class MultiHeadAttention(torch.nn.Module):
def point_wise_feed_forward_network(d_model_size, dff):
return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), torch.nn.ReLU(), torch.nn.Linear(dff, d_model_size))
return nn.Sequential(nn.Linear(d_model_size, dff), nn.ReLU(), nn.Linear(dff, d_model_size))
class EncoderLayer(torch.nn.Module):
class EncoderLayer(nn.Module):
def __init__(self, d_model_size, num_heads, dff, rate=0.1):
super().__init__()
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
self.ffn = point_wise_feed_forward_network(d_model_size, dff)
self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
self.layernorm2 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6)
self.layernorm2 = nn.LayerNorm(d_model_size, eps=1e-6)
self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate)
self.dropout1 = nn.Dropout(rate)
self.dropout2 = nn.Dropout(rate)
def forward(
self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
......
......@@ -163,7 +163,7 @@ class XDropout(torch.autograd.Function):
return grad_output, None
class StableDropout(torch.nn.Module):
class StableDropout(nn.Module):
"""
Optimized dropout module for stabilizing the training
......@@ -477,7 +477,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
class DisentangledSelfAttention(torch.nn.Module):
class DisentangledSelfAttention(nn.Module):
"""
Disentangled self-attention module
......@@ -498,19 +498,17 @@ class DisentangledSelfAttention(torch.nn.Module):
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
self.relative_attention = getattr(config, "relative_attention", False)
self.talking_head = getattr(config, "talking_head", False)
if self.talking_head:
self.head_logits_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
self.head_weights_proj = torch.nn.Linear(
config.num_attention_heads, config.num_attention_heads, bias=False
)
self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
if self.relative_attention:
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
......@@ -519,9 +517,9 @@ class DisentangledSelfAttention(torch.nn.Module):
self.pos_dropout = StableDropout(config.hidden_dropout_prob)
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False)
self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size)
self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = StableDropout(config.attention_probs_dropout_prob)
......@@ -1122,7 +1120,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
self.pooler = ContextPooler(config)
output_dim = self.pooler.output_dim
self.classifier = torch.nn.Linear(output_dim, num_labels)
self.classifier = nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
......@@ -1182,7 +1180,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
if labels is not None:
if self.num_labels == 1:
# regression task
loss_fn = torch.nn.MSELoss()
loss_fn = nn.MSELoss()
logits = logits.view(-1).to(labels.dtype)
loss = loss_fn(logits, labels.view(-1))
elif labels.dim() == 1 or labels.size(-1) == 1:
......@@ -1196,7 +1194,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
else:
loss = torch.tensor(0).to(logits)
else:
log_softmax = torch.nn.LogSoftmax(-1)
log_softmax = nn.LogSoftmax(-1)
loss = -((log_softmax(logits) * labels).sum(-1)).mean()
if not return_dict:
output = (logits,) + outputs[1:]
......
......@@ -168,7 +168,7 @@ class XDropout(torch.autograd.Function):
# Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(torch.nn.Module):
class StableDropout(nn.Module):
"""
Optimized dropout module for stabilizing the training
......@@ -342,7 +342,7 @@ class ConvLayer(nn.Module):
kernel_size = getattr(config, "conv_kernel_size", 3)
groups = getattr(config, "conv_groups", 1)
self.conv_act = getattr(config, "conv_act", "tanh")
self.conv = torch.nn.Conv1d(
self.conv = nn.Conv1d(
config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
......@@ -546,7 +546,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
class DisentangledSelfAttention(torch.nn.Module):
class DisentangledSelfAttention(nn.Module):
"""
Disentangled self-attention module
......@@ -1244,7 +1244,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
self.pooler = ContextPooler(config)
output_dim = self.pooler.output_dim
self.classifier = torch.nn.Linear(output_dim, num_labels)
self.classifier = nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
......@@ -1304,7 +1304,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
if labels is not None:
if self.num_labels == 1:
# regression task
loss_fn = torch.nn.MSELoss()
loss_fn = nn.MSELoss()
logits = logits.view(-1).to(labels.dtype)
loss = loss_fn(logits, labels.view(-1))
elif labels.dim() == 1 or labels.size(-1) == 1:
......@@ -1318,7 +1318,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
else:
loss = torch.tensor(0).to(logits)
else:
log_softmax = torch.nn.LogSoftmax(-1)
log_softmax = nn.LogSoftmax(-1)
loss = -((log_softmax(logits) * labels).sum(-1)).mean()
if not return_dict:
output = (logits,) + outputs[1:]
......
......@@ -30,7 +30,7 @@ from ...utils import logging
if is_torch_available():
import torch
import torch.nn.functional as F
from torch import nn
logger = logging.get_logger(__name__)
......@@ -374,7 +374,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
# use PyTorch as current workaround
# TODO replace by self.resize
masks = torch.from_numpy(target["masks"][:, None]).float()
interpolated_masks = F.interpolate(masks, size=(h, w), mode="nearest")[:, 0] > 0.5
interpolated_masks = nn.functional.interpolate(masks, size=(h, w), mode="nearest")[:, 0] > 0.5
target["masks"] = interpolated_masks.numpy()
return rescaled_image, target
......@@ -697,7 +697,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
target_sizes.shape[1] == 2
), "Each element of target_sizes must contain the size (h, w) of each image of the batch"
prob = F.softmax(out_logits, -1)
prob = nn.functional.softmax(out_logits, -1)
scores, labels = prob[..., :-1].max(-1)
# convert to [x0, y0, x1, y1] format
......@@ -742,13 +742,15 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
), "Make sure to pass in as many orig_target_sizes as max_target_sizes"
max_h, max_w = max_target_sizes.max(0)[0].tolist()
outputs_masks = outputs.pred_masks.squeeze(2)
outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
outputs_masks = nn.functional.interpolate(
outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
)
outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
img_h, img_w = t[0], t[1]
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
results[i]["masks"] = F.interpolate(
results[i]["masks"] = nn.functional.interpolate(
results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
).byte()
......@@ -810,7 +812,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep]
cur_masks = cur_masks[keep]
cur_masks = F.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
cur_boxes = center_to_corners_format(cur_boxes[keep])
h, w = cur_masks.shape[-2:]
......
......@@ -21,7 +21,6 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from ...activations import ACT2FN
......@@ -314,7 +313,7 @@ class DetrFrozenBatchNorm2d(nn.Module):
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, torch.nn.BatchNorm2d):
if isinstance(target_attr, nn.BatchNorm2d):
frozen = DetrFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight)
......@@ -362,7 +361,7 @@ class DetrTimmConvEncoder(nn.Module):
out = []
for feature_map in features:
# downsample pixel_mask to match shape of corresponding feature_map
mask = F.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
out.append((feature_map, mask))
return out
......@@ -570,7 +569,7 @@ class DetrAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
......@@ -582,7 +581,7 @@ class DetrAttention(nn.Module):
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
......@@ -642,16 +641,16 @@ class DetrEncoderLayer(nn.Module):
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
......@@ -731,7 +730,7 @@ class DetrDecoderLayer(nn.Module):
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
......@@ -749,16 +748,16 @@ class DetrDecoderLayer(nn.Module):
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
......@@ -885,7 +884,7 @@ class DetrEncoder(DetrPreTrainedModel):
Args:
config: DetrConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: DetrConfig):
......@@ -946,7 +945,7 @@ class DetrEncoder(DetrPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
......@@ -999,7 +998,7 @@ class DetrDecoder(DetrPreTrainedModel):
Args:
config: DetrConfig
embed_tokens (torch.nn.Embedding): output embedding
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: DetrConfig, embed_tokens: Optional[nn.Embedding] = None):
......@@ -1717,23 +1716,23 @@ class DetrMaskHeadSmallConv(nn.Module):
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
self.gn1 = torch.nn.GroupNorm(8, dim)
self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)
self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
self.gn1 = nn.GroupNorm(8, dim)
self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
self.gn2 = nn.GroupNorm(8, inter_dims[1])
self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
self.gn3 = nn.GroupNorm(8, inter_dims[2])
self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
self.gn4 = nn.GroupNorm(8, inter_dims[3])
self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
self.gn5 = nn.GroupNorm(8, inter_dims[4])
self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
self.dim = dim
self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
......@@ -1748,34 +1747,34 @@ class DetrMaskHeadSmallConv(nn.Module):
x = self.lay1(x)
x = self.gn1(x)
x = F.relu(x)
x = nn.functional.relu(x)
x = self.lay2(x)
x = self.gn2(x)
x = F.relu(x)
x = nn.functional.relu(x)
cur_fpn = self.adapter1(fpns[0])
if cur_fpn.size(0) != x.size(0):
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = self.lay3(x)
x = self.gn3(x)
x = F.relu(x)
x = nn.functional.relu(x)
cur_fpn = self.adapter2(fpns[1])
if cur_fpn.size(0) != x.size(0):
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = self.lay4(x)
x = self.gn4(x)
x = F.relu(x)
x = nn.functional.relu(x)
cur_fpn = self.adapter3(fpns[2])
if cur_fpn.size(0) != x.size(0):
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = self.lay5(x)
x = self.gn5(x)
x = F.relu(x)
x = nn.functional.relu(x)
x = self.out_lay(x)
return x
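The adapter/upsample/add pattern repeated in this forward can be sketched in isolation as follows (channel counts and spatial sizes are hypothetical):
import torch
from torch import nn

x = torch.randn(1, 64, 16, 16)                     # coarse mask features
fpn_feature = torch.randn(1, 512, 32, 32)          # finer backbone feature map
adapter = nn.Conv2d(512, 64, 1)                    # plays the role of self.adapter1

cur_fpn = adapter(fpn_feature)
x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
print(x.shape)  # torch.Size([1, 64, 32, 32])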
......@@ -1797,14 +1796,14 @@ class DetrMHAttentionMap(nn.Module):
def forward(self, q, k, mask: Optional[Tensor] = None):
q = self.q_linear(q)
k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights
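A toy version of the masked spatial softmax used above, with dummy shapes: padded positions are filled with -inf, the trailing dimensions are flattened for the softmax, and the original shape is restored afterwards.
import torch
from torch import nn

weights = torch.randn(2, 5, 8, 10, 12)             # (batch, queries, heads, height, width), made up
mask = torch.zeros(2, 10, 12, dtype=torch.bool)    # True where pixels are padding
mask[:, :, -2:] = True

weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())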
......@@ -1847,7 +1846,7 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
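The focal-loss term in this hunk can be reproduced with dummy tensors as below (the alpha weighting and the division by num_boxes that follow in the original file are omitted):
import torch
from torch import nn

inputs = torch.randn(2, 4)                          # raw logits, illustrative only
targets = torch.randint(0, 2, (2, 4)).float()
gamma = 2.0

prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)               # down-weights easy, well-classified examples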
......@@ -1909,7 +1908,7 @@ class DetrLoss(nn.Module):
)
target_classes[idx] = target_classes_o
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce}
return losses
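A rough stand-alone sketch of the weighted classification loss above; the shapes, number of classes, and the 0.1 weight for the "no object" class are assumptions for the example only.
import torch
from torch import nn

num_classes = 91                                     # hypothetical
src_logits = torch.randn(2, 100, num_classes + 1)    # (batch, queries, classes + "no object")
target_classes = torch.full((2, 100), num_classes, dtype=torch.long)
empty_weight = torch.ones(num_classes + 1)
empty_weight[-1] = 0.1                               # down-weight the "no object" class

loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight)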
......@@ -1926,7 +1925,7 @@ class DetrLoss(nn.Module):
tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
losses = {"cardinality_error": card_err}
return losses
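The cardinality error above compares how many queries predict something other than "no object" against the number of ground-truth boxes; a minimal sketch with made-up shapes:
import torch
from torch import nn

logits = torch.randn(2, 100, 92)                     # last class index = "no object" (dummy sizes)
tgt_lengths = torch.tensor([3.0, 7.0])               # number of target boxes per image

card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths)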
......@@ -1942,7 +1941,7 @@ class DetrLoss(nn.Module):
src_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
......@@ -1972,7 +1971,7 @@ class DetrLoss(nn.Module):
target_masks = target_masks[tgt_idx]
# upsample predictions to the target size
src_masks = F.interpolate(
src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)
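The upsampling step above, sketched on its own with toy shapes: predicted masks are bilinearly resized to the ground-truth resolution, then flattened for the mask losses.
import torch
from torch import nn

src_masks = torch.randn(3, 25, 34)                   # one low-resolution mask per matched query
target_masks = torch.zeros(3, 200, 272)

src_masks = nn.functional.interpolate(
    src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)               # (3, 200 * 272)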
......@@ -2068,7 +2067,7 @@ class DetrMLPPredictionHead(nn.Module):
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
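The forward above applies ReLU after every layer except the last; a self-contained version with invented layer sizes:
import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 4)])
num_layers = len(layers)

x = torch.randn(2, 100, 256)
for i, layer in enumerate(layers):
    x = nn.functional.relu(layer(x)) if i < num_layers - 1 else layer(x)
print(x.shape)  # torch.Size([2, 100, 4])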
......
......@@ -23,7 +23,7 @@ import math
import numpy as np
import torch
import torch.nn as nn
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import gelu
......