Commit 62d065ca authored by lizhigong

Merge branch 'v0.5.4_dev_niuhb' into 'v0.5.4_dev'

mtp: add dcu_assign_req_to_token_pool, dcu_get_last_loc, dcu_assign_extend_cache_locs, d...

See merge request OpenDAS/sglang!32
parents 769353e6 f6d91d7e
@@ -11,6 +11,8 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
+from sglang.srt.utils import get_bool_env_var

try:
    from flash_mla import (
@@ -104,6 +106,7 @@ class DCUMLABackend(AttentionBackend):
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
@@ -118,15 +121,27 @@ class DCUMLABackend(AttentionBackend):
                dtype=torch.int32,
                device=forward_batch.seq_lens.device
            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )

            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
@@ -149,15 +164,27 @@ class DCUMLABackend(AttentionBackend):
                dtype=torch.int32,
                device=seq_lens.device,
            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )

            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
@@ -185,15 +212,27 @@ class DCUMLABackend(AttentionBackend):
            )
            # Call the Triton kernel to generate block_kv_indices
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token.to(torch.int32),
+                    req_pool_indices_ptr=forward_batch.req_pool_indices.to(torch.int32),
+                    page_kernel_lens_ptr=forward_batch.seq_lens.to(torch.int32),
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices.to(torch.int32),
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )

            # MLA
            mla_metadata, num_splits = get_mla_metadata(
@@ -211,6 +250,7 @@ class DCUMLABackend(AttentionBackend):
            self.flashattn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
@@ -489,9 +529,10 @@ class DCUMLABackend(AttentionBackend):
        k_rope: Optional[torch.Tensor] = None,
        sinks=None,
    ):
-        if (
+        if ((
            forward_batch.forward_mode == ForwardMode.EXTEND
-            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
        ):
            if not self.skip_prefill:
                return self.flashattn_backend.forward_extend(
...
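The DCU code paths above are opt-in: init_forward_metadata reads a boolean environment flag through get_bool_env_var and then dispatches either to the fused sgl_kernel op or to the original Triton kernel. A minimal standalone sketch of that dispatch pattern, using only the standard library (the helper names below are illustrative, not the sglang implementation):

import os

def _env_flag(name: str, default: str = "false") -> bool:
    # Illustrative stand-in for sglang's get_bool_env_var: "1"/"true"/"yes" enable the flag.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

def select_kv_indices_kernel() -> str:
    # Mirrors the gating in init_forward_metadata: the DCU op when the flag is set,
    # otherwise the original Triton kernel.
    if _env_flag("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON"):
        return "dcu_create_flashmla_kv_indices"
    return "create_flashmla_kv_indices_triton"

if __name__ == "__main__":
    print("selected kernel:", select_kv_indices_kernel())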
@@ -674,16 +674,11 @@ class FlashAttentionBackend(AttentionBackend):
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )
+        # if not self.use_mla:
        if k_rope is None:
-            if not self.use_mla:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-            else:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v
-                )
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v,  # layer.k_scale, layer.v_scale
+            )
        else:
            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                layer,
...
@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var
+from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(
@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                dtype=torch.int32,
                device=forward_batch.seq_lens.device,
            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )

            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads,
@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                dtype=torch.int32,
                device=seq_lens.device,
            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )

            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            )
        else:
            super().init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
...
@@ -13,7 +13,8 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import support_triton
+from sglang.srt.utils import support_triton, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_get_last_loc

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -125,13 +126,17 @@ def get_last_loc(
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
-    if (
-        get_global_server_args().attention_backend != "ascend"
-        and get_global_server_args().attention_backend != "torch_native"
-    ):
-        impl = get_last_loc_triton
-    else:
-        impl = get_last_loc_torch
+    use_sglang_get_last_loc = get_bool_env_var("SGLANG_GET_LAST_LOC")
+    if use_sglang_get_last_loc:
+        impl = dcu_get_last_loc
+    else:
+        if (
+            get_global_server_args().attention_backend != "ascend"
+            and get_global_server_args().attention_backend != "torch_native"
+        ):
+            impl = get_last_loc_triton
+        else:
+            impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
...
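The semantics that the new dcu_get_last_loc op has to reproduce can be read off the get_last_loc_kernel added later in this diff: for each request, return the token-pool location of its last prefix token, or -1 when the prefix is empty. A torch-only sketch of that behavior (the function name is illustrative, not part of the change):

import torch

def get_last_loc_reference(req_to_token: torch.Tensor,
                           req_pool_indices: torch.Tensor,
                           prefix_lens: torch.Tensor) -> torch.Tensor:
    # For request i: req_to_token[req_pool_indices[i], prefix_lens[i] - 1],
    # or -1 when prefix_lens[i] == 0. Mirrors the CUDA kernel in this merge request.
    out = torch.full_like(prefix_lens, -1)
    has_prefix = prefix_lens > 0
    rows = req_pool_indices[has_prefix].long()
    cols = (prefix_lens[has_prefix] - 1).long()
    out[has_prefix] = req_to_token[rows, cols].to(out.dtype)
    return out

if __name__ == "__main__":
    req_to_token = torch.arange(12, dtype=torch.int32).reshape(3, 4)
    print(get_last_loc_reference(req_to_token,
                                 torch.tensor([2, 0]),
                                 torch.tensor([3, 0])))  # tensor([10, -1])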
@@ -46,7 +46,11 @@ from sglang.srt.layers.dp_attention import (
    set_dp_buffer_len,
    set_is_extend_in_batch,
)
-from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_chunked_prefix_cache_kv_indices
+import logging
+
+logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -128,8 +132,8 @@ class ForwardMode(IntEnum):
            self == ForwardMode.EXTEND
            or self == ForwardMode.DRAFT_EXTEND
            or self == ForwardMode.MIXED
            or self == ForwardMode.SPLIT_PREFILL
-            or self == ForwardMode.DRAFT_EXTEND_V2
+            or self == ForwardMode.DRAFT_EXTEND_V2  # nhb
        )

    def is_cuda_graph(self):
@@ -317,6 +321,8 @@ class ForwardBatch:
    tbo_parent_token_range: Optional[Tuple[int, int]] = None
    tbo_children: Optional[List[ForwardBatch]] = None

+    use_sglang_create_chunked_prefix_cache_kv_indices = get_bool_env_var("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES")
+
    @classmethod
    def init_new(
        cls,
@@ -635,15 +641,28 @@ class ForwardBatch:
                num_chunk_tokens, dtype=torch.int32, device=device
            )
-            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
-                self.req_to_token_pool.req_to_token,
-                self.req_pool_indices,
-                chunk_starts,
-                chunk_seq_lens,
-                chunk_cu_seq_lens,
-                chunk_kv_indices,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
+            if self.use_sglang_create_chunked_prefix_cache_kv_indices:
+                dcu_create_chunked_prefix_cache_kv_indices(
+                    req_to_token=self.req_to_token_pool.req_to_token,
+                    req_pool_indices=self.req_pool_indices,
+                    chunk_starts=chunk_starts,
+                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_cu_seq_lens=chunk_cu_seq_lens,
+                    chunk_kv_indices=chunk_kv_indices,
+                    col_num=self.req_to_token_pool.req_to_token.shape[1],
+                    bs=self.batch_size,
+                )
+            else:
+                logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
+                create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
+                    self.req_to_token_pool.req_to_token,
+                    self.req_pool_indices,
+                    chunk_starts,
+                    chunk_seq_lens,
+                    chunk_cu_seq_lens,
+                    chunk_kv_indices,
+                    self.req_to_token_pool.req_to_token.shape[1],
+                )
            self.prefix_chunk_kv_indices.append(chunk_kv_indices)

    def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
...
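The dcu_create_chunked_prefix_cache_kv_indices call above performs the same per-request gather as the Triton kernel it optionally replaces (see launch_create_chunked_prefix_cache_kv_indices later in this diff): each request's slice of req_to_token, starting at chunk_starts[i] and of length chunk_seq_lens[i], is copied into the flat chunk_kv_indices buffer at offset chunk_cu_seq_lens[i]. A torch-only reference sketch with an illustrative name:

import torch

def chunked_prefix_kv_indices_reference(req_to_token, req_pool_indices, chunk_starts,
                                        chunk_seq_lens, chunk_cu_seq_lens, chunk_kv_indices):
    # Per request: copy req_to_token[row, start:start+length] into the flat output buffer.
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        start = int(chunk_starts[i])
        length = int(chunk_seq_lens[i])
        offset = int(chunk_cu_seq_lens[i])
        chunk_kv_indices[offset:offset + length] = req_to_token[row, start:start + length]
    return chunk_kv_indices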
@@ -237,7 +237,14 @@ class DraftBackendFactory:
        return None

    def _create_dcumla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
+        # logger.warning(
+        #     "flashmla prefill backend is not yet supported for draft extend."
+        # )
+        # return None
+        # nhb
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
@@ -29,6 +29,12 @@ from sglang.srt.speculative.spec_utils import (
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool, dcu_assign_extend_cache_locs
+import logging
+
+logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker
    from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
@@ -77,6 +83,9 @@ def assign_draft_cache_locs_page_size_1(

@dataclass
class EagleDraftInputV2Mixin:
+    use_sglang_assign_req_to_token_pool = get_bool_env_var("SGLANG_ASSIGN_REQ_TO_TOKEN_POOL")
+
    def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
@@ -112,15 +121,26 @@ class EagleDraftInputV2Mixin:
            extend_num_tokens,
        )
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            self.allocate_lens,
-            new_allocate_lens,
-            out_cache_loc,
-            batch.req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
+        if self.use_sglang_assign_req_to_token_pool:
+            dcu_assign_req_to_token_pool(
+                req_pool_indices=batch.req_pool_indices,
+                req_to_token=batch.req_to_token_pool.req_to_token,
+                allocate_lens=self.allocate_lens,
+                new_allocate_lens=new_allocate_lens,
+                out_cache_loc=out_cache_loc,
+                shape=batch.req_to_token_pool.req_to_token.shape[1],
+                bs=bs,
+            )
+        else:
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                self.allocate_lens,
+                new_allocate_lens,
+                out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )
        self.allocate_lens = new_allocate_lens

        # FIXME(lsyin): make this sync optional
@@ -189,6 +209,9 @@ class EagleDraftInputV2Mixin:

@dataclass
class EagleVerifyInputV2Mixin:
+    use_sglang_assign_extend_cache_locs = get_bool_env_var("SGLANG_ASSIGN_EXTEND_CACHE_LOCS")
+
    def prepare_for_v2_verify(
        self: EagleVerifyInput,
        req_to_token_pool: ReqToTokenPool,
@@ -205,15 +228,26 @@ class EagleVerifyInputV2Mixin:
            device=device,
        )
-        assign_extend_cache_locs[(bs,)](
-            batch.req_pool_indices,
-            req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
-            batch.out_cache_loc,
-            req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
+        if self.use_sglang_assign_extend_cache_locs:
+            dcu_assign_extend_cache_locs(
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                bs,
+            )
+        else:
+            assign_extend_cache_locs[(bs,)](
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )

        # Get a forward batch
        batch.forward_mode = ForwardMode.TARGET_VERIFY
...
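dcu_assign_req_to_token_pool, used in prepare_for_decode above, is a scatter: request i's freshly allocated slice of out_cache_loc is written into row req_pool_indices[i] of the req_to_token pool between allocate_lens[i] and new_allocate_lens[i] (see launch_assign_req_to_token_pool later in this diff). A torch-only reference sketch under that reading (name illustrative):

import torch

def assign_req_to_token_pool_reference(req_pool_indices, req_to_token,
                                       allocate_lens, new_allocate_lens, out_cache_loc):
    # Consume out_cache_loc in request order, writing each slice into the token pool.
    offset = 0
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        start = int(allocate_lens[i])
        end = int(new_allocate_lens[i])
        n = end - start
        req_to_token[row, start:end] = out_cache_loc[offset:offset + n].to(req_to_token.dtype)
        offset += n
    return req_to_token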
@@ -19,6 +19,14 @@ limitations under the License.
#include "sgl_kernel_ops.h"

TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+  /*
+   * From FlashMLA
+   */
+  m.def("dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor page_kernel_lens, Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
+  m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);
+
  /*
   * From csrc/activation
   */
@@ -133,6 +141,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   */
  m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
m.def("dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
m.def("dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);
m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);
m.def("dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr,Tensor req_to_token_ptr,Tensor allocate_lens_ptr,Tensor new_allocate_lens,Tensor out_cache_loc_ptr,int shape,int bs) -> ()");
m.impl("dcu_assign_req_to_token_pool",torch::kCUDA,&dcu_assign_req_to_token_pool);
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()"); m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel); m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
......
@@ -836,4 +836,322 @@ void dcu_alloc_extend_kernel(
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
\ No newline at end of file
__global__ void launch_assign_req_to_token_pool(
const int64_t* req_pool_indices_ptr,
int32_t* req_to_token_ptr,
const int64_t* allocate_lens_ptr,
int64_t* new_allocate_lens,
int64_t* out_cache_loc_ptr,
int64_t shape,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = allocate_lens_ptr[pid];
int64_t kv_end = new_allocate_lens[pid];
int64_t pool_idx = req_pool_indices_ptr[pid];
int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);
int64_t sum_out_offset = 0;
for(int length_offset = 0; length_offset < pid;length_offset++){
int64_t start = allocate_lens_ptr[length_offset];
int64_t end = new_allocate_lens[length_offset];
sum_out_offset += (end- start);
}
int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
}
}
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs) {
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
__global__ void get_last_loc_kernel(
const int32_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices_tensor,
const int64_t* __restrict__ prefix_lens_tensor,
int64_t* __restrict__ result,
int64_t num_tokens,
int64_t req_to_token_stride){
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= num_tokens) return;
int64_t pre_len = prefix_lens_tensor[pid];
if (pre_len > 0) {
int64_t req_idx = req_pool_indices_tensor[pid];
int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
} else {
result[pid] = static_cast<int64_t>(-1);
}
}
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens) {
TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");
int64_t num_tokens = prefix_lens.numel();
TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");
int64_t req_to_token_stride = req_to_token.stride(0);
auto req_to_token_c = req_to_token.contiguous();
auto req_pool_indices_c = req_pool_indices.contiguous();
auto prefix_lens_c = prefix_lens.contiguous();
const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();
auto result = at::empty_like(prefix_lens_c);
int64_t* result_ptr = result.data_ptr<int64_t>();
const int64_t block_size = 64;
const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
req_to_token_ptr,
req_pool_indices_ptr,
prefix_lens_ptr,
result_ptr,
num_tokens,
req_to_token_stride
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return result;
}
__global__ void launch_assign_extend_cache_locs_kernel(
const int64_t* __restrict__ req_pool_indices, // [bs]
const int32_t* __restrict__ req_to_token, // [max_num_req, pool_len]
const int64_t* __restrict__ start_offset, // [bs]
const int64_t* __restrict__ end_offset, // [bs]
int64_t* __restrict__ out_cache_loc, // [sum(draft_token_num)]
int64_t pool_len,
int64_t bs)
{
int pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = start_offset[pid];
int64_t kv_end = end_offset[pid];
int64_t req_id = req_pool_indices[pid];
int64_t out_offset = 0;
for (int i = 0; i < pid; ++i) {
out_offset += end_offset[i] - start_offset[i];
}
const int32_t* src = req_to_token + req_id * pool_len + kv_start;
int64_t* dst = out_cache_loc + out_offset;
for (int64_t i = 0; i < kv_end - kv_start; ++i) {
dst[i] = src[i];
}
}
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs)
{
const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();
constexpr int64_t threads = 128;
int64_t blocks = (bs + threads - 1) / threads;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
req_pool_indices_ptr,
req_to_token_ptr,
start_offset_ptr,
end_offset_ptr,
out_cache_loc_ptr,
pool_len,
bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
const int32_t* __restrict__ req_to_token,
const int32_t* __restrict__ req_pool_indices,
const int32_t* __restrict__ page_kernel_lens,
const int32_t* __restrict__ kv_start_idx,
int32_t* __restrict__ kv_indices,
int req_to_token_stride,
int kv_indices_stride)
{
int pid = blockIdx.x; // batch index
int req_pool_index = req_pool_indices[pid];
int kv_start = 0;
int kv_end = 0;
if (kv_start_idx != nullptr) {
kv_start = kv_start_idx[pid];
kv_end = kv_start;
}
kv_end += page_kernel_lens[pid];
int total_len = kv_end - kv_start;
int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;
for (int pg = 0; pg < num_pages; ++pg) {
int offset = pg * PAGED_SIZE;
// token id = req_to_token[req_pool_index][kv_start + offset]
int64_t token =
req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
// page index
kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
}
}
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE)
{
TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");
int bs = req_pool_indices.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(bs);
dim3 block(1);
const int32_t* kv_start_idx_ptr = nullptr;
if (kv_start_idx.has_value()) {
kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
}
if (PAGED_SIZE == 64) {
dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
req_to_token.data_ptr<int32_t>(),
req_pool_indices.data_ptr<int32_t>(),
page_kernel_lens.data_ptr<int32_t>(),
kv_start_idx_ptr,
kv_indices.data_ptr<int32_t>(),
req_to_token_stride,
kv_indices_stride
);
} else {
TORCH_CHECK(false, "Unsupported PAGED_SIZE");
}
}
__global__ void launch_create_chunked_prefix_cache_kv_indices(
int32_t* req_to_token_ptr,
const int64_t* req_pool_indices_ptr,
const int32_t* chunk_starts_ptr,
const int32_t* chunk_seq_lens_ptr,
const int32_t* chunk_cu_seq_lens_ptr,
int32_t* chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t req_pool_index = req_pool_indices_ptr[pid];
int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
int32_t chunk_start_pos = chunk_starts_ptr[pid];
int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
#pragma unroll(32)
for(int32_t offset = 0;offset < chunk_seq_len;offset++){
chunk_kv_indices_ptr[chunk_kv_indices_offset+offset] = req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
}
}
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token_ptr,
const at::Tensor req_pool_indices_ptr,
const at::Tensor chunk_starts_ptr,
const at::Tensor chunk_seq_lens_ptr,
const at::Tensor chunk_cu_seq_lens_ptr,
at::Tensor chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs) {
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1, chunk_cu_seq_lens_ptr1,chunk_kv_indices_ptr1, col_num, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
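The paged-KV index computation in dcu_create_flashmla_kv_indices_kernel above reduces to: for every request, walk its token list one page (PAGED_SIZE tokens) at a time and record the page number of the first token of each page. A torch-only sketch of that mapping, ignoring kv_start_idx (name illustrative):

import torch

def flashmla_kv_indices_reference(req_to_token, req_pool_indices, page_kernel_lens,
                                  kv_indices, page_size=64):
    # kv_indices[pid, pg] = req_to_token[row, pg * page_size] // page_size
    for pid in range(req_pool_indices.numel()):
        row = int(req_pool_indices[pid])
        total = int(page_kernel_lens[pid])
        num_pages = (total + page_size - 1) // page_size
        for pg in range(num_pages):
            kv_indices[pid, pg] = int(req_to_token[row, pg * page_size]) // page_size
    return kv_indices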
@@ -538,6 +538,7 @@ void segment_packbits(

/*
 * From csrc/kvcacheio
 */
void dcu_create_extend_after_decode_spec_info(
    const at::Tensor verified_id,
    const at::Tensor seq_lens,
@@ -545,6 +546,49 @@ void dcu_create_extend_after_decode_spec_info(
    at::Tensor positions,
    at::Tensor new_verified_id,
    int64_t bs);
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor chunk_starts,
const at::Tensor chunk_seq_lens,
const at::Tensor chunk_cu_seq_lens,
at::Tensor chunk_kv_indices,
int64_t col_num,
int64_t bs);
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE);
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs);
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens);
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs);
void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
...
@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
    "Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
def dcu_create_flashmla_kv_indices(
req_to_token_ptr,
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride,
kv_indices_ptr_stride,
PAGED_SIZE = 64,
):
torch.ops.sgl_kernel.dcu_create_flashmla_kv_indices(req_to_token_ptr,
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride,
kv_indices_ptr_stride,
PAGED_SIZE,
)
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
...
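A hedged usage sketch for the dcu_create_flashmla_kv_indices wrapper above. It assumes a DCU/CUDA device and an sgl_kernel build that registers the new op; dtypes follow the CUDA kernel in this merge request (int32 index tensors, the default PAGED_SIZE of 64). Shapes and values are made up for illustration:

import torch
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

PAGE = 64
bs, pool_len = 4, 256
req_to_token = torch.arange(8 * pool_len, dtype=torch.int32, device="cuda").reshape(8, pool_len)
req_pool_indices = torch.tensor([0, 2, 5, 7], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([65, 10, 200, 1], dtype=torch.int32, device="cuda")
max_seqlen_pad = (int(seq_lens.max()) + PAGE - 1) // PAGE

# One row of page indices per request, padded to max_seqlen_pad pages.
block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32, device="cuda")

dcu_create_flashmla_kv_indices(
    req_to_token_ptr=req_to_token,
    req_pool_indices_ptr=req_pool_indices,
    page_kernel_lens_ptr=seq_lens,
    kv_start_idx=None,
    kv_indices_ptr=block_kv_indices,
    req_to_token_ptr_stride=req_to_token.stride(0),
    kv_indices_ptr_stride=max_seqlen_pad,
)
print(block_kv_indices)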
@@ -293,3 +293,76 @@ def transfer_kv_all_layer_mla_lf_pf(
        block_quota,
        num_warps_per_block,
    )
def dcu_assign_req_to_token_pool(
req_pool_indices:torch.Tensor,
req_to_token:torch.Tensor,
allocate_lens:torch.Tensor,
new_allocate_lens:torch.Tensor,
out_cache_loc:torch.Tensor,
shape:int,
bs:int,
):
torch.ops.sgl_kernel.dcu_assign_req_to_token_pool(
req_pool_indices,
req_to_token,
allocate_lens,
new_allocate_lens,
out_cache_loc,
shape,
bs,
)
def dcu_get_last_loc(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
):
result = torch.ops.sgl_kernel.dcu_get_last_loc(
req_to_token,
req_pool_indices,
prefix_lens,
)
return result
def dcu_assign_extend_cache_locs(
req_pool_indices: torch.Tensor,
req_to_token: torch.Tensor,
start_offset: torch.Tensor,
end_offset: torch.Tensor,
out_cache_loc: torch.Tensor,
pool_len: int,
bs: int,
):
torch.ops.sgl_kernel.dcu_assign_extend_cache_locs(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
pool_len,
bs,
)
def dcu_create_chunked_prefix_cache_kv_indices(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
chunk_starts: torch.Tensor,
chunk_seq_lens: torch.Tensor,
chunk_cu_seq_lens: torch.Tensor,
chunk_kv_indices: torch.Tensor,
col_num: int,
bs: int,
):
torch.ops.sgl_kernel.dcu_create_chunked_prefix_cache_kv_indices(
req_to_token,
req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
col_num,
bs,
)
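A small sanity-check sketch for the new kvcacheio wrappers, assuming a DCU/CUDA device and an sgl_kernel build that exports them; the test itself is illustrative and not part of this merge request:

import torch
from sgl_kernel.kvcacheio import dcu_get_last_loc

def check_dcu_get_last_loc():
    req_to_token = torch.randint(0, 1000, (8, 64), dtype=torch.int32, device="cuda")
    req_pool_indices = torch.randint(0, 8, (5,), dtype=torch.int64, device="cuda")
    prefix_lens = torch.randint(0, 64, (5,), dtype=torch.int64, device="cuda")
    got = dcu_get_last_loc(req_to_token, req_pool_indices, prefix_lens)
    # Reference: req_to_token[row, prefix_len - 1], or -1 when prefix_len == 0.
    want = torch.where(
        prefix_lens > 0,
        req_to_token[req_pool_indices, (prefix_lens - 1).clamp(min=0)].to(torch.int64),
        torch.full_like(prefix_lens, -1),
    )
    assert torch.equal(got, want)

if __name__ == "__main__":
    check_dcu_get_last_loc()
    print("dcu_get_last_loc matches the torch reference")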