Unverified commit 7e6d5fc6 authored by Ke Bao, committed by GitHub

Support Eagle cuda graph for Triton backend (#3500)

parent cadd5dbe
@@ -38,6 +38,8 @@ class TritonAttnBackend(AttentionBackend):
self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.skip_prefill = skip_prefill
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
@@ -48,13 +50,15 @@ class TritonAttnBackend(AttentionBackend):
self.kv_indptr = kv_indptr_buf
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.mask_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
)
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.mask_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
@@ -196,22 +200,29 @@ class TritonAttnBackend(AttentionBackend):
mask_indptr,
)
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
self.cuda_graph_start_loc = torch.zeros(
(max_bs,), dtype=torch.int32, device=self.device
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
self.cuda_graph_attn_logits = torch.zeros(
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
dtype=torch.float32,
device=self.device,
)
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
else:
self.cuda_graph_kv_indices = kv_indices_buf
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)
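
The new kv_indices_buf argument lets a caller hand the backend a pre-allocated index buffer instead of having the backend allocate its own. Below is a minimal standalone sketch of that allocation fallback, with illustrative names and sizes and CPU tensors; it is not the backend class itself.

import torch

max_bs, max_context_len = 4, 16

def init_kv_indices(max_bs, max_context_len, kv_indices_buf=None, device="cpu"):
    # Mirrors the branch above: allocate a flat int32 buffer only when the
    # caller did not supply one; otherwise reuse the caller's buffer so its
    # storage stays stable across CUDA graph capture and replay.
    if kv_indices_buf is None:
        return torch.zeros(
            (max_bs * max_context_len,), dtype=torch.int32, device=device
        )
    return kv_indices_buf

shared = torch.zeros((max_bs * max_context_len,), dtype=torch.int32)
assert init_kv_indices(max_bs, max_context_len, shared) is shared  # reused, not copied
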
def init_forward_metadata_capture_cuda_graph(
self,
@@ -224,31 +235,71 @@ class TritonAttnBackend(AttentionBackend):
spec_info: Optional[SpecInfo],
):
assert encoder_lens is None, "Not supported"
assert forward_mode.is_decode(), "Not supported"
assert spec_info is None, "Not supported"
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
if forward_mode.is_decode_or_idle():
if spec_info is None:
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
attn_logits = self.cuda_graph_attn_logits
max_extend_len = None
qo_indptr = None
custom_mask = None
mask_indptr = None
elif forward_mode.is_target_verify():
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
custom_mask = self.cuda_graph_custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
max_extend_len = self.num_draft_tokens
attn_logits = None
else:
raise ValueError(
f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
)
self.forward_metadata = (
self.cuda_graph_attn_logits,
None,
attn_logits,
max_extend_len,
kv_indptr,
kv_indices,
None,
None,
None,
qo_indptr,
custom_mask,
mask_indptr,
)
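
To make the target-verify indexing above concrete, here is a small worked example with made-up sizes (num_draft_tokens = 4, two requests with sequence lengths 7 and 5). The tensors are CPU-only and none of the numbers come from the patch.

import torch

num_draft_tokens = 4
seq_lens = torch.tensor([7, 5])
bs = seq_lens.numel()

# Each request contributes num_draft_tokens query tokens.
qo_indptr = torch.arange(
    0, (1 + bs) * num_draft_tokens, step=num_draft_tokens, dtype=torch.int32
)
# qo_indptr -> tensor([0, 4, 8], dtype=torch.int32)

# KV pointers are a prefix sum of the per-request sequence lengths.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
# kv_indptr -> tensor([0, 7, 12])

# The flattened custom mask holds num_draft_tokens rows of
# (seq_len + num_draft_tokens) columns per request.
seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)
mask_indptr = torch.zeros(bs + 1, dtype=torch.int64)
mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)
# seq_mask_len -> tensor([44, 36]); mask_indptr -> tensor([0, 44, 80])
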
def init_forward_metadata_replay_cuda_graph(
@@ -262,22 +313,57 @@ class TritonAttnBackend(AttentionBackend):
spec_info: Optional[SpecInfo],
):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
if forward_mode.is_decode_or_idle():
# Update kv_indptr, kv_indices
kv_indptr = self.kv_indptr
kv_indices = self.cuda_graph_kv_indices
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
kv_indptr = kv_indptr[: bs + 1]
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
elif forward_mode.is_target_verify():
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
bs = len(req_pool_indices)
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
else:
raise ValueError(
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
)
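
The replay path writes into the buffers that were captured rather than creating new tensors, because a replayed CUDA graph reads from fixed memory addresses. A toy illustration of that constraint, with CPU tensors and illustrative sizes:

import torch

# Buffer that would have been captured into the graph.
captured_kv_indptr = torch.zeros(9, dtype=torch.int32)
# New values arriving at replay time, e.g. from spec_info.kv_indptr.
new_kv_indptr = torch.tensor([0, 3, 7, 9], dtype=torch.int32)

# What the code above does: copy into the existing storage.
captured_kv_indptr[: new_kv_indptr.shape[0]] = new_kv_indptr

# What would break replay: rebinding the name allocates a different tensor,
# which the captured graph would never see.
# captured_kv_indptr = new_kv_indptr
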
def get_cuda_graph_seq_len_fill_value(self):
return 1
@@ -407,6 +493,7 @@ class TritonMultiStepDraftBackend:
)
)
self.max_context_len = self.attn_backends[0].max_context_len
self.device = model_runner.device
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -450,7 +537,7 @@ class TritonMultiStepDraftBackend:
forward_batch.batch_size * self.topk * self.max_context_len,
),
dtype=torch.int32,
device="cuda",
device=self.device,
)
def call_fn(i, forward_batch):
@@ -468,7 +555,7 @@ class TritonMultiStepDraftBackend:
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
device="cuda",
device=self.device,
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
......
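
The arguments of the init_cuda_graph_state call are elided in the diff; presumably each per-step backend is handed one row of the shared buffer allocated just above. A hypothetical sketch of that wiring follows; everything in it is an assumption, not the patch itself.

import torch

speculative_num_steps, max_bs, max_context_len = 3, 4, 16

# Shape mirrors the allocation above: one row of kv indices per draft step.
cuda_graph_kv_indices = torch.zeros(
    (speculative_num_steps, max_bs * max_context_len), dtype=torch.int32
)

# Hypothetical: pass row i to the i-th backend so every step reuses
# graph-stable memory instead of allocating its own buffer, e.g.
#   attn_backends[i].init_cuda_graph_state(max_bs, kv_indices_buf=cuda_graph_kv_indices[i])
for i in range(speculative_num_steps):
    per_step_buf = cuda_graph_kv_indices[i]
    assert per_step_buf.shape == (max_bs * max_context_len,)
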
@@ -216,8 +216,6 @@ class TestEAGLEServerTriton(TestEAGLEServer):
"0.7",
"--attention-backend",
"triton",
# TODO: Support cuda graph
"--disable-cuda-graph",
],
)
......