Unverified Commit 28830402 authored by LuGY's avatar LuGY Committed by GitHub
Browse files

[example] change qkv processing (#870)

parent 96211c2c
......@@ -89,13 +89,14 @@ class GPTSelfAttention(nn.Module):
def forward(self, x, attention_mask=None):
qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3
num_attention_heads = divide(all_head_size, self.attention_head_size)
new_qkv_shape = qkv.shape[:-1] + \
(num_attention_heads, 3 * self.attention_head_size)
qkv = qkv.view(new_qkv_shape)
qkv = qkv.permute((0, 2, 1, 3))
q, k, v = torch.chunk(qkv, 3, dim=-1)
all_head_size = q.shape[-1]
num_attention_heads = divide(all_head_size, self.attention_head_size)
new_shape = q.shape[:-1] + \
(num_attention_heads, self.attention_head_size)
q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()
x = torch.matmul(q, k.transpose(-1, -2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment