"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "c54fb922631f15347087444702060f1c907a2926"
Unverified Commit bad83008 authored by Benoit's avatar Benoit Committed by GitHub
Browse files

Error (also in original) model, scaling only q matrix not qk.T dot product...

Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head)) (#21627)

* Error in model, scaling only q matrix not qK.T dot product (qk.T/sqrt(dim_per_head))

As per Vaswani et al, 2017 p.4
Is torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head) not q / math.sqrt(dim_per_head)
https://arxiv.org/pdf/1912.05372.pdf

Error was in original FlauBERT repo and effectively scales queries but not values
cf. https://github.com/getalp/Flaubert/pull/45/commits/6d176880ca3a1a8dfa2b76c97030bb51c5e917b8

* Update modeling_flaubert.py

Update to https://github.com/huggingface/transformers/pull/21627
make fixup
make repo_consistency

* Update modeling_xlm.py

* Update modeling_flaubert.py

* Update modeling_xlm.py
parent aaf6795f
......@@ -172,8 +172,7 @@ class MultiHeadAttention(nn.Module):
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
......
......@@ -176,8 +176,7 @@ class MultiHeadAttention(nn.Module):
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment