"vscode:/vscode.git/clone" did not exist on "0de15430fd79c57358bf8749410602ccf9c240f3"
Commit 215930bc authored by Tri Dao's avatar Tri Dao
Browse files

Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd

parent 4f81aff4
...@@ -257,11 +257,9 @@ def _bwd_kernel_one_col_block( ...@@ -257,11 +257,9 @@ def _bwd_kernel_one_col_block(
start_m = tl.multiple_of(start_m, BLOCK_M) start_m = tl.multiple_of(start_m, BLOCK_M)
offs_m_curr = start_m + offs_m offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip # load q, k, v, do on-chip
if EVEN_M: # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
if EVEN_HEADDIM: if EVEN_M & EVEN_HEADDIM:
q = tl.load(q_ptrs) q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=(offs_d[None, :] < headdim))
else: else:
if EVEN_HEADDIM: if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment