Commit 8e5d84fc authored by v_sboliu's avatar v_sboliu Committed by Julien Chaumond
Browse files

Fixed typo

parent 5d3b8daa
...@@ -278,7 +278,7 @@ class BertAttention(nn.Module): ...@@ -278,7 +278,7 @@ class BertAttention(nn.Module):
if len(heads) == 0: if len(heads) == 0:
return return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads: for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly # Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads) head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment