Unverified commit 00b19f19, authored by Byron Hsu and committed by GitHub
Browse files

[triton] Remove the zero initialization of qk_acc by directly writing the result (#1288)

parent 6cb32ef9
@@ -127,8 +127,7 @@ def _fwd_kernel(
 )
 k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
-qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-qk += tl.dot(q.to(k.dtype), k)
+qk = tl.dot(q.to(k.dtype), k)
 if BLOCK_DPE > 0:
 offs_kpe = (
 offs_kv_loc[None, :] * stride_buf_kbs
@@ -179,9 +178,7 @@ def _fwd_kernel(
 )
 k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
-qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-qk += tl.dot(q, k)
+qk = tl.dot(q, k, out_dtype=tl.float32)
 if BLOCK_DPE > 0:
 offs_kpe = (
 (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment