Unverified Commit 55dae94c authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Revert "Error (also in original) model, scaling only q matrix not qk.T dot...

Revert "Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head))" (#22444)

Revert "Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head)) (#21627)"

This reverts commit bad83008.
parent 8894b817
@@ -172,7 +172,8 @@ class MultiHeadAttention(nn.Module):
             k, v = cache[self.layer_id]
             cache[self.layer_id] = (k, v)
-        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
+        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
+        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
         mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
         scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)
......
@@ -176,7 +176,8 @@ class MultiHeadAttention(nn.Module):
             k, v = cache[self.layer_id]
             cache[self.layer_id] = (k, v)
-        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
+        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
+        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
         mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
         scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment