Commit 6741925c authored by lizhigong

Merge branch 'v0.5.4_dev' into v0.5.4_dev_linhai

parents 12b60933 93eb92f8
@@ -163,6 +163,9 @@ class Envs:
     SGLANG_USE_AITER = EnvBool(False)
     SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
     SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
+    # DCU Lightop
+    SGLANG_USE_LIGHTOP = EnvBool(False)
+
     # Quantization
     SGLANG_INT4_WEIGHT = EnvBool(False)
......
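Note: the new switch defaults to off, so the DCU lightop kernels are strictly opt-in. A minimal sketch of turning it on from the environment; which truthy strings are accepted depends on EnvBool / get_bool_env_var, "1" is an assumption here:

    # e.g. in the shell before launching the server:
    #   export SGLANG_USE_LIGHTOP=1
    import os
    os.environ["SGLANG_USE_LIGHTOP"] = "1"

    from sglang.srt.utils import get_bool_env_var
    print(get_bool_env_var("SGLANG_USE_LIGHTOP"))  # expected to print True once the flag is set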
@@ -127,9 +127,10 @@ def flash_attn_varlen_func(
         k=k,
         v=v,
         cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
         softmax_scale=softmax_scale,
         causal=causal,
+        return_attn_probs=return_softmax_lse,
     )
\ No newline at end of file
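Note: the hunk above fixes a copy-paste slip in the varlen wrapper: the key-side metadata (cu_seqlens_k, max_seqlen_k) was previously forwarded from the query side, which is only harmless when every request has identical query and key lengths, and the wrapper now also passes return_softmax_lse through as return_attn_probs. A small illustration of why the two argument sets differ, with made-up lengths:

    import torch

    # two sequences: 3 and 5 new query tokens attending over 10 and 12 key tokens (cached + new)
    cu_seqlens_q = torch.tensor([0, 3, 8], dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0, 10, 22], dtype=torch.int32)
    max_seqlen_q, max_seqlen_k = 5, 12
    # reusing cu_seqlens_q / max_seqlen_q for the key side would make the kernel read the
    # wrong key ranges whenever prefix caching gives K a longer length than Q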
@@ -167,8 +167,6 @@ class RMSNorm(CustomOp):
         if residual is not None:
             try:
-                output = torch.empty_like(x)
-                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     x,
                     residual,
@@ -177,6 +175,8 @@ class RMSNorm(CustomOp):
                 )
                 return x, residual
             except TypeError:
+                output = torch.empty_like(x)
+                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,
......
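Note: the RMSNorm hunk above defers the two torch.empty_like scratch allocations to the except TypeError branch, so the in-place fused_add_rms_norm fast path no longer allocates buffers on every call. A rough sketch of the dispatch pattern, with the kernel passed in as a callable; the fallback argument order past the first two tensors is an assumption, not taken from the diff:

    import torch

    def add_rmsnorm_with_fallback(x, residual, weight, eps, fused_add_rms_norm):
        try:
            # fast path: the kernel rewrites x and residual in place, nothing to allocate
            fused_add_rms_norm(x, residual, weight, eps)
            return x, residual
        except TypeError:
            # out-of-place variant: the buffers are only paid for when this branch is taken
            output = torch.empty_like(x)
            residual_out = torch.empty_like(x)
            fused_add_rms_norm(output, x, residual_out, residual, weight, eps)
            return output, residual_out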
@@ -28,6 +28,8 @@ from typing import (
     runtime_checkable,
 )
+from numpy import dtype
+
 import torch
 import torch.nn.functional as F
@@ -68,6 +70,7 @@ _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -79,6 +82,8 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+if _use_lightop:
+    from lightop import op as op
 if _is_npu:
     import torch_npu
@@ -725,6 +730,18 @@ def biased_grouped_topk_gpu(
             routed_scaling_factor,
         )
         return topk_weights, topk_ids
+    elif _use_lightop:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
+        topk_weights, topk_ids = op.moe_fused_gate(
+            gating_output.to(dtype=torch.float32),  # or bfloat16
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+            0,  # 0 in vllm
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         return biased_grouped_topk_impl(
             hidden_states,
......
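Note: biased_grouped_topk_gpu now has a third dispatch branch: aiter on ROCm first, then lightop on DCU, then the reference implementation. A hedged sketch of invoking the lightop gate directly with the same argument order; the shapes, the example values (picked to resemble a DeepSeek-V3-style config), and the interpretation of the literal 0 are assumptions, only the call signature comes from the hunk above:

    import torch
    from lightop import op  # only available on DCU builds with lightop installed

    num_tokens, num_experts = 8, 256
    gating_output = torch.randn(num_tokens, num_experts, device="cuda")  # DCU devices show up as "cuda"
    correction_bias = torch.zeros(num_experts, device="cuda")

    topk_weights, topk_ids = op.moe_fused_gate(
        gating_output.to(dtype=torch.float32),  # fp32 (or bf16) router scores
        correction_bias,
        8,    # num_expert_group
        4,    # topk_group
        8,    # topk
        0,    # "0 in vllm": presumably the fused shared-expert count
        2.5,  # routed_scaling_factor
    )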
@@ -22,6 +22,8 @@ from sglang.srt.utils import (
     is_xpu,
 )
+from sglang.srt.utils import direct_register_custom_op
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
@@ -29,6 +31,7 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_xpu = is_xpu()
+_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
 if _is_cuda:
     from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -57,6 +60,34 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     x = torch.stack((-x2, x1), dim=-1)
     return x.flatten(-2)
+
+
+# for dcu
+@triton.jit
+def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q, stride1: int,
+                                            stride2: int, stride_cs: int,
+                                            dim1: int, dim2: int, dim3: int,
+                                            BLOCK_SIZE: tl.constexpr):
+    pid0 = tl.program_id(0)
+    pid1 = tl.program_id(1)
+    pid2 = tl.program_id(2)
+    offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
+    offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2
+    offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
+    mask = offsets_cs < dim3
+    mask2 = offsets_q < dim3 * 2
+    v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
+    v_cos2 = tl.interleave(v_cos, v_cos)
+    v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
+    v_sin2 = tl.interleave(v_sin, v_sin)
+    x12 = tl.load(q + offsets, mask=mask2)
+    x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
+    # we are both reading and writing 'q'; make sure all warps are in sync
+    tl.debug_barrier()
+    x12_ = tl.ravel(tl.join(-x2, x1))
+    x12 = x12 * v_cos2 + x12_ * v_sin2
+    tl.store(q + offsets, x12, mask=mask2)
+
+
 def _apply_rotary_emb(
     x: torch.Tensor,
@@ -736,7 +767,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         # Re-dispatch
         if _is_hip:
-            self._forward_method = self.forward_native
+            if _use_lightop:
+                self._forward_method = self.forward_dcu
+            else:
+                self._forward_method = self.forward_native

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
@@ -778,6 +812,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         sin = freqs.sin() * self.mscale
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+    def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
+                                       head_size: int, cos_sin_cache: torch.Tensor,
+                                       is_neox_style: bool) -> None:
+        from lightop import op
+        op.rotary_embedding_deepseek_fuse(positions, query, key, head_size, cos_sin_cache, is_neox_style)
+
+    def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
+                                            head_size: int, cos_sin_cache: torch.Tensor,
+                                            is_neox_style: bool) -> None:
+        pass
+
+    direct_register_custom_op(
+        op_name="rotary_embedding_deepseek_fuse",
+        op_func=rotary_embedding_deepseek_fuse,
+        mutates_args=["query", "key"],
+        fake_impl=rotary_embedding_deepseek_fuse_fake,
+    )
+
     def forward_native(
         self,
@@ -819,6 +871,77 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             query = query_rot
             key = key_rot
         return query.to(dtype), key.to(dtype)
+
+    def forward_dcu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert key is not None
+        if self.cos_sin_cache.device != positions.device:
+            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+                positions.device)
+        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
+                                     if offsets is not None else positions]
+        if query.device.type == 'cuda' and not self.is_neox_style:  # not self.reference ?
+            assert len(query.shape) == 3
+
+            def call(q):
+                BLOCK_SIZE = 64
+                grid = (
+                    q.shape[-3],
+                    q.shape[-2],
+                    triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
+                )
+                deepseek_scaling_rotary_emb_kernel_gptj[grid](
+                    cos_sin,
+                    q,
+                    stride1=q.stride()[-3],
+                    stride2=q.stride()[-2],
+                    stride_cs=cos_sin.stride()[-2],
+                    dim1=q.shape[0],
+                    dim2=q.shape[1],
+                    dim3=self.rotary_dim // 2,
+                    BLOCK_SIZE=BLOCK_SIZE,
+                    num_warps=1)
+
+            if _use_lightop:
+                torch.ops.sglang.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
+            else:
+                call(query)
+                call(key)
+            return query, key
+        else:
+            query_rot = query[..., :self.rotary_dim]
+            key_rot = key[..., :self.rotary_dim]
+            if self.rotary_dim < self.head_size:
+                query_pass = query[..., self.rotary_dim:]
+                key_pass = key[..., self.rotary_dim:]
+
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            if self.is_neox_style:
+                # NOTE(woosuk): Here we assume that the positions tensor has the
+                # shape [batch_size, seq_len].
+                cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+                sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+            else:
+                cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+                sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+            rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+            query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+            key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+            if self.rotary_dim < self.head_size:
+                query = torch.cat((query_rot, query_pass), dim=-1)
+                key = torch.cat((key_rot, key_pass), dim=-1)
+            else:
+                query = query_rot
+                key = key_rot
+            return query, key
+
     def forward_npu(
         self,
......
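Note: the direct_register_custom_op call above exposes the lightop fused rope as torch.ops.sglang.rotary_embedding_deepseek_fuse, with a no-op fake implementation so tracing (e.g. torch.compile) can account for the in-place mutation of query and key without running the DCU kernel. A rough usage sketch with made-up shapes; the layout and dtype expected by the lightop kernel are assumptions, only the argument order comes from this diff:

    import torch

    num_tokens, num_heads, head_size = 16, 16, 64
    positions = torch.arange(num_tokens, device="cuda")
    query = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(num_tokens, 1, head_size, device="cuda", dtype=torch.bfloat16)
    # stand-in for DeepseekScalingRotaryEmbedding.cos_sin_cache: shape [max_position, rotary_dim],
    # cos in the first half and sin in the second half of the last dimension
    cos_sin_cache = torch.randn(4096, head_size, device="cuda", dtype=torch.bfloat16)

    # rotates query and key in place; is_neox_style=False matches the GPT-J / DeepSeek layout
    torch.ops.sglang.rotary_embedding_deepseek_fuse(
        positions, query, key, head_size, cos_sin_cache, False
    )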
@@ -487,29 +487,15 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
         if self.sglang_kvalloc_kernel:
-            if bs < 3:
-                dcu_alloc_extend_kernel(
-                    pre_lens_ptr = prefix_lens,
-                    seq_lens_ptr = seq_lens,
-                    last_loc_ptr = last_loc,
-                    free_page_ptr = self.free_pages,
-                    out_indices = out_indices,
-                    bs = bs,
-                    bs_upper = next_power_of_2(bs),
-                    page_size = self.page_size,
-                    max_num_extend_tokens = self.seen_max_num_extend_tokens_next_power_of_2,
-                )
-            else:
-                alloc_extend_kernel[(bs,)](
-                    prefix_lens,
-                    seq_lens,
-                    last_loc,
-                    self.free_pages,
-                    out_indices,
-                    next_power_of_2(bs),
-                    self.page_size,
-                    self.seen_max_num_extend_tokens_next_power_of_2,
-                )
+            dcu_alloc_extend_kernel(
+                pre_lens_ptr = prefix_lens,
+                seq_lens_ptr = seq_lens,
+                last_loc_ptr = last_loc,
+                free_page_ptr = self.free_pages,
+                out_indices = out_indices,
+                bs = bs,
+                page_size = self.page_size,
+            )
         else:
             alloc_extend_kernel[(bs,)](
                 prefix_lens,
@@ -560,7 +546,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 free_page_ptr = self.free_pages,
                 out_indices = out_indices,
                 bs = bs,
-                bs_upper = next_power_of_2(bs),
                 page_size = self.page_size,
             )
         else:
......
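Note: both allocator call sites now pass only bs and page_size; the kernels bound their work by the real batch size and per-request extend length, so the next_power_of_2(bs) and seen_max_num_extend_tokens padding arguments disappear on the Python side as well, and the bs < 3 special-casing is gone because the DCU kernel handles all batch sizes. A hedged sketch of the extend-side call with the new keyword set; the import path and the tensor values are illustrative assumptions, only the keyword names come from the hunk above (see also the Python wrapper hunk further down):

    import torch
    from sgl_kernel import dcu_alloc_extend_kernel  # import path assumed for the DCU build of sgl-kernel

    bs, page_size = 2, 16
    prefix_lens = torch.tensor([0, 16], dtype=torch.int64, device="cuda")   # cached tokens per request
    seq_lens = torch.tensor([5, 40], dtype=torch.int64, device="cuda")      # total tokens after the extend
    last_loc = torch.tensor([-1, 31], dtype=torch.int64, device="cuda")     # last allocated slot per request (illustrative)
    free_pages = torch.arange(1, 1025, dtype=torch.int64, device="cuda")    # indices of free KV pages
    extend_num_tokens = int((seq_lens - prefix_lens).sum())                 # 29 output slots to fill
    out_indices = torch.empty(extend_num_tokens, dtype=torch.int64, device="cuda")

    dcu_alloc_extend_kernel(
        pre_lens_ptr=prefix_lens,
        seq_lens_ptr=seq_lens,
        last_loc_ptr=last_loc,
        free_page_ptr=free_pages,
        out_indices=out_indices,
        bs=bs,
        page_size=page_size,
    )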
@@ -131,9 +131,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  /*
   * From csrc/kvcacheio
   */
- m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size, int max_num_extend_tokens) -> ()");
+ m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
  m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
- m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
+ m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
  m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
  m.def(
      "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
......
@@ -585,12 +585,12 @@ __global__ void launch_alloc_decode_kernel(
     const int32_t* last_loc_ptr,
     const int64_t* free_page_ptr,
     int64_t* out_indices,
-    int64_t bs_upper,
+    int64_t bs,
     int64_t page_size) {
   int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (pid >= bs_upper) return;
+  if (pid >= bs) return;
   int64_t seq_len = seq_lens_ptr[pid];
   int64_t pre_len = seq_len - 1;
@@ -625,13 +625,12 @@ __global__ void launch_alloc_extend_kernel(
     const int64_t* last_loc_ptr,
     const int64_t* free_page_ptr,
     int64_t* out_indices,
-    int64_t bs_upper,
-    int64_t page_size,
-    int64_t max_num_extend_tokens)
+    int64_t bs,
+    int64_t page_size)
 {
   int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (pid >= bs_upper) return;
+  if (pid >= bs) return;
   int64_t seq_len = seq_lens_ptr[pid];
   int64_t pre_len = pre_lens_ptr[pid];
@@ -674,7 +673,7 @@ __global__ void launch_alloc_extend_kernel(
   }

   int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
-  for (int64_t offset = 0; offset < num_part2 && offset < max_num_extend_tokens; offset++) {
+  for (int64_t offset = 0; offset < num_part2; offset++) {
     int64_t page_idx = new_page_start_loc + offset / page_size;
     int64_t page_start = free_page_ptr[page_idx];
     int64_t output_idx = output_start_loc + num_part1 + offset;
@@ -701,7 +700,6 @@ void dcu_alloc_decode_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
     int64_t page_size) {
   const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
@@ -712,7 +710,7 @@ void dcu_alloc_decode_kernel(
   int64_t block_size = 64;
   int64_t grid_size = (bs + block_size - 1) / block_size;
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
-  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size);
+  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -723,9 +721,7 @@ void dcu_alloc_extend_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
-    int64_t page_size,
-    int64_t max_num_extend_tokens) {
+    int64_t page_size) {
   const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
   const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
@@ -736,6 +732,6 @@ void dcu_alloc_extend_kernel(
   int64_t block_size = 64;
   int64_t grid_size = (bs + block_size - 1) / block_size;
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
-  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size, max_num_extend_tokens);
+  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
\ No newline at end of file
@@ -545,9 +545,7 @@ void dcu_alloc_extend_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
-    int64_t page_size,
-    int64_t max_num_extend_tokens);
+    int64_t page_size);

 void dcu_alloc_decode_kernel(
     const at::Tensor seq_lens_ptr,
@@ -555,7 +553,6 @@ void dcu_alloc_decode_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
     int64_t page_size);

 void transfer_kv_per_layer(
......
@@ -17,9 +17,7 @@ def dcu_alloc_extend_kernel(
     free_page_ptr: torch.Tensor,
     out_indices: torch.Tensor,
     bs: int,
-    bs_upper: int,
     page_size: int,
-    max_num_extend_tokens: int,
 ):
     torch.ops.sgl_kernel.dcu_alloc_extend_kernel(
         pre_lens_ptr,
@@ -28,9 +26,7 @@ def dcu_alloc_extend_kernel(
         free_page_ptr,
         out_indices,
         bs,
-        bs_upper,
         page_size,
-        max_num_extend_tokens,
     )

 def dcu_alloc_decode_kernel(
@@ -39,7 +35,6 @@ def dcu_alloc_decode_kernel(
     free_page_ptr: torch.Tensor,
     out_indices: torch.Tensor,
     bs: int,
-    bs_upper: int,
     page_size: int,
 ):
     torch.ops.sgl_kernel.dcu_alloc_decode_kernel(
@@ -48,7 +43,6 @@ def dcu_alloc_decode_kernel(
         free_page_ptr,
         out_indices,
         bs,
-        bs_upper,
         page_size,
     )
......