Commit 33fbf3ca authored by liucong's avatar liucong
Browse files

增加dcu_create_extend_after_decode_spec_info_kernel实现

parent cb4fb0ee
......@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
get_src_tgt_cache_loc,
get_target_cache_loc,
)
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info
if is_cuda():
from sgl_kernel import (
......@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None
use_sglang_create_extend_after_decode_spec_info = get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO")
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT)
......@@ -679,14 +682,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
if self.use_sglang_create_extend_after_decode_spec_info:
dcu_create_extend_after_decode_spec_info(
verified_id = batch.input_ids,
seq_lens = batch.seq_lens,
accept_lens = self.accept_length,
positions = self.positions,
new_verified_id = self.verified_id,
bs = max(speculative_num_steps + 1, len(batch.seq_lens)),
)
else:
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
def generate_attn_arg_prefill(
self,
......
......@@ -125,6 +125,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/kvcacheio
*/
m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
......
......@@ -694,6 +694,66 @@ __global__ void launch_alloc_extend_kernel(
}
}
__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int32_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int32_t accept_length = accept_lens_ptr[pid];
int32_t accept_len_cumsum = 0;
for (int32_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int32_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int32_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int64_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int64_t accept_length = accept_lens_ptr[pid];
int64_t accept_len_cumsum = 0;
for (int64_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int64_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int64_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
void dcu_alloc_decode_kernel(
const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr,
......@@ -734,4 +794,47 @@ void dcu_alloc_extend_kernel(
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
\ No newline at end of file
}
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
const at::Tensor accept_lens,
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs) {
const int32_t* verified_id_ptr;
const int64_t* seq_lens_ptr;
const int32_t* accept_lens_ptr_int32;
const int64_t* accept_lens_ptr_int64;
int64_t* positions_ptr;
int32_t* new_verified_id_ptr;
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
if (accept_lens.dtype() == torch::kInt32)
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
else
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
};
\ No newline at end of file
......@@ -538,6 +538,14 @@ void segment_packbits(
/*
* From csrc/kvcacheio
*/
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
const at::Tensor accept_lens,
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs);
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
const at::Tensor seq_lens_ptr,
......
......@@ -9,6 +9,22 @@ def is_hip() -> bool:
_is_hip = is_hip()
def dcu_create_extend_after_decode_spec_info(
verified_id: torch.Tensor,
seq_lens: torch.Tensor,
accept_lens: torch.Tensor,
positions: torch.Tensor,
new_verified_id: torch.Tensor,
bs: int,
):
torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(
verified_id,
seq_lens,
accept_lens,
positions,
new_verified_id,
bs,
)
def dcu_alloc_extend_kernel(
pre_lens_ptr: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment