"...resnet50_tensorflow.git" did not exist on "092a5461e4c6d272ffdeb26b940bec45f3019427"
Commit 0c8e823b authored by LysandreJik

Added patch to remaining models

parent 0cd28352
@@ -337,12 +337,14 @@ class BertAttention(nn.Module):
         super(BertAttention, self).__init__()
         self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)
+        self.pruned_heads = []

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
         for head in heads:
+            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -354,6 +356,7 @@ class BertAttention(nn.Module):
         # Update hyper params
         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads.extend(heads)

     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
...
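The line added inside each `for head in heads` loop is the heart of the patch: once some heads have already been physically removed, the remaining heads shift down, so a head index expressed relative to the original model has to be reduced by the number of already-pruned heads that sat before it. A minimal standalone sketch of that remapping (the helper name and the head counts below are illustrative, not part of the commit):

def remap_head(head, pruned_heads):
    # Mirror of the patched line:
    # head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
    return head - len([h for h in pruned_heads if h < head])

# 12 original heads; heads 2 and 5 were removed by an earlier prune_heads call.
pruned = [2, 5]
# Pruning original head 7 now has to zero out row 5 of the current
# 10-head mask, not row 7.
assert remap_head(7, pruned) == 5
# A head that sits before every pruned head keeps its index.
assert remap_head(1, pruned) == 1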
@@ -249,12 +249,14 @@ class Attention(nn.Module):
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.pruned_heads = []

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_head, self.split_size // self.n_head)
         for head in heads:
+            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -265,6 +267,7 @@ class Attention(nn.Module):
         # Update hyper params
         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
         self.n_head = self.n_head - len(heads)
+        self.pruned_heads.extend(heads)

     def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
...
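In every hunk the boolean mask is flattened into an index of surviving positions, and the elided part of each method uses that index to slice the projection weights down to the remaining heads. A hedged, self-contained sketch of that step for a plain nn.Linear (the commit itself relies on library helpers not shown in these hunks; the function below is only an illustration):

import torch
from torch import nn

def prune_linear_rows(layer, index):
    # Illustrative only: keep the output rows of `layer` selected by `index`
    # and return a smaller nn.Linear with the surviving weights copied over.
    weight = layer.weight.index_select(0, index).clone().detach()
    bias = layer.bias.index_select(0, index).clone().detach()
    new_layer = nn.Linear(layer.in_features, len(index))
    new_layer.weight.data.copy_(weight)
    new_layer.bias.data.copy_(bias)
    return new_layer

# Example: a 12-dim projection split into 3 heads of size 4; prune head 1.
lin = nn.Linear(16, 12)
mask = torch.ones(3, 4)
mask[1] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
smaller = prune_linear_rows(lin, index)
assert smaller.weight.shape == (8, 16)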
@@ -271,6 +271,7 @@ class MultiHeadAttention(nn.Module):
         self.k_lin = nn.Linear(dim, dim)
         self.v_lin = nn.Linear(dim, dim)
         self.out_lin = nn.Linear(dim, dim)
+        self.pruned_heads = []

     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
@@ -278,6 +279,7 @@ class MultiHeadAttention(nn.Module):
             return
         mask = torch.ones(self.n_heads, attention_head_size)
         for head in heads:
+            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -289,6 +291,7 @@ class MultiHeadAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
+        self.pruned_heads.extend(heads)

     def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """
...
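With `pruned_heads` tracked on each attention module, pruning can now be applied in several passes while still addressing heads by their indices in the original, unpruned model. A usage sketch, assuming the model-level `prune_heads` API (a dict mapping layer index to head indices) and the package/checkpoint names of the pytorch-transformers era of this commit:

from pytorch_transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')

# First pass: drop heads 2 and 5 of layer 0 (original numbering).
model.prune_heads({0: [2, 5]})

# Second pass: head 7 still refers to the original numbering; the
# per-module pruned_heads bookkeeping added in this commit remaps it
# onto the already-shrunken projections.
model.prune_heads({0: [7]})

assert model.encoder.layer[0].attention.self.num_attention_heads == 9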