Unverified commit ebd9dbe7, authored by Yineng Zhang and committed by GitHub

fix: revert #8593 (#9581)

parent 938e986e
@@ -24,7 +24,9 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -179,6 +181,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
@@ -210,25 +213,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf

-        self.kv_indices = torch.empty(
-            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
-            dtype=torch.int32,
-            device=model_runner.device,
-        )
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )

         if q_indptr_decode_buf is None:
-            # A hack to pre-initialize large batch size for dp attention
-            if model_runner.server_args.enable_dp_attention:
-                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
         else:
             self.q_indptr_decode = q_indptr_decode_buf
@@ -273,7 +266,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -331,9 +323,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        self.cuda_graph_kv_indices = (
-            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
-        )
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
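
For scale, the two allocations above differ only in how the capture-time index buffer is sized: the reverted code reserved roughly one int32 slot per KV page, while the restored code reserves one slot per token. A small arithmetic sketch with made-up limits (illustrative only; the numbers are not from the repository):

    # Hypothetical capture-time limits, for illustration only.
    max_bs = 2
    max_context_len = 1024
    page_size = 16  # page size used by the reverted code path

    # Reverted sizing: roughly one int32 slot per KV page
    # (same arithmetic as the removed kv_indices allocation above).
    paged_slots = max_bs * (max_context_len + page_size - 1) // page_size

    # Restored sizing: one int32 slot per KV token (the page_size == 1 case).
    token_slots = max_bs * max_context_len

    print(paged_slots, token_slots)  # 129 2048
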
@@ -359,7 +358,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -370,6 +368,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
+
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -440,13 +439,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
-            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                num_pages_per_req, dim=0
+                kv_len_arr_cpu, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
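
The semantic change in this hunk is what `kv_indptr` counts: with paging it advanced by pages per request, after the revert it advances by tokens. A small PyTorch sketch with made-up sequence lengths:

    import torch

    # Hypothetical per-request sequence lengths for a batch of 3.
    seq_lens_cpu = torch.tensor([5, 8, 3], dtype=torch.int32)
    page_size = 4  # page size used by the reverted code path

    bs = seq_lens_cpu.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)

    # Reverted behaviour: offsets advance by pages per request.
    num_pages_per_req = (seq_lens_cpu + page_size - 1) // page_size
    kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)
    print(kv_indptr.tolist())  # [0, 2, 4, 5]

    # Restored behaviour (page_size == 1): offsets advance by tokens.
    kv_indptr[1:] = torch.cumsum(seq_lens_cpu, dim=0)
    print(kv_indptr.tolist())  # [0, 5, 13, 16]
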
@@ -455,6 +452,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
+
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -534,6 +532,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
+
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -554,8 +553,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
-
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -617,17 +614,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_nope = reshaped_q[:, :, : layer.v_head_dim]
             q_rope = reshaped_q[:, :, layer.v_head_dim :]

-            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])

             o = q_nope.new_empty(q_nope.shape)
+            # Direct call to run without the wrapper
             o = decode_wrapper.run(
                 q_nope,
                 q_rope,
-                k_buf[:, :, : layer.v_head_dim],
-                k_buf[:, :, layer.v_head_dim :],
+                k_buffer[:, :, : layer.v_head_dim],
+                k_buffer[:, :, layer.v_head_dim :],
                 out=o,
             )
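
For reference, the two slices passed to `decode_wrapper.run` split each MLA cache row at `layer.v_head_dim` into its compressed-KV part and the appended RoPE part. A standalone sketch of that split, with made-up dimensions and under the assumption that the split point equals the compressed-KV width (`kv_lora_rank`) as it does for this MLA layout:

    import torch

    # Hypothetical MLA dimensions, for illustration only.
    kv_lora_rank = 512      # compressed-KV width (plays the role of layer.v_head_dim here)
    qk_rope_head_dim = 64   # rotary part appended to each cache entry
    num_tokens = 10

    # One cache row per token: [compressed KV | RoPE key], as a single buffer.
    k_buffer = torch.randn(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)

    k_ckv = k_buffer[:, :, :kv_lora_rank]  # compressed-KV portion
    k_pe = k_buffer[:, :, kv_lora_rank:]   # RoPE portion

    assert k_ckv.shape[-1] == kv_lora_rank
    assert k_pe.shape[-1] == qk_rope_head_dim
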
@@ -646,10 +643,9 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-        self.page_size = model_runner.page_size

         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
@@ -693,17 +689,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                self.kv_indices[: kv_indptr[-1]]
+                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -712,40 +704,39 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

         if not init_metadata_replay:
             wrapper.plan(
-                qo_indptr=q_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
         else:
             wrapper.plan(
-                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
-                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices=kv_indices,
-                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
@@ -767,14 +758,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.page_size = model_runner.page_size

     def update(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -788,6 +777,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
+
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -821,12 +811,13 @@ class FlashInferMLAIndicesUpdaterPrefill:
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = self.kv_indices[: kv_indptr[-1]]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -835,7 +826,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -853,6 +843,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
+
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -867,26 +858,20 @@ class FlashInferMLAIndicesUpdaterPrefill:
             )
         else:
             # mla paged prefill
-            if spec_info is not None:
-                assert (
-                    self.page_size == 1
-                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
-                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
-            else:
-                kv_lens = paged_kernel_lens.to(torch.int32)
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
             wrapper_paged.plan(
-                qo_indptr=qo_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=True,
-                sm_scale=sm_scale,
-                q_data_type=self.q_data_type,
-                kv_data_type=self.data_type,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                True,
+                sm_scale,
+                self.q_data_type,
+                self.data_type,
             )
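
The dropped speculative-decoding branch relied on the fact that, with one token per page, the per-request KV length can be read straight off `kv_indptr`. A quick PyTorch check with toy values:

    import torch

    # Toy per-request KV lengths.
    paged_kernel_lens = torch.tensor([7, 2, 11], dtype=torch.int32)

    # kv_indptr built the page_size == 1 way (cumulative token counts).
    kv_indptr = torch.zeros(paged_kernel_lens.numel() + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

    # With one token per page the two length definitions coincide.
    kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
    assert torch.equal(kv_len_arr, paged_kernel_lens)
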
@@ -981,7 +966,6 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1017,7 +1001,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -1034,7 +1017,6 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
...
@@ -9,89 +9,18 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
-    PAGE_SIZE: tl.constexpr = 1,
 ):
-    """
-    Create KV indices for FlashInfer attention backend.
-
-    This Triton kernel builds a lookup table that maps from logical request/token
-    coordinates to physical token locations in the global KV cache pool. It's used
-    by FlashInfer attention backends to efficiently access scattered KV cache data.
-
-    The kernel processes each request in parallel and converts the req_to_token
-    lookup table into a flat list of token indices that can be used by attention kernels.
-
-    general idea:
-    blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
-    fixed number of pages)]
-    max_pages = max_context_len / PAGED_SIZE
-    kv_indices_ptr will store the flat list of the pages used by each request
-
-    Args:
-    Inputs Arguments (non mutable):
-        req_to_token_ptr: Request to token location look up table
-            Shape: [max_batch, max_context_len]
-        req_pool_indices_ptr: Request to pool index look up table. Each request uses
-            one pool.
-            Shape: [batch_size]
-        page_kernel_lens_ptr: sequence lengths per request
-            Shape: [batch_size]
-        kv_indptr: Should be computed based on number of pages used by each request.
-            It is used by flashinfer attention kernels to index into the kv_indices_ptr.
-            per request.
-            Shape: [batch_size + 1]
-            kv_indptr[i] = start index in kv_indices for request i
-        kv_start_idx: Pointer to array containing start offsets for each request in SGL.
-            Can be None. If provided, adds offset to token positions.
-        req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
-            Equal to max_context_len.
-        PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
-
-    Outputs:
-        kv_indices_ptr: Pointer to output array where KV indices will be stored.
-            Shape:[total-num-pages],
-            where total_num_pages = sum(seq_lens // PAGED_SIZE)
-
-    Example:
-    If we have:
-    - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
-    - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
-    - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
-      in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
-
-    The kernel will output:
-    If PAGE_SIZE = 1:
-        packed
-        - kv_indptr (passed in as input arg): [0,3,5]
-        - kv_indices = [10, 11, 12, 20, 21]
-        padded - max_pages is 10 tokens per req
-        - kv_indptr (passed in as input arg): [0,10, 20]
-        - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
-                        20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
-    If PAGE_SIZE = 2
-        packed:
-        - kv_indptr (passed in as input arg): [0,3,4]
-        - kv_indices = [5,6,10]
-        padded: max_pages is 4
-        - kv_indptr (passed in as input arg): [0,4,8,..] (note that 4 is the max_pages)
-        - kv_indices = [5, 6, -1, -1,
-                        10, -1, -1, -1]
-
-    This allows attention kernels to directly access the correct KV cache
-    entries for each request's tokens.
-    """
     BLOCK_SIZE: tl.constexpr = 512
-    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
-    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
@@ -102,27 +31,19 @@ def create_flashinfer_kv_indices_triton(
         kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

-    kv_range = kv_end - kv_start
-    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
-    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
-    req_to_token_block_start = (
-        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
-    )
-    for i in range(num_loops):
-        token_offsets_in_block = (
-            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
-        ) * PAGE_SIZE
-        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
-        valid_tokens = token_offsets_in_block < kv_range
-        valid_pages = page_offsets_in_block < num_pages
-        token_numbers = tl.load(
-            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
-        )
-        tl.store(
-            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
-            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
-            mask=valid_pages,
-        )
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)

 @triton.jit
...
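
What the restored kernel computes is a per-request gather: copy the first `page_kernel_lens[i]` entries of the request's `req_to_token` row into `kv_indices` starting at `kv_indptr[i]`. A pure-PyTorch reference of that behaviour (illustrative only; the helper name is made up and it ignores `kv_start_idx`, which the unit test passes as None):

    import torch

    def gather_kv_indices_reference(req_to_token, req_pool_indices, seq_lens, kv_indptr):
        """CPU/GPU reference for the restored (page_size == 1) kernel semantics."""
        kv_indices = torch.empty(
            int(kv_indptr[-1]), dtype=torch.int32, device=req_to_token.device
        )
        for i in range(req_pool_indices.numel()):
            row = req_to_token[req_pool_indices[i]]
            start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
            kv_indices[start:end] = row[: int(seq_lens[i])]
        return kv_indices

    # Toy example: 2 requests drawn from a pool of 3 rows.
    req_to_token = torch.arange(3 * 8, dtype=torch.int32).reshape(3, 8)
    req_pool_indices = torch.tensor([2, 0], dtype=torch.int32)
    seq_lens = torch.tensor([3, 5], dtype=torch.int32)
    kv_indptr = torch.zeros(3, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

    print(gather_kv_indices_reference(req_to_token, req_pool_indices, seq_lens, kv_indptr))
    # tensor([16, 17, 18,  0,  1,  2,  3,  4], dtype=torch.int32)
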
@@ -639,10 +639,6 @@ class ServerArgs:
             logger.warning(
                 "DeepSeek MTP does not require setting speculative_draft_model_path."
             )
-            if self.page_size != 1 and self.attention_backend == "flashinfer":
-                raise ValueError(
-                    "Speculative decoding with page_size != 1 is not supported. Please set page_size to 1."
-                )

         # Auto choose parameters
         if self.speculative_num_steps is None:
...
@@ -4,10 +4,7 @@ import unittest
 import numpy as np
 import torch

-from sglang.srt.layers.attention.utils import (
-    create_flashinfer_kv_indices_triton,
-    create_flashmla_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.test.test_utils import CustomTestCase
@@ -18,14 +15,10 @@ class TestCreateKvIndices(CustomTestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")

-    def _run_test(self, batch, max_batch, max_context_len, page_size):
-        np.random.seed(9)
-        PAGE_SIZE = page_size
+    def _run_test(self, batch, max_batch, max_context_len):
         req_to_token = torch.arange(
             max_batch * max_context_len, dtype=torch.int32, device="cuda"
         ).reshape((max_batch, max_context_len))
-        # the block table
         req_pool_indices = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_batch), size=batch, replace=False)
@@ -33,84 +26,49 @@ class TestCreateKvIndices(CustomTestCase):
             dtype=torch.int32,
             device="cuda",
         )
-        seq_lens = torch.tensor(
+        paged_kernel_lens = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_context_len), size=batch, replace=False)
             ),
             dtype=torch.int32,
             device="cuda",
         )
-        num_pages_per_req = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE
         kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
-        kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

         # ref
-        kv_indices_ref = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        seq_lens_cpu = seq_lens.cpu().numpy()
-        for i in range(batch):
-            kv_indptr_req = kv_indptr[i]
-            num_toks_seq = seq_lens_cpu[i]
-            curr_req_pool = req_pool_indices_cpu[i]
-            curr_num_pages = num_pages_per_req[i]
-            curr_token_ids = req_to_token[curr_req_pool]
-            curr_pages = (curr_token_ids[:num_toks_seq] // PAGE_SIZE).unique()
-            assert (
-                len(curr_pages) == curr_num_pages
-            ), f"req {i} has #{curr_num_pages} pages, but got {len(curr_pages)} pages"
-            kv_indices_ref[kv_indptr_req : kv_indptr_req + curr_num_pages] = curr_pages
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices_ref = torch.cat(
+            [
+                req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
+                for i in range(batch)
+            ],
+            dim=0,
+        ).contiguous()

         # triton
         kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(batch,)](
             req_to_token,
             req_pool_indices,
-            seq_lens,
+            paged_kernel_lens,
             kv_indptr,
             None,
             kv_indices_triton,
             req_to_token.size(1),
-            PAGE_SIZE,
         )
-
-        max_pages = max_context_len // PAGE_SIZE
-        kv_indices_flashmla = torch.empty(
-            batch, max_pages, dtype=torch.int32, device="cuda"
-        )
-        create_flashmla_kv_indices_triton[(batch,)](
-            req_to_token,
-            req_pool_indices,
-            seq_lens,
-            None,
-            kv_indices_flashmla,
-            req_to_token.size(1),
-            max_pages,
-            PAGE_SIZE,
-        )

         # Check
         self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))

     def test_create_kvindices(self):
-        BATCH = [4, 37, 512, 1786]
+        BATCH = [1, 37, 1786]
         MAX_BATCH = 4096
         MAX_CONTEXT_LEN = 4096
-        PAGE_SIZE = [1, 2, 16, 64]
-        # for debug
-        # BATCH = [4]
-        # MAX_BATCH = 4
-        # MAX_CONTEXT_LEN = 10
-        # Test for small batch size
-        for page_size in PAGE_SIZE[:1]:
-            print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
-            self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
-        # Test for larger batch size
-        for batch in BATCH[1:]:
-            for page_size in PAGE_SIZE:
-                print(
-                    f"Running test for batch size: {batch} and page size: {page_size}"
-                )
-                self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
+        for batch in BATCH:
+            self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)

 if __name__ == "__main__":
...
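
The restored test launches the kernel with one Triton program per request. A trimmed-down launch along the same lines (assumes a CUDA device and the `sglang.srt.layers.attention.utils` import path used by the restored test):

    import torch

    from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton

    batch, max_batch, max_context_len = 2, 8, 16
    req_to_token = torch.arange(
        max_batch * max_context_len, dtype=torch.int32, device="cuda"
    ).reshape(max_batch, max_context_len)
    req_pool_indices = torch.tensor([5, 1], dtype=torch.int32, device="cuda")
    paged_kernel_lens = torch.tensor([4, 9], dtype=torch.int32, device="cuda")

    kv_indptr = torch.zeros(batch + 1, dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32, device="cuda")

    # One Triton program per request, exactly as in the restored unit test.
    create_flashinfer_kv_indices_triton[(batch,)](
        req_to_token,
        req_pool_indices,
        paged_kernel_lens,
        kv_indptr,
        None,
        kv_indices,
        req_to_token.size(1),
    )
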
@@ -120,49 +120,5 @@ class TestFlashinferMLAMTP(CustomTestCase):
         self.assertGreater(avg_spec_accept_length, 2.5)

-class TestFlashinferMLAPageSize16(CustomTestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.model = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        other_args = ["--trust-remote-code"]
-        if torch.cuda.is_available() and torch.version.cuda:
-            other_args.extend(
-                [
-                    "--cuda-graph-max-bs",
-                    "4",
-                    "--attention-backend",
-                    "flashinfer",
-                    "--page-size",
-                    "16",
-                ]
-            )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=other_args,
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        kill_process_tree(cls.process.pid)
-
-    def test_gsm8k(self):
-        args = SimpleNamespace(
-            num_shots=5,
-            data_path=None,
-            num_questions=200,
-            max_new_tokens=512,
-            parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.base_url.split(":")[-1]),
-        )
-
-        metrics = run_eval_few_shot_gsm8k(args)
-        print(metrics)
-
-        self.assertGreater(metrics["accuracy"], 0.615)
-
 if __name__ == "__main__":
     unittest.main()