Commit 62d065ca authored by lizhigong's avatar lizhigong
Browse files

Merge branch 'v0.5.4_dev_niuhb' into 'v0.5.4_dev'

mtp增加dcu_assign_req_to_token_pool、dcu_get_last_loc、dcu_assign_extend_cache_locs、d...

See merge request OpenDAS/sglang!32
parents 769353e6 f6d91d7e
......@@ -11,6 +11,8 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
from sglang.srt.utils import get_bool_env_var
try:
from flash_mla import (
......@@ -104,6 +106,7 @@ class DCUMLABackend(AttentionBackend):
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
......@@ -118,6 +121,18 @@ class DCUMLABackend(AttentionBackend):
dtype=torch.int32,
device=forward_batch.seq_lens.device
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = forward_batch.seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
......@@ -149,10 +164,22 @@ class DCUMLABackend(AttentionBackend):
dtype=torch.int32,
device=seq_lens.device,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = forward_batch.seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
......@@ -185,10 +212,22 @@ class DCUMLABackend(AttentionBackend):
)
# 调用 Triton kernel 生成 block_kv_indices
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token.to(torch.int32),
req_pool_indices_ptr = forward_batch.req_pool_indices.to(torch.int32),
page_kernel_lens_ptr = forward_batch.seq_lens.to(torch.int32),
kv_start_idx = None,
kv_indices_ptr = block_kv_indices.to(torch.int32),
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
......@@ -211,6 +250,7 @@ class DCUMLABackend(AttentionBackend):
self.flashattn_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(
self,
max_bs: int,
......@@ -489,9 +529,10 @@ class DCUMLABackend(AttentionBackend):
k_rope: Optional[torch.Tensor] = None,
sinks=None,
):
if (
if ((
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
):
if not self.skip_prefill:
return self.flashattn_backend.forward_extend(
......
......@@ -674,16 +674,11 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
# if not self.use_mla:
if k_rope is None:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
layer, cache_loc, k, v, #layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......
......@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO")
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(
......@@ -91,6 +95,18 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = forward_batch.seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
......@@ -121,10 +137,22 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32,
device=seq_lens.device,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = forward_batch.seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
......
......@@ -13,7 +13,8 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import support_triton
from sglang.srt.utils import support_triton,get_bool_env_var
from sgl_kernel.kvcacheio import dcu_get_last_loc
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
......@@ -125,6 +126,10 @@ def get_last_loc(
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
use_sglang_get_last_loc = get_bool_env_var("SGLANG_GET_LAST_LOC")
if use_sglang_get_last_loc:
impl = dcu_get_last_loc
else:
if (
get_global_server_args().attention_backend != "ascend"
and get_global_server_args().attention_backend != "torch_native"
......
......@@ -46,7 +46,11 @@ from sglang.srt.layers.dp_attention import (
set_dp_buffer_len,
set_is_extend_in_batch,
)
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton,get_bool_env_var
from sgl_kernel.kvcacheio import dcu_create_chunked_prefix_cache_kv_indices
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -129,7 +133,7 @@ class ForwardMode(IntEnum):
or self == ForwardMode.DRAFT_EXTEND
or self == ForwardMode.MIXED
or self == ForwardMode.SPLIT_PREFILL
or self == ForwardMode.DRAFT_EXTEND_V2
or self == ForwardMode.DRAFT_EXTEND_V2 #nhb
)
def is_cuda_graph(self):
......@@ -317,6 +321,8 @@ class ForwardBatch:
tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List[ForwardBatch]] = None
use_sglang_create_chunked_prefix_cache_kv_indices = get_bool_env_var("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES")
@classmethod
def init_new(
cls,
......@@ -635,6 +641,19 @@ class ForwardBatch:
num_chunk_tokens, dtype=torch.int32, device=device
)
if self.use_sglang_create_chunked_prefix_cache_kv_indices:
dcu_create_chunked_prefix_cache_kv_indices(
req_to_token = self.req_to_token_pool.req_to_token,
req_pool_indices = self.req_pool_indices,
chunk_starts = chunk_starts,
chunk_seq_lens = chunk_seq_lens,
chunk_cu_seq_lens = chunk_cu_seq_lens,
chunk_kv_indices = chunk_kv_indices,
col_num = self.req_to_token_pool.req_to_token.shape[1],
bs = self.batch_size,
)
else:
logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
......
......@@ -237,7 +237,14 @@ class DraftBackendFactory:
return None
def _create_dcumla_prefill_backend(self):
logger.warning(
"flashmla prefill backend is not yet supported for draft extend."
# logger.warning(
# "flashmla prefill backend is not yet supported for draft extend."
# )
# return None
#nhb
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
return None
return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
......@@ -29,6 +29,12 @@ from sglang.srt.speculative.spec_utils import (
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool,dcu_assign_extend_cache_locs
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
......@@ -77,6 +83,9 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
use_sglang_assign_req_to_token_pool = get_bool_env_var("SGLANG_ASSIGN_REQ_TO_TOKEN_POOL")
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
......@@ -112,6 +121,17 @@ class EagleDraftInputV2Mixin:
extend_num_tokens,
)
if self.use_sglang_assign_req_to_token_pool:
dcu_assign_req_to_token_pool(
req_pool_indices = batch.req_pool_indices,
req_to_token = batch.req_to_token_pool.req_to_token,
allocate_lens = self.allocate_lens,
new_allocate_lens = new_allocate_lens,
out_cache_loc = out_cache_loc,
shape = batch.req_to_token_pool.req_to_token.shape[1],
bs = bs,
)
else:
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
......@@ -189,6 +209,9 @@ class EagleDraftInputV2Mixin:
@dataclass
class EagleVerifyInputV2Mixin:
use_sglang_assign_extend_cache_locs = get_bool_env_var("SGLANG_ASSIGN_EXTEND_CACHE_LOCS")
def prepare_for_v2_verify(
self: EagleVerifyInput,
req_to_token_pool: ReqToTokenPool,
......@@ -205,6 +228,17 @@ class EagleVerifyInputV2Mixin:
device=device,
)
if self.use_sglang_assign_extend_cache_locs:
dcu_assign_extend_cache_locs(
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
bs,
)
else:
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
......
......@@ -19,6 +19,14 @@ limitations under the License.
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From FlashMLA
*/
m.def("dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices,Tensor page_kernel_lens, Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);
/*
* From csrc/activation
*/
......@@ -133,6 +141,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
*/
m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
m.def("dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
m.def("dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);
m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);
m.def("dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr,Tensor req_to_token_ptr,Tensor allocate_lens_ptr,Tensor new_allocate_lens,Tensor out_cache_loc_ptr,int shape,int bs) -> ()");
m.impl("dcu_assign_req_to_token_pool",torch::kCUDA,&dcu_assign_req_to_token_pool);
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
......
......@@ -837,3 +837,321 @@ void dcu_alloc_extend_kernel(
launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
__global__ void launch_assign_req_to_token_pool(
const int64_t* req_pool_indices_ptr,
int32_t* req_to_token_ptr,
const int64_t* allocate_lens_ptr,
int64_t* new_allocate_lens,
int64_t* out_cache_loc_ptr,
int64_t shape,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = allocate_lens_ptr[pid];
int64_t kv_end = new_allocate_lens[pid];
int64_t pool_idx = req_pool_indices_ptr[pid];
int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);
int64_t sum_out_offset = 0;
for(int length_offset = 0; length_offset < pid;length_offset++){
int64_t start = allocate_lens_ptr[length_offset];
int64_t end = new_allocate_lens[length_offset];
sum_out_offset += (end- start);
}
int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
}
}
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs) {
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
__global__ void get_last_loc_kernel(
const int32_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices_tensor,
const int64_t* __restrict__ prefix_lens_tensor,
int64_t* __restrict__ result,
int64_t num_tokens,
int64_t req_to_token_stride){
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= num_tokens) return;
int64_t pre_len = prefix_lens_tensor[pid];
if (pre_len > 0) {
int64_t req_idx = req_pool_indices_tensor[pid];
int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
} else {
result[pid] = static_cast<int64_t>(-1);
}
}
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens) {
TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");
int64_t num_tokens = prefix_lens.numel();
TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");
int64_t req_to_token_stride = req_to_token.stride(0);
auto req_to_token_c = req_to_token.contiguous();
auto req_pool_indices_c = req_pool_indices.contiguous();
auto prefix_lens_c = prefix_lens.contiguous();
const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();
auto result = at::empty_like(prefix_lens_c);
int64_t* result_ptr = result.data_ptr<int64_t>();
const int64_t block_size = 64;
const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
req_to_token_ptr,
req_pool_indices_ptr,
prefix_lens_ptr,
result_ptr,
num_tokens,
req_to_token_stride
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return result;
}
__global__ void launch_assign_extend_cache_locs_kernel(
const int64_t* __restrict__ req_pool_indices, // [bs]
const int32_t* __restrict__ req_to_token, // [max_num_req, pool_len]
const int64_t* __restrict__ start_offset, // [bs]
const int64_t* __restrict__ end_offset, // [bs]
int64_t* __restrict__ out_cache_loc, // [sum(draft_token_num)]
int64_t pool_len,
int64_t bs)
{
int pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = start_offset[pid];
int64_t kv_end = end_offset[pid];
int64_t req_id = req_pool_indices[pid];
int64_t out_offset = 0;
for (int i = 0; i < pid; ++i) {
out_offset += end_offset[i] - start_offset[i];
}
const int32_t* src = req_to_token + req_id * pool_len + kv_start;
int64_t* dst = out_cache_loc + out_offset;
for (int64_t i = 0; i < kv_end - kv_start; ++i) {
dst[i] = src[i];
}
}
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs)
{
const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();
constexpr int64_t threads = 128;
int64_t blocks = (bs + threads - 1) / threads;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
req_pool_indices_ptr,
req_to_token_ptr,
start_offset_ptr,
end_offset_ptr,
out_cache_loc_ptr,
pool_len,
bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
const int32_t* __restrict__ req_to_token,
const int32_t* __restrict__ req_pool_indices,
const int32_t* __restrict__ page_kernel_lens,
const int32_t* __restrict__ kv_start_idx,
int32_t* __restrict__ kv_indices,
int req_to_token_stride,
int kv_indices_stride)
{
int pid = blockIdx.x; // batch index
int req_pool_index = req_pool_indices[pid];
int kv_start = 0;
int kv_end = 0;
if (kv_start_idx != nullptr) {
kv_start = kv_start_idx[pid];
kv_end = kv_start;
}
kv_end += page_kernel_lens[pid];
int total_len = kv_end - kv_start;
int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;
for (int pg = 0; pg < num_pages; ++pg) {
int offset = pg * PAGED_SIZE;
// token id = req_to_token[req_pool_index][kv_start + offset]
int64_t token =
req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
// 页索引
kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
}
}
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE)
{
TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");
int bs = req_pool_indices.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(bs);
dim3 block(1);
const int32_t* kv_start_idx_ptr = nullptr;
if (kv_start_idx.has_value()) {
kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
}
if (PAGED_SIZE == 64) {
dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
req_to_token.data_ptr<int32_t>(),
req_pool_indices.data_ptr<int32_t>(),
page_kernel_lens.data_ptr<int32_t>(),
kv_start_idx_ptr,
kv_indices.data_ptr<int32_t>(),
req_to_token_stride,
kv_indices_stride
);
} else {
TORCH_CHECK(false, "Unsupported PAGED_SIZE");
}
}
__global__ void launch_create_chunked_prefix_cache_kv_indices(
int32_t* req_to_token_ptr,
const int64_t* req_pool_indices_ptr,
const int32_t* chunk_starts_ptr,
const int32_t* chunk_seq_lens_ptr,
const int32_t* chunk_cu_seq_lens_ptr,
int32_t* chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t req_pool_index = req_pool_indices_ptr[pid];
int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
int32_t chunk_start_pos = chunk_starts_ptr[pid];
int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
#pragma unroll(32)
for(int32_t offset = 0;offset < chunk_seq_len;offset++){
chunk_kv_indices_ptr[chunk_kv_indices_offset+offset] = req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
}
}
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token_ptr,
const at::Tensor req_pool_indices_ptr,
const at::Tensor chunk_starts_ptr,
const at::Tensor chunk_seq_lens_ptr,
const at::Tensor chunk_cu_seq_lens_ptr,
at::Tensor chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs) {
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1, chunk_cu_seq_lens_ptr1,chunk_kv_indices_ptr1, col_num, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
......@@ -538,6 +538,7 @@ void segment_packbits(
/*
* From csrc/kvcacheio
*/
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
......@@ -545,6 +546,49 @@ void dcu_create_extend_after_decode_spec_info(
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs);
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor chunk_starts,
const at::Tensor chunk_seq_lens,
const at::Tensor chunk_cu_seq_lens,
at::Tensor chunk_kv_indices,
int64_t col_num,
int64_t bs);
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE);
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs);
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens);
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs);
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
......
......@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
def dcu_create_flashmla_kv_indices(
req_to_token_ptr,
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride,
kv_indices_ptr_stride,
PAGED_SIZE = 64,
):
torch.ops.sgl_kernel.dcu_create_flashmla_kv_indices(req_to_token_ptr,
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride,
kv_indices_ptr_stride,
PAGED_SIZE,
)
def get_mla_metadata(
cache_seqlens: torch.Tensor,
......
......@@ -293,3 +293,76 @@ def transfer_kv_all_layer_mla_lf_pf(
block_quota,
num_warps_per_block,
)
def dcu_assign_req_to_token_pool(
req_pool_indices:torch.Tensor,
req_to_token:torch.Tensor,
allocate_lens:torch.Tensor,
new_allocate_lens:torch.Tensor,
out_cache_loc:torch.Tensor,
shape:int,
bs:int,
):
torch.ops.sgl_kernel.dcu_assign_req_to_token_pool(
req_pool_indices,
req_to_token,
allocate_lens,
new_allocate_lens,
out_cache_loc,
shape,
bs,
)
def dcu_get_last_loc(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
):
result = torch.ops.sgl_kernel.dcu_get_last_loc(
req_to_token,
req_pool_indices,
prefix_lens,
)
return result
def dcu_assign_extend_cache_locs(
req_pool_indices: torch.Tensor,
req_to_token: torch.Tensor,
start_offset: torch.Tensor,
end_offset: torch.Tensor,
out_cache_loc: torch.Tensor,
pool_len: int,
bs: int,
):
torch.ops.sgl_kernel.dcu_assign_extend_cache_locs(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
pool_len,
bs,
)
def dcu_create_chunked_prefix_cache_kv_indices(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
chunk_starts: torch.Tensor,
chunk_seq_lens: torch.Tensor,
chunk_cu_seq_lens: torch.Tensor,
chunk_kv_indices: torch.Tensor,
col_num: int,
bs: int,
):
torch.ops.sgl_kernel.dcu_create_chunked_prefix_cache_kv_indices(
req_to_token,
req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
col_num,
bs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment