Commit 0cd28352 authored by LysandreJik's avatar LysandreJik
Browse files

Attempt to fix head index

parent c85b5db6
......@@ -233,12 +233,14 @@ class Attention(nn.Module):
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.pruned_heads = []
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
for head in heads:
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
......@@ -249,6 +251,7 @@ class Attention(nn.Module):
# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
self.pruned_heads.extend(heads)
def _attn(self, q, k, v, head_mask=None):
w = torch.matmul(q, k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment