Commit 9285d94d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove undefined reference

parent 998155e9
...@@ -410,10 +410,16 @@ class Attention(nn.Module): ...@@ -410,10 +410,16 @@ class Attention(nn.Module):
Args: Args:
q_x: q_x:
[*, Q, C_q] query data [*, Q, C_q] query data
k_x: kv_x:
[*, K, C_k] key data [*, K, C_k] key data
v_x: biases:
[*, V, C_v] value data List of biases that broadcast to [*, H, Q, K]
use_lma:
Whether to use low-memory attention
q_chunk_size:
Query chunk size (for LMA)
kv_chunk_size:
Key/Value chunk size (for LMA)
Returns Returns
[*, Q, C_q] attention update [*, Q, C_q] attention update
""" """
...@@ -429,7 +435,7 @@ class Attention(nn.Module): ...@@ -429,7 +435,7 @@ class Attention(nn.Module):
if(use_lma): if(use_lma):
biases = [ biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],)) b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases for b in biases
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment