Unverified Commit 6bf88537 authored by Kian Sierra McGettigan, committed by GitHub

Prophetnet batch dimension inversion fix (#21870)

* decoder forward pass is working

* now model has forward pass returning attentions

* decoder ngram changed to not mix batch size

* current basic forward pass returns identical result

* passed test_model attentions

* passed test_encoder_decoder_model_generate

* passed test_headmasking

* removed old block

* removed comments bug/fixme

* removed bug comments

* applied styling

* applied fix-copies

* applied ngram forward comments

* corrected dimension notation

* applied styling and comment fixes

* changed asserts for raise ValueError

* changed question gen test

* updated hidden_states integration test

* applied styling
parent 99ba36e7
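
Context for the diff below: the core of this fix is a change of shape convention in the attention code. The old code folds batch and heads together and calls torch.bmm on tensors of shape [batch_size * num_attn_heads, tgt_len, head_dim]; the new code keeps the batch dimension separate and uses torch.einsum on [batch_size, num_attn_heads, tgt_len, head_dim]. A minimal sketch of why both give the same attention weights (tensor sizes and variable names here are illustrative only, not taken from the PR):

import torch

# illustrative sizes, not from the PR
batch_size, num_attn_heads, tgt_len, head_dim = 2, 4, 5, 8
query_states = torch.randn(batch_size, num_attn_heads, tgt_len, head_dim)
key_states = torch.randn(batch_size, num_attn_heads, tgt_len, head_dim)

# old convention: fold batch and heads, then bmm on [batch*heads, tgt_len, head_dim]
old_weights = torch.bmm(
    query_states.reshape(batch_size * num_attn_heads, tgt_len, head_dim),
    key_states.reshape(batch_size * num_attn_heads, tgt_len, head_dim).transpose(1, 2),
).view(batch_size, num_attn_heads, tgt_len, tgt_len)

# new convention: keep the batch dimension separate, einsum on [batch, heads, tgt_len, head_dim]
new_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))

# identical result, but the batch dimension is never mixed into the head dimension
assert torch.allclose(old_weights, new_weights, atol=1e-6)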
@@ -701,44 +701,27 @@ class ProphetNetAttention(nn.Module):
             past_key_value = (key_states, value_states)
         # project states into the correct shape
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        assert attn_weights.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            src_len,
-        ), (
-            f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
-            f" {attn_weights.shape}"
-        )
+        src_len = key_states.size(2)
+        attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
+        if attn_weights.size() != expected_shape:
+            raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
         # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
         if attention_mask is not None and attention_mask.dim() == 0:
             attention_mask = None
-        assert attention_mask is None or attention_mask.size() == (
-            self.num_attn_heads * batch_size,
-            1,
-            src_len,
-        ), (
-            "`attention_mask` should be `None` or of shape attention_mask.size() =="
-            f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
-        )
+        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
+        if attention_mask is not None and attention_mask.size() != expected_shape:
+            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
         if attention_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights + attention_mask
         if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
+            attn_weights_reshaped = attn_weights
         else:
             attn_weights_reshaped = None
@@ -752,7 +735,6 @@ class ProphetNetAttention(nn.Module):
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                 batch_size, self.num_attn_heads, tgt_len, src_len
             )
-            attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
             # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
             attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
@@ -762,23 +744,12 @@ class ProphetNetAttention(nn.Module):
             p=self.attention_dropout,
             training=self.training,
         )
-        attn_output = torch.bmm(attn_probs, value_states)
-        assert attn_output.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            self.head_dim,
-        ), (
-            f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
-            f" shape {attn_output.size()}"
-        )
-        attn_output = (
-            attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, tgt_len, hidden_size)
-        )
+        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
+        if attn_output.size() != expected_shape:
+            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
         attn_output = self.out_proj(attn_output)
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -856,7 +827,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
         position_ids=None,
     ):
         batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
         assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
             f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
             f" {hidden_states.shape}"
@@ -874,8 +844,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
         query_states = self._shape(query_states, ngram_sequence_length, batch_size)
         key_states = self._shape(key_states, -1, batch_size)
         value_states = self._shape(value_states, -1, batch_size)
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
         query_states = query_states.view(*proj_shape)
         key_states = key_states.view(*proj_shape)
@@ -883,10 +852,9 @@ class ProphetNetNgramSelfAttention(nn.Module):
         # chunk into main stream and predict stream
         hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
-        query_states_list = query_states.chunk(1 + self.ngram, dim=1)
-        key_states_list = key_states.chunk(1 + self.ngram, dim=1)
-        value_states_list = value_states.chunk(1 + self.ngram, dim=1)
+        query_states_list = query_states.chunk(1 + self.ngram, dim=2)
+        key_states_list = key_states.chunk(1 + self.ngram, dim=2)
+        value_states_list = value_states.chunk(1 + self.ngram, dim=2)
         main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
         main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
@@ -895,28 +863,29 @@ class ProphetNetNgramSelfAttention(nn.Module):
         # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
         if past_key_value is not None:
-            prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
-            prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)
+            prev_main_key_states = past_key_value[0]
+            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
+            prev_main_value_states = past_key_value[1]
+            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
         # Update cache
-        past_key_value = (
-            main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-            main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-        )
+        past_key_value = (main_key_states, main_value_states)
         # get seq_length of main stream only
         sequence_length = ngram_sequence_length // (1 + self.ngram)
         # MAIN-STREAM
         # main attn weights
-        main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2))
+        # [batch_size, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, number_heads, head_dimesion, sequence_length]
+        # -> [batch_size, number_heads, sequence_length, sequence_length]
+        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
         # retrieve relative position embeddings for each layer -> see paper for more details
         main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
             main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
         )
         main_attn_weights = main_attn_weights + main_relative_pos_embeddings
         if attention_mask is not None:
@@ -936,55 +905,53 @@ class ProphetNetNgramSelfAttention(nn.Module):
             main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
                 batch_size, self.num_attn_heads, -1, sequence_length
             )
-            main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
         main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
         # project to attn_output
-        main_attn_output = torch.bmm(main_attn_probs, main_value_states)
+        # [batch_size, number_heads, sequence_length, sequence_length]
+        # x [batch_size, number_heads, sequence_length, head_dimesion]
+        # -> [batch_size, number_heads, sequence_length, head_dimesion]
+        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
         # reshape so that num_heads dim is merged into last `head_dim` axis
-        main_attn_output = (
-            main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, 1, sequence_length, hidden_size)
-        )
+        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
         main_attn_output = self.out_proj(main_attn_output)
         # PREDICT-STREAM
-        # [ngram, B*head, T, c]
-        predict_query_states = torch.cat(predict_query_states_list, 0).view(
-            self.ngram, -1, sequence_length, self.head_dim
-        )
-        # [ngram, B*head, 2*T, c]
-        predict_key_states = torch.cat(
-            [torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0
-        )
-        # [ngram, T, B, C]
-        predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
-            self.ngram, sequence_length, batch_size, hidden_size
-        )
-        # [ngram, B*head, 2*T, c]
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_query_states = torch.stack(predict_query_states_list, 1).view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
+        )
+        # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
+        # [batch_size, sequence_length, ngram, hidden_size]
+        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
+        # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion]
         predict_value_states = torch.cat(
-            [torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0
+            [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
         )
-        # [ngram, B*head, T, 2*T]
-        predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states))
-        # [ngram, B*head, T, S]
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
         # retrieve relative position embeddings for each layer -> see paper for more details
+        # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
         predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
             predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
         )
-        # [ngram, B*head, T, 2*T]
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
         predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
         if extended_predict_attention_mask is not None:
-            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
-                predict_attn_weights.dtype
-            )
+            # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
+            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
+            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
         predict_attn_probs = softmax(
             predict_attn_weights,
@@ -997,37 +964,30 @@ class ProphetNetNgramSelfAttention(nn.Module):
                     f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                     f" {layer_head_mask.size()}"
                 )
-            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
-                self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
-            predict_attn_probs = predict_attn_probs.view(
-                self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
+            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
         predict_attn_probs = nn.functional.dropout(
             predict_attn_probs, p=self.attention_dropout, training=self.training
         )
         # project to attention output
-        # [ngram, B*head, T, c]
-        predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_attn_output = torch.einsum(
+            "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
+        )
         # reshape so that num_heads dim is merged into last `head_dim` axis
-        # [ngram, B, T, C]
-        predict_attn_output = (
-            predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .permute(1, 0, 3, 2, 4)
-            .reshape(batch_size, self.ngram, sequence_length, hidden_size)
-        )
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size]
+        predict_attn_output = predict_attn_output.transpose(2, 3)
+        predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
         predict_attn_output = self.out_proj(predict_attn_output)
         # concat to single attn output
-        # [B, 1+ngram*T, C]
+        # [batch_size, (1+ngram)*sequence_length, hidden_size]
         attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
         # reshape into better form for `config.output_attentions`
         main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
-        predict_attn_probs = predict_attn_probs.view(
-            self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
-        ).transpose(0, 1)
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -1036,8 +996,11 @@ class ProphetNetNgramSelfAttention(nn.Module):
     def get_main_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
     ):
-        # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
+        # input hidden_states [batch_size, sequence_length, hidden_size]
+        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
+        attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
         if main_relative_position_buckets is None:
             batch_size, sequence_length = hidden_states.shape[:2]
             relative_positions = (
@@ -1047,39 +1010,42 @@ class ProphetNetNgramSelfAttention(nn.Module):
                 .repeat(batch_size, sequence_length, 1)
                 .to(position_ids.device)
             )
-            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(
-                batch_size, sequence_length, 1
-            )  # [B, T, s]
+            # [batch_size, sequence_length, sequence_length+1]
+            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
             main_relative_position_buckets = compute_relative_buckets(
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)  # [B,T,Buckets*head]
+        # [batch_size, sequence_length, num_buckets * num_heads]
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
         rel_pos_embeddings = rel_pos_embeddings.view(
             rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
-        ).permute(
-            0, 3, 1, 2
-        )  # [B,T,Buckets,head]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,))  # [B*head,T,Buckets]
-        main_relative_position_buckets = (
-            main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
-            .view(-1, main_relative_position_buckets.shape[-1])
-            .long()
-        )  # [B*head*T, T]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))  # [B*head*T,Buckets]
-        main_relative_pos_embeddings = torch.gather(
-            rel_pos_embeddings, dim=1, index=main_relative_position_buckets
-        ).view(attn_weights.shape[:2] + (-1,))
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
+        # [batch_size, num_heads, sequence_length, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
+        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        main_relative_position_buckets = main_relative_position_buckets.view(
+            -1, main_relative_position_buckets.shape[-1]
+        )
+        main_relative_position_buckets = main_relative_position_buckets.long()
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
+        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
         return main_relative_pos_embeddings
     def get_predict_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
     ):
-        # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
-        sequence_length, batch_size = hidden_states.shape[1:3]
+        # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
+        # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
+        batch_size, sequence_length = hidden_states.shape[0:2]
         if predict_relative_position_buckets is None:
             key_sequence_length = attn_weights.shape[-1]
@@ -1099,28 +1065,35 @@ class ProphetNetNgramSelfAttention(nn.Module):
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
-        hidden_states = hidden_states.transpose(1, 2)  # [ngram, B, T, C]
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view(
-            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
-        )  # [ngram, B, T, bucket, head]
-        rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape(
-            self.ngram * batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram*B*head, T, bucket]
-        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat(
+        # [batch_size, ngram, sequence_length, hidden_size]
+        hidden_states = hidden_states.transpose(1, 2)
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
+        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
+        rel_pos_embeddings = rel_pos_embeddings.view(
+            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
+        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
+        # [ngram, batch_size, num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
+        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
             self.ngram, 1, self.num_attn_heads, 1
-        )  # [ngram, B, head*T, S]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+        )
+        # [ngram * batch_size * num_heads * sequence_length, -1]
         predict_relative_position_buckets = predict_relative_position_buckets.view(
             -1, predict_relative_position_buckets.size(-1)
-        ).long()  # [ngram*B*head*T, S]
+        ).long()
         predict_relative_pos_embeddings = torch.gather(
             rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
-        ).view(
-            self.ngram, batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram, B*head, T, S]
+        )
+        # [batch_size, gram, num_heads, sequence_length, -1]
+        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
+        )
         return predict_relative_pos_embeddings
@@ -1331,7 +1304,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
         # prepare attention mask
         if attention_mask is not None:
             extended_attention_mask = (
-                1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
+                1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
             ) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
         else:
@@ -1549,7 +1522,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         # prepare encoder attention mask
         if encoder_attention_mask is not None:
             extended_encoder_attention_mask = (
-                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
+                1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
             ) * torch.finfo(self.dtype).min
             extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
         else:
@@ -1717,17 +1690,18 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             device=hidden_states.device,
         )
         causal_mask = torch.triu(causal_mask, 1)
-        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
-            (batch_size,) + causal_mask.shape
+        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
+            (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
         )
         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
+            extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_causal_mask + extended_attention_mask
         else:
             extended_attention_mask = extended_causal_mask
-        return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)
+        return extended_attention_mask.to(hidden_states.dtype)
     def prepare_predict_attention_mask(self, hidden_states, attention_mask):
         batch_size, seq_length = hidden_states.shape[:2]
@@ -1745,14 +1719,16 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             ],
             dim=-1,
         )
-        extended_predict_causal_mask = predict_causal_mask[:, None, :, :].expand(
-            predict_causal_mask.shape[:1] + (batch_size,) + predict_causal_mask.shape[1:]
+        extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
+            (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
         )
         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
-            extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
+            extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
+            extended_attention_mask = extended_attention_mask.expand(
+                (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
+            )
             # predicted stream attention_mask should always be 0
             extended_attention_mask = torch.cat(
                 [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
@@ -1760,9 +1736,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
         else:
             extended_predict_attention_mask = extended_predict_causal_mask
-        return extended_predict_attention_mask.repeat(1, self.config.num_decoder_attention_heads, 1, 1).to(
-            hidden_states.dtype
-        )
+        return extended_predict_attention_mask.to(hidden_states.dtype)
 @add_start_docstrings(
......
@@ -716,44 +716,27 @@ class XLMProphetNetAttention(nn.Module):
             past_key_value = (key_states, value_states)
         # project states into the correct shape
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        assert attn_weights.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            src_len,
-        ), (
-            f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
-            f" {attn_weights.shape}"
-        )
+        src_len = key_states.size(2)
+        attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
+        if attn_weights.size() != expected_shape:
+            raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
         # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
         if attention_mask is not None and attention_mask.dim() == 0:
             attention_mask = None
-        assert attention_mask is None or attention_mask.size() == (
-            self.num_attn_heads * batch_size,
-            1,
-            src_len,
-        ), (
-            "`attention_mask` should be `None` or of shape attention_mask.size() =="
-            f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
-        )
+        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
+        if attention_mask is not None and attention_mask.size() != expected_shape:
+            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
         if attention_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights + attention_mask
         if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
+            attn_weights_reshaped = attn_weights
         else:
             attn_weights_reshaped = None
@@ -767,7 +750,6 @@ class XLMProphetNetAttention(nn.Module):
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                 batch_size, self.num_attn_heads, tgt_len, src_len
             )
-            attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
             # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
             attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
@@ -777,23 +759,12 @@ class XLMProphetNetAttention(nn.Module):
             p=self.attention_dropout,
             training=self.training,
         )
-        attn_output = torch.bmm(attn_probs, value_states)
-        assert attn_output.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            self.head_dim,
-        ), (
-            f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
-            f" shape {attn_output.size()}"
-        )
-        attn_output = (
-            attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, tgt_len, hidden_size)
-        )
+        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
+        if attn_output.size() != expected_shape:
+            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
         attn_output = self.out_proj(attn_output)
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -873,7 +844,6 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
         position_ids=None,
     ):
         batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
         assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
             f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
             f" {hidden_states.shape}"
@@ -891,8 +861,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
         query_states = self._shape(query_states, ngram_sequence_length, batch_size)
         key_states = self._shape(key_states, -1, batch_size)
         value_states = self._shape(value_states, -1, batch_size)
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
         query_states = query_states.view(*proj_shape)
         key_states = key_states.view(*proj_shape)
@@ -900,10 +869,9 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
         # chunk into main stream and predict stream
         hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
-        query_states_list = query_states.chunk(1 + self.ngram, dim=1)
-        key_states_list = key_states.chunk(1 + self.ngram, dim=1)
-        value_states_list = value_states.chunk(1 + self.ngram, dim=1)
+        query_states_list = query_states.chunk(1 + self.ngram, dim=2)
+        key_states_list = key_states.chunk(1 + self.ngram, dim=2)
+        value_states_list = value_states.chunk(1 + self.ngram, dim=2)
         main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
         main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
@@ -912,28 +880,29 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
         # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
         if past_key_value is not None:
-            prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
-            prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)
+            prev_main_key_states = past_key_value[0]
+            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
+            prev_main_value_states = past_key_value[1]
+            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
         # Update cache
-        past_key_value = (
-            main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-            main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-        )
+        past_key_value = (main_key_states, main_value_states)
         # get seq_length of main stream only
         sequence_length = ngram_sequence_length // (1 + self.ngram)
         # MAIN-STREAM
         # main attn weights
-        main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2))
+        # [batch_size, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, number_heads, head_dimesion, sequence_length]
+        # -> [batch_size, number_heads, sequence_length, sequence_length]
+        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
         # retrieve relative position embeddings for each layer -> see paper for more details
         main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
             main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
         )
         main_attn_weights = main_attn_weights + main_relative_pos_embeddings
         if attention_mask is not None:
@@ -953,55 +922,53 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
             main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
                 batch_size, self.num_attn_heads, -1, sequence_length
             )
-            main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
         main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
         # project to attn_output
-        main_attn_output = torch.bmm(main_attn_probs, main_value_states)
+        # [batch_size, number_heads, sequence_length, sequence_length]
+        # x [batch_size, number_heads, sequence_length, head_dimesion]
+        # -> [batch_size, number_heads, sequence_length, head_dimesion]
+        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
         # reshape so that num_heads dim is merged into last `head_dim` axis
-        main_attn_output = (
-            main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, 1, sequence_length, hidden_size)
-        )
+        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
         main_attn_output = self.out_proj(main_attn_output)
         # PREDICT-STREAM
-        # [ngram, B*head, T, c]
-        predict_query_states = torch.cat(predict_query_states_list, 0).view(
-            self.ngram, -1, sequence_length, self.head_dim
-        )
-        # [ngram, B*head, 2*T, c]
-        predict_key_states = torch.cat(
-            [torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0
-        )
-        # [ngram, T, B, C]
-        predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
-            self.ngram, sequence_length, batch_size, hidden_size
-        )
-        # [ngram, B*head, 2*T, c]
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_query_states = torch.stack(predict_query_states_list, 1).view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
+        )
+        # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
+        # [batch_size, sequence_length, ngram, hidden_size]
+        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
+        # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion]
         predict_value_states = torch.cat(
-            [torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0
+            [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
        )
-        # [ngram, B*head, T, 2*T]
-        predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states))
-        # [ngram, B*head, T, S]
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
         # retrieve relative position embeddings for each layer -> see paper for more details
+        # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
         predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
             predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
         )
-        # [ngram, B*head, T, 2*T]
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
         predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
         if extended_predict_attention_mask is not None:
-            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
-                predict_attn_weights.dtype
-            )
+            # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
+            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
+            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
         predict_attn_probs = softmax(
             predict_attn_weights,
@@ -1014,37 +981,30 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
                     f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                     f" {layer_head_mask.size()}"
                 )
-            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
-                self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
-            predict_attn_probs = predict_attn_probs.view(
-                self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
+            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
         predict_attn_probs = nn.functional.dropout(
             predict_attn_probs, p=self.attention_dropout, training=self.training
         )
         # project to attention output
-        # [ngram, B*head, T, c]
-        predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_attn_output = torch.einsum(
+            "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
+        )
         # reshape so that num_heads dim is merged into last `head_dim` axis
-        # [ngram, B, T, C]
-        predict_attn_output = (
-            predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .permute(1, 0, 3, 2, 4)
-            .reshape(batch_size, self.ngram, sequence_length, hidden_size)
-        )
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size]
+        predict_attn_output = predict_attn_output.transpose(2, 3)
+        predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
         predict_attn_output = self.out_proj(predict_attn_output)
         # concat to single attn output
-        # [B, 1+ngram*T, C]
+        # [batch_size, (1+ngram)*sequence_length, hidden_size]
         attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
         # reshape into better form for `config.output_attentions`
         main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
-        predict_attn_probs = predict_attn_probs.view(
-            self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
-        ).transpose(0, 1)
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -1053,8 +1013,11 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
     def get_main_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
     ):
-        # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
+        # input hidden_states [batch_size, sequence_length, hidden_size]
+        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
+        attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
         if main_relative_position_buckets is None:
             batch_size, sequence_length = hidden_states.shape[:2]
             relative_positions = (
@@ -1064,39 +1027,42 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
                 .repeat(batch_size, sequence_length, 1)
                 .to(position_ids.device)
             )
-            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(
-                batch_size, sequence_length, 1
-            )  # [B, T, s]
+            # [batch_size, sequence_length, sequence_length+1]
+            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
             main_relative_position_buckets = compute_relative_buckets(
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)  # [B,T,Buckets*head]
+        # [batch_size, sequence_length, num_buckets * num_heads]
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
         rel_pos_embeddings = rel_pos_embeddings.view(
             rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
-        ).permute(
-            0, 3, 1, 2
-        )  # [B,T,Buckets,head]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,))  # [B*head,T,Buckets]
-        main_relative_position_buckets = (
-            main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
-            .view(-1, main_relative_position_buckets.shape[-1])
-            .long()
-        )  # [B*head*T, T]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))  # [B*head*T,Buckets]
-        main_relative_pos_embeddings = torch.gather(
-            rel_pos_embeddings, dim=1, index=main_relative_position_buckets
-        ).view(attn_weights.shape[:2] + (-1,))
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
+        # [batch_size, num_heads, sequence_length, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
+        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        main_relative_position_buckets = main_relative_position_buckets.view(
+            -1, main_relative_position_buckets.shape[-1]
+        )
+        main_relative_position_buckets = main_relative_position_buckets.long()
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
+        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
        return main_relative_pos_embeddings
     def get_predict_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
     ):
-        # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
-        sequence_length, batch_size = hidden_states.shape[1:3]
+        # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
+        # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
+        batch_size, sequence_length = hidden_states.shape[0:2]
         if predict_relative_position_buckets is None:
             key_sequence_length = attn_weights.shape[-1]
@@ -1116,28 +1082,35 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
-        hidden_states = hidden_states.transpose(1, 2)  # [ngram, B, T, C]
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view(
-            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
-        )  # [ngram, B, T, bucket, head]
-        rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape(
-            self.ngram * batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram*B*head, T, bucket]
-        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat(
+        # [batch_size, ngram, sequence_length, hidden_size]
+        hidden_states = hidden_states.transpose(1, 2)
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
+        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
+        rel_pos_embeddings = rel_pos_embeddings.view(
+            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
+        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
+        # [ngram, batch_size, num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
+        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
             self.ngram, 1, self.num_attn_heads, 1
-        )  # [ngram, B, head*T, S]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+        )
+        # [ngram * batch_size * num_heads * sequence_length, -1]
         predict_relative_position_buckets = predict_relative_position_buckets.view(
             -1, predict_relative_position_buckets.size(-1)
-        ).long()  # [ngram*B*head*T, S]
+        ).long()
         predict_relative_pos_embeddings = torch.gather(
             rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
-        ).view(
-            self.ngram, batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram, B*head, T, S]
+        )
+        # [batch_size, gram, num_heads, sequence_length, -1]
+        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
+        )
         return predict_relative_pos_embeddings
@@ -1351,7 +1324,7 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
        # prepare attention mask
        if attention_mask is not None:
            extended_attention_mask = (
                1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
            ) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
        else:
@@ -1572,7 +1545,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
        # prepare encoder attention mask
        if encoder_attention_mask is not None:
            extended_encoder_attention_mask = (
                1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
            ) * torch.finfo(self.dtype).min
            extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
        else:
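Both mask-preparation hunks above build a batch-first additive mask of shape [batch_size, num_heads, 1, src_len]. The snippet below is a minimal standalone sketch of that broadcast with made-up sizes; it is illustrative only, not the model code.

```python
import torch

batch_size, num_heads, src_len = 2, 4, 6
attention_mask = torch.ones(batch_size, src_len)
attention_mask[0, -2:] = 0  # pretend the first sample has two padding tokens

neg_inf = torch.finfo(torch.float32).min
extended = (1.0 - attention_mask[:, None, None, :].repeat(1, num_heads, 1, 1)) * neg_inf

print(extended.shape)     # torch.Size([2, 4, 1, 6])
print(extended[0, 0, 0])  # 0 for real tokens, a very large negative value for padding
```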
@@ -1740,17 +1713,18 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
            device=hidden_states.device,
        )
        causal_mask = torch.triu(causal_mask, 1)

        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
        )

        # add usual attention mask
        if attention_mask is not None:
            extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_causal_mask + extended_attention_mask
        else:
            extended_attention_mask = extended_causal_mask
        return extended_attention_mask.to(hidden_states.dtype)
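For reference, here is a tiny standalone sketch (toy sizes, not the model code) of the same causal-mask construction: `torch.triu` keeps large negative values strictly above the diagonal, and the mask is then expanded to the batch-first [batch_size, num_heads, seq_len, seq_len] layout returned above.

```python
import torch

batch_size, num_heads, seq_len = 2, 4, 5
neg_inf = torch.finfo(torch.float32).min

causal_mask = torch.full((seq_len, seq_len), neg_inf)
causal_mask = torch.triu(causal_mask, 1)  # zeros on and below the diagonal, neg_inf above

extended_causal_mask = causal_mask[None, None, :, :].expand(
    (batch_size, num_heads) + causal_mask.shape
)
print(extended_causal_mask.shape)  # torch.Size([2, 4, 5, 5])
```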
    def prepare_predict_attention_mask(self, hidden_states, attention_mask):
        batch_size, seq_length = hidden_states.shape[:2]
@@ -1768,14 +1742,16 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
            ],
            dim=-1,
        )
        extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
        )

        # add usual attention mask
        if attention_mask is not None:
            extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.expand(
                (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
            )
            # predicted stream attention_mask should always be 0
            extended_attention_mask = torch.cat(
                [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
@@ -1783,9 +1759,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
            extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
        else:
            extended_predict_attention_mask = extended_predict_causal_mask
        return extended_predict_attention_mask.to(hidden_states.dtype)
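The predict-stream mask above covers twice as many key positions as the main stream. The sketch below (toy sizes, hypothetical names, not the model code) shows the same idea: broadcast a padding mask to [batch_size, num_heads, ngram, seq_len, seq_len], then concatenate a zero block along the key dimension so the predicted-stream keys are never masked out.

```python
import torch

batch_size, num_heads, ngram, seq_len = 2, 4, 2, 5
neg_inf = torch.finfo(torch.float32).min

# hypothetical padding mask: 1 = keep, 0 = pad
padding = torch.ones(batch_size, seq_len)
padding[1, -2:] = 0

additive = (1.0 - padding[:, None, None, None, :]) * neg_inf  # [B, 1, 1, 1, S]
additive = additive.expand(batch_size, num_heads, ngram, seq_len, seq_len)

# the extra key block (the predicted stream) gets an all-zero mask
extended = torch.cat([additive, torch.zeros_like(additive)], dim=-1)
print(extended.shape)  # torch.Size([2, 4, 2, 5, 10])
```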
@add_start_docstrings(
@@ -1206,7 +1206,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
        expected_shape = torch.Size((1, 12, 30522))
        self.assertEqual(output_predited_logits.shape, expected_shape)

        expected_slice = torch.tensor(
            [[[-7.7729, -8.0343, -8.26001], [-7.74213, -7.8629, -8.6000], [-7.7328, -7.8269, -8.5264]]]
        ).to(torch_device)
        # self.assertTrue(torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4))
        assert torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4)
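As a quick reminder of what the `atol=1e-4` comparison above tolerates, here is a standalone toy example with made-up numbers unrelated to the checkpoint:

```python
import torch

a = torch.tensor([-7.7729, -8.0343])
b = torch.tensor([-7.77293, -8.03435])  # within ~5e-5 of a, passes
c = torch.tensor([-7.7729, -8.0353])    # off by 1e-3 in one entry, fails

print(torch.allclose(a, b, atol=1e-4))  # True
print(torch.allclose(a, c, atol=1e-4))  # False
```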
@@ -1306,7 +1306,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
        EXPECTED_QUESTIONS = [
            "along with paul allen, who founded microsoft?",
            "what year was microsoft founded?",
            "when was microsoft founded?",
        ]

        self.assertListEqual(