change / sglang · Commits · a1175a4e

Commit a1175a4e, authored Nov 22, 2025 by maxiao1

    Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5

Parents: 0c006b88, 31653dd9
Changes: 62 files in total; this page shows 20 changed files with 1126 additions and 112 deletions (+1126, -112).
python/sglang/srt/model_executor/model_runner.py      +122   -3
python/sglang/srt/models/deepseek_v2.py                +82  -51
python/sglang/srt/models/qwen3_next.py                  +1   -1
python/sglang/srt/profile/prof.py                      +58   -0
python/sglang/srt/server_args.py                        +7   -1
python/sglang/srt/speculative/draft_utils.py           +24   -0
python/sglang/srt/speculative/eagle_info.py            +22  -11
python/sglang/srt/speculative/eagle_info_v2.py         +53   -9
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh     +7   -7
sgl-kernel/csrc/attention/merge_attn_states.cu         +18   -1
sgl-kernel/csrc/common_extension_rocm.cc               +30   -0
sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu        +1   -1
sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu      +1   -1
sgl-kernel/csrc/kvcacheio/transfer.cu                 +585   -1
sgl-kernel/csrc/moe/moe_align_kernel.cu                 +2   -1
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu         +3   -3
sgl-kernel/csrc/quantization/gguf/ggml-common.h         +1   -1
sgl-kernel/include/sgl_kernel_ops.h                    +69   -0
sgl-kernel/include/utils.h                             +20  -20
sgl-kernel/python/sgl_kernel/flash_mla.py              +20   -0
python/sglang/srt/model_executor/model_runner.py

@@ -168,6 +168,7 @@ MLA_ATTENTION_BACKENDS = [
     "triton",
     "flashmla",
     "cutlass_mla",
+    "dcu_mla",
     "trtllm_mla",
     "ascend",
     "nsa",

@@ -176,6 +177,7 @@ MLA_ATTENTION_BACKENDS = [
 CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
     "flashinfer",
     "fa3",
+    "dcu_mla",
     "fa4",
     "flashmla",
     "cutlass_mla",

@@ -207,7 +209,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

 # Detect stragger ranks in model loading
-UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600

 # the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
 MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

@@ -511,6 +513,121 @@ class ModelRunner:
    def model_specific_adjustment(self):
        server_args = self.server_args

        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not _is_cpu_amx_available
        ):
            logger.info(
                "The current platform does not support Intel AMX, will fallback to torch_native backend."
            )
            server_args.attention_backend = "torch_native"

        if (
            server_args.attention_backend == "intel_xpu"
            and server_args.device == "xpu"
            and not _is_xpu_xmx_available
        ):
            logger.info(
                "The current platform does not support Intel XMX, will fallback to triton backend."
            )
            server_args.attention_backend = "triton"

        if server_args.prefill_attention_backend is not None and (
            server_args.prefill_attention_backend == server_args.decode_attention_backend
        ):  # override the default attention backend
            server_args.attention_backend = server_args.prefill_attention_backend

        if (
            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
            is not None
        ):
            if server_args.attention_backend is None:
                server_args.attention_backend = "dual_chunk_flash_attn"
                logger.info("Dual chunk attention is turned on by default.")
            elif server_args.attention_backend != "dual_chunk_flash_attn":
                raise ValueError(
                    "Dual chunk attention is enabled, but attention backend is set to "
                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
                )

        if server_args.attention_backend is None:
            """
            Auto select the fastest attention backend.

            1. Models with MHA Architecture (e.g: Llama, QWen)
                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
            2. Models with MLA Architecture and using FA3
                2.1 We will use FA3 backend on hopper.
                2.2 We will use Flashinfer backend on blackwell.
                2.3 Otherwise, we will use triton backend.
            """
            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "triton"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                elif is_sm100_supported():
                    server_args.attention_backend = "flashinfer"
                elif _is_hip:
                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
                    # TODO current aiter only support head number 16 or 128 head number
                    if head_num == 128 or head_num == 16:
                        server_args.attention_backend = "triton"
                    else:
                        server_args.attention_backend = "triton"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = "triton"
            log_info_on_rank0(
                logger,
                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
            )
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                    logger.info(
                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                if server_args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization not supported on CPU except for intel_amx backend."
                    )

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."

@@ -1521,12 +1638,14 @@ class ModelRunner:
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e5m2fnuz
+                # self.kv_cache_dtype = torch.float8_e5m2fnuz
+                self.kv_cache_dtype = torch.float8_e5m2
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e4m3fnuz
+                # self.kv_cache_dtype = torch.float8_e4m3fnuz
+                self.kv_cache_dtype = torch.float8_e4m3fn
             else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
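The last hunk above switches the HIP path from the fnuz FP8 formats to the OCP formats, so HIP and CUDA now resolve the same torch dtypes. A minimal sketch of the resulting string-to-dtype mapping (the standalone helper below is illustrative, not part of the patch; the branch that precedes this hunk is not shown in the diff):

import torch

def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    """Sketch of the selection logic after this patch (HIP and CUDA branches now agree)."""
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    if kv_cache_dtype == "fp8_e4m3":
        return torch.float8_e4m3fn
    if kv_cache_dtype in ("bf16", "bfloat16"):
        return torch.bfloat16
    # The branch preceding this hunk (not shown) falls back to the model dtype.
    return model_dtype

assert resolve_kv_cache_dtype("fp8_e4m3", torch.bfloat16) is torch.float8_e4m3fn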
python/sglang/srt/models/deepseek_v2.py

@@ -137,6 +137,7 @@ from sglang.srt.utils import (
     make_layers,
     use_intel_amx_backend,
 )
+from sglang.srt.layers.attention.lightop_concat import concat_decode_opt

 _is_hip = is_hip()
 _is_cuda = is_cuda()

@@ -147,8 +148,10 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
 _is_gfx95_supported = is_gfx95_supported()
+_user_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
+_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")

 _use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+_use_opt_cat_decode = get_bool_env_var("SGLANG_USE_OPT_CAT")

 if _use_aiter_gfx95:
     from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights

@@ -181,6 +184,7 @@ elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+    from sgl_kernel import merge_state_v2
 elif _is_npu:
     import custom_ops  # noqa: F401
     import sgl_kernel_npu  # noqa: F401

@@ -366,6 +370,10 @@ def handle_attention_flashmla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "flashmla")


+def handle_attention_dcu_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "dcu_mla")
+
+
 def handle_attention_cutlass_mla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "cutlass_mla")

@@ -507,11 +515,13 @@ class DeepseekV2MLP(nn.Module):
             x = (x, None, y)
         gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
-            x,
-            skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
-        )
+        if _use_fused_silu_mul_quant:
+            x, _ = self.down_proj(
+                gate_up,
+                skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
+                use_fused_silu_mul_quant=True,
+            )
+        else:
+            x = self.act_fn(gate_up)
+            x, _ = self.down_proj(
+                x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+            )
         return x

@@ -811,52 +821,58 @@ class DeepseekV2MoE(nn.Module):
             self.shared_experts.gate_up_proj
         ):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

+        if _user_lightop_moe_sum_mul_add:
+            if hidden_states.shape[0] > 0:
+                if not self._fuse_shared_experts_inside_sbo:
+                    shared_output = self._forward_shared_experts(
+                        hidden_states, gemm_output_zero_allocator
+                    )
+                # router_logits: (num_tokens, n_experts)
+                router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
+                topk_output = self.topk(hidden_states, router_logits)
+                final_hidden_states = self.experts(
+                    hidden_states, topk_output, shared_output=shared_output
+                )
+            else:
+                shared_output = None
+                topk_output = self.topk.empty_topk_output(hidden_states.device)
+        else:
             if hidden_states.shape[0] > 0:
                 if not self._fuse_shared_experts_inside_sbo:
                     shared_output = self._forward_shared_experts(
                         hidden_states, gemm_output_zero_allocator
                     )
                 # router_logits: (num_tokens, n_experts)
                 router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
                 topk_output = self.topk(hidden_states, router_logits)
             else:
                 shared_output = None
                 topk_output = self.topk.empty_topk_output(hidden_states.device)

             if self._fuse_shared_experts_inside_sbo:
                 shared_output = None

             def _forward_shared_experts_and_put_results():
                 nonlocal shared_output
                 shared_output = self._forward_shared_experts(
                     hidden_states, gemm_output_zero_allocator
                 )

             final_hidden_states = self.experts(
                 hidden_states,
                 topk_output,
                 **(
                     dict(
                         forward_shared_experts=_forward_shared_experts_and_put_results,
                         alt_stream=self.alt_stream,
                     )
                     if self._fuse_shared_experts_inside_sbo
                     else {}
                 ),
             )

-        if not _is_cuda and not _use_aiter:
-            # fused in biased_grouped_topk so we can skip here
-            final_hidden_states *= self.routed_scaling_factor
-        if shared_output is not None:
-            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
-                final_hidden_states_out = torch.empty_like(final_hidden_states)
-                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
-                final_hidden_states = final_hidden_states_out
-                sm.tag(final_hidden_states)
+        if (
+            not _is_cuda
+            and not _use_aiter
+            or isinstance(self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod)
+            or isinstance(self.experts.quant_method, CompressedTensorsWNA16MoEMethod)
+        ):
+            # fused in biased_grouped_topk so we can skip here
+            final_hidden_states *= self.routed_scaling_factor
+        if shared_output is not None:
+            final_hidden_states += shared_output
         if (
             self.tp_size > 1
             and not should_allreduce_fusion

@@ -1766,7 +1782,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.rotary_emb.is_neox_style,
             )
         else:
-            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            if _use_opt_cat_decode and q_nope_out.shape[0] < 1024:
+                q = concat_decode_opt(q_nope_out, q_pe, dim=2)
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
         k = torch.cat([k_nope, k_pe], dim=-1)

         attn_output = self.attn_mqa(

@@ -2365,9 +2384,20 @@ class DeepseekV2AttentionMLA(nn.Module):
-            kv_indices = forward_batch.prefix_chunk_kv_indices[i]
             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            kv_a_normed, k_pe = self._get_mla_kv_buffer(
-                kv_indices, q.dtype, forward_batch
-            )
+            latent_cache_buf, dtype = forward_batch.token_to_kv_pool.get_key_buffer_DeepSeekV2(
+                self.attn_mha.layer_id
+            )
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .view(dtype)
+                .to(q.dtype)
+            )
+            kv_a_normed, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
             kv = self.kv_b_proj(kv_a_normed)[0]
             kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim

@@ -3838,6 +3868,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
 AttentionBackendRegistry.register("fa3", handle_attention_fa3)
 AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
 AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
 AttentionBackendRegistry.register("fa4", handle_attention_fa4)
 AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
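For small decode batches the MLA attention hunk above routes the query concatenation through concat_decode_opt instead of torch.cat. A minimal reference of what the fallback path computes; the shapes below are illustrative assumptions, and concat_decode_opt itself is the fork's fused kernel and is not reproduced here:

import torch

# Hypothetical decode-shaped inputs: [num_tokens, num_heads, head_dim_*]
q_nope_out = torch.randn(8, 16, 128)
q_pe = torch.randn(8, 16, 64)

# The fallback path shown in the diff concatenates along the last dimension.
# concat_decode_opt(q_nope_out, q_pe, dim=2) is expected to produce the same
# [8, 16, 192] result for these small decode batches, just via a fused kernel.
q = torch.cat([q_nope_out, q_pe], dim=-1)
assert q.shape == (8, 16, 192)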
python/sglang/srt/models/qwen3_next.py

@@ -351,7 +351,7 @@ class Qwen3GatedDeltaNet(nn.Module):
     def _forward_input_proj(self, hidden_states: torch.Tensor):
         DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
         seq_len, _ = hidden_states.shape
-        if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
+        if seq_len < DUAL_STREAM_TOKEN_THRESHOLD and self.alt_stream is not None:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
             projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
python/sglang/srt/profile/prof.py (new file, mode 100644)

from ctypes import *
import os
import time
import threading


class Prof:
    def __init__(self):
        self.use_roctx = os.getenv('SGLANG_HIP_PROF') is not None
        if self.use_roctx:
            self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctxRangePushA.argtypes = [c_char_p]
            self.lib.roctxRangePushA.restype = c_int
            self.lib.roctxRangePop.restype = c_int
        self.tm = time.perf_counter()
        self.push_depth = {}

    def StartTracer(self):
        if self.use_roctx:
            if self.lib is None:
                self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        if self.use_roctx:
            if self.lib is None:
                self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        current_thread = threading.current_thread()
        thread_id = current_thread.ident
        if thread_id not in self.push_depth.keys():
            self.push_depth[thread_id] = 0
        if num < 0 and self.push_depth[thread_id] == 0:
            return False
        self.push_depth[thread_id] += num
        return True

    def ProfRangePush(self, message):
        if profile.use_roctx and self.roc_tracer_flag:
            profile.lib.roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        if profile.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            profile.lib.roctxRangePop()

    def ProfRangeAutoPush(self, message):
        self.ProfRangePop()
        self.ProfRangePush(message)


profile = Prof()
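A minimal usage sketch for the profile singleton added above, assuming the process has SGLANG_HIP_PROF set before import and libroctracer64.so available on the machine; the workload here is just a placeholder:

import os, time
os.environ.setdefault("SGLANG_HIP_PROF", "1")  # read once when the module below is imported

from sglang.srt.profile.prof import profile

profile.StartTracer()                  # calls roctracer_start(); ranges are recorded from here on
profile.ProfRangePush("decode_step")   # open a named roctx range for the region of interest
time.sleep(0.01)                       # stand-in for the real work being profiled
profile.ProfRangePop()                 # close the range (no-op if nothing was pushed on this thread)
profile.StopTracer()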
python/sglang/srt/server_args.py

@@ -103,6 +103,8 @@ QUANTIZATION_CHOICES = [
     "mxfp4",
     "auto-round",
     "compressed-tensors",
     # for Ktransformers
     "slimquant_w4a8_marlin",
     "slimquant_marlin",
 ]

 ATTENTION_BACKEND_CHOICES = [

@@ -111,6 +113,8 @@ ATTENTION_BACKEND_CHOICES = [
     "torch_native",
     "flex_attention",
     "nsa",
+    # ransplant from vllm
+    "dcu_mla",
     # NVIDIA specific
     "cutlass_mla",
     "fa3",

@@ -1198,9 +1202,11 @@ class ServerArgs:
         if (
             self.attention_backend == "flashmla"
             or self.decode_attention_backend == "flashmla"
+            or self.attention_backend == "dcu_mla"
+            or self.decode_attention_backend == "dcu_mla"
         ):
             logger.warning(
-                "FlashMLA only supports a page_size of 64, change page_size to 64."
+                "FlashMLA/DCU MLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
python/sglang/srt/speculative/draft_utils.py

@@ -46,6 +46,7 @@ class DraftBackendFactory:
                 else self._create_triton_decode_backend
             ),
             "flashmla": self._create_flashmla_decode_backend,
+            "dcu_mla": self._create_dcumla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,

@@ -70,6 +71,7 @@ class DraftBackendFactory:
                 else self._create_triton_prefill_backend
             ),
             "flashmla": self._create_flashmla_prefill_backend,
+            "dcu_mla": self._create_dcumla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,

@@ -151,6 +153,15 @@ class DraftBackendFactory:
         return FlashMLAMultiStepDraftBackend(
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )

+    def _create_dcumla_decode_backend(self):
+        from sglang.srt.layers.attention.dcu_mla_backend import (
+            DCUMLAMultiStepDraftBackend,
+        )
+
+        return DCUMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_trtllm_mha_decode_backend(self):
         from sglang.srt.layers.attention.trtllm_mha_backend import (

@@ -240,3 +251,16 @@ class DraftBackendFactory:
             "flashmla prefill backend is not yet supported for draft extend."
         )
         return None

+    def _create_dcumla_prefill_backend(self):
+        # logger.warning(
+        #     "flashmla prefill backend is not yet supported for draft extend."
+        # )
+        # return None
+        # nhb
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
python/sglang/srt/speculative/eagle_info.py

@@ -38,9 +38,8 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
-
-_is_npu = is_npu()
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info

 if is_cuda():
     from sgl_kernel import (

@@ -620,6 +619,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
     new_seq_lens: Optional[torch.Tensor] = None
     verify_done: Optional[torch.cuda.Event] = None

+    use_sglang_create_extend_after_decode_spec_info = get_bool_env_var(
+        "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
+    )

     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_DRAFT)

@@ -684,14 +685,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
         self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)

-        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
-            batch.input_ids,
-            batch.seq_lens,
-            self.accept_length,
-            self.positions,
-            self.verified_id,
-            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
-        )
+        if self.use_sglang_create_extend_after_decode_spec_info:
+            dcu_create_extend_after_decode_spec_info(
+                verified_id=batch.input_ids,
+                seq_lens=batch.seq_lens,
+                accept_lens=self.accept_length,
+                positions=self.positions,
+                new_verified_id=self.verified_id,
+                bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
+            )
+        else:
+            create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+                batch.input_ids,
+                batch.seq_lens,
+                self.accept_length,
+                self.positions,
+                self.verified_id,
+                next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
+            )

     def generate_attn_arg_prefill(
         self,
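The dcu_create_extend_after_decode_spec_info op used above is implemented later in this commit (sgl-kernel/csrc/kvcacheio/transfer.cu). A plain-PyTorch reference of what it computes per request: the positions of the accepted tokens and the last verified id. The example tensors below are made up for illustration:

import torch

def create_extend_after_decode_spec_info_ref(verified_id, seq_lens, accept_lens):
    """Mirrors the per-request loop of the CUDA kernel registered in transfer.cu."""
    positions = torch.empty(int(accept_lens.sum()), dtype=torch.long)
    new_verified_id = torch.empty(len(seq_lens), dtype=torch.int32)
    cumsum = 0
    for i in range(len(seq_lens)):
        acc = int(accept_lens[i])
        # accepted tokens occupy the last `acc` positions of the sequence
        positions[cumsum : cumsum + acc] = torch.arange(
            int(seq_lens[i]) - acc, int(seq_lens[i]), dtype=torch.long
        )
        new_verified_id[i] = verified_id[cumsum + acc - 1]
        cumsum += acc
    return positions, new_verified_id

seq_lens = torch.tensor([7, 5], dtype=torch.int64)
accept_lens = torch.tensor([2, 3], dtype=torch.int32)
verified_id = torch.tensor([11, 12, 21, 22, 23], dtype=torch.int32)
pos, new_ids = create_extend_after_decode_spec_info_ref(verified_id, seq_lens, accept_lens)
print(pos.tolist())      # [5, 6, 2, 3, 4]
print(new_ids.tolist())  # [12, 23]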
python/sglang/srt/speculative/eagle_info_v2.py

@@ -34,6 +34,12 @@ _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_npu = is_npu()

+from sglang.srt.utils import get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool, dcu_assign_extend_cache_locs
+import logging
+
+logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.managers.tp_worker import TpModelWorker
     from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (

@@ -79,6 +85,9 @@ def assign_draft_cache_locs_page_size_1(
 @dataclass
 class EagleDraftInputV2Mixin:
+    use_sglang_assign_req_to_token_pool = get_bool_env_var("SGLANG_ASSIGN_REQ_TO_TOKEN_POOL")
+
     def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
         from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func

@@ -114,15 +123,26 @@ class EagleDraftInputV2Mixin:
             extend_num_tokens,
         )
-        assign_req_to_token_pool_func(
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            self.allocate_lens,
-            new_allocate_lens,
-            out_cache_loc,
-            bs,
-        )
+        if self.use_sglang_assign_req_to_token_pool:
+            dcu_assign_req_to_token_pool(
+                req_pool_indices=batch.req_pool_indices,
+                req_to_token=batch.req_to_token_pool.req_to_token,
+                allocate_lens=self.allocate_lens,
+                new_allocate_lens=new_allocate_lens,
+                out_cache_loc=out_cache_loc,
+                shape=batch.req_to_token_pool.req_to_token.shape[1],
+                bs=bs,
+            )
+        else:
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                self.allocate_lens,
+                new_allocate_lens,
+                out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )
         self.allocate_lens = new_allocate_lens

         # FIXME(lsyin): make this sync optional

@@ -191,6 +211,9 @@ class EagleDraftInputV2Mixin:
 @dataclass
 class EagleVerifyInputV2Mixin:
+    use_sglang_assign_extend_cache_locs = get_bool_env_var("SGLANG_ASSIGN_EXTEND_CACHE_LOCS")
+
     def prepare_for_v2_verify(
         self: EagleVerifyInput,
         req_to_token_pool: ReqToTokenPool,

@@ -211,6 +234,27 @@ class EagleVerifyInputV2Mixin:
             device=device,
         )
+        if self.use_sglang_assign_extend_cache_locs:
+            dcu_assign_extend_cache_locs(
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                bs,
+            )
+        else:
+            assign_extend_cache_locs[(bs,)](
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )
         # Get a forward batch
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.capture_hidden_mode = CaptureHiddenMode.FULL
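The DCU-specific paths added in the two speculative-decoding files above are all opt-in via environment variables that are read once, at import/class-definition time, through get_bool_env_var. A sketch of enabling them, assuming the usual truthy values accepted by sglang's env parsing:

import os

# Must be set before the sglang speculative modules are imported, because the
# flags are evaluated once when the classes are defined.
os.environ["SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"] = "1"  # eagle_info.py
os.environ["SGLANG_ASSIGN_REQ_TO_TOKEN_POOL"] = "1"              # eagle_info_v2.py (draft decode)
os.environ["SGLANG_ASSIGN_EXTEND_CACHE_LOCS"] = "1"              # eagle_info_v2.py (verify)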
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh

@@ -165,10 +165,10 @@ DINLINE void start_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
-        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __hip_atomic_store(
+        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
+    while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
            flag)
       ;
   }

@@ -211,16 +211,16 @@ DINLINE void end_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
+    __hip_atomic_store(
         &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
         flag,
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
-        __MEMORY_SCOPE_SYSTEM);
+        __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(
+    while (__hip_atomic_load(
                &self_sg->end[blockIdx.x][threadIdx.x],
                final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
-               __MEMORY_SCOPE_DEVICE) < flag)
+               __HIP_MEMORY_SCOPE_AGENT) < flag)
       ;
   }
   __syncthreads();
sgl-kernel/csrc/attention/merge_attn_states.cu

 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_bf16.h>
+#endif

 #include <algorithm>
 #include <optional>

-#include "pytorch_extension_utils.h"
+#include "pytorch_extension_utils_rocm.h"

 // Helper functions to convert between different data types
 // (float, half, bfloat16) for the merge attention states kernel.

@@ -27,6 +31,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) {
   d = __float2bfloat16(s);
 }

+inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
+  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
+  for (int i = 0; i < a.dim(); ++i) {
+    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
+  }
+}
+
+#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
+#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
+
 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
 template <typename scalar_t, const uint NUM_THREADS>
 __global__ void merge_attn_states_kernel(
sgl-kernel/csrc/common_extension_rocm.cc

@@ -19,6 +19,14 @@ limitations under the License.
 #include "sgl_kernel_ops.h"

 TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+  /*
+   * From FlashMLA
+   */
+  m.def(
+      "dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices,Tensor page_kernel_lens, Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
+  m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);
+
   /*
    * From csrc/activation
    */

@@ -34,6 +42,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
   m.impl("gelu_quick", torch::kCUDA, &gelu_quick);

+  /*
+   * From csrc/attention
+   */
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
+
   /*
    * From csrc/allreduce
    */

@@ -119,6 +133,22 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def(
+      "dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
+  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
+  m.def(
+      "dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
+  m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
+  m.def(
+      "dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
+  m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);
+  m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
+  m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);
+  m.def(
+      "dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr,Tensor req_to_token_ptr,Tensor allocate_lens_ptr,Tensor new_allocate_lens,Tensor out_cache_loc_ptr,int shape,int bs) -> ()");
+  m.impl("dcu_assign_req_to_token_pool", torch::kCUDA, &dcu_assign_req_to_token_pool);
+  m.def(
+      "dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
+  m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
+  m.def(
+      "dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
+  m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
   m.def(
       "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
       "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu

@@ -25,7 +25,7 @@
 #define INTRIN_M 16
 #define INTRIN_N 16
 #define INTRIN_K 32
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #define SMEM_PAD_A 0
 #define SMEM_PAD_B 0
 #define PACK_SIZE 16
sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu

@@ -25,7 +25,7 @@
 #define INTRIN_M 16
 #define INTRIN_N 16
 #define INTRIN_K 32
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #define SMEM_PAD_A 0
 #define SMEM_PAD_B 0
 #define PACK_SIZE 16
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -5,7 +5,7 @@
 #include <cstdint>

 #ifndef USE_ROCM
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #include "pytorch_extension_utils.h"
 #else
 #include "pytorch_extension_utils_rocm.h"

@@ -805,3 +805,587 @@ void transfer_kv_all_layer_direct_lf_pf(
    int64_t page_size) {
  transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
}

__device__ int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

__device__ int64_t safe_min(int64_t a, int64_t b) {
  return a < b ? a : b;
}

__global__ void launch_alloc_decode_kernel(
    const int64_t* seq_lens_ptr,
    const int32_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_len = seq_lens_ptr[pid];
  int64_t pre_len = seq_len - 1;
  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);

  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;

  if (num_page_start_loc_self == 0) {
    int32_t last_loc = last_loc_ptr[pid];
    out_indices[pid] = last_loc + 1;
  } else {
    int64_t page = free_page_ptr[new_page_start_loc];
    out_indices[pid] = page * page_size;
  }
}

__global__ void launch_alloc_extend_kernel(
    const int64_t* pre_lens_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_len = seq_lens_ptr[pid];
  int64_t pre_len = pre_lens_ptr[pid];
  int64_t extend_len = seq_len - pre_len;

  int64_t sum_extend_lens = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_extend_len = other_seq_len - other_pre_len;
    sum_extend_lens += other_extend_len;
  }
  int64_t output_start_loc = sum_extend_lens - extend_len;

  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;

  int64_t last_loc = last_loc_ptr[pid];
  int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
  for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + offset;
    out_indices[output_idx] = last_loc + 1 + offset;
  }
  if (pre_len + num_part1 == seq_len) {
    return;
  }

  int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
  for (int64_t offset = 0; offset < num_part2; offset++) {
    int64_t page_idx = new_page_start_loc + offset / page_size;
    int64_t page_start = free_page_ptr[page_idx];
    int64_t output_idx = output_start_loc + num_part1 + offset;
    out_indices[output_idx] = page_start * page_size + offset % page_size;
  }
  if (pre_len + num_part1 + num_part2 == seq_len) {
    return;
  }

  int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
  int64_t last_page_idx = new_page_start_loc + num_page_start_loc_self - 1;
  int64_t start_loc = free_page_ptr[last_page_idx];
  for (int64_t offset = 0; offset < num_part3 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + num_part1 + num_part2 + offset;
    out_indices[output_idx] = start_loc * page_size + offset;
  }
}

__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
    const int32_t* verified_id_ptr,
    const int64_t* seq_lens_ptr,
    const int32_t* accept_lens_ptr,
    int64_t* positions_ptr,
    int32_t* new_verified_id_ptr,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_length = seq_lens_ptr[pid];
  int32_t accept_length = accept_lens_ptr[pid];

  int32_t accept_len_cumsum = 0;
  for (int32_t offset = 0; offset < pid; offset++) {
    accept_len_cumsum += accept_lens_ptr[offset];
  }
  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
  for (int32_t offset = 0; offset < accept_length && offset < bs; offset++) {
    positions_ptr1[offset] = seq_length - accept_length + offset;
  }
  int32_t verified_idx = accept_len_cumsum + accept_length - 1;
  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}

__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
    const int32_t* verified_id_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* accept_lens_ptr,
    int64_t* positions_ptr,
    int32_t* new_verified_id_ptr,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_length = seq_lens_ptr[pid];
  int64_t accept_length = accept_lens_ptr[pid];

  int64_t accept_len_cumsum = 0;
  for (int64_t offset = 0; offset < pid; offset++) {
    accept_len_cumsum += accept_lens_ptr[offset];
  }
  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
  for (int64_t offset = 0; offset < accept_length && offset < bs; offset++) {
    positions_ptr1[offset] = seq_length - accept_length + offset;
  }
  int64_t verified_idx = accept_len_cumsum + accept_length - 1;
  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}

void dcu_alloc_decode_kernel(
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

void dcu_create_extend_after_decode_spec_info(
    const at::Tensor verified_id,
    const at::Tensor seq_lens,
    const at::Tensor accept_lens,
    at::Tensor positions,
    at::Tensor new_verified_id,
    int64_t bs) {
  const int32_t* verified_id_ptr;
  const int64_t* seq_lens_ptr;
  const int32_t* accept_lens_ptr_int32;
  const int64_t* accept_lens_ptr_int64;
  int64_t* positions_ptr;
  int32_t* new_verified_id_ptr;

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

  if (accept_lens.dtype() == torch::kInt32) {
    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
    accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
    launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
    accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
    launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

__global__ void launch_assign_req_to_token_pool(
    const int64_t* req_pool_indices_ptr,
    int32_t* req_to_token_ptr,
    const int64_t* allocate_lens_ptr,
    int64_t* new_allocate_lens,
    int64_t* out_cache_loc_ptr,
    int64_t shape,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t kv_start = allocate_lens_ptr[pid];
  int64_t kv_end = new_allocate_lens[pid];
  int64_t pool_idx = req_pool_indices_ptr[pid];
  int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);

  int64_t sum_out_offset = 0;
  for (int length_offset = 0; length_offset < pid; length_offset++) {
    int64_t start = allocate_lens_ptr[length_offset];
    int64_t end = new_allocate_lens[length_offset];
    sum_out_offset += (end - start);
  }
  int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
  int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
  for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
    token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
  }
}

void dcu_assign_req_to_token_pool(
    const at::Tensor req_pool_indices_ptr,
    at::Tensor req_to_token_ptr,
    const at::Tensor allocate_lens_ptr,
    at::Tensor new_allocate_lens,
    at::Tensor out_cache_loc_ptr,
    int64_t shape,
    int64_t bs) {
  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
  const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
  int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
  int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(
      req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

__global__ void get_last_loc_kernel(
    const int32_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices_tensor,
    const int64_t* __restrict__ prefix_lens_tensor,
    int64_t* __restrict__ result,
    int64_t num_tokens,
    int64_t req_to_token_stride) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= num_tokens) return;

  int64_t pre_len = prefix_lens_tensor[pid];
  if (pre_len > 0) {
    int64_t req_idx = req_pool_indices_tensor[pid];
    int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
    result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
  } else {
    result[pid] = static_cast<int64_t>(-1);
  }
}

at::Tensor dcu_get_last_loc(
    const at::Tensor req_to_token, const at::Tensor req_pool_indices, const at::Tensor prefix_lens) {
  TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
  TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
  TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
  TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
  TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
  TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");

  int64_t num_tokens = prefix_lens.numel();
  TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");

  int64_t req_to_token_stride = req_to_token.stride(0);

  auto req_to_token_c = req_to_token.contiguous();
  auto req_pool_indices_c = req_pool_indices.contiguous();
  auto prefix_lens_c = prefix_lens.contiguous();

  const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
  const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
  const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();

  auto result = at::empty_like(prefix_lens_c);
  int64_t* result_ptr = result.data_ptr<int64_t>();

  const int64_t block_size = 64;
  const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
      req_to_token_ptr, req_pool_indices_ptr, prefix_lens_ptr, result_ptr, num_tokens, req_to_token_stride);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return result;
}

__global__ void launch_assign_extend_cache_locs_kernel(
    const int64_t* __restrict__ req_pool_indices,  // [bs]
    const int32_t* __restrict__ req_to_token,      // [max_num_req, pool_len]
    const int64_t* __restrict__ start_offset,      // [bs]
    const int64_t* __restrict__ end_offset,        // [bs]
    int64_t* __restrict__ out_cache_loc,           // [sum(draft_token_num)]
    int64_t pool_len,
    int64_t bs) {
  int pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t kv_start = start_offset[pid];
  int64_t kv_end = end_offset[pid];
  int64_t req_id = req_pool_indices[pid];

  int64_t out_offset = 0;
  for (int i = 0; i < pid; ++i) {
    out_offset += end_offset[i] - start_offset[i];
  }

  const int32_t* src = req_to_token + req_id * pool_len + kv_start;
  int64_t* dst = out_cache_loc + out_offset;
  for (int64_t i = 0; i < kv_end - kv_start; ++i) {
    dst[i] = src[i];
  }
}

void dcu_assign_extend_cache_locs(
    const at::Tensor req_pool_indices,
    const at::Tensor req_to_token,
    const at::Tensor start_offset,
    const at::Tensor end_offset,
    at::Tensor out_cache_loc,
    int64_t pool_len,
    int64_t bs) {
  const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
  const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
  const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
  const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
  int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();

  constexpr int64_t threads = 128;
  int64_t blocks = (bs + threads - 1) / threads;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
      req_pool_indices_ptr, req_to_token_ptr, start_offset_ptr, end_offset_ptr, out_cache_loc_ptr, pool_len, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
    const int32_t* __restrict__ req_to_token,
    const int32_t* __restrict__ req_pool_indices,
    const int32_t* __restrict__ page_kernel_lens,
    const int32_t* __restrict__ kv_start_idx,
    int32_t* __restrict__ kv_indices,
    int req_to_token_stride,
    int kv_indices_stride) {
  int pid = blockIdx.x;  // batch index
  int req_pool_index = req_pool_indices[pid];
  int kv_start = 0;
  int kv_end = 0;
  if (kv_start_idx != nullptr) {
    kv_start = kv_start_idx[pid];
    kv_end = kv_start;
  }
  kv_end += page_kernel_lens[pid];

  int total_len = kv_end - kv_start;
  int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;

  for (int pg = 0; pg < num_pages; ++pg) {
    int offset = pg * PAGED_SIZE;
    // token id = req_to_token[req_pool_index][kv_start + offset]
    int64_t token = req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
    // page index
    kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
  }
}

void dcu_create_flashmla_kv_indices(
    const at::Tensor& req_to_token,
    const at::Tensor& req_pool_indices,
    const at::Tensor& page_kernel_lens,
    const c10::optional<at::Tensor>& kv_start_idx,
    at::Tensor& kv_indices,
    int64_t req_to_token_stride,
    int64_t kv_indices_stride,
    int64_t PAGED_SIZE) {
  TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
  TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");

  int bs = req_pool_indices.size(0);
  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 grid(bs);
  dim3 block(1);

  const int32_t* kv_start_idx_ptr = nullptr;
  if (kv_start_idx.has_value()) {
    kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
  }

  if (PAGED_SIZE == 64) {
    dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
        req_to_token.data_ptr<int32_t>(),
        req_pool_indices.data_ptr<int32_t>(),
        page_kernel_lens.data_ptr<int32_t>(),
        kv_start_idx_ptr,
        kv_indices.data_ptr<int32_t>(),
        req_to_token_stride,
        kv_indices_stride);
  } else {
    TORCH_CHECK(false, "Unsupported PAGED_SIZE");
  }
}

__global__ void launch_create_chunked_prefix_cache_kv_indices(
    int32_t* req_to_token_ptr,
    const int64_t* req_pool_indices_ptr,
    const int32_t* chunk_starts_ptr,
    const int32_t* chunk_seq_lens_ptr,
    const int32_t* chunk_cu_seq_lens_ptr,
    int32_t* chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t req_pool_index = req_pool_indices_ptr[pid];
  int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
  int32_t chunk_start_pos = chunk_starts_ptr[pid];
  int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
#pragma unroll(32)
  for (int32_t offset = 0; offset < chunk_seq_len; offset++) {
    chunk_kv_indices_ptr[chunk_kv_indices_offset + offset] =
        req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
  }
}

void dcu_create_chunked_prefix_cache_kv_indices(
    at::Tensor req_to_token_ptr,
    const at::Tensor req_pool_indices_ptr,
    const at::Tensor chunk_starts_ptr,
    const at::Tensor chunk_seq_lens_ptr,
    const at::Tensor chunk_cu_seq_lens_ptr,
    at::Tensor chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs) {
  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
  const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
  const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
  const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
  int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(
      req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1,
      chunk_cu_seq_lens_ptr1, chunk_kv_indices_ptr1, col_num, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
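The three-part layout in launch_alloc_extend_kernel (first fill the tail of the page the prefix already occupies, then whole fresh pages, then the tail of the final fresh page) can be checked against a plain-Python reference. The example values below are made up for illustration:

def ceil_div(a, b):
    return (a + b - 1) // b

def alloc_extend_ref(pre_lens, seq_lens, last_locs, free_pages, page_size):
    """Python mirror of launch_alloc_extend_kernel (token slot indices for all requests)."""
    out = []
    page_cursor = 0  # plays the role of new_page_start_loc for the current request
    for pre_len, seq_len, last_loc in zip(pre_lens, seq_lens, last_locs):
        num_new_pages = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size)
        # part 1: free slots left in the page that already holds the prefix
        num_part1 = min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len
        out += [last_loc + 1 + k for k in range(num_part1)]
        # part 2: fully used fresh pages taken from the free list
        num_part2 = (seq_len // page_size) * page_size - ceil_div(pre_len, page_size) * page_size
        out += [
            free_pages[page_cursor + k // page_size] * page_size + k % page_size
            for k in range(max(num_part2, 0))
        ]
        # part 3: the partially filled tail of the last fresh page
        num_part3 = seq_len - (seq_len // page_size) * page_size
        if pre_len + num_part1 + max(num_part2, 0) < seq_len:
            last_page = free_pages[page_cursor + num_new_pages - 1]
            out += [last_page * page_size + k for k in range(num_part3)]
        page_cursor += num_new_pages
    return out

# Example: page_size=4, one request extends from 3 to 10 tokens, its last prefix
# slot is 6, and pages [5, 7] are free: slot 7 finishes the old page, page 5
# holds the next four tokens, page 7 holds the last two.
print(alloc_extend_ref([3], [10], [6], [5, 7], 4))
# -> [7, 20, 21, 22, 23, 28, 29]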
sgl-kernel/csrc/moe/moe_align_kernel.cu

@@ -21,6 +21,7 @@ limitations under the License.
 #include "utils.h"

 #define WARP_SIZE 64
 #define VEC_SIZE 4
 using Vec = int4;

@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
   int original = v;
 #pragma unroll
   for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
-    int n = __shfl_up_sync(mask, v, offset);
+    int n = __shfl_up(v, offset);
     if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
   }
   return v - original;
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu

@@ -60,7 +60,7 @@ template <typename T>
 __device__ float convert_to_float(T x) {
   if constexpr (std::is_same_v<T, __half>) {
     return __half2float(x);
-  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+  } else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
     return __bfloat162float(x);
   } else if constexpr (std::is_same_v<T, float>) {
     return x;

@@ -686,8 +686,8 @@ void topk_softmax(
         bias_ptr,
         stream);
   } else if (dtype == at::ScalarType::BFloat16) {
-    topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
-        reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
+    topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
+        reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
         topk_weights.data_ptr<float>(),
         topk_indices.data_ptr<int>(),
         softmax_workspace.data_ptr<float>(),
sgl-kernel/csrc/quantization/gguf/ggml-common.h

@@ -3,7 +3,7 @@
 // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
 #define QK_K 256
 #define K_QUANTS_PER_ITERATION 2
-#define WARP_SIZE_GGUF 32
+#define WARP_SIZE_GGUF 64
 #define K_SCALE_SIZE 12
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
sgl-kernel/include/sgl_kernel_ops.h

@@ -515,6 +515,75 @@ void segment_packbits(
/*
 * From csrc/kvcacheio
 */
void dcu_create_extend_after_decode_spec_info(
    const at::Tensor verified_id,
    const at::Tensor seq_lens,
    const at::Tensor accept_lens,
    at::Tensor positions,
    at::Tensor new_verified_id,
    int64_t bs);

void dcu_create_chunked_prefix_cache_kv_indices(
    at::Tensor req_to_token,
    const at::Tensor req_pool_indices,
    const at::Tensor chunk_starts,
    const at::Tensor chunk_seq_lens,
    const at::Tensor chunk_cu_seq_lens,
    at::Tensor chunk_kv_indices,
    int64_t col_num,
    int64_t bs);

void dcu_create_flashmla_kv_indices(
    const at::Tensor& req_to_token,
    const at::Tensor& req_pool_indices,
    const at::Tensor& page_kernel_lens,
    const c10::optional<at::Tensor>& kv_start_idx,
    at::Tensor& kv_indices,
    int64_t req_to_token_stride,
    int64_t kv_indices_stride,
    int64_t PAGED_SIZE);

void dcu_assign_extend_cache_locs(
    const at::Tensor req_pool_indices,
    const at::Tensor req_to_token,
    const at::Tensor start_offset,
    const at::Tensor end_offset,
    at::Tensor out_cache_loc,
    int64_t pool_len,
    int64_t bs);

at::Tensor dcu_get_last_loc(
    const at::Tensor req_to_token,
    const at::Tensor req_pool_indices,
    const at::Tensor prefix_lens);

void dcu_assign_req_to_token_pool(
    const at::Tensor req_pool_indices_ptr,
    at::Tensor req_to_token_ptr,
    const at::Tensor allocate_lens_ptr,
    at::Tensor new_allocate_lens,
    at::Tensor out_cache_loc_ptr,
    int64_t shape,
    int64_t bs);

void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size);

void dcu_alloc_decode_kernel(
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size);

void transfer_kv_per_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
sgl-kernel/include/utils.h

@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))

 #ifndef USE_ROCM
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #else
 #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
 #define WARP_SIZE 64

@@ -369,25 +369,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
 #endif

 // add FP8 support
-#ifndef USE_ROCM
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else  // USE_ROCM
-#if HIP_FP8_TYPE_FNUZ
-#include <c10/util/Float8_e4m3fnuz.h>
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-constexpr auto FP8_E4M3_MAX = 224.0f;
-#else
-#if HIP_FP8_TYPE_E4M3
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#error "fp8 is not supported in this processor (arch < gfx942)."
-#endif  // HIP_FP8_TYPE_E4M3
-#endif  // HIP_FP8_TYPE_FNUZ
-#endif  // USE_ROCM
+// #ifndef USE_ROCM
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else  // USE_ROCM
+// #if HIP_FP8_TYPE_FNUZ
+// #include <c10/util/Float8_e4m3fnuz.h>
+// using FP8_TYPE = c10::Float8_e4m3fnuz;
+// constexpr auto FP8_E4M3_MAX = 224.0f;
+// #else
+// #if HIP_FP8_TYPE_E4M3
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else
+// #error "fp8 is not supported in this processor (arch < gfx942)."
+// #endif  // HIP_FP8_TYPE_E4M3
+// #endif  // HIP_FP8_TYPE_FNUZ
+// #endif  // USE_ROCM

 #define FULL_MASK 0xffffffff
sgl-kernel/python/sgl_kernel/flash_mla.py

@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
     "Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
 )


+def dcu_create_flashmla_kv_indices(
+    req_to_token_ptr,
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride,
+    kv_indices_ptr_stride,
+    PAGED_SIZE=64,
+):
+    torch.ops.sgl_kernel.dcu_create_flashmla_kv_indices(
+        req_to_token_ptr,
+        req_pool_indices_ptr,
+        page_kernel_lens_ptr,
+        kv_start_idx,
+        kv_indices_ptr,
+        req_to_token_ptr_stride,
+        kv_indices_ptr_stride,
+        PAGED_SIZE,
+    )
+
+
 def get_mla_metadata(
     cache_seqlens: torch.Tensor,
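A minimal call sketch for the wrapper above. The dtypes follow the kernel registered in common_extension_rocm.cc (int32 req_to_token, pool indices, page lens and kv_indices), but the concrete sizes here are made-up examples:

import torch
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

bs, pool_len, page_size = 2, 256, 64
req_to_token = torch.arange(bs * pool_len, dtype=torch.int32, device="cuda").view(bs, pool_len)
req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
page_kernel_lens = torch.tensor([130, 70], dtype=torch.int32, device="cuda")
max_pages = (int(page_kernel_lens.max()) + page_size - 1) // page_size
kv_indices = torch.zeros(bs, max_pages, dtype=torch.int32, device="cuda")

dcu_create_flashmla_kv_indices(
    req_to_token,
    req_pool_indices,
    page_kernel_lens,
    None,                      # kv_start_idx is optional in the op schema
    kv_indices,
    req_to_token.stride(0),
    kv_indices.stride(0),
    PAGED_SIZE=page_size,      # only 64 is accepted by the kernel
)
# kv_indices[b, p] now holds req_to_token[req_pool_indices[b], p * 64] // 64,
# i.e. the page id backing the p-th 64-token page of request b.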