Unverified Commit 1ed2ebf6 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[style] consistent nn. and nn.functional (#12124)

* consistent nn. and nn.functional

* fix glitch

* fix glitch #2
parent ff7c8168
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import math import math
import torch import torch
import torch.nn.functional as F
from packaging import version from packaging import version
from torch import nn
from .utils import logging from .utils import logging
...@@ -28,8 +28,8 @@ def _gelu_python(x): ...@@ -28,8 +28,8 @@ def _gelu_python(x):
""" """
Original Implementation of the GELU activation function in Google BERT repo when initially created. For Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
""" """
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
...@@ -45,7 +45,7 @@ def gelu_new(x): ...@@ -45,7 +45,7 @@ def gelu_new(x):
if version.parse(torch.__version__) < version.parse("1.4"): if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python gelu = _gelu_python
else: else:
gelu = F.gelu gelu = nn.functional.gelu
def gelu_fast(x): def gelu_fast(x):
...@@ -70,11 +70,11 @@ def _silu_python(x): ...@@ -70,11 +70,11 @@ def _silu_python(x):
if version.parse(torch.__version__) < version.parse("1.7"): if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python silu = _silu_python
else: else:
silu = F.silu silu = nn.functional.silu
def mish(x): def mish(x):
return x * torch.tanh(torch.nn.functional.softplus(x)) return x * torch.tanh(nn.functional.softplus(x))
def linear_act(x): def linear_act(x):
...@@ -82,7 +82,7 @@ def linear_act(x): ...@@ -82,7 +82,7 @@ def linear_act(x):
ACT2FN = { ACT2FN = {
"relu": F.relu, "relu": nn.functional.relu,
"silu": silu, "silu": silu,
"swish": silu, "swish": silu,
"gelu": gelu, "gelu": gelu,
......
...@@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union ...@@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.nn import functional as F from torch import nn
from .file_utils import ModelOutput from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_beam_search import BeamScorer, BeamSearchScorer
...@@ -1564,7 +1564,7 @@ class GenerationMixin: ...@@ -1564,7 +1564,7 @@ class GenerationMixin:
) )
# sample # sample
probs = F.softmax(next_token_scores, dim=-1) probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token # finished sentences should have their next token be a padding token
...@@ -1801,9 +1801,11 @@ class GenerationMixin: ...@@ -1801,9 +1801,11 @@ class GenerationMixin:
next_token_logits = outputs.logits[:, -1, :] next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation. # cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores = logits_processor(input_ids, next_token_scores) next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
...@@ -2098,9 +2100,11 @@ class GenerationMixin: ...@@ -2098,9 +2100,11 @@ class GenerationMixin:
next_token_logits = outputs.logits[:, -1, :] next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation. # cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores = logits_processor(input_ids, next_token_scores) next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
...@@ -2128,7 +2132,7 @@ class GenerationMixin: ...@@ -2128,7 +2132,7 @@ class GenerationMixin:
vocab_size = next_token_scores.shape[-1] vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
probs = F.softmax(next_token_scores, dim=-1) probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens) next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
...@@ -2426,9 +2430,11 @@ class GenerationMixin: ...@@ -2426,9 +2430,11 @@ class GenerationMixin:
next_token_logits = outputs.logits[batch_group_indices, -1, :] next_token_logits = outputs.logits[batch_group_indices, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation. # cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * group_size, vocab_size)
vocab_size = next_token_scores.shape[-1] vocab_size = next_token_scores.shape[-1]
next_token_scores = logits_processor( next_token_scores = logits_processor(
......
...@@ -4,6 +4,7 @@ import inspect ...@@ -4,6 +4,7 @@ import inspect
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import torch import torch
from torch import nn
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx.node import Argument from torch.fx.node import Argument
...@@ -277,7 +278,7 @@ class HFTracer(Tracer): ...@@ -277,7 +278,7 @@ class HFTracer(Tracer):
return path return path
def path_of_module(self, mod: torch.nn.Module) -> str: def path_of_module(self, mod: nn.Module) -> str:
""" """
Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if
``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function ``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function
......
...@@ -25,7 +25,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union ...@@ -25,7 +25,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch import torch
from torch import Tensor, device, dtype, nn from torch import Tensor, device, dtype, nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from .activations import get_activation from .activations import get_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
...@@ -355,9 +354,7 @@ class ModuleUtilsMixin: ...@@ -355,9 +354,7 @@ class ModuleUtilsMixin:
""" """
def parameter_filter(x): def parameter_filter(x):
return (x.requires_grad or not only_trainable) and not ( return (x.requires_grad or not only_trainable) and not (isinstance(x, nn.Embedding) and exclude_embeddings)
isinstance(x, torch.nn.Embedding) and exclude_embeddings
)
params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters() params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params) return sum(p.numel() for p in params)
...@@ -549,7 +546,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -549,7 +546,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
): ):
assert isinstance(decoder_pointer, nn.Module) and isinstance( assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module encoder_pointer, nn.Module
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
if hasattr(decoder_pointer, "weight"): if hasattr(decoder_pointer, "weight"):
assert hasattr(encoder_pointer, "weight") assert hasattr(encoder_pointer, "weight")
encoder_pointer.weight = decoder_pointer.weight encoder_pointer.weight = decoder_pointer.weight
...@@ -613,7 +610,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -613,7 +610,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
output_embeddings.weight = input_embeddings.weight output_embeddings.weight = input_embeddings.weight
if getattr(output_embeddings, "bias", None) is not None: if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad( output_embeddings.bias.data = nn.functional.pad(
output_embeddings.bias.data, output_embeddings.bias.data,
( (
0, 0,
...@@ -625,7 +622,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -625,7 +622,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
""" """
Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`. Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
...@@ -668,8 +665,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -668,8 +665,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return self.get_input_embeddings() return self.get_input_embeddings()
def _get_resized_embeddings( def _get_resized_embeddings(
self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
) -> torch.nn.Embedding: ) -> nn.Embedding:
""" """
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
initialized vectors at the end. Reducing the size will remove vectors from the end initialized vectors at the end. Reducing the size will remove vectors from the end
...@@ -732,8 +729,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -732,8 +729,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return new_embeddings return new_embeddings
def _get_resized_lm_head( def _get_resized_lm_head(
self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
) -> torch.nn.Linear: ) -> nn.Linear:
""" """
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end vectors at the end. Reducing the size will remove vectors from the end
...@@ -1681,7 +1678,7 @@ class SQuADHead(nn.Module): ...@@ -1681,7 +1678,7 @@ class SQuADHead(nn.Module):
else: else:
# during inference, compute the end logits based on beam search # during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size() bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk( start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1 start_log_probs, self.start_n_top, dim=-1
...@@ -1695,7 +1692,7 @@ class SQuADHead(nn.Module): ...@@ -1695,7 +1692,7 @@ class SQuADHead(nn.Module):
) # shape (bsz, slen, start_n_top, hsz) ) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk( end_top_log_probs, end_top_index = torch.topk(
end_log_probs, self.end_n_top, dim=1 end_log_probs, self.end_n_top, dim=1
...@@ -1820,7 +1817,7 @@ class SequenceSummary(nn.Module): ...@@ -1820,7 +1817,7 @@ class SequenceSummary(nn.Module):
return output return output
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: def unwrap_model(model: nn.Module) -> nn.Module:
""" """
Recursively unwraps a model from potential containers (as used in distributed training). Recursively unwraps a model from potential containers (as used in distributed training).
...@@ -1834,7 +1831,7 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: ...@@ -1834,7 +1831,7 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
return model return model
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear: def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
""" """
Prune a linear layer to keep only entries in index. Prune a linear layer to keep only entries in index.
...@@ -1902,8 +1899,8 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> ...@@ -1902,8 +1899,8 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
def prune_layer( def prune_layer(
layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[torch.nn.Linear, Conv1D]: ) -> Union[nn.Linear, Conv1D]:
""" """
Prune a Conv1D or linear layer to keep only entries in index. Prune a Conv1D or linear layer to keep only entries in index.
......
...@@ -20,7 +20,7 @@ from dataclasses import dataclass ...@@ -20,7 +20,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn as nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
......
...@@ -20,7 +20,6 @@ import warnings ...@@ -20,7 +20,6 @@ import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
...@@ -223,7 +222,7 @@ class BartAttention(nn.Module): ...@@ -223,7 +222,7 @@ class BartAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -243,7 +242,7 @@ class BartAttention(nn.Module): ...@@ -243,7 +242,7 @@ class BartAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -303,15 +302,15 @@ class BartEncoderLayer(nn.Module): ...@@ -303,15 +302,15 @@ class BartEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -398,7 +397,7 @@ class BartDecoderLayer(nn.Module): ...@@ -398,7 +397,7 @@ class BartDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
...@@ -418,7 +417,7 @@ class BartDecoderLayer(nn.Module): ...@@ -418,7 +417,7 @@ class BartDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
...@@ -428,9 +427,9 @@ class BartDecoderLayer(nn.Module): ...@@ -428,9 +427,9 @@ class BartDecoderLayer(nn.Module):
# Fully Connected # Fully Connected
residual = hidden_states residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -661,7 +660,7 @@ class BartEncoder(BartPretrainedModel): ...@@ -661,7 +660,7 @@ class BartEncoder(BartPretrainedModel):
Args: Args:
config: BartConfig config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -760,7 +759,7 @@ class BartEncoder(BartPretrainedModel): ...@@ -760,7 +759,7 @@ class BartEncoder(BartPretrainedModel):
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
...@@ -826,7 +825,7 @@ class BartDecoder(BartPretrainedModel): ...@@ -826,7 +825,7 @@ class BartDecoder(BartPretrainedModel):
Args: Args:
config: BartConfig config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -997,7 +996,7 @@ class BartDecoder(BartPretrainedModel): ...@@ -997,7 +996,7 @@ class BartDecoder(BartPretrainedModel):
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
......
...@@ -139,7 +139,7 @@ class BertGenerationEmbeddings(nn.Module): ...@@ -139,7 +139,7 @@ class BertGenerationEmbeddings(nn.Module):
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file # any TensorFlow checkpoint file
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized # position_ids (1, len position emb) is contiguous in memory and exported when serialized
......
...@@ -22,7 +22,6 @@ from typing import Optional, Tuple ...@@ -22,7 +22,6 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
...@@ -379,7 +378,7 @@ class BigBirdSelfAttention(nn.Module): ...@@ -379,7 +378,7 @@ class BigBirdSelfAttention(nn.Module):
attention_scores = attention_scores + attention_mask attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
attention_probs = F.softmax(attention_scores, dim=-1) attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might # This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
...@@ -608,7 +607,9 @@ class BigBirdBlockSparseAttention(nn.Module): ...@@ -608,7 +607,9 @@ class BigBirdBlockSparseAttention(nn.Module):
first_product = first_product * rsqrt_d first_product = first_product * rsqrt_d
first_product += (1.0 - to_mask) * attn_mask_penalty first_product += (1.0 - to_mask) * attn_mask_penalty
first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len] first_attn_weights = nn.functional.softmax(
first_product, dim=-1
) # [bsz, n_heads, from_block_size, to_seq_len]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4)
...@@ -660,7 +661,7 @@ class BigBirdBlockSparseAttention(nn.Module): ...@@ -660,7 +661,7 @@ class BigBirdBlockSparseAttention(nn.Module):
) )
second_product = second_product * rsqrt_d second_product = second_product * rsqrt_d
second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty
second_attn_weights = F.softmax( second_attn_weights = nn.functional.softmax(
second_product, dim=-1 second_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
...@@ -721,7 +722,7 @@ class BigBirdBlockSparseAttention(nn.Module): ...@@ -721,7 +722,7 @@ class BigBirdBlockSparseAttention(nn.Module):
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
# safely doing softmax since attention matrix is completed # safely doing softmax since attention matrix is completed
attn_weights = F.softmax( attn_weights = nn.functional.softmax(
band_product, dim=-1 band_product, dim=-1
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
...@@ -794,7 +795,7 @@ class BigBirdBlockSparseAttention(nn.Module): ...@@ -794,7 +795,7 @@ class BigBirdBlockSparseAttention(nn.Module):
) )
second_last_product = second_last_product * rsqrt_d second_last_product = second_last_product * rsqrt_d
second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty
second_last_attn_weights = F.softmax( second_last_attn_weights = nn.functional.softmax(
second_last_product, dim=-1 second_last_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
...@@ -810,7 +811,7 @@ class BigBirdBlockSparseAttention(nn.Module): ...@@ -810,7 +811,7 @@ class BigBirdBlockSparseAttention(nn.Module):
last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4)
last_product = last_product * rsqrt_d last_product = last_product * rsqrt_d
last_product += (1.0 - to_mask) * attn_mask_penalty last_product += (1.0 - to_mask) * attn_mask_penalty
last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4)
...@@ -2210,10 +2211,10 @@ class BigBirdModel(BigBirdPreTrainedModel): ...@@ -2210,10 +2211,10 @@ class BigBirdModel(BigBirdPreTrainedModel):
f"`config.block_size`: {block_size}" f"`config.block_size`: {block_size}"
) )
if input_ids is not None: if input_ids is not None:
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
if position_ids is not None: if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_bigbird.BigBirdEmbeddings # pad with position_id = pad_token_id as in modeling_bigbird.BigBirdEmbeddings
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id)
if inputs_embeds is not None: if inputs_embeds is not None:
input_ids_padding = inputs_embeds.new_full( input_ids_padding = inputs_embeds.new_full(
(batch_size, padding_len), (batch_size, padding_len),
...@@ -2223,8 +2224,10 @@ class BigBirdModel(BigBirdPreTrainedModel): ...@@ -2223,8 +2224,10 @@ class BigBirdModel(BigBirdPreTrainedModel):
inputs_embeds_padding = self.embeddings(input_ids_padding) inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens attention_mask = nn.functional.pad(
token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 attention_mask, (0, padding_len), value=False
) # no attention on the padding tokens
token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
......
...@@ -22,7 +22,6 @@ from typing import Optional, Tuple ...@@ -22,7 +22,6 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
...@@ -206,7 +205,7 @@ class BigBirdPegasusSelfAttention(nn.Module): ...@@ -206,7 +205,7 @@ class BigBirdPegasusSelfAttention(nn.Module):
attention_scores = attention_scores + attention_mask attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
attention_probs = F.softmax(attention_scores, dim=-1) attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might # This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
...@@ -436,7 +435,9 @@ class BigBirdPegasusBlockSparseAttention(nn.Module): ...@@ -436,7 +435,9 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
first_product = first_product * rsqrt_d first_product = first_product * rsqrt_d
first_product += (1.0 - to_mask) * attn_mask_penalty first_product += (1.0 - to_mask) * attn_mask_penalty
first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len] first_attn_weights = nn.functional.softmax(
first_product, dim=-1
) # [bsz, n_heads, from_block_size, to_seq_len]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4)
...@@ -488,7 +489,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module): ...@@ -488,7 +489,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
) )
second_product = second_product * rsqrt_d second_product = second_product * rsqrt_d
second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty
second_attn_weights = F.softmax( second_attn_weights = nn.functional.softmax(
second_product, dim=-1 second_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
...@@ -549,7 +550,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module): ...@@ -549,7 +550,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
# safely doing softmax since attention matrix is completed # safely doing softmax since attention matrix is completed
attn_weights = F.softmax( attn_weights = nn.functional.softmax(
band_product, dim=-1 band_product, dim=-1
) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size]
...@@ -622,7 +623,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module): ...@@ -622,7 +623,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
) )
second_last_product = second_last_product * rsqrt_d second_last_product = second_last_product * rsqrt_d
second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty
second_last_attn_weights = F.softmax( second_last_attn_weights = nn.functional.softmax(
second_last_product, dim=-1 second_last_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
...@@ -638,7 +639,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module): ...@@ -638,7 +639,7 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4)
last_product = last_product * rsqrt_d last_product = last_product * rsqrt_d
last_product += (1.0 - to_mask) * attn_mask_penalty last_product += (1.0 - to_mask) * attn_mask_penalty
last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4)
...@@ -1295,7 +1296,7 @@ class BigBirdPegasusDecoderAttention(nn.Module): ...@@ -1295,7 +1296,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -1315,7 +1316,7 @@ class BigBirdPegasusDecoderAttention(nn.Module): ...@@ -1315,7 +1316,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -1384,7 +1385,7 @@ class BigBirdPegasusEncoderLayer(nn.Module): ...@@ -1384,7 +1385,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
) )
hidden_states = self_attention_outputs[0] hidden_states = self_attention_outputs[0]
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
...@@ -1392,7 +1393,7 @@ class BigBirdPegasusEncoderLayer(nn.Module): ...@@ -1392,7 +1393,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and ( if hidden_states.dtype == torch.float16 and (
...@@ -1492,7 +1493,7 @@ class BigBirdPegasusDecoderLayer(nn.Module): ...@@ -1492,7 +1493,7 @@ class BigBirdPegasusDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Cross-Attention Block # Cross-Attention Block
...@@ -1512,7 +1513,7 @@ class BigBirdPegasusDecoderLayer(nn.Module): ...@@ -1512,7 +1513,7 @@ class BigBirdPegasusDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple # add cross-attn to positions 3,4 of present_key_value tuple
...@@ -1522,9 +1523,9 @@ class BigBirdPegasusDecoderLayer(nn.Module): ...@@ -1522,9 +1523,9 @@ class BigBirdPegasusDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)
...@@ -1733,7 +1734,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): ...@@ -1733,7 +1734,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
Args: Args:
config: BigBirdPegasusConfig config: BigBirdPegasusConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -1829,7 +1830,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): ...@@ -1829,7 +1830,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
embed_pos = self.embed_positions(input_shape) embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(input_shape, device=hidden_states.device) attention_mask = torch.ones(input_shape, device=hidden_states.device)
...@@ -2015,7 +2016,9 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): ...@@ -2015,7 +2016,9 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
inputs_embeds_padding = self.embed_tokens(input_ids_padding) inputs_embeds_padding = self.embed_tokens(input_ids_padding)
hidden_states = torch.cat([hidden_states, inputs_embeds_padding], dim=-2) hidden_states = torch.cat([hidden_states, inputs_embeds_padding], dim=-2)
attention_mask = F.pad(attention_mask, (0, padding_len), value=0) # no attention on the padding tokens attention_mask = nn.functional.pad(
attention_mask, (0, padding_len), value=0
) # no attention on the padding tokens
return padding_len, hidden_states, attention_mask return padding_len, hidden_states, attention_mask
...@@ -2027,7 +2030,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): ...@@ -2027,7 +2030,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
Args: Args:
config: BigBirdPegasusConfig config: BigBirdPegasusConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -2198,7 +2201,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): ...@@ -2198,7 +2201,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
......
...@@ -23,7 +23,6 @@ import warnings ...@@ -23,7 +23,6 @@ import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
...@@ -224,7 +223,7 @@ class BlenderbotAttention(nn.Module): ...@@ -224,7 +223,7 @@ class BlenderbotAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -244,7 +243,7 @@ class BlenderbotAttention(nn.Module): ...@@ -244,7 +243,7 @@ class BlenderbotAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -306,15 +305,15 @@ class BlenderbotEncoderLayer(nn.Module): ...@@ -306,15 +305,15 @@ class BlenderbotEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and ( if hidden_states.dtype == torch.float16 and (
...@@ -402,7 +401,7 @@ class BlenderbotDecoderLayer(nn.Module): ...@@ -402,7 +401,7 @@ class BlenderbotDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Cross-Attention Block # Cross-Attention Block
...@@ -422,7 +421,7 @@ class BlenderbotDecoderLayer(nn.Module): ...@@ -422,7 +421,7 @@ class BlenderbotDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple # add cross-attn to positions 3,4 of present_key_value tuple
...@@ -432,9 +431,9 @@ class BlenderbotDecoderLayer(nn.Module): ...@@ -432,9 +431,9 @@ class BlenderbotDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)
...@@ -617,7 +616,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): ...@@ -617,7 +616,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
Args: Args:
config: BlenderbotConfig config: BlenderbotConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -715,7 +714,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): ...@@ -715,7 +714,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
embed_pos = self.embed_positions(input_shape) embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
...@@ -784,7 +783,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): ...@@ -784,7 +783,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
Args: Args:
config: BlenderbotConfig config: BlenderbotConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -956,7 +955,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): ...@@ -956,7 +955,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
......
...@@ -21,7 +21,6 @@ import random ...@@ -21,7 +21,6 @@ import random
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
...@@ -222,7 +221,7 @@ class BlenderbotSmallAttention(nn.Module): ...@@ -222,7 +221,7 @@ class BlenderbotSmallAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,): if layer_head_mask.size() != (self.num_heads,):
...@@ -242,7 +241,7 @@ class BlenderbotSmallAttention(nn.Module): ...@@ -242,7 +241,7 @@ class BlenderbotSmallAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -303,15 +302,15 @@ class BlenderbotSmallEncoderLayer(nn.Module): ...@@ -303,15 +302,15 @@ class BlenderbotSmallEncoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -399,7 +398,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): ...@@ -399,7 +398,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
...@@ -419,7 +418,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): ...@@ -419,7 +418,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
...@@ -429,9 +428,9 @@ class BlenderbotSmallDecoderLayer(nn.Module): ...@@ -429,9 +428,9 @@ class BlenderbotSmallDecoderLayer(nn.Module):
# Fully Connected # Fully Connected
residual = hidden_states residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -618,7 +617,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): ...@@ -618,7 +617,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
Args: Args:
config: BlenderbotSmallConfig config: BlenderbotSmallConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -717,7 +716,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): ...@@ -717,7 +716,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
...@@ -784,7 +783,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): ...@@ -784,7 +783,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
Args: Args:
config: BlenderbotSmallConfig config: BlenderbotSmallConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -957,7 +956,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): ...@@ -957,7 +956,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
inputs_embeds = self.layernorm_embedding(inputs_embeds) inputs_embeds = self.layernorm_embedding(inputs_embeds)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
......
...@@ -21,6 +21,7 @@ import os ...@@ -21,6 +21,7 @@ import os
import numpy as np import numpy as np
import torch import torch
from packaging import version from packaging import version
from torch import nn
import gluonnlp as nlp import gluonnlp as nlp
import mxnet as mx import mxnet as mx
...@@ -170,8 +171,8 @@ def convert_bort_checkpoint_to_pytorch(bort_checkpoint_path: str, pytorch_dump_f ...@@ -170,8 +171,8 @@ def convert_bort_checkpoint_to_pytorch(bort_checkpoint_path: str, pytorch_dump_f
# | `encoder.transformer_cells.*.proj.weight` | `bert.encoder.layer.*.output.dense.weight` # | `encoder.transformer_cells.*.proj.weight` | `bert.encoder.layer.*.output.dense.weight`
# Helper function to convert MXNET Arrays to PyTorch # Helper function to convert MXNET Arrays to PyTorch
def to_torch(mx_array) -> torch.nn.Parameter: def to_torch(mx_array) -> nn.Parameter:
return torch.nn.Parameter(torch.FloatTensor(mx_array.data().asnumpy())) return nn.Parameter(torch.FloatTensor(mx_array.data().asnumpy()))
# Check param shapes and map new HF param back # Check param shapes and map new HF param back
def check_and_map_params(hf_param, gluon_param): def check_and_map_params(hf_param, gluon_param):
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
...@@ -62,7 +61,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] ...@@ -62,7 +61,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
# contrastive loss function, adapted from # contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor: def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
neg_ce = torch.diag(F.log_softmax(logits, dim=dim)) neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
return -neg_ce.mean() return -neg_ce.mean()
...@@ -235,7 +234,7 @@ class CLIPAttention(nn.Module): ...@@ -235,7 +234,7 @@ class CLIPAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions: if output_attentions:
# this operation is a bit akward, but it's required to # this operation is a bit akward, but it's required to
...@@ -247,7 +246,7 @@ class CLIPAttention(nn.Module): ...@@ -247,7 +246,7 @@ class CLIPAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -493,7 +492,7 @@ class CLIPEncoder(nn.Module): ...@@ -493,7 +492,7 @@ class CLIPEncoder(nn.Module):
Args: Args:
config: CLIPConfig config: CLIPConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: CLIPConfig): def __init__(self, config: CLIPConfig):
......
...@@ -383,7 +383,7 @@ class ConvBertSelfAttention(nn.Module): ...@@ -383,7 +383,7 @@ class ConvBertSelfAttention(nn.Module):
attention_scores = attention_scores + attention_mask attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1) attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might # This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
......
...@@ -19,7 +19,7 @@ from typing import Tuple ...@@ -19,7 +19,7 @@ from typing import Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
...@@ -87,7 +87,7 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N ...@@ -87,7 +87,7 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
return output, attention_weights return output, attention_weights
class MultiHeadAttention(torch.nn.Module): class MultiHeadAttention(nn.Module):
def __init__(self, d_model_size, num_heads): def __init__(self, d_model_size, num_heads):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
...@@ -95,11 +95,11 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -95,11 +95,11 @@ class MultiHeadAttention(torch.nn.Module):
self.depth = int(d_model_size / self.num_heads) self.depth = int(d_model_size / self.num_heads)
self.Wq = torch.nn.Linear(d_model_size, d_model_size) self.Wq = nn.Linear(d_model_size, d_model_size)
self.Wk = torch.nn.Linear(d_model_size, d_model_size) self.Wk = nn.Linear(d_model_size, d_model_size)
self.Wv = torch.nn.Linear(d_model_size, d_model_size) self.Wv = nn.Linear(d_model_size, d_model_size)
self.dense = torch.nn.Linear(d_model_size, d_model_size) self.dense = nn.Linear(d_model_size, d_model_size)
self.pruned_heads = set() self.pruned_heads = set()
def prune_heads(self, heads): def prune_heads(self, heads):
...@@ -167,21 +167,21 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -167,21 +167,21 @@ class MultiHeadAttention(torch.nn.Module):
def point_wise_feed_forward_network(d_model_size, dff): def point_wise_feed_forward_network(d_model_size, dff):
return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), torch.nn.ReLU(), torch.nn.Linear(dff, d_model_size)) return nn.Sequential(nn.Linear(d_model_size, dff), nn.ReLU(), nn.Linear(dff, d_model_size))
class EncoderLayer(torch.nn.Module): class EncoderLayer(nn.Module):
def __init__(self, d_model_size, num_heads, dff, rate=0.1): def __init__(self, d_model_size, num_heads, dff, rate=0.1):
super().__init__() super().__init__()
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads) self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
self.ffn = point_wise_feed_forward_network(d_model_size, dff) self.ffn = point_wise_feed_forward_network(d_model_size, dff)
self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6) self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6)
self.layernorm2 = torch.nn.LayerNorm(d_model_size, eps=1e-6) self.layernorm2 = nn.LayerNorm(d_model_size, eps=1e-6)
self.dropout1 = torch.nn.Dropout(rate) self.dropout1 = nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate) self.dropout2 = nn.Dropout(rate)
def forward( def forward(
self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
......
...@@ -163,7 +163,7 @@ class XDropout(torch.autograd.Function): ...@@ -163,7 +163,7 @@ class XDropout(torch.autograd.Function):
return grad_output, None return grad_output, None
class StableDropout(torch.nn.Module): class StableDropout(nn.Module):
""" """
Optimized dropout module for stabilizing the training Optimized dropout module for stabilizing the training
...@@ -477,7 +477,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer): ...@@ -477,7 +477,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
class DisentangledSelfAttention(torch.nn.Module): class DisentangledSelfAttention(nn.Module):
""" """
Disentangled self-attention module Disentangled self-attention module
...@@ -498,19 +498,17 @@ class DisentangledSelfAttention(torch.nn.Module): ...@@ -498,19 +498,17 @@ class DisentangledSelfAttention(torch.nn.Module):
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False) self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
self.relative_attention = getattr(config, "relative_attention", False) self.relative_attention = getattr(config, "relative_attention", False)
self.talking_head = getattr(config, "talking_head", False) self.talking_head = getattr(config, "talking_head", False)
if self.talking_head: if self.talking_head:
self.head_logits_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False) self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
self.head_weights_proj = torch.nn.Linear( self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
config.num_attention_heads, config.num_attention_heads, bias=False
)
if self.relative_attention: if self.relative_attention:
self.max_relative_positions = getattr(config, "max_relative_positions", -1) self.max_relative_positions = getattr(config, "max_relative_positions", -1)
...@@ -519,9 +517,9 @@ class DisentangledSelfAttention(torch.nn.Module): ...@@ -519,9 +517,9 @@ class DisentangledSelfAttention(torch.nn.Module):
self.pos_dropout = StableDropout(config.hidden_dropout_prob) self.pos_dropout = StableDropout(config.hidden_dropout_prob)
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False) self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size) self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = StableDropout(config.attention_probs_dropout_prob) self.dropout = StableDropout(config.attention_probs_dropout_prob)
...@@ -1122,7 +1120,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel): ...@@ -1122,7 +1120,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
self.pooler = ContextPooler(config) self.pooler = ContextPooler(config)
output_dim = self.pooler.output_dim output_dim = self.pooler.output_dim
self.classifier = torch.nn.Linear(output_dim, num_labels) self.classifier = nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None) drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out) self.dropout = StableDropout(drop_out)
...@@ -1182,7 +1180,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel): ...@@ -1182,7 +1180,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
if labels is not None: if labels is not None:
if self.num_labels == 1: if self.num_labels == 1:
# regression task # regression task
loss_fn = torch.nn.MSELoss() loss_fn = nn.MSELoss()
logits = logits.view(-1).to(labels.dtype) logits = logits.view(-1).to(labels.dtype)
loss = loss_fn(logits, labels.view(-1)) loss = loss_fn(logits, labels.view(-1))
elif labels.dim() == 1 or labels.size(-1) == 1: elif labels.dim() == 1 or labels.size(-1) == 1:
...@@ -1196,7 +1194,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel): ...@@ -1196,7 +1194,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
else: else:
loss = torch.tensor(0).to(logits) loss = torch.tensor(0).to(logits)
else: else:
log_softmax = torch.nn.LogSoftmax(-1) log_softmax = nn.LogSoftmax(-1)
loss = -((log_softmax(logits) * labels).sum(-1)).mean() loss = -((log_softmax(logits) * labels).sum(-1)).mean()
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
......
...@@ -168,7 +168,7 @@ class XDropout(torch.autograd.Function): ...@@ -168,7 +168,7 @@ class XDropout(torch.autograd.Function):
# Copied from transformers.models.deberta.modeling_deberta.StableDropout # Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(torch.nn.Module): class StableDropout(nn.Module):
""" """
Optimized dropout module for stabilizing the training Optimized dropout module for stabilizing the training
...@@ -342,7 +342,7 @@ class ConvLayer(nn.Module): ...@@ -342,7 +342,7 @@ class ConvLayer(nn.Module):
kernel_size = getattr(config, "conv_kernel_size", 3) kernel_size = getattr(config, "conv_kernel_size", 3)
groups = getattr(config, "conv_groups", 1) groups = getattr(config, "conv_groups", 1)
self.conv_act = getattr(config, "conv_act", "tanh") self.conv_act = getattr(config, "conv_act", "tanh")
self.conv = torch.nn.Conv1d( self.conv = nn.Conv1d(
config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
) )
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
...@@ -546,7 +546,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer): ...@@ -546,7 +546,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
class DisentangledSelfAttention(torch.nn.Module): class DisentangledSelfAttention(nn.Module):
""" """
Disentangled self-attention module Disentangled self-attention module
...@@ -1244,7 +1244,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): ...@@ -1244,7 +1244,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
self.pooler = ContextPooler(config) self.pooler = ContextPooler(config)
output_dim = self.pooler.output_dim output_dim = self.pooler.output_dim
self.classifier = torch.nn.Linear(output_dim, num_labels) self.classifier = nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None) drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out) self.dropout = StableDropout(drop_out)
...@@ -1304,7 +1304,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): ...@@ -1304,7 +1304,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
if labels is not None: if labels is not None:
if self.num_labels == 1: if self.num_labels == 1:
# regression task # regression task
loss_fn = torch.nn.MSELoss() loss_fn = nn.MSELoss()
logits = logits.view(-1).to(labels.dtype) logits = logits.view(-1).to(labels.dtype)
loss = loss_fn(logits, labels.view(-1)) loss = loss_fn(logits, labels.view(-1))
elif labels.dim() == 1 or labels.size(-1) == 1: elif labels.dim() == 1 or labels.size(-1) == 1:
...@@ -1318,7 +1318,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): ...@@ -1318,7 +1318,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
else: else:
loss = torch.tensor(0).to(logits) loss = torch.tensor(0).to(logits)
else: else:
log_softmax = torch.nn.LogSoftmax(-1) log_softmax = nn.LogSoftmax(-1)
loss = -((log_softmax(logits) * labels).sum(-1)).mean() loss = -((log_softmax(logits) * labels).sum(-1)).mean()
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
......
...@@ -30,7 +30,7 @@ from ...utils import logging ...@@ -30,7 +30,7 @@ from ...utils import logging
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn.functional as F from torch import nn
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -374,7 +374,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): ...@@ -374,7 +374,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
# use PyTorch as current workaround # use PyTorch as current workaround
# TODO replace by self.resize # TODO replace by self.resize
masks = torch.from_numpy(target["masks"][:, None]).float() masks = torch.from_numpy(target["masks"][:, None]).float()
interpolated_masks = F.interpolate(masks, size=(h, w), mode="nearest")[:, 0] > 0.5 interpolated_masks = nn.functional.interpolate(masks, size=(h, w), mode="nearest")[:, 0] > 0.5
target["masks"] = interpolated_masks.numpy() target["masks"] = interpolated_masks.numpy()
return rescaled_image, target return rescaled_image, target
...@@ -697,7 +697,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): ...@@ -697,7 +697,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
target_sizes.shape[1] == 2 target_sizes.shape[1] == 2
), "Each element of target_sizes must contain the size (h, w) of each image of the batch" ), "Each element of target_sizes must contain the size (h, w) of each image of the batch"
prob = F.softmax(out_logits, -1) prob = nn.functional.softmax(out_logits, -1)
scores, labels = prob[..., :-1].max(-1) scores, labels = prob[..., :-1].max(-1)
# convert to [x0, y0, x1, y1] format # convert to [x0, y0, x1, y1] format
...@@ -742,13 +742,15 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): ...@@ -742,13 +742,15 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
), "Make sure to pass in as many orig_target_sizes as max_target_sizes" ), "Make sure to pass in as many orig_target_sizes as max_target_sizes"
max_h, max_w = max_target_sizes.max(0)[0].tolist() max_h, max_w = max_target_sizes.max(0)[0].tolist()
outputs_masks = outputs.pred_masks.squeeze(2) outputs_masks = outputs.pred_masks.squeeze(2)
outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) outputs_masks = nn.functional.interpolate(
outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
)
outputs_masks = (outputs_masks.sigmoid() > threshold).cpu() outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
img_h, img_w = t[0], t[1] img_h, img_w = t[0], t[1]
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
results[i]["masks"] = F.interpolate( results[i]["masks"] = nn.functional.interpolate(
results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
).byte() ).byte()
...@@ -810,7 +812,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): ...@@ -810,7 +812,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
cur_scores = cur_scores[keep] cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep] cur_classes = cur_classes[keep]
cur_masks = cur_masks[keep] cur_masks = cur_masks[keep]
cur_masks = F.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
cur_boxes = center_to_corners_format(cur_boxes[keep]) cur_boxes = center_to_corners_format(cur_boxes[keep])
h, w = cur_masks.shape[-2:] h, w = cur_masks.shape[-2:]
......
...@@ -21,7 +21,6 @@ from dataclasses import dataclass ...@@ -21,7 +21,6 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F
from torch import Tensor, nn from torch import Tensor, nn
from ...activations import ACT2FN from ...activations import ACT2FN
...@@ -314,7 +313,7 @@ class DetrFrozenBatchNorm2d(nn.Module): ...@@ -314,7 +313,7 @@ class DetrFrozenBatchNorm2d(nn.Module):
def replace_batch_norm(m, name=""): def replace_batch_norm(m, name=""):
for attr_str in dir(m): for attr_str in dir(m):
target_attr = getattr(m, attr_str) target_attr = getattr(m, attr_str)
if isinstance(target_attr, torch.nn.BatchNorm2d): if isinstance(target_attr, nn.BatchNorm2d):
frozen = DetrFrozenBatchNorm2d(target_attr.num_features) frozen = DetrFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str) bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight) frozen.weight.data.copy_(bn.weight)
...@@ -362,7 +361,7 @@ class DetrTimmConvEncoder(nn.Module): ...@@ -362,7 +361,7 @@ class DetrTimmConvEncoder(nn.Module):
out = [] out = []
for feature_map in features: for feature_map in features:
# downsample pixel_mask to match shape of corresponding feature_map # downsample pixel_mask to match shape of corresponding feature_map
mask = F.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
out.append((feature_map, mask)) out.append((feature_map, mask))
return out return out
...@@ -570,7 +569,7 @@ class DetrAttention(nn.Module): ...@@ -570,7 +569,7 @@ class DetrAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions: if output_attentions:
# this operation is a bit awkward, but it's required to # this operation is a bit awkward, but it's required to
...@@ -582,7 +581,7 @@ class DetrAttention(nn.Module): ...@@ -582,7 +581,7 @@ class DetrAttention(nn.Module):
else: else:
attn_weights_reshaped = None attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states) attn_output = torch.bmm(attn_probs, value_states)
...@@ -642,16 +641,16 @@ class DetrEncoderLayer(nn.Module): ...@@ -642,16 +641,16 @@ class DetrEncoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -731,7 +730,7 @@ class DetrDecoderLayer(nn.Module): ...@@ -731,7 +730,7 @@ class DetrDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
...@@ -749,16 +748,16 @@ class DetrDecoderLayer(nn.Module): ...@@ -749,16 +748,16 @@ class DetrDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected # Fully Connected
residual = hidden_states residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -885,7 +884,7 @@ class DetrEncoder(DetrPreTrainedModel): ...@@ -885,7 +884,7 @@ class DetrEncoder(DetrPreTrainedModel):
Args: Args:
config: DetrConfig config: DetrConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: DetrConfig): def __init__(self, config: DetrConfig):
...@@ -946,7 +945,7 @@ class DetrEncoder(DetrPreTrainedModel): ...@@ -946,7 +945,7 @@ class DetrEncoder(DetrPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds hidden_states = inputs_embeds
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
...@@ -999,7 +998,7 @@ class DetrDecoder(DetrPreTrainedModel): ...@@ -999,7 +998,7 @@ class DetrDecoder(DetrPreTrainedModel):
Args: Args:
config: DetrConfig config: DetrConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (nn.Embedding): output embedding
""" """
def __init__(self, config: DetrConfig, embed_tokens: Optional[nn.Embedding] = None): def __init__(self, config: DetrConfig, embed_tokens: Optional[nn.Embedding] = None):
...@@ -1717,23 +1716,23 @@ class DetrMaskHeadSmallConv(nn.Module): ...@@ -1717,23 +1716,23 @@ class DetrMaskHeadSmallConv(nn.Module):
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
self.gn1 = torch.nn.GroupNorm(8, dim) self.gn1 = nn.GroupNorm(8, dim)
self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) self.gn2 = nn.GroupNorm(8, inter_dims[1])
self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) self.gn3 = nn.GroupNorm(8, inter_dims[2])
self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) self.gn4 = nn.GroupNorm(8, inter_dims[3])
self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) self.gn5 = nn.GroupNorm(8, inter_dims[4])
self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
self.dim = dim self.dim = dim
self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
...@@ -1748,34 +1747,34 @@ class DetrMaskHeadSmallConv(nn.Module): ...@@ -1748,34 +1747,34 @@ class DetrMaskHeadSmallConv(nn.Module):
x = self.lay1(x) x = self.lay1(x)
x = self.gn1(x) x = self.gn1(x)
x = F.relu(x) x = nn.functional.relu(x)
x = self.lay2(x) x = self.lay2(x)
x = self.gn2(x) x = self.gn2(x)
x = F.relu(x) x = nn.functional.relu(x)
cur_fpn = self.adapter1(fpns[0]) cur_fpn = self.adapter1(fpns[0])
if cur_fpn.size(0) != x.size(0): if cur_fpn.size(0) != x.size(0):
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = self.lay3(x) x = self.lay3(x)
x = self.gn3(x) x = self.gn3(x)
x = F.relu(x) x = nn.functional.relu(x)
cur_fpn = self.adapter2(fpns[1]) cur_fpn = self.adapter2(fpns[1])
if cur_fpn.size(0) != x.size(0): if cur_fpn.size(0) != x.size(0):
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = self.lay4(x) x = self.lay4(x)
x = self.gn4(x) x = self.gn4(x)
x = F.relu(x) x = nn.functional.relu(x)
cur_fpn = self.adapter3(fpns[2]) cur_fpn = self.adapter3(fpns[2])
if cur_fpn.size(0) != x.size(0): if cur_fpn.size(0) != x.size(0):
cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
x = self.lay5(x) x = self.lay5(x)
x = self.gn5(x) x = self.gn5(x)
x = F.relu(x) x = nn.functional.relu(x)
x = self.out_lay(x) x = self.out_lay(x)
return x return x
...@@ -1797,14 +1796,14 @@ class DetrMHAttentionMap(nn.Module): ...@@ -1797,14 +1796,14 @@ class DetrMHAttentionMap(nn.Module):
def forward(self, q, k, mask: Optional[Tensor] = None): def forward(self, q, k, mask: Optional[Tensor] = None):
q = self.q_linear(q) q = self.q_linear(q)
k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head) weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
if mask is not None: if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size()) weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights) weights = self.dropout(weights)
return weights return weights
...@@ -1847,7 +1846,7 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f ...@@ -1847,7 +1846,7 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
Loss tensor Loss tensor
""" """
prob = inputs.sigmoid() prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets) p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma) loss = ce_loss * ((1 - p_t) ** gamma)
...@@ -1909,7 +1908,7 @@ class DetrLoss(nn.Module): ...@@ -1909,7 +1908,7 @@ class DetrLoss(nn.Module):
) )
target_classes[idx] = target_classes_o target_classes[idx] = target_classes_o
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce} losses = {"loss_ce": loss_ce}
return losses return losses
...@@ -1926,7 +1925,7 @@ class DetrLoss(nn.Module): ...@@ -1926,7 +1925,7 @@ class DetrLoss(nn.Module):
tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class) # Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
losses = {"cardinality_error": card_err} losses = {"cardinality_error": card_err}
return losses return losses
...@@ -1942,7 +1941,7 @@ class DetrLoss(nn.Module): ...@@ -1942,7 +1941,7 @@ class DetrLoss(nn.Module):
src_boxes = outputs["pred_boxes"][idx] src_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
losses = {} losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes losses["loss_bbox"] = loss_bbox.sum() / num_boxes
...@@ -1972,7 +1971,7 @@ class DetrLoss(nn.Module): ...@@ -1972,7 +1971,7 @@ class DetrLoss(nn.Module):
target_masks = target_masks[tgt_idx] target_masks = target_masks[tgt_idx]
# upsample predictions to the target size # upsample predictions to the target size
src_masks = F.interpolate( src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
) )
src_masks = src_masks[:, 0].flatten(1) src_masks = src_masks[:, 0].flatten(1)
...@@ -2068,7 +2067,7 @@ class DetrMLPPredictionHead(nn.Module): ...@@ -2068,7 +2067,7 @@ class DetrMLPPredictionHead(nn.Module):
def forward(self, x): def forward(self, x):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x return x
......
...@@ -23,7 +23,7 @@ import math ...@@ -23,7 +23,7 @@ import math
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import gelu from ...activations import gelu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment