Unverified Commit 53f7874a authored by valarLip, committed by GitHub

refine aiter_backend for mtp (#7279)


Co-authored-by: HAI <hixiao@gmail.com>
parent 61a46804
@@ -32,7 +32,7 @@ try:
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
-    from aiter.mla import mla_decode_fwd
+    from aiter.mla import mla_decode_fwd, mla_prefill_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
@@ -52,10 +52,8 @@ class ForwardMetadata:
     kv_indices: torch.Tensor
     qo_indptr: torch.Tensor
     kv_last_page_len: torch.Tensor
-    max_extend_len: int
-    max_prefix_extend_len: int
     max_q_len: int
-    max_kv_len: int
+    max_kv_len: Optional[int]


 global_workspace_buffer = None
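After this hunk the metadata container has six positional fields, which is why the ForwardMetadata(...) constructor calls later in the diff pass six arguments instead of eight. A minimal sketch of the resulting shape, with field order taken from this hunk and the constructor calls below (imports added here only to make the sketch self-contained):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    kv_last_page_len: torch.Tensor
    max_q_len: int
    max_kv_len: Optional[int]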
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
         spec_info = forward_batch.spec_info
         qo_indptr = None
         kv_last_page_len = None
-        max_extend_len = None
+        max_q_len = None
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.zeros(
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr = self.qo_indptr_[: bs + 1]
                 qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                 kv_last_page_len = self.kv_last_page_len[:bs]
-                max_extend_len = 1
+                max_q_len = 1
             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
         elif forward_batch.forward_mode.is_draft_extend():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
-                )
+                kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                    spec_info.generate_attn_arg_prefill(
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.seq_lens_sum,
+                        self.req_to_token,
+                    )
+                )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
         elif forward_batch.forward_mode.is_target_verify():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
-                )
+                draft_num = spec_info.draft_token_num
+                kv_lens = forward_batch.seq_lens + draft_num
+                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+                device = forward_batch.seq_lens.device
+                qo_indptr = torch.arange(
+                    0,
+                    (1 + bs) * draft_num,
+                    step=draft_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    kv_lens_sum,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    kv_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    draft_num,
+                    None,
                 )
             else:
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
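The MLA target-verify branch above now builds its metadata directly instead of going through the prefill indices updater: every request contributes exactly draft_token_num query rows, and its keys cover the committed sequence plus the draft slots. A runnable, illustrative-only sketch of the resulting index pointers (batch size, draft count, and sequence lengths are invented here, not taken from the commit):

import torch

bs, draft_num = 2, 4                                 # hypothetical batch of 2, 4 draft tokens each
seq_lens = torch.tensor([10, 7], dtype=torch.int32)  # hypothetical committed lengths

qo_indptr = torch.arange(0, (1 + bs) * draft_num, step=draft_num, dtype=torch.int32)
kv_lens = seq_lens + draft_num
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)

print(qo_indptr.tolist())  # [0, 4, 8]   -> draft_num query rows per request
print(kv_indptr.tolist())  # [0, 14, 25] -> seq_len + draft_num keys per request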
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
                 extend_no_prefix = False
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
+                    forward_batch.extend_prefix_lens,
+                    sum(forward_batch.extend_prefix_lens_cpu),
                     forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                     spec_info=None,
                 )
+                self.mla_indices_updater_prefill.kv_indptr += (
+                    self.mla_indices_updater_prefill.qo_indptr
+                )
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
                     self.mla_indices_updater_prefill.kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    self.kv_last_page_len[:bs],
+                    self.mla_indices_updater_prefill.max_q_len,
+                    self.mla_indices_updater_prefill.max_kv_len,
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
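In the normal-extend MLA path above, the `kv_indptr += qo_indptr` adjustment is now performed at the call site rather than inside the updater: the updater's kv_indptr appears to count only the cached prefix tokens, while the ragged prefill also attends to the freshly extended tokens, so each pointer is shifted by the cumulative extend length. Illustrative arithmetic only (the lengths below are invented):

import torch

prefix_lens = torch.tensor([5, 3], dtype=torch.int32)  # hypothetical cached tokens per request
extend_lens = torch.tensor([2, 4], dtype=torch.int32)  # hypothetical new tokens per request

kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)       # [0, 5, 8]
qo_indptr = torch.zeros(3, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)       # [0, 2, 6]

kv_indptr += qo_indptr                                  # [0, 7, 14]: prefix + extend keys per request
print(kv_indptr.tolist())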
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
-            max_extend_len = None
+            max_q_len = None
             if spec_info is None:
                 kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr[1 : bs + 1] = torch.cumsum(
                     self.cuda_graph_kv_last_page_len[:bs], dim=0
                 )
-                max_extend_len = 1
                 kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = 1
             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
-            max_extend_len = self.num_draft_tokens
-            kv_last_page_len = None
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = self.num_draft_tokens
            self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
         else:
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
                 self.indices_updater_prefill.kv_indices,
                 None,
                 None,
-                None,
-                None,
                 self.indices_updater_prefill.max_q_len,
                 self.indices_updater_prefill.max_kv_len,
             )
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = num_tokens_per_bs
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_q_len,
+                None,
+            )
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
         elif forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                req_pool_indices[:bs],
-                seq_lens[:bs],
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-                spec_info=spec_info,
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
             )
         else:
             raise ValueError("Invalid forward mode")
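The new draft-extend handling differs between CUDA-graph capture and replay: capture (the earlier hunk) pads every request to a fixed speculative_num_steps + 1 query tokens so the graph shape stays static, while replay (above) rebuilds qo_indptr from the actual accept_length of each request. A runnable, illustrative-only comparison, with all numbers invented for the example:

import torch

bs, speculative_num_steps = 2, 3
num_tokens_per_bs = speculative_num_steps + 1

# Capture: every request is padded to the same number of query tokens.
capture_qo_indptr = torch.arange(
    0, bs * num_tokens_per_bs + 1, step=num_tokens_per_bs, dtype=torch.int32
)
print(capture_qo_indptr.tolist())  # [0, 4, 8]

# Replay: the real accepted lengths (<= num_tokens_per_bs) drive the pointers.
accept_length = torch.tensor([2, 4], dtype=torch.int32)
replay_qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
replay_qo_indptr[1:] = torch.cumsum(accept_length, dim=0)
print(replay_qo_indptr.tolist())   # [0, 2, 6]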
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
             )
         if self.use_mla:
-            max_extend_len = self.forward_metadata.max_extend_len
-            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            max_q_len = self.forward_metadata.max_q_len
+            max_kv_len = self.forward_metadata.max_kv_len
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
-            kv_last_page_lens = self.forward_metadata.kv_last_page_len
             qo_indptr = self.forward_metadata.qo_indptr
             K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     qo_indptr,
-                    max_extend_len,
-                    max_extend_len,
+                    max_q_len,
+                    max_q_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     kv_indptr,
-                    max_extend_len,
-                    max_prefix_extend_len,
+                    max_q_len,
+                    max_kv_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
                 return o
+            elif forward_batch.forward_mode.is_target_verify():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                mla_decode_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+            elif forward_batch.forward_mode.is_draft_extend():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                causal = True
+                sliding_window_size = -1
+                kv_indptr = self.forward_metadata.kv_indptr
+                kv_indices = self.forward_metadata.kv_indices
+                mla_prefill_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+                # self.extend_attention_fwd(
+                #     q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                #     k.contiguous(),
+                #     v.contiguous(),
+                #     o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                #     forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                #     forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                #     self.forward_metadata.qo_indptr,
+                #     kv_indptr,
+                #     kv_indices,
+                #     None,
+                #     causal,
+                #     None,
+                #     self.forward_metadata.max_q_len,
+                #     layer.scaling,
+                #     layer.logit_cap,
+                #     sliding_window_size,
+                # )
+                # return o
+            else:
+                raise ValueError(
+                    f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+                )
         else:
             k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
                 self.forward_metadata.kv_indptr,
                 self.forward_metadata.kv_indices,
                 self.forward_metadata.kv_last_page_len,
-                self.forward_metadata.max_extend_len,
+                self.forward_metadata.max_q_len,
                 layer.scaling,
                 layer.logit_cap,
             )
@@ -816,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indices = None
         self.qo_indptr = None
         self.kv_last_page_len = None
-        self.max_extend_len = 0
-        self.max_prefix_extend_len = 0
+        self.max_q_len = 0
+        self.max_kv_len = 0

     def update(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -834,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
-        paged_kernel_lens = prefix_lens
-        paged_kernel_lens_sum = prefix_lens_sum
         bs = len(req_pool_indices)
         kv_indptr = self.attn_backend.kv_indptr
         if spec_info is None:
             # Normal extend
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum,
+                kv_lens_sum,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
-                paged_kernel_lens,
+                kv_lens,
                 kv_indptr,
                 None,
                 kv_indices,
@@ -870,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
             qo_indptr = self.attn_backend.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
-            max_extend_len = torch.max(extend_lens).item()
-            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
-            kv_indptr += qo_indptr
         else:
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
-                    paged_kernel_lens,
-                    paged_kernel_lens_sum,
+                    kv_lens,
+                    kv_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -887,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
         self.qo_indptr = qo_indptr
-        self.max_extend_len = max_extend_len
-        self.max_prefix_extend_len = max_prefix_extend_len
+        self.max_q_len = max_q_len
+        self.max_kv_len = max_kv_len
+
+
+class AiterMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                AiterAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
+        assert self.page_size == 1, "Page size must be 1"
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+            self.page_size,
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.empty(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
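In common_template above, the per-step slice of kv_indices_buffer grows by one entry per draft branch per step: each of the bs = topk * num_seqs branches appears to carry its own copy of the prompt indices (hence seq_lens_sum * topk) plus one freshly drafted token per completed step. A rough sizing sketch with invented numbers, illustrative only:

# Why step i of common_template slices kv_indices_buffer[i] to
# seq_lens_sum * topk + bs * (i + 1) entries (hypothetical values).
num_seqs, topk, seq_lens_sum = 2, 4, 30          # 2 requests, 30 prompt tokens in total
bs = num_seqs * topk                             # 8 draft branches
for i in range(3):                               # 3 speculative steps
    entries = seq_lens_sum * topk + bs * (i + 1) # each branch gains one token per step
    print(i, entries)                            # 0 -> 128, 1 -> 136, 2 -> 144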
@@ -1722,6 +1722,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             or attention_backend_str == "cutlass_mla"
             or attention_backend_str == "ascend"
             or attention_backend_str == "trtllm_mha"
+            or attention_backend_str == "aiter"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
@@ -226,6 +226,22 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import (
+                AiterAttnBackend,
+                AiterMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = AiterMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = AiterAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,