"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "68c5d532645b9ca4edb986236ded0cc0af0d7761"
Unverified Commit a7d0b288 authored by Funtowicz Morgan, committed by GitHub

Remove the need for `einsum` in Albert's attention computation (#12394)

* debug albert einsum

* Fix matmul computation

* Let's use torch linear layer.

* Style.
parent 276bc149
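
Note (not part of the commit, a minimal sketch for context): for a square nn.Linear(hidden_size, hidden_size), the removed einsum plus hand-added bias is algebraically the same as merging the head and head-size dimensions of the context tensor and calling the layer directly, because torch.einsum("bfnd,ndh->bfh", x, weight.t().view(n, d, h)) + bias reduces to x.flatten(2) @ weight.T + bias. The snippet below checks this numerically; the sizes (batch, seq_len, num_heads, head_size) and the standalone dense layer are illustrative stand-ins for the module's self.dense.

import torch
from torch import nn

torch.manual_seed(0)
batch, seq_len, num_heads, head_size = 2, 5, 4, 8
hidden_size = num_heads * head_size          # ALBERT requires num_heads * head_size == hidden_size

dense = nn.Linear(hidden_size, hidden_size)  # stand-in for AlbertAttention.self.dense

# context_layer as produced by torch.matmul(attention_probs, value_layer):
# shape [batch, num_heads, seq_len, head_size]
context_layer = torch.randn(batch, num_heads, seq_len, head_size)

# Old path (removed): bring heads next to head_size, then einsum against the
# weight viewed as [num_heads, head_size, hidden_size] and add the bias by hand.
old = context_layer.permute(0, 2, 1, 3).contiguous()
w = dense.weight.t().view(num_heads, head_size, hidden_size)
old_out = torch.einsum("bfnd,ndh->bfh", old, w) + dense.bias

# New path: same permutation via transpose(2, 1), merge the last two dims,
# and let the linear layer do the matmul + bias.
new_out = dense(context_layer.transpose(2, 1).flatten(2))

print(torch.allclose(old_out, new_out, atol=1e-6))  # expected: True; both yield [batch, seq_len, hidden_size]

Besides being easier to read, the new form no longer rebuilds the reshaped weight view on every forward pass; the only per-call reshape is on the activations.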
@@ -360,18 +360,9 @@ class AlbertAttention(nn.Module):
             attention_probs = attention_probs * head_mask

         context_layer = torch.matmul(attention_probs, value_layer)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-
-        # Should find a better way to do this
-        w = (
-            self.dense.weight.t()
-            .view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
-            .to(context_layer.dtype)
-        )
-        b = self.dense.bias.to(context_layer.dtype)
-
-        projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
+        context_layer = context_layer.transpose(2, 1).flatten(2)
+
+        projected_context_layer = self.dense(context_layer)
         projected_context_layer_dropout = self.output_dropout(projected_context_layer)
         layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
         return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)