Unverified Commit 7dc66fcb authored by Ke Bao, committed by GitHub

Optimize Triton decoding kernel for long context (#2394)

parent 1f09e84b
@@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend):
         else:
             self.reduce_dtype = torch.float16
 
+        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+
         self.forward_metadata = None
 
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
@@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
+                (
+                    forward_batch.batch_size,
+                    self.num_head,
+                    self.num_kv_splits,
+                    self.v_head_dim + 1,
+                ),
+                dtype=torch.float32,
                 device=self.device,
             )
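Note on the new buffer layout: instead of one reduce-dtype row of logits per cached token, the decode path now allocates a fixed fp32 buffer of shape `(batch, num_head, num_kv_splits, v_head_dim + 1)`. This is the flash-decoding pattern: each KV split writes a partial output vector plus one extra slot for that split's log-sum-exp, and a second reduction stage merges the splits. The sketch below is illustrative only (not code from this commit); it assumes each split stores its softmax-normalized partial output with the log-sum-exp in the last slot, and `merge_kv_splits` is a hypothetical helper name.

```python
import torch

def merge_kv_splits(attn_logits: torch.Tensor) -> torch.Tensor:
    """Merge per-split partial results of shape (B, H, num_kv_splits, Dv + 1)."""
    partial_out = attn_logits[..., :-1]             # (B, H, S, Dv) normalized partial outputs
    lse = attn_logits[..., -1]                      # (B, H, S) per-split log-sum-exp
    max_lse = lse.max(dim=-1, keepdim=True).values  # stabilize before exponentiating
    weights = torch.exp(lse - max_lse)              # relative weight of each split
    combined = (weights.unsqueeze(-1) * partial_out).sum(dim=-2)  # (B, H, Dv)
    return combined / weights.sum(dim=-1, keepdim=True)           # renormalize across splits
```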
@@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
+            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            dtype=torch.float32,
             device="cuda",
         )
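The CUDA-graph buffer is sized by the maximum batch size rather than by the maximum total number of cached tokens, so it no longer grows with context length. As a rough worked example with hypothetical numbers (not from this commit): `max_bs = 160`, 32 heads, 8 splits, `v_head_dim = 128` gives 160 × 32 × 8 × 129 × 4 B ≈ 21 MB, regardless of how long the sequences captured by the graph become.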
@@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.seq_lens,
             attn_logits,
             max_seq_len,
+            self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
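The kernel call now receives `self.num_kv_splits`, which controls how many chunks each request's KV cache is processed in, and therefore how many thread blocks can work on a single sequence in parallel. A rough sketch of the partitioning, assuming an even ceil-division tiling (the Triton kernel's exact tiling may differ); `split_ranges` is a hypothetical helper:

```python
def split_ranges(seq_len: int, num_kv_splits: int):
    """Illustrative only: partition [0, seq_len) into at most num_kv_splits chunks."""
    chunk = -(-seq_len // num_kv_splits)  # ceil division
    return [
        (s * chunk, min((s + 1) * chunk, seq_len))
        for s in range(num_kv_splits)
        if s * chunk < seq_len
    ]

# split_ranges(10, 8)   -> [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]  (short context: some splits idle)
# split_ranges(4096, 8) -> eight chunks of 512 tokens each
```

This is why a larger split count mainly pays off for long contexts: with short sequences the extra splits have little or no work to do.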
......
@@ -141,6 +141,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
@@ -753,6 +754,12 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--triton-attention-num-kv-splits",
+            type=int,
+            default=ServerArgs.triton_attention_num_kv_splits,
+            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
+        )
         parser.add_argument(
             "--num-continuous-decode-steps",
             type=int,
......
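The new flag mirrors the dataclass default above and can be raised for long-context workloads, for example with something like `python -m sglang.launch_server --model-path <model> --attention-backend triton --triton-attention-num-kv-splits 16` (launch command shown for illustration; only `--triton-attention-num-kv-splits` is introduced by this commit).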
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D + 1),
+            dtype=torch.float32,
             device="cuda",
         )
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
     def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 10  # This represents the number of tokens already in the sequence
+        seq_len = 128  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
         v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
 
         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
-        o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
             device="cuda",
         )
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
 
+        attn_logits1 = torch.empty(
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
+            device="cuda",
+        )
+
         decode_attention_fwd_grouped(
             q,
             k_buffer,
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
-            attn_logits,
+            attn_logits1,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
 
         cos_sim = torch.nn.functional.cosine_similarity(
             o.flatten(), o_grouped.flatten(), dim=0
         )
+        print(cos_sim.item())
         self.assertTrue(cos_sim.item() > 0.99)
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
 
     def test_grouped_decode_attention(self):
         configs = [
+            (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
             (2, 64, 1, 13, 13),
             (2, 128, 1, 80, 80),
......
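For readers of the test: the cosine-similarity and allclose checks compare the regular and grouped (GQA) decode kernels against each other. As a mental model of what both kernels compute per decode step, here is a naive PyTorch reference; it is illustrative only and not part of the test, `naive_grouped_decode` is a hypothetical helper, and the shapes follow the test setup above.

```python
import torch

def naive_grouped_decode(q, k_buffer, v_buffer, req_to_token, sm_scale):
    """q: (B, H_Q, D); k_buffer: (N, H_KV, D); v_buffer: (N, H_KV, D_V); req_to_token: (B, seq_len)."""
    B, H_Q, _ = q.shape
    H_KV = k_buffer.shape[1]
    group = H_Q // H_KV  # number of query heads sharing each KV head
    out = torch.empty(B, H_Q, v_buffer.shape[-1], dtype=q.dtype, device=q.device)
    for b in range(B):
        idx = req_to_token[b]  # KV-cache slots belonging to this request
        k = k_buffer[idx].repeat_interleave(group, dim=1)  # (seq_len, H_Q, D)
        v = v_buffer[idx].repeat_interleave(group, dim=1)  # (seq_len, H_Q, D_V)
        scores = torch.einsum("hd,shd->hs", q[b].float(), k.float()) * sm_scale
        probs = torch.softmax(scores, dim=-1)
        out[b] = torch.einsum("hs,shd->hd", probs, v.float()).to(q.dtype)
    return out
```

Comparing the Triton outputs against such a reference in fp32 would be an additional sanity check; the test itself relies on agreement between the two kernel variants.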