Commit edfc8f82 authored by Rémi Louf's avatar Rémi Louf
Browse files

Remove and do the branching in

parent 09cfd122
...@@ -282,53 +282,13 @@ class BertAttention(nn.Module): ...@@ -282,53 +282,13 @@ class BertAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None): def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask) self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs return outputs
class BertDecoderAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertGeneralAttention(config)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, query, key, value, attention_mask=None, head_mask=None):
self_outputs = self.self(query, key, value, attention_mask, head_mask)
# in encoder-decoder attention we use the output of the previous decoder stage as the query
# in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput.
# This shows the limits of the current code architecture, which may benefit from some refactoring.
attention_output = self.output(self_outputs[0], query)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module): class BertIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertIntermediate, self).__init__() super(BertIntermediate, self).__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment