Unverified Commit 2f9bd0fa authored by Ke Bao, committed by GitHub

Fix correctness issue for triton decoding kernel (#2479)

parent 5282a473
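The change threads the per-batch sequence length (B_Seqlen / b_seq_len) into the stage-2 reduction kernel so it can recompute each split's KV range and skip splits that cover no tokens. Previously the reduction folded every split into the running softmax accumulators, including splits for which stage 1 has no tokens to process, which is presumably the source of the correctness issue for short sequences. As a standalone sketch of the split-boundary arithmetic (variable names mirror the kernel; the example values come from the new test case added below):

import math

# With a 5-token sequence and 8 KV splits, tl.cdiv(5, 8) == 1, so splits 5-7
# cover no tokens at all and must be skipped by the new guard.
seq_len, num_kv_splits = 5, 8
kv_len_per_split = math.ceil(seq_len / num_kv_splits)  # mirrors tl.cdiv
for split_kv_id in range(num_kv_splits):
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = min(split_kv_start + kv_len_per_split, seq_len)
    status = "empty" if split_kv_end <= split_kv_start else "non-empty"
    print(split_kv_id, split_kv_start, split_kv_end, status)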
@@ -32,7 +32,7 @@ is_hip_ = is_hip()
 logger = logging.getLogger(__name__)
 
 # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
-logger.warn(
+logger.warning(
     "The following error message 'operation scheduled before its operands' can be ignored."
 )
@@ -474,6 +474,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
+    B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -486,6 +487,8 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
 
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -497,19 +500,24 @@ def _fwd_kernel_stage2(
     offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
 
     for split_kv_id in range(0, NUM_KV_SPLITS):
-        tv = tl.load(
-            Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
-        )
-        tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
-        n_e_max = tl.maximum(tlogic, e_max)
-
-        old_scale = tl.exp(e_max - n_e_max)
-        acc *= old_scale
-        exp_logic = tl.exp(tlogic - n_e_max)
-        acc += exp_logic * tv
-
-        e_sum = e_sum * old_scale + exp_logic
-        e_max = n_e_max
+        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+        split_kv_start = kv_len_per_split * split_kv_id
+        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
+
+        if split_kv_end > split_kv_start:
+            tv = tl.load(
+                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
+            )
+            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            n_e_max = tl.maximum(tlogic, e_max)
+
+            old_scale = tl.exp(e_max - n_e_max)
+            acc *= old_scale
+            exp_logic = tl.exp(tlogic - n_e_max)
+            acc += exp_logic * tv
+
+            e_sum = e_sum * old_scale + exp_logic
+            e_max = n_e_max
 
     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
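For reference, the rewritten loop is the usual log-sum-exp merge of per-split partial attention results, now restricted to splits that actually contain KV tokens. Below is a minimal plain-PyTorch sketch of what stage 2 computes for a single (batch, head), assuming each non-empty split has stored a split-local normalized output plus a log-sum-exp value (the layout this kernel reads from Mid_O); the function and tensor names are illustrative, not taken from the repository:

import math

import torch


def merge_kv_splits(mid_o, mid_lse, seq_len, num_kv_splits):
    # mid_o:   [num_kv_splits, Lv] split-local normalized partial outputs (assumed layout)
    # mid_lse: [num_kv_splits]     split-local log-sum-exp values (assumed layout)
    Lv = mid_o.shape[-1]
    acc = torch.zeros(Lv, dtype=torch.float32)
    e_sum, e_max = 0.0, float("-inf")
    kv_len_per_split = math.ceil(seq_len / num_kv_splits)  # mirrors tl.cdiv
    for split_kv_id in range(num_kv_splits):
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = min(split_kv_start + kv_len_per_split, seq_len)
        if split_kv_end <= split_kv_start:
            continue  # empty split: stage 1 wrote nothing for it, so skip it
        tv = mid_o[split_kv_id].float()
        tlogic = float(mid_lse[split_kv_id])
        n_e_max = max(tlogic, e_max)            # new running max of the log-sum-exps
        old_scale = math.exp(e_max - n_e_max)   # rescale what was accumulated so far
        exp_logic = math.exp(tlogic - n_e_max)
        acc = acc * old_scale + exp_logic * tv
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max
    return acc / e_sum  # final normalization by the accumulated denominator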
@@ -523,6 +531,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
+    b_seq_len,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -541,6 +550,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -580,7 +590,7 @@ def decode_attention_fwd_normal(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
@@ -608,7 +618,7 @@ def decode_attention_fwd_grouped(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
 
 
 def decode_attention_fwd(
...
@@ -232,9 +232,9 @@ class TestTritonAttention(unittest.TestCase):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)
 
-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 128  # This represents the number of tokens already in the sequence
+        seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         num_kv_splits = 8
@@ -300,6 +300,7 @@ class TestTritonAttention(unittest.TestCase):
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
 
     def test_grouped_decode_attention(self):
+        seq_lens = [5, 100, 128, 500]
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
@@ -309,8 +310,9 @@ class TestTritonAttention(unittest.TestCase):
             (2, 128, 1, 576, 512),
         ]
 
-        for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+        for S in seq_lens:
+            for B, H_Q, H_KV, D, D_V in configs:
+                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
 
 
 if __name__ == "__main__":
...
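The grouped decode test now sweeps sequence lengths of 5, 100, 128, and 500 over every head/dim configuration, so the short-sequence case (S=5 with the test's num_kv_splits=8, where splits 5-7 are empty) is exercised alongside the original length of 128. To run just this test in isolation, something like the following should work; the module name test_triton_attention_kernels is an assumption about where TestTritonAttention lives, not taken from this diff:

import unittest

# Hypothetical module path; point the import at the file defining TestTritonAttention.
from test_triton_attention_kernels import TestTritonAttention

suite = unittest.TestSuite()
suite.addTest(TestTritonAttention("test_grouped_decode_attention"))
unittest.TextTestRunner(verbosity=2).run(suite)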