Unverified commit 4e4403c9, authored by Sam Shleifer, committed by GitHub

[BART] torch 1.0 compatibility (#3322)

* config.activation_function
parent c44a17db
@@ -44,8 +44,4 @@ def get_activation(activation_string):
     if activation_string in ACT2FN:
         return ACT2FN[activation_string]
     else:
-        raise KeyError(
-            "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
-                activation_string, list(ACT2FN.keys())
-            )
-        )
+        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
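For context, ACT2FN is the registry that get_activation consults: a plain dictionary mapping configuration strings to activation callables. The snippet below is a minimal sketch of that pattern, not the library's actual table; the real ACT2FN in transformers.activations covers more names, and its gelu is written to work even on torch releases that predate torch.nn.functional.gelu, which is presumably why the BART changes below route through ACT2FN instead of calling F.gelu directly.

```python
import math

import torch
import torch.nn.functional as F


def gelu(x):
    # Pure-Python GELU, x * Phi(x); runs on any torch version, including ones
    # that do not ship torch.nn.functional.gelu.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# Sketch of a name-to-callable registry in the spirit of ACT2FN.
ACT2FN = {"gelu": gelu, "relu": F.relu, "tanh": torch.tanh}


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
```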
@@ -39,6 +39,7 @@ class BartConfig(PretrainedConfig):
     def __init__(
         self,
         activation_dropout=0.0,
+        activation_function="gelu",
         vocab_size=50265,
         bos_token_id=0,
         pad_token_id=1,
@@ -89,6 +90,7 @@ class BartConfig(PretrainedConfig):
         self.decoder_attention_heads = decoder_attention_heads
         self.max_position_embeddings = max_position_embeddings
         self.init_std = init_std  # Normal(0, this parameter)
+        self.activation_function = activation_function

         # 3 Types of Dropout
         self.attention_dropout = attention_dropout
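With the new keyword, the activation is an ordinary configuration field stored on the config object like any other hyperparameter. A quick usage sketch, assuming a transformers version that already ships BART:

```python
from transformers import BartConfig

# "gelu" is the new default; any key present in ACT2FN should be accepted.
config = BartConfig(activation_function="gelu")
print(config.activation_function)  # -> "gelu"
```

The modeling changes below then read this field via ACT2FN[config.activation_function] instead of hard-coding F.gelu.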
@@ -21,6 +21,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

+from .activations import ACT2FN
 from .configuration_bart import BartConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
@@ -196,7 +197,7 @@ class EncoderLayer(nn.Module):
         )
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
-        self.activation_fn = F.gelu
+        self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
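The looked-up callable is applied inside the position-wise feed-forward sub-block built from fc1 and fc2. The class below is a simplified stand-in for that part of the layer, not the actual EncoderLayer.forward (residual connections and layer norms are omitted), just to show where activation_fn and activation_dropout enter:

```python
import torch.nn.functional as F
from torch import nn


class FeedForwardSketch(nn.Module):
    """Simplified stand-in for the fc1 / activation / fc2 portion of an encoder layer."""

    def __init__(self, embed_dim, ffn_dim, activation_fn, activation_dropout):
        super().__init__()
        self.activation_fn = activation_fn  # e.g. ACT2FN[config.activation_function]
        self.activation_dropout = activation_dropout
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)

    def forward(self, x):
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        return self.fc2(x)
```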
@@ -278,8 +279,8 @@ class BartEncoder(nn.Module):
         # check attention mask and invert
         if attention_mask is not None:
             assert attention_mask.dim() == 2
-            attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
-            assert attention_mask.max() <= 0
+            attention_mask = attention_mask.eq(0)
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
         x = inputs_embeds + embed_pos
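Rather than building an additive mask of large negative numbers up front, the encoder now keeps a boolean padding mask and defers the actual masking to masked_fill inside SelfAttention. A small illustration of the conversion, with made-up values:

```python
import torch

# Conventional attention mask: 1 = real token, 0 = padding (batch of 2, length 4).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

padding_mask = attention_mask.eq(0)
# -> [[False, False, False,  True],
#     [False, False,  True,  True]]
# True marks positions to ignore (a 1 in a byte tensor on very old torch,
# where eq() does not yet return torch.bool).
```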
@@ -318,7 +319,7 @@ class DecoderLayer(nn.Module):
             embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = F.gelu
+        self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
@@ -334,13 +335,7 @@ class DecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(
-        self,
-        x,
-        encoder_hidden_states,
-        encoder_attn_mask=None,
-        layer_state=None,
-        attention_mask=None,
-        need_attn_weights=False,
+        self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None,
     ):
         residual = x
@@ -437,9 +432,7 @@ class BartDecoder(nn.Module):
         # check attention mask and invert
         if encoder_padding_mask is not None:
             assert encoder_padding_mask.dim() == 2
-            encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
-            assert encoder_padding_mask.max() <= 0
+            encoder_padding_mask = encoder_padding_mask.eq(0)

         # embed positions
         positions = self.embed_positions(input_ids, generation_mode=generation_mode)
@@ -469,12 +462,7 @@ class BartDecoder(nn.Module):
             layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
             x, layer_self_attn, layer_past = decoder_layer(
-                x,
-                encoder_hidden_states,
-                encoder_padding_mask,
-                layer_state=layer_state,
-                attention_mask=combined_mask,
-                need_attn_weights=self.output_attentions,
+                x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
             )

             if self.output_past:
@@ -598,7 +586,7 @@ class SelfAttention(nn.Module):
         if key_padding_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
+            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
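Dropping the .to(torch.bool) call is the actual torch 1.0 fix in this spot: the torch.bool dtype does not exist in torch 1.0, while the result of .eq(0) (a byte tensor on old releases, a bool tensor on newer ones) is already a valid mask for masked_fill. A toy version of the masking step on invented scores:

```python
import torch

# One query attending over four keys; the scores are made up.
attn_weights = torch.tensor([[0.5, 1.2, -0.3, 0.8]])
# key_padding_mask as produced by attention_mask.eq(0): the last key is padding.
key_padding_mask = torch.tensor([[False, False, False, True]])

attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
probs = torch.softmax(attn_weights, dim=-1)
# The padded key ends up with exactly zero attention probability.
```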
@@ -648,22 +636,20 @@ class SelfAttention(nn.Module):
         static_kv: bool,
     ) -> Optional[Tensor]:
         # saved key padding masks have shape (bsz, seq_len)
-        if prev_key_padding_mask is not None and static_kv:
-            new_key_padding_mask = prev_key_padding_mask
-        elif prev_key_padding_mask is not None and key_padding_mask is not None:
-            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
-        # During incremental decoding, as the padding token enters and
-        # leaves the frame, there will be a time when prev or current is None
-        elif prev_key_padding_mask is not None:
-            filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
-            if prev_key_padding_mask.is_cuda:
-                filler = filler.to(prev_key_padding_mask.device)
-            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+        if prev_key_padding_mask is not None:
+            if static_kv:
+                new_key_padding_mask = prev_key_padding_mask
+            else:
+                new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
         elif key_padding_mask is not None:
-            filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
-            if key_padding_mask.is_cuda:
-                filler = filler.cuda()
-            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+            filler = torch.zeros(
+                batch_size,
+                src_len - key_padding_mask.size(1),
+                dtype=key_padding_mask.dtype,
+                device=key_padding_mask.device,
+            )
+            new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
         else:
             new_key_padding_mask = prev_key_padding_mask
         return new_key_padding_mask
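The rewritten helper also shows the other torch 1.0-friendly pattern in this commit: instead of branching on .is_cuda and converting masks to float, the zero filler is allocated directly with the dtype and device of the existing mask. Below is a standalone sketch of the same bookkeeping; cat_key_padding_masks is a hypothetical free function, whereas the real helper is a method on SelfAttention:

```python
import torch


def cat_key_padding_masks(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
    # Keep the cached key padding mask aligned with the cached keys during
    # incremental decoding.
    if prev_key_padding_mask is not None:
        if static_kv:
            # Static (cross-attention) keys never change, so neither does their mask.
            return prev_key_padding_mask
        return torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
    if key_padding_mask is not None:
        # Earlier steps had no mask: pad the front with "not padding" entries,
        # created with the same dtype/device so no .cuda()/.float() juggling is needed.
        filler = torch.zeros(
            batch_size,
            src_len - key_padding_mask.size(1),
            dtype=key_padding_mask.dtype,
            device=key_padding_mask.device,
        )
        return torch.cat([filler, key_padding_mask], dim=1)
    return prev_key_padding_mask  # both None -> stays None
```

During step-by-step generation this keeps the cached padding mask in sync with the growing set of keys: self-attention concatenates the per-step masks, while the static_kv (cross-attention) case reuses the existing mask unchanged.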