Unverified commit 7dc66fcb, authored by Ke Bao and committed by GitHub

Optimize Triton decoding kernel for long context (#2394)

parent 1f09e84b
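For context on the change: the decode path moves to a flash-decoding style split-KV scheme, where each of `num_kv_splits` chunks of the KV cache produces a partial result plus a normalizer, and a second pass merges the splits. Below is a minimal PyTorch sketch of that idea, for illustration only; it is not the Triton kernel, and the function and variable names are made up.

```python
# Illustrative reference for the flash-decoding idea: each of num_kv_splits
# handles one chunk of the KV cache, producing a partial softmax-weighted sum
# of V plus its log-sum-exp; a second pass merges the splits in a numerically
# stable way. NOT the Triton kernel; all names here are made up.
import torch

def split_kv_decode_reference(q, k, v, num_kv_splits, sm_scale):
    # q: [H, D] single decode token; k: [S, H, D]; v: [S, H, Dv] for one sequence
    S, H, Dv = v.shape
    partial_out = torch.zeros(H, num_kv_splits, Dv, dtype=torch.float32, device=q.device)
    partial_lse = torch.full((H, num_kv_splits), float("-inf"), device=q.device)
    for s in range(num_kv_splits):
        lo, hi = s * S // num_kv_splits, (s + 1) * S // num_kv_splits
        if lo >= hi:
            continue  # empty split (e.g. very short sequence)
        scores = torch.einsum("hd,shd->hs", q.float(), k[lo:hi].float()) * sm_scale
        partial_lse[:, s] = torch.logsumexp(scores, dim=-1)
        probs = torch.softmax(scores, dim=-1)
        partial_out[:, s] = torch.einsum("hs,shd->hd", probs, v[lo:hi].float())
    # Merge: weight each split by exp(lse_s - logsumexp_over_splits(lse)).
    weights = torch.softmax(partial_lse, dim=-1)  # [H, num_kv_splits]
    return torch.einsum("hs,hsd->hd", weights, partial_out)
```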
@@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend):
else:
self.reduce_dtype = torch.float16
self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.forward_metadata = None
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
@@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend):
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
total_num_tokens = forward_batch.seq_lens_sum
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.reduce_dtype,
(
forward_batch.batch_size,
self.num_head,
self.num_kv_splits,
self.v_head_dim + 1,
),
dtype=torch.float32,
device=self.device,
)
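The intermediate buffer changes from `(num_head, total_num_tokens)` in the reduce dtype to `(batch_size, num_head, num_kv_splits, v_head_dim + 1)` in fp32, so its footprint no longer scales with the number of cached tokens. Presumably each split's row stores a partial output of width `v_head_dim` plus one extra normalizer slot; under that assumption, the final reduction would look roughly like the sketch below (illustration only, not the actual Triton reduce stage).

```python
# Hedged sketch of how a (B, H, num_kv_splits, Dv + 1) buffer like attn_logits
# might be folded into the final output, assuming the first Dv entries of each
# split hold its (already normalized) partial output and the last entry holds a
# log-sum-exp style normalizer. This mirrors the merge in the sketch above.
import torch

def merge_kv_splits(attn_logits: torch.Tensor) -> torch.Tensor:
    partial_out = attn_logits[..., :-1]           # [B, H, splits, Dv]
    partial_lse = attn_logits[..., -1]            # [B, H, splits]
    weights = torch.softmax(partial_lse, dim=-1)  # exp(lse_s - total_lse)
    return torch.einsum("bhs,bhsd->bhd", weights, partial_out)
```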
@@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend):
(max_bs,), dtype=torch.int32, device=self.device
)
self.cuda_graph_attn_logits = torch.empty(
(
self.num_head,
self.cuda_graph_max_total_num_tokens,
),
dtype=self.reduce_dtype,
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
dtype=torch.float32,
device="cuda",
)
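The CUDA-graph buffer follows the same layout: it is now bounded by `max_bs * num_head * num_kv_splits * (v_head_dim + 1)` elements instead of scaling with `cuda_graph_max_total_num_tokens`. A rough size comparison with hypothetical dimensions follows; the numbers are not from the diff, and the old total-token count is assumed to be on the order of `max_bs * context_len`.

```python
# Illustrative buffer-size comparison (all numbers are hypothetical, not taken
# from the diff). The old intermediate buffer grew with the total number of
# cached tokens; the new layout is independent of context length.
num_head, v_head_dim, num_kv_splits = 32, 128, 8
max_bs, context_len = 64, 131_072

# Old layout: (num_head, cuda_graph_max_total_num_tokens), assumed ~ max_bs * context_len tokens.
old_elems = num_head * (max_bs * context_len)                     # ~268M elements
# New layout: (max_bs, num_head, num_kv_splits, v_head_dim + 1).
new_elems = max_bs * num_head * num_kv_splits * (v_head_dim + 1)  # ~2.1M elements
print(f"old: {old_elems:,}  new: {new_elems:,}")
```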
@@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch.seq_lens,
attn_logits,
max_seq_len,
self.num_kv_splits,
layer.scaling,
layer.logit_cap,
)
...
@@ -141,6 +141,7 @@ class ServerArgs:
enable_nan_detection: bool = False
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
triton_attention_num_kv_splits: int = 8
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False
@@ -753,6 +754,12 @@ class ServerArgs:
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
)
parser.add_argument(
"--triton-attention-num-kv-splits",
type=int,
default=ServerArgs.triton_attention_num_kv_splits,
help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
)
parser.add_argument(
"--num-continuous-decode-steps",
type=int,
...
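The new `triton_attention_num_kv_splits` field flows from `ServerArgs` into the backend (see `self.num_kv_splits` above) and can be set at launch with the `--triton-attention-num-kv-splits` flag. A hedged sketch of overriding it programmatically is shown below; only the field name comes from the diff, while the import path, model path, and construction style are assumptions.

```python
# Hedged sketch: the new field can be set via the CLI flag added above or
# directly on the dataclass. Only the field name comes from the diff; the
# import path and model path below are assumptions/placeholders.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="/path/to/model")  # placeholder model path
args.triton_attention_num_kv_splits = 16  # larger values help at very long context lengths
```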
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
seq_len = 10 # This represents the number of tokens already in the sequence
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)
num_kv_splits = 8
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len = torch.full((B,), seq_len, device="cuda")
attn_logits = torch.empty(
(H_Q, total_tokens),
dtype=dtype,
(B, H_Q, num_kv_splits, D + 1),
dtype=torch.float32,
device="cuda",
)
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len,
attn_logits,
seq_len,
num_kv_splits,
sm_scale,
)
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
dtype = torch.bfloat16
seq_len = 10 # This represents the number of tokens already in the sequence
seq_len = 128 # This represents the number of tokens already in the sequence
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)
num_kv_splits = 8
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len = torch.full((B,), seq_len, device="cuda")
attn_logits = torch.empty(
(H_Q, total_tokens),
dtype=dtype,
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
seq_len,
num_kv_splits,
sm_scale,
)
attn_logits1 = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
decode_attention_fwd_grouped(
q,
k_buffer,
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
o_grouped,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
attn_logits1,
seq_len,
num_kv_splits,
sm_scale,
)
cos_sim = torch.nn.functional.cosine_similarity(
o.flatten(), o_grouped.flatten(), dim=0
)
print(cos_sim.item())
self.assertTrue(cos_sim.item() > 0.99)
self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
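The test compares the two kernel variants against each other. If an independent baseline were wanted, a naive PyTorch reference like the following could be compared against as well; this is a sketch, and the helper name and GQA handling via `repeat_interleave` are mine, not part of the kernels under test.

```python
# Hedged sketch of an additional, kernel-independent baseline the test could
# compare against. GQA is handled by repeating KV heads; the helper name and
# structure are illustrative.
import torch

def naive_decode_reference(q, k_buffer, v_buffer, req_to_token, b_seq_len, sm_scale):
    B, H_Q, _ = q.shape
    H_KV, D_V = k_buffer.shape[1], v_buffer.shape[2]
    out = torch.empty(B, H_Q, D_V, dtype=q.dtype, device=q.device)
    for b in range(B):
        idx = req_to_token[b, : b_seq_len[b]]                     # KV slots of this request
        k = k_buffer[idx].repeat_interleave(H_Q // H_KV, dim=1)   # [S, H_Q, D]
        v = v_buffer[idx].repeat_interleave(H_Q // H_KV, dim=1)   # [S, H_Q, D_V]
        scores = torch.einsum("hd,shd->hs", q[b].float(), k.float()) * sm_scale
        probs = torch.softmax(scores, dim=-1)
        out[b] = torch.einsum("hs,shd->hd", probs, v.float()).to(q.dtype)
    return out
```

Something like `torch.testing.assert_close(o, naive_decode_reference(q, k_buffer, v_buffer, req_to_token, b_seq_len, sm_scale), atol=3e-2, rtol=0)` would then tie both kernels to the same ground truth.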
def test_grouped_decode_attention(self):
configs = [
(2, 16, 16, 64, 64),
(2, 16, 1, 64, 64),
(2, 64, 1, 13, 13),
(2, 128, 1, 80, 80),
...