Unverified Commit 4e4403c9 authored by Sam Shleifer, committed by GitHub

[BART] torch 1.0 compatibility (#3322)

* config.activation_function
parent c44a17db
activations.py
@@ -44,8 +44,4 @@ def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
-raise KeyError(
-    "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
-        activation_string, list(ACT2FN.keys())
-    )
-)
+raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
configuration_bart.py
@@ -39,6 +39,7 @@ class BartConfig(PretrainedConfig):
def __init__(
self,
activation_dropout=0.0,
activation_function="gelu",
vocab_size=50265,
bos_token_id=0,
pad_token_id=1,
@@ -89,6 +90,7 @@ class BartConfig(PretrainedConfig):
self.decoder_attention_heads = decoder_attention_heads
self.max_position_embeddings = max_position_embeddings
self.init_std = init_std # Normal(0, this parameter)
+self.activation_function = activation_function
# 3 Types of Dropout
self.attention_dropout = attention_dropout
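With the new activation_function field, the choice of nonlinearity becomes part of the configuration rather than being hard-coded in the layers. A small usage sketch, assuming BartConfig is importable from the installed transformers package and that the string passed is one of the keys of ACT2FN:

from transformers import BartConfig

# "gelu" remains the default; any other ACT2FN key (e.g. "relu") can be requested explicitly
config = BartConfig(activation_function="relu")
print(config.activation_function)  # -> relu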
modeling_bart.py
@@ -21,6 +21,7 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn
+from .activations import ACT2FN
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
@@ -196,7 +197,7 @@ class EncoderLayer(nn.Module):
)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
-self.activation_fn = F.gelu
+self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
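On the modeling side, the layer now looks its activation up by name instead of touching F.gelu at construction time, which is what keeps the module usable on torch 1.0. The pattern in miniature, as a self-contained sketch (TinyConfig, FFNSketch and the mapping are hypothetical stand-ins, not the library's classes):

import torch
from torch import nn

ACT2FN_SKETCH = {"relu": torch.relu, "tanh": torch.tanh}  # illustrative mapping

class TinyConfig:
    def __init__(self, activation_function="relu", d_model=16, ffn_dim=32):
        self.activation_function = activation_function
        self.d_model = d_model
        self.ffn_dim = ffn_dim

class FFNSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        # resolve the activation from the config, as EncoderLayer/DecoderLayer now do
        self.activation_fn = ACT2FN_SKETCH[config.activation_function]
        self.fc1 = nn.Linear(config.d_model, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, config.d_model)

    def forward(self, x):
        return self.fc2(self.activation_fn(self.fc1(x)))

out = FFNSketch(TinyConfig())(torch.randn(2, 4, 16))  # -> shape (2, 4, 16)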
@@ -278,8 +279,8 @@ class BartEncoder(nn.Module):
# check attention mask and invert
if attention_mask is not None:
assert attention_mask.dim() == 2
-attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
-assert attention_mask.max() <= 0
+attention_mask = attention_mask.eq(0)
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
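The mask convention also changes: instead of turning the 1/0 attention mask into an additive float mask (zero where attended, a large negative value where padded), the encoder now inverts it with eq(0) into a mask that is true exactly at padding positions, in the form masked_fill expects (a ByteTensor on torch 1.0/1.1, a bool tensor on newer releases). A small sketch of the two styles with illustrative values:

import torch

# 1/0 attention mask as produced by the tokenizer: 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# old style: additive float mask added to the attention scores
additive_mask = (1.0 - attention_mask.float()) * -10000.0

# new style: boolean/byte padding mask, true only at padded positions
padding_mask = attention_mask.eq(0)

scores = torch.zeros(1, 5)
print(scores.masked_fill(padding_mask, float("-inf")))  # padded positions become -inf before softmax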
@@ -318,7 +319,7 @@ class DecoderLayer(nn.Module):
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
)
self.dropout = config.dropout
-self.activation_fn = F.gelu
+self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
@@ -334,13 +335,7 @@ class DecoderLayer(nn.Module):
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(
-self,
-x,
-encoder_hidden_states,
-encoder_attn_mask=None,
-layer_state=None,
-attention_mask=None,
-need_attn_weights=False,
+self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None,
):
residual = x
@@ -437,9 +432,7 @@ class BartDecoder(nn.Module):
# check attention mask and invert
if encoder_padding_mask is not None:
assert encoder_padding_mask.dim() == 2
-encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
-assert encoder_padding_mask.max() <= 0
+encoder_padding_mask = encoder_padding_mask.eq(0)
# embed positions
positions = self.embed_positions(input_ids, generation_mode=generation_mode)
@@ -469,12 +462,7 @@ class BartDecoder(nn.Module):
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer(
-x,
-encoder_hidden_states,
-encoder_padding_mask,
-layer_state=layer_state,
-attention_mask=combined_mask,
-need_attn_weights=self.output_attentions,
+x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
)
if self.output_past:
@@ -598,7 +586,7 @@ class SelfAttention(nn.Module):
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
+reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
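Removing the cast is the actual torch 1.0 fix in this hunk: torch.bool does not exist on 1.0, while the uint8 (or, on newer torch, bool) tensor produced by eq(0) is already a valid mask for masked_fill. If a codebase wanted to keep the cast only where it is available, a guarded helper along these lines would do it (to_mask_dtype is a hypothetical name, not part of the library):

import torch

def to_mask_dtype(mask):
    # cast to torch.bool only when the running torch version has that dtype
    if hasattr(torch, "bool"):
        return mask.to(torch.bool)
    return mask  # torch 1.0/1.1 comparisons already yield uint8 masks that masked_fill accepts

padding_mask = torch.tensor([[1, 1, 0]]).eq(0)
print(to_mask_dtype(padding_mask).dtype)  # torch.bool on recent torch, torch.uint8 on 1.0/1.1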
@@ -648,22 +636,20 @@ class SelfAttention(nn.Module):
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
-if prev_key_padding_mask is not None and static_kv:
-    new_key_padding_mask = prev_key_padding_mask
-elif prev_key_padding_mask is not None and key_padding_mask is not None:
-    new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
-# During incremental decoding, as the padding token enters and
-# leaves the frame, there will be a time when prev or current is None
-elif prev_key_padding_mask is not None:
-    filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
-    if prev_key_padding_mask.is_cuda:
-        filler = filler.to(prev_key_padding_mask.device)
-    new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+if prev_key_padding_mask is not None:
+    if static_kv:
+        new_key_padding_mask = prev_key_padding_mask
+    else:
+        new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
elif key_padding_mask is not None:
-    filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
-    if key_padding_mask.is_cuda:
-        filler = filler.cuda()
-    new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+    filler = torch.zeros(
+        batch_size,
+        src_len - key_padding_mask.size(1),
+        dtype=key_padding_mask.dtype,
+        device=key_padding_mask.device,
+    )
+    new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
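The rewritten branch builds the filler directly with the dtype and device of the mask it will be concatenated with, instead of creating a CPU float tensor and patching it up with .cuda()/.float() afterwards, so torch.cat never has to reconcile mismatched tensors. A minimal sketch of the same pattern:

import torch

key_padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.uint8)  # stands in for a saved mask
batch_size, src_len = 1, 5

# zeros for the missing positions, allocated with the mask's own dtype and device
filler = torch.zeros(
    batch_size,
    src_len - key_padding_mask.size(1),
    dtype=key_padding_mask.dtype,
    device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
print(new_key_padding_mask)  # tensor([[0, 0, 0, 0, 1]], dtype=torch.uint8)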