Unverified commit a3c37825, authored by KarlFelixJoehnk and committed by GitHub
Browse files

Make the attention_head_size in distilbert an object attribute (#20970)



* [Fix] Make the attention head size in distilbert an object attribute

* Fix code style
Co-authored-by: Felix Joehnk <fjoehnk@N73GCH2NDH.corp.proofpoint.com>
parent e3ecbaa4
...@@ -153,12 +153,14 @@ class MultiHeadSelfAttention(nn.Module): ...@@ -153,12 +153,14 @@ class MultiHeadSelfAttention(nn.Module):
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.pruned_heads: Set[int] = set() self.pruned_heads: Set[int] = set()
self.attention_head_size = self.dim // self.n_heads
def prune_heads(self, heads: List[int]): def prune_heads(self, heads: List[int]):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0: if len(heads) == 0:
return return
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads) heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.attention_head_size, self.pruned_heads
)
# Prune linear layers # Prune linear layers
self.q_lin = prune_linear_layer(self.q_lin, index) self.q_lin = prune_linear_layer(self.q_lin, index)
self.k_lin = prune_linear_layer(self.k_lin, index) self.k_lin = prune_linear_layer(self.k_lin, index)
...@@ -166,7 +168,7 @@ class MultiHeadSelfAttention(nn.Module): ...@@ -166,7 +168,7 @@ class MultiHeadSelfAttention(nn.Module):
self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
# Update hyper params # Update hyper params
self.n_heads = self.n_heads - len(heads) self.n_heads = self.n_heads - len(heads)
self.dim = attention_head_size * self.n_heads self.dim = self.attention_head_size * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward( def forward(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment