Commit 9285d94d authored by Gustaf Ahdritz
Browse files

Remove undefined reference

parent 998155e9
......@@ -410,10 +410,16 @@ class Attention(nn.Module):
 Args:
 q_x:
 [*, Q, C_q] query data
-k_x:
+kv_x:
 [*, K, C_k] key data
-v_x:
-[*, V, C_v] value data
+biases:
+List of biases that broadcast to [*, H, Q, K]
+use_lma:
+Whether to use low-memory attention
+q_chunk_size:
+Query chunk size (for LMA)
+kv_chunk_size:
+Key/Value chunk size (for LMA)
 Returns
 [*, Q, C_q] attention update
 """
......@@ -429,7 +435,7 @@ class Attention(nn.Module):
 if(use_lma):
 biases = [
-b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
+b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
 for b in biases
 ]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment