"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f8208fa456039b46873a2e497b6318d30a4fc84e"
Unverified commit 734b7e2a, authored by Had, committed by GitHub

Mask T5 relative position bias when heads are pruned (#17968)



* add position bias head masking if heads pruned

* fix pruning function in t5 encoder

* make style

* make fix-copies

* Revert added folder
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent d4dbd7ca
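For context, this change makes head pruning usable on T5/LongT5 encoders: before it, prune_heads shrank the query/key/value/output projections but left the relative position bias with one row per original head, so adding the bias to the attention scores failed with a shape mismatch. A minimal end-to-end sketch of the patched behavior, assuming a transformers build that contains this commit and the "t5-small" checkpoint (checkpoint, layer, and head choices are illustrative):

import torch
from transformers import AutoTokenizer, T5EncoderModel

model = T5EncoderModel.from_pretrained("t5-small")   # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Prune heads 0 and 2 of layer 0's self-attention; the dict maps {layer: [head indices]}.
model.prune_heads({0: [0, 2]})

inputs = tokenizer("Pruning heads should not break the forward pass.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)   # raised a size-mismatch error before this patch

print(outputs.last_hidden_state.shape)   # (1, seq_len, 512) for t5-small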
@@ -518,7 +518,14 @@ class LongT5Attention(nn.Module):
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
 
-        scores += position_bias
+        if self.pruned_heads:
+            mask = torch.ones(position_bias.shape[1])
+            mask[list(self.pruned_heads)] = 0
+            position_bias_masked = position_bias[:, mask.bool()]
+        else:
+            position_bias_masked = position_bias
+
+        scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
             scores
         )  # (batch_size, n_heads, seq_length, key_length)
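The masking step introduced above can be reproduced in isolation. A standalone sketch with illustrative shapes (no model required), showing how the boolean mask built from pruned_heads drops the bias rows of pruned heads so the bias matches the reduced scores tensor:

import torch

n_heads, q_len, k_len = 8, 5, 5
pruned_heads = {1, 4}                        # heads that prune_heads() removed
position_bias = torch.randn(1, n_heads, q_len, k_len)

mask = torch.ones(position_bias.shape[1])    # one entry per original head
mask[list(pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]

print(position_bias_masked.shape)            # torch.Size([1, 6, 5, 5]), matching scores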
@@ -528,7 +528,14 @@ class T5Attention(nn.Module):
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
 
-        scores += position_bias
+        if self.pruned_heads:
+            mask = torch.ones(position_bias.shape[1])
+            mask[list(self.pruned_heads)] = 0
+            position_bias_masked = position_bias[:, mask.bool()]
+        else:
+            position_bias_masked = position_bias
+
+        scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
             scores
         )  # (batch_size, n_heads, seq_length, key_length)
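This second hunk mirrors the same logic into T5Attention (kept in sync via make fix-copies). The effect can be checked directly at the module level; a sketch assuming a transformers build with this commit, using a deliberately small, illustrative T5Config:

import torch
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Attention

config = T5Config(d_model=64, d_kv=8, num_heads=8)           # illustrative sizes
attn = T5Attention(config, has_relative_attention_bias=True)
attn.prune_heads([0, 3])                                      # 8 heads -> 6 heads

hidden_states = torch.randn(1, 5, config.d_model)
out = attn(hidden_states)   # works because the 8-head bias is masked down to 6 heads
print(out[0].shape)         # torch.Size([1, 5, 64])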
@@ -1802,7 +1809,7 @@ class T5EncoderModel(T5PreTrainedModel):
         class PreTrainedModel
         """
         for layer, heads in heads_to_prune.items():
-            self.encoder.layer[layer].attention.prune_heads(heads)
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
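The encoder fix is needed because T5EncoderModel has no encoder.layer attribute; its self-attention modules live at encoder.block[layer].layer[0].SelfAttention, which is what the corrected line now reaches. A quick check of that path, again assuming the illustrative "t5-small" checkpoint and arbitrary layer/head indices:

from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained("t5-small")
model.prune_heads({0: [0, 1], 2: [5]})       # {layer index: [head indices]}

# prune_heads() dispatches to the fixed _prune_heads(), which resolves this module path:
attn = model.encoder.block[0].layer[0].SelfAttention
print(attn.n_heads, attn.pruned_heads)       # 6 {0, 1} (t5-small starts with 8 heads)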