sglang · Commit f810bda3
Authored Nov 24, 2025 by renzhc

Merge branch 'v0.5.4_dev' into v0.5.4_rzc

Parents: 4167eff9, 48542418

Showing 20 changed files with 1169 additions and 140 deletions (+1169 / -140)
python/sglang/srt/_custom_ops.py                                                                   +12   -0
python/sglang/srt/layers/attention/dcu_mla_backend.py                                              +70   -29
python/sglang/srt/layers/attention/flashattention_backend.py                                       +4    -9
python/sglang/srt/layers/attention/flashmla_backend.py                                             +48   -20
python/sglang/srt/layers/moe/ep_moe/kernels.py                                                     +131  -0
python/sglang/srt/layers/moe/ep_moe/layer.py                                                       +115  -8
python/sglang/srt/layers/moe/token_dispatcher/deepep.py                                            +16   -14
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py                     +29   -1
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py   +68   -8
python/sglang/srt/layers/quantization/slimquant_w4a8.py                                            +8    -2
python/sglang/srt/mem_cache/common.py                                                              +12   -7
python/sglang/srt/model_executor/forward_batch_info.py                                             +38   -19
python/sglang/srt/speculative/draft_utils.py                                                       +10   -3
python/sglang/srt/speculative/eagle_info_v2.py                                                     +52   -18
python/sglang/srt/two_batch_overlap.py                                                             +1    -1
python/sglang/srt/utils/common.py                                                                  +155  -0
sgl-kernel/csrc/common_extension_rocm.cc                                                           +17   -0
sgl-kernel/csrc/kvcacheio/transfer.cu                                                              +319  -1
sgl-kernel/include/sgl_kernel_ops.h                                                                +44   -0
sgl-kernel/python/sgl_kernel/flash_mla.py                                                          +20   -0
python/sglang/srt/_custom_ops.py
@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def blaslt_scaled_mm(a: torch.Tensor,
                     b: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    m = a.shape[0]
    n = b.shape[0]
    k = a.shape[1]
    _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
    return out


def triton_int8_gemm_helper(m: int,
                            n: int,
                            k: int,
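A minimal sketch of how the new blaslt_scaled_mm wrapper might be exercised on its own; the int8 inputs, scale shapes, and dimensions below are illustrative assumptions, not part of the commit:

    # Hypothetical standalone check (assumes a ROCm build where lmslim's quant_ops is available).
    import torch
    from sglang.srt import _custom_ops as ops

    m, n, k = 128, 4096, 4096
    a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")   # activations, row-major
    b = torch.randint(-128, 127, (n, k), dtype=torch.int8, device="cuda")   # weights in 'NT' layout
    scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")          # assumed per-token scales
    scale_b = torch.rand(n, 1, dtype=torch.float32, device="cuda")          # assumed per-channel scales

    out = ops.blaslt_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
    print(out.shape)  # expected (m, n)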
python/sglang/srt/layers/attention/dcu_mla_backend.py
@@ -11,6 +11,8 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
from sglang.srt.utils import get_bool_env_var

try:
    from flash_mla import (
@@ -104,6 +106,7 @@ class DCUMLABackend(AttentionBackend):
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
@@ -118,15 +121,27 @@ class DCUMLABackend(AttentionBackend):
                 dtype=torch.int32,
                 device=forward_batch.seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 forward_batch.seq_lens.to(torch.int32),
@@ -149,15 +164,27 @@ class DCUMLABackend(AttentionBackend):
                 dtype=torch.int32,
                 device=seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
@@ -185,15 +212,27 @@ class DCUMLABackend(AttentionBackend):
             )
             # Call the Triton kernel to generate block_kv_indices
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token.to(torch.int32),
+                    req_pool_indices_ptr=forward_batch.req_pool_indices.to(torch.int32),
+                    page_kernel_lens_ptr=forward_batch.seq_lens.to(torch.int32),
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices.to(torch.int32),
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             # MLA
             mla_metadata, num_splits = get_mla_metadata(
@@ -211,6 +250,7 @@ class DCUMLABackend(AttentionBackend):
            self.flashattn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
@@ -489,9 +529,10 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if (
-            forward_batch.forward_mode == ForwardMode.EXTEND
-            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-        ):
+        if ((
+            forward_batch.forward_mode == ForwardMode.EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+        )):
             if not self.skip_prefill:
                 return self.flashattn_backend.forward_extend(
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -674,16 +674,11 @@ class FlashAttentionBackend(AttentionBackend):
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
-        # if not self.use_mla:
         if k_rope is None:
-            if not self.use_mla:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-            else:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v
-                )
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v,
+                # layer.k_scale, layer.v_scale
+            )
         else:
             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,
python/sglang/srt/layers/attention/flashmla_backend.py
@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO")
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(
@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 dtype=torch.int32,
                 device=forward_batch.seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 forward_batch.seq_lens.to(torch.int32),
                 self.num_q_heads,
@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 dtype=torch.int32,
                 device=seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            )
        else:
            super().init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1013,3 +1013,134 @@ def zero_experts_compute_triton(
    )
    return output


from triton.language.extra import libdevice
from typing import Optional


@triton.jit
def _per_token_quant_int8_one_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    T_dim,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    row_id = tl.program_id(0)
    if tokens_per_expert_ptr is not None:
        e = row_id // T_dim
        t = row_id % T_dim
        num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
        if t >= num_valid_tokens_for_e:
            return
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    x_q = libdevice.nearbyint(x_q).to(tl.int8)
    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


@triton.jit
def _per_token_quant_int8_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    E_dim,
    T_dim,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    token_idx_start = tl.program_id(0)
    grid_size = tl.num_programs(0)
    num_total_tokens = E_dim * T_dim
    for token_idx in range(token_idx_start, num_total_tokens, grid_size):
        is_valid_token = True
        if tokens_per_expert_ptr is not None:
            e = token_idx // T_dim
            t = token_idx % T_dim
            num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
            if t >= num_valid_tokens_for_e:
                is_valid_token = False
        if is_valid_token:
            cols = tl.arange(0, BLOCK)
            mask = cols < N
            x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
            absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
            scale_x = absmax / 127
            x_q = x * (127 / absmax)
            x_q = libdevice.nearbyint(x_q).to(tl.int8)
            tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
            tl.store(scale_ptr + token_idx, scale_x)


def per_token_quant_int8_triton_opt(
    x: torch.Tensor,
    tokens_per_expert: Optional[torch.Tensor] = None,
):
    if x.dim() != 3:
        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
    E, T, H = x.shape
    N = H
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
        num_warps = 1
    num_tokens = E * T
    grid_opt = num_tokens
    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
        grid_opt = max(1, num_tokens // (T // 256))
        _per_token_quant_int8_kernel_opt[(grid_opt,)](
            x,
            x_q,
            scales,
            stride_x=x.stride(-2),
            stride_xq=x_q.stride(-2),
            N=N,
            E_dim=E,
            T_dim=T,
            tokens_per_expert_ptr=tokens_per_expert,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
    else:
        _per_token_quant_int8_one_kernel_opt[(grid_opt,)](
            x,
            x_q,
            scales,
            stride_x=x.stride(-2),
            stride_xq=x_q.stride(-2),
            N=N,
            T_dim=T,
            tokens_per_expert_ptr=tokens_per_expert,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
    return x_q, scales
\ No newline at end of file
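A minimal sketch of how per_token_quant_int8_triton_opt might be called on a masked expert batch; the shapes and the dummy tokens_per_expert values are illustrative assumptions:

    # Hypothetical call (requires a GPU build with Triton available).
    import torch
    from sglang.srt.layers.moe.ep_moe.kernels import per_token_quant_int8_triton_opt

    E, T, H = 8, 1024, 4096                       # experts, max tokens per expert, hidden size
    x = torch.randn(E, T, H, device="cuda", dtype=torch.bfloat16)
    tokens_per_expert = torch.randint(0, T, (E,), device="cuda", dtype=torch.int32)

    x_q, scales = per_token_quant_int8_triton_opt(x, tokens_per_expert)
    # x_q: int8 tensor of shape [E, T, H]; scales: float32 tensor of shape [E, T, 1].
    # Rows at or beyond tokens_per_expert[e] are skipped by the kernels and stay uninitialized.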
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig

import torch
@@ -20,6 +20,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
    ep_scatter,
    silu_and_mul_masked_post_quant_fwd,
    tma_align_input_scale,
    per_token_quant_int8_triton_opt,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
@@ -40,7 +41,7 @@ if TYPE_CHECKING:
     DeepEPNormalOutput,
     DispatchOutput,
 )
-from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant
 from lmslim.layers.gemm.int8_utils import per_token_quant_int8

 _is_hip = is_hip()
@@ -605,6 +606,8 @@ class DeepEPMoE(EPMoE):
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif self.use_w4a8_marlin:
            return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
        elif self.use_w8a8_marlin:
            return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
        else:
            raise ValueError(
                f"Dispatch output is not supported"
@@ -709,6 +712,111 @@ class DeepEPMoE(EPMoE):
        )
        return expert_output

    def forward_groupgemm_w8a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"

        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states.bfloat16()

        device = hidden_states.device
        M = hidden_states.shape[0]
        K = hidden_states.shape[1]
        topk = topk_idx.shape[1]

        active_experts = set()
        token_expert_pos = [None] * M
        for t in range(M):
            lst = []
            for pos in range(topk):
                e = int(topk_idx[t, pos].item())
                if e >= 0:
                    lst.append((e, pos))
                    active_experts.add(e)
            token_expert_pos[t] = lst

        active_experts = sorted(list(active_experts))
        num_active = len(active_experts)
        if num_active == 0:
            return hidden_states.bfloat16()

        counts = defaultdict(int)
        for t in range(M):
            for (e, pos) in token_expert_pos[t]:
                counts[e] += 1

        per_expert_block = {}
        for e in active_experts:
            cnt = counts.get(e, 0)
            if cnt <= 0:
                per_expert_block[e] = 0
            else:
                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
                per_expert_block[e] = max(256, needed)

        expert_slot_offset = {}
        offset = 0
        for e in active_experts:
            expert_slot_offset[e] = offset
            offset += per_expert_block[e]
        pad_M = offset

        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
        m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)

        slot_counters = {e: 0 for e in active_experts}
        token_row_weight_list = {t: [] for t in range(M)}
        for t in range(M):
            for (e, pos) in token_expert_pos[t]:
                start = expert_slot_offset[e]
                slot = slot_counters[e]
                if slot >= per_expert_block[e]:
                    raise RuntimeError(
                        f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}"
                    )
                row = start + slot
                hidden_states_packed[row] = hidden_states[t]
                m_indices[row] = int(e)
                slot_counters[e] += 1
                w = topk_weights[t, pos].to(device=device)
                w_f = w.float() if w.dtype != torch.float32 else w
                token_row_weight_list[t].append((row, w_f))

        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)

        N = self.w13_weight.size(1)
        gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
        m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a1_all, q_a1_scale),
            (self.w13_weight, self.w13_weight_scale),
            gateup_output,
            m_indices,
        )

        q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)

        down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
        down_output = m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a2_all, q_a2_scale),
            (self.w2_weight, self.w2_weight_scale),
            down_output,
            m_indices,
        )

        result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
        for t in range(M):
            pairs = token_row_weight_list[t]
            if not pairs:
                continue
            acc = None
            for (row, w) in pairs:
                vec = down_output[row].float()
                weighted = vec * w
                acc = weighted if acc is None else (acc + weighted)
            result[t] = acc.to(result.dtype)
        return result

    def forward_deepgemm_contiguous(
        self,
@@ -899,10 +1007,10 @@ class DeepEPMoE(EPMoE):
         # base shapes
         num_groups, m, k = hidden_states.size()
         expected_m = min(m, expected_m)

         # ---- first quant: ensure float input for quantizer ----
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)

         # ---- weights & scales ----
         w13_weight = self.w13_weight
@@ -943,16 +1051,15 @@ class DeepEPMoE(EPMoE):
         dispatch_output: DeepEPLLOutput,
     ):
-        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+        hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

         # base shapes
         num_groups, m, k = hidden_states.size()
         expected_m = min(m, expected_m)

         # ---- first quant: ensure float input for quantizer ----
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)

         # ---- weights & scales ----
         w13_weight = self.w13_weight
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -308,7 +308,7 @@ class _DeepEPDispatcherImplBase:
         self.params_bytes = 2
         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
-            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
+            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 64
         )
         # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
         # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
@@ -441,18 +441,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             expert_alignment=1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
-        # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-        #     num_recv_tokens_per_expert,
-        #     num_tokens_per_rank=num_tokens_per_rank,
-        #     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-        #     num_tokens_per_expert=num_tokens_per_expert,
-        # )
-        self.rank_expert_offset = get_moe_expert_parallel_rank() * (
-            self.num_experts // get_moe_expert_parallel_world_size()
-        )
-        recv_topk_ids = torch.where(
-            recv_topk_ids == -1,
-            self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
-            recv_topk_ids + self.rank_expert_offset,
-        )
+        if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
+            self.rank_expert_offset = get_moe_expert_parallel_rank() * (
+                self.num_experts // get_moe_expert_parallel_world_size()
+            )
+            recv_topk_ids = torch.where(
+                recv_topk_ids == -1,
+                self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+                recv_topk_ids + self.rank_expert_offset,
+            )
+        else:
+            get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+                num_recv_tokens_per_expert,
+                num_tokens_per_rank=num_tokens_per_rank,
+                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+                num_tokens_per_expert=num_tokens_per_expert,
+            )
         return (
             recv_x,
             recv_topk_ids,
@@ -541,7 +542,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        self.return_recv_hook = False
+        self.return_recv_hook = return_recv_hook
         self.device_module = torch.get_device_module()
         self.quant_config = {}
@@ -724,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        self.packed_recv_count = self.handle = None
        return combined_hidden_states, event, hook

    @torch._dynamo.disable()
    def _get_buffer(self):
        DeepEPBuffer.set_dispatch_mode_as_low_latency()
        return DeepEPBuffer.get_deepep_buffer(
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import os
import logging
from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import W8a8GetCacheJSON

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]
@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config
        self.tritonsingleton = W8a8GetCacheJSON()
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        n = layer.weight.shape[0]
        k = layer.weight.shape[1]
        if self.w8a8_strategy == 1:
            if [n, k] not in self.tritonsingleton.weight_shapes:
                self.tritonsingleton.weight_shapes.append([n, k])
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(json_file, n, k)
                if configs_dict:
                    self.tritonsingleton.triton_json_dict.update(configs_dict)
                    for key, value in configs_dict.items():
                        m = int(key.split('_')[0])
                        ops.triton_int8_gemm_helper(
                            m=m,
                            n=n,
                            k=k,
                            per_token_act_quant=True,
                            per_out_channel_weight_quant=True,
                            use_bias=False,
                            device=layer.weight.device,
                            best_config=value,
                        )
        elif self.w8a8_strategy == 3:
            layer.weight.data = layer.weight.data.T
        else:
            weight_data = layer.weight.data
            _weight = weight_data.T.contiguous().reshape(n, -1)
            layer.weight.data = _weight
        self.tritonsingleton.gen_model_json()
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Callable, Optional

import torch
@@ -19,11 +20,13 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
from lmslim import quant_ops
from sglang.srt import _custom_ops as ops

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import int8_scaled_mm

from sglang.srt.utils import W8a8GetCacheJSON

W8A8_TRITONJSON = W8a8GetCacheJSON()


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
@@ -33,6 +36,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))  # TODO

    @classmethod
    def get_min_capability(cls) -> int:
@@ -163,14 +167,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         )
         layer.register_parameter("input_zero_point", input_zero_point)

+    @torch._dynamo.disable()
     def apply_weights(
         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
     ) -> torch.Tensor:
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)
-        # TODO: fix with lmslim/lightop
-        return quant_ops.triton_scaled_mm(
-            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
-        )
-        # return quant_ops.custom_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
+        if self.w8a8_strategy == 1:
+            m = x_q.shape[0]
+            k = x_q.shape[1]
+            n = layer.weight.shape[1]
+            if len(W8A8_TRITONJSON.triton_json_dict) == 0:
+                best_config = None
+            elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
+                if m <= 16:
+                    m_ = m
+                elif m <= 64:
+                    m_ = (m + 3) & -4  # round up to the nearest multiple of 4
+                elif m <= 160:
+                    m_ = (m + 7) & -8
+                elif m < 200:    # 256
+                    m_ = 160
+                elif m < 480:    # 512
+                    m_ = 256
+                elif m < 960:    # 1024
+                    m_ = 512
+                elif m < 2048:
+                    m_ = 1024
+                elif m < 4096:
+                    m_ = 2048
+                elif m < 6000:
+                    m_ = 4096
+                else:
+                    m_ = 8192
+                best_config = W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
+            else:
+                best_config = None
+            return ops.triton_scaled_mm(
+                x_q,
+                layer.weight,
+                x_scale,
+                layer.weight_scale,
+                out_dtype=x.dtype,
+                bias=bias,
+                best_config=best_config,
+            )
+        elif self.w8a8_strategy == 2:
+            return ops.cutlass_scaled_mm(
+                x_q,
+                layer.weight,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+                out_dtype=x.dtype,
+                bias=bias,
+            )
+        elif self.w8a8_strategy == 3:
+            return ops.blaslt_scaled_mm(
+                x_q,
+                layer.weight,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+                out_dtype=x.dtype,
+                bias=None,
+            )
+        else:
+            return ops.rocblas_scaled_mm(
+                x_q,
+                layer.weight,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+                out_dtype=x.dtype,
+                bias=bias,
+            )
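For reference, a quick check of the bit-trick rounding used in the m <= 64 and m <= 160 branches above (plain Python arithmetic, no assumptions beyond the expressions themselves):

    # (m + 3) & -4 snaps m up to the next multiple of 4; (m + 7) & -8 to the next multiple of 8.
    m = 50
    print((m + 3) & -4)   # 52
    m = 100
    print((m + 7) & -8)   # 104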
python/sglang/srt/layers/quantization/slimquant_w4a8.py
@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import (
     per_token_group_quant_int8,
     per_token_quant_int8,
 )
 from sglang.srt import _custom_ops as ops
-from vllm.utils import W8a8GetCacheJSON
+from sglang.srt.utils import W8a8GetCacheJSON
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig

 import os
@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight_scale", weight_scale)

     @torch._dynamo.disable()
-    # TODO: performance optimization here needs support from lmslim/lightop
     def apply(
         self,
         layer: torch.nn.Module,
@@ -227,6 +226,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias,
            )
        elif self.w8a8_strategy == 3:
            return ops.blaslt_scaled_mm(
                x_q,
                layer.weight,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=None,
            )
        else:
            return ops.rocblas_scaled_mm(
                x_q,
                layer.weight,
python/sglang/srt/mem_cache/common.py
@@ -13,7 +13,8 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import support_triton
+from sglang.srt.utils import support_triton, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_get_last_loc

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -125,13 +126,17 @@ def get_last_loc(
     req_pool_indices_tensor: torch.Tensor,
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
-    if (
-        get_global_server_args().attention_backend != "ascend"
-        and get_global_server_args().attention_backend != "torch_native"
-    ):
-        impl = get_last_loc_triton
-    else:
-        impl = get_last_loc_torch
+    use_sglang_get_last_loc = get_bool_env_var("SGLANG_GET_LAST_LOC")
+    if use_sglang_get_last_loc:
+        impl = dcu_get_last_loc
+    else:
+        if (
+            get_global_server_args().attention_backend != "ascend"
+            and get_global_server_args().attention_backend != "torch_native"
+        ):
+            impl = get_last_loc_triton
+        else:
+            impl = get_last_loc_torch
     return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
python/sglang/srt/model_executor/forward_batch_info.py
@@ -46,7 +46,11 @@ from sglang.srt.layers.dp_attention import (
     set_dp_buffer_len,
     set_is_extend_in_batch,
 )
-from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_chunked_prefix_cache_kv_indices
+import logging
+
+logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -123,13 +127,13 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2

-    def is_extend_or_draft_extend_or_mixed(self):  #nhb
+    def is_extend_or_draft_extend_or_mixed(self):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
-            or self == ForwardMode.SPLIT_PREFILL
-            or self == ForwardMode.DRAFT_EXTEND_V2
+            or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2  #nhb
         )

     def is_cuda_graph(self):
@@ -317,6 +321,8 @@ class ForwardBatch:
    tbo_parent_token_range: Optional[Tuple[int, int]] = None
    tbo_children: Optional[List[ForwardBatch]] = None

    use_sglang_create_chunked_prefix_cache_kv_indices = get_bool_env_var(
        "SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES"
    )

    @classmethod
    def init_new(
        cls,
if
batch
.
extend_input_logprob_token_ids
is
not
None
:
ret
.
extend_input_logprob_token_ids_gpu
=
(
batch
.
extend_input_logprob_token_ids
.
to
(
device
,
non_blocking
=
True
)
batch
.
extend_input_logprob_token_ids
.
pin_memory
().
to
(
device
,
non_blocking
=
True
)
)
if
enable_num_token_non_padded
(
model_runner
.
server_args
):
ret
.
num_token_non_padded
=
torch
.
tensor
(
len
(
batch
.
input_ids
),
dtype
=
torch
.
int32
).
to
(
device
,
non_blocking
=
True
)
).
pin_memory
(
).
to
(
device
,
non_blocking
=
True
)
ret
.
num_token_non_padded_cpu
=
len
(
batch
.
input_ids
)
# For MLP sync
...
...
@@ -389,12 +395,12 @@ class ForwardBatch:
             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
                 global_num_tokens, dtype=torch.int64
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)

             ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
                 global_num_tokens_for_logprob, dtype=torch.int64
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)

         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
@@ -419,10 +425,10 @@ class ForwardBatch:
             assert isinstance(batch.extend_prefix_lens, list)
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position(
                 model_runner.server_args.attention_backend,
@@ -635,15 +641,28 @@ class ForwardBatch:
                 num_chunk_tokens, dtype=torch.int32, device=device
             )
-            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
-                self.req_to_token_pool.req_to_token,
-                self.req_pool_indices,
-                chunk_starts,
-                chunk_seq_lens,
-                chunk_cu_seq_lens,
-                chunk_kv_indices,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
+            if self.use_sglang_create_chunked_prefix_cache_kv_indices:
+                dcu_create_chunked_prefix_cache_kv_indices(
+                    req_to_token=self.req_to_token_pool.req_to_token,
+                    req_pool_indices=self.req_pool_indices,
+                    chunk_starts=chunk_starts,
+                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_cu_seq_lens=chunk_cu_seq_lens,
+                    chunk_kv_indices=chunk_kv_indices,
+                    col_num=self.req_to_token_pool.req_to_token.shape[1],
+                    bs=self.batch_size,
+                )
+            else:
+                # logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
+                create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
+                    self.req_to_token_pool.req_to_token,
+                    self.req_pool_indices,
+                    chunk_starts,
+                    chunk_seq_lens,
+                    chunk_cu_seq_lens,
+                    chunk_kv_indices,
+                    self.req_to_token_pool.req_to_token.shape[1],
+                )
             self.prefix_chunk_kv_indices.append(chunk_kv_indices)

     def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
python/sglang/srt/speculative/draft_utils.py
@@ -237,7 +237,14 @@ class DraftBackendFactory:
         return None

     def _create_dcumla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
+        # logger.warning(
+        #     "flashmla prefill backend is not yet supported for draft extend."
+        # )
+        # return None
+        #nhb
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
python/sglang/srt/speculative/eagle_info_v2.py
@@ -29,6 +29,12 @@ from sglang.srt.speculative.spec_utils import (
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool, dcu_assign_extend_cache_locs
import logging

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker
    from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
@@ -77,6 +83,9 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
    use_sglang_assign_req_to_token_pool = get_bool_env_var("SGLANG_ASSIGN_REQ_TO_TOKEN_POOL")

    def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
@@ -112,15 +121,26 @@ class EagleDraftInputV2Mixin:
             extend_num_tokens,
         )
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            self.allocate_lens,
-            new_allocate_lens,
-            out_cache_loc,
-            batch.req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
+        if self.use_sglang_assign_req_to_token_pool:
+            dcu_assign_req_to_token_pool(
+                req_pool_indices=batch.req_pool_indices,
+                req_to_token=batch.req_to_token_pool.req_to_token,
+                allocate_lens=self.allocate_lens,
+                new_allocate_lens=new_allocate_lens,
+                out_cache_loc=out_cache_loc,
+                shape=batch.req_to_token_pool.req_to_token.shape[1],
+                bs=bs,
+            )
+        else:
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                self.allocate_lens,
+                new_allocate_lens,
+                out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )
         self.allocate_lens = new_allocate_lens

         # FIXME(lsyin): make this sync optional
@@ -189,6 +209,9 @@ class EagleDraftInputV2Mixin:
@dataclass
class EagleVerifyInputV2Mixin:
    use_sglang_assign_extend_cache_locs = get_bool_env_var("SGLANG_ASSIGN_EXTEND_CACHE_LOCS")

    def prepare_for_v2_verify(
        self: EagleVerifyInput,
        req_to_token_pool: ReqToTokenPool,
@@ -205,15 +228,26 @@ class EagleVerifyInputV2Mixin:
             device=device,
         )
-        assign_extend_cache_locs[(bs,)](
-            batch.req_pool_indices,
-            req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
-            batch.out_cache_loc,
-            req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
+        if self.use_sglang_assign_extend_cache_locs:
+            dcu_assign_extend_cache_locs(
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                bs,
+            )
+        else:
+            assign_extend_cache_locs[(bs,)](
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )

         # Get a forward batch
         batch.forward_mode = ForwardMode.TARGET_VERIFY
python/sglang/srt/two_batch_overlap.py
@@ -758,7 +758,7 @@ class TboForwardBatchPreparer:
         # TODO we may make padding on both sub-batches to make it slightly more balanced
         value_a = min(tbo_split_token_index, num_token_non_padded)
         value_b = max(0, num_token_non_padded - tbo_split_token_index)
-        return torch.tensor([value_a, value_b], dtype=torch.int32).to(
+        return torch.tensor([value_a, value_b], dtype=torch.int32).pin_memory().to(
             device=get_global_server_args().device, non_blocking=True
         )
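The recurring edit in this commit swaps ".to(device, non_blocking=True)" on freshly built CPU tensors for ".pin_memory().to(device, non_blocking=True)". A minimal sketch of the pattern in isolation (tensor contents are arbitrary):

    # Non-blocking host-to-device copies only overlap with GPU work when the source
    # lives in page-locked (pinned) host memory; from pageable memory the copy is
    # effectively synchronous.
    import torch

    values = torch.tensor([3, 5], dtype=torch.int32)              # pageable CPU tensor
    gpu_values = values.pin_memory().to("cuda", non_blocking=True)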
python/sglang/srt/utils/common.py
@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None):
        return CachedKernel(fn, key_fn)

    return decorator


# from vllm
class W8a8GetCacheJSON:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(W8a8GetCacheJSON, cls).__new__(cls, *args, **kwargs)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        current_folder_path = os.path.dirname(os.path.abspath(__file__))
        json_folder_path = current_folder_path + '/../../lmslim/configs/w8a8'
        self.triton_json_dir = (os.getenv('TRITON_JSON_DIR', json_folder_path))
        self.triton_json_dict = {}
        self.triton_moejson_dict = {}
        self.triton_json_list = []
        self.weight_shapes = []
        self.moe_weight_shapes = []
        arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
        arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
        device_name = arch_name + '_' + str(arch_cu) + 'cu'
        self.device_name = device_name
        self.topk = 1
        self.quant_method = None

    # Destructor-like hook: generates the model.json configuration file at the end
    def gen_model_json(self, E: Optional[int] = 0, block_size: Optional[list] = None):
        json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
        if json_dir != "None" and os.path.exists(json_dir):
            # Generate the model configuration file
            # logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
            config = {
                "layers": {
                    "linear": {
                        "shapes": [],
                        "m_range": "None",
                    },
                    "moe": {
                        "shapes": [],
                        "m_range": "None",
                        "topk": self.topk
                    }
                },
                "quantization_config": {
                    "quant_method": self.quant_method,
                    "weight_block_size": "None"
                }
            }
            # Handle MoE shapes
            for shape in self.moe_weight_shapes:
                if len(shape) == 4:
                    # Assume the MoE shape is in [N1, N2, K] format
                    moe_config = {
                        "E": shape[0],
                        "N1": shape[1],
                        "N2": shape[2],
                        "K": shape[3],
                        # default values
                    }
                    config["layers"]["moe"]["shapes"].append(moe_config)
            for shape in self.weight_shapes:
                config["layers"]["linear"]["shapes"].append(shape)
            if block_size is not None:
                config["quantization_config"]["weight_block_size"] = block_size
            with open(json_dir + "/model.json", 'w') as f:
                json.dump(config, f, indent=4)
        # else:
        #     logger.info("LMSLIM_TUNING_JSON is not set")

    def getspec_config(self, configs_dict, M, N, K):
        if f"{M}_{N}_{K}" in configs_dict:
            return configs_dict[f"{M}_{N}_{K}"]
        else:
            return None

    def get_triton_cache(self, file_path, n, k):
        # Used outside tuning; if the file does not exist, return None directly
        cache_json_file = file_path
        if os.path.exists(file_path):
            # try:
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Parse the whole cache into key:config form: [M_N_K]: [config]
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict

    def get_w8a8json_name(self, n, k):
        return self.triton_json_dir + f"/W8A8_{n}_{k}_{self.device_name}.json"

    def get_blockint8_triton_cache(self, file_path, n, k, block_n, block_k):
        cache_json_file = file_path
        if os.path.exists(file_path):
            # try:
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Parse the whole cache into key:config form: [M_N_K]: [config]
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict

    def get_blockint8json_name(self, n, k, block_n, block_k):
        return self.triton_json_dir + f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"

    def get_moeint8json_name(self, E, N1, N2, K, TOPK, block_size: Optional[list] = None, use_int4_w4a8: Optional[bool] = False):
        if use_int4_w4a8:
            if block_size is not None:
                return self.triton_json_dir + f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
            else:
                return self.triton_json_dir + f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
        else:
            if block_size is not None:
                return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
            else:
                return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"

    def get_moeint8_triton_cache(self, file_path, E, N1, N2, K, TOPK):
        cache_json_file = file_path
        if os.path.exists(file_path):
            # try:
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Parse the whole cache into key:config form: [M_N_K]: [config1, config2]
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict
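A small sketch of the key flattening that get_triton_cache performs; the JSON layout and config fields shown are assumptions inferred from the loop, not a file shipped with this commit:

    # Hypothetical cache file content: outer keys are "N_K", inner keys are M.
    cachedata = {
        "4096_4096": {
            "16": {"BLOCK_M": 16, "num_warps": 4},
            "256": {"BLOCK_M": 64, "num_warps": 8},
        }
    }

    # Same flattening as get_triton_cache: "{M}_{N}_{K}" -> config
    configs_dict = {}
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_dict[f"{sub_key}_{key}"] = sub_value

    print(configs_dict["256_4096_4096"])  # {'BLOCK_M': 64, 'num_warps': 8}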
sgl-kernel/csrc/common_extension_rocm.cc
@@ -19,6 +19,14 @@ limitations under the License.
#include "sgl_kernel_ops.h"

TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  /*
   * From FlashMLA
   */
  m.def(
      "dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices,Tensor page_kernel_lens, Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
  m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);

  /*
   * From csrc/activation
   */
@@ -133,6 +141,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   */
  m.def(
      "dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);

  m.def(
      "dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
  m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);

  m.def(
      "dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
  m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);

  m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
  m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);

  m.def(
      "dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr,Tensor req_to_token_ptr,Tensor allocate_lens_ptr,Tensor new_allocate_lens,Tensor out_cache_loc_ptr,int shape,int bs) -> ()");
  m.impl("dcu_assign_req_to_token_pool", torch::kCUDA, &dcu_assign_req_to_token_pool);

  m.def(
      "dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
  m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
sgl-kernel/csrc/kvcacheio/transfer.cu
@@ -836,4 +836,322 @@ void dcu_alloc_extend_kernel(
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
\ No newline at end of file
+}

__global__ void launch_assign_req_to_token_pool(
    const int64_t* req_pool_indices_ptr,
    int32_t* req_to_token_ptr,
    const int64_t* allocate_lens_ptr,
    int64_t* new_allocate_lens,
    int64_t* out_cache_loc_ptr,
    int64_t shape,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t kv_start = allocate_lens_ptr[pid];
  int64_t kv_end = new_allocate_lens[pid];
  int64_t pool_idx = req_pool_indices_ptr[pid];
  int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);

  int64_t sum_out_offset = 0;
  for (int length_offset = 0; length_offset < pid; length_offset++) {
    int64_t start = allocate_lens_ptr[length_offset];
    int64_t end = new_allocate_lens[length_offset];
    sum_out_offset += (end - start);
  }
  int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
  int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
  for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
    token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
  }
}

void dcu_assign_req_to_token_pool(
    const at::Tensor req_pool_indices_ptr,
    at::Tensor req_to_token_ptr,
    const at::Tensor allocate_lens_ptr,
    at::Tensor new_allocate_lens,
    at::Tensor out_cache_loc_ptr,
    int64_t shape,
    int64_t bs) {
  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
  const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
  int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
  int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(
      req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

__global__ void get_last_loc_kernel(
    const int32_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices_tensor,
    const int64_t* __restrict__ prefix_lens_tensor,
    int64_t* __restrict__ result,
    int64_t num_tokens,
    int64_t req_to_token_stride) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= num_tokens) return;

  int64_t pre_len = prefix_lens_tensor[pid];
  if (pre_len > 0) {
    int64_t req_idx = req_pool_indices_tensor[pid];
    int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
    result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
  } else {
    result[pid] = static_cast<int64_t>(-1);
  }
}

at::Tensor dcu_get_last_loc(
    const at::Tensor req_to_token, const at::Tensor req_pool_indices, const at::Tensor prefix_lens) {
  TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
  TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
  TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
  TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
  TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
  TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");

  int64_t num_tokens = prefix_lens.numel();
  TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");

  int64_t req_to_token_stride = req_to_token.stride(0);

  auto req_to_token_c = req_to_token.contiguous();
  auto req_pool_indices_c = req_pool_indices.contiguous();
  auto prefix_lens_c = prefix_lens.contiguous();

  const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
  const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
  const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();

  auto result = at::empty_like(prefix_lens_c);
  int64_t* result_ptr = result.data_ptr<int64_t>();

  const int64_t block_size = 64;
  const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
      req_to_token_ptr, req_pool_indices_ptr, prefix_lens_ptr, result_ptr, num_tokens, req_to_token_stride);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return result;
}

__global__ void launch_assign_extend_cache_locs_kernel(
    const int64_t* __restrict__ req_pool_indices,  // [bs]
    const int32_t* __restrict__ req_to_token,      // [max_num_req, pool_len]
    const int64_t* __restrict__ start_offset,      // [bs]
    const int64_t* __restrict__ end_offset,        // [bs]
    int64_t* __restrict__ out_cache_loc,           // [sum(draft_token_num)]
    int64_t pool_len,
    int64_t bs) {
  int pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t kv_start = start_offset[pid];
  int64_t kv_end = end_offset[pid];
  int64_t req_id = req_pool_indices[pid];

  int64_t out_offset = 0;
  for (int i = 0; i < pid; ++i) {
    out_offset += end_offset[i] - start_offset[i];
  }

  const int32_t* src = req_to_token + req_id * pool_len + kv_start;
  int64_t* dst = out_cache_loc + out_offset;

  for (int64_t i = 0; i < kv_end - kv_start; ++i) {
    dst[i] = src[i];
  }
}

void dcu_assign_extend_cache_locs(
    const at::Tensor req_pool_indices,
    const at::Tensor req_to_token,
    const at::Tensor start_offset,
    const at::Tensor end_offset,
    at::Tensor out_cache_loc,
    int64_t pool_len,
    int64_t bs) {
  const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
  const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
  const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
  const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
  int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();

  constexpr int64_t threads = 128;
  int64_t blocks = (bs + threads - 1) / threads;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
      req_pool_indices_ptr, req_to_token_ptr, start_offset_ptr, end_offset_ptr, out_cache_loc_ptr, pool_len, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
    const int32_t* __restrict__ req_to_token,
    const int32_t* __restrict__ req_pool_indices,
    const int32_t* __restrict__ page_kernel_lens,
    const int32_t* __restrict__ kv_start_idx,
    int32_t* __restrict__ kv_indices,
    int req_to_token_stride,
    int kv_indices_stride) {
  int pid = blockIdx.x;  // batch index
  int req_pool_index = req_pool_indices[pid];

  int kv_start = 0;
  int kv_end = 0;
  if (kv_start_idx != nullptr) {
    kv_start = kv_start_idx[pid];
    kv_end = kv_start;
  }
  kv_end += page_kernel_lens[pid];

  int total_len = kv_end - kv_start;
  int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;

  for (int pg = 0; pg < num_pages; ++pg) {
    int offset = pg * PAGED_SIZE;
    // token id = req_to_token[req_pool_index][kv_start + offset]
    int64_t token = req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
    // page index
    kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
  }
}

void dcu_create_flashmla_kv_indices(
    const at::Tensor& req_to_token,
    const at::Tensor& req_pool_indices,
    const at::Tensor& page_kernel_lens,
    const c10::optional<at::Tensor>& kv_start_idx,
    at::Tensor& kv_indices,
    int64_t req_to_token_stride,
    int64_t kv_indices_stride,
    int64_t PAGED_SIZE) {
  TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
  TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");

  int bs = req_pool_indices.size(0);
  auto stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(bs);
  dim3 block(1);

  const int32_t* kv_start_idx_ptr = nullptr;
  if (kv_start_idx.has_value()) {
    kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
  }

  if (PAGED_SIZE == 64) {
    dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
        req_to_token.data_ptr<int32_t>(),
        req_pool_indices.data_ptr<int32_t>(),
        page_kernel_lens.data_ptr<int32_t>(),
        kv_start_idx_ptr,
        kv_indices.data_ptr<int32_t>(),
        req_to_token_stride,
        kv_indices_stride);
  } else {
    TORCH_CHECK(false, "Unsupported PAGED_SIZE");
  }
}

__global__ void launch_create_chunked_prefix_cache_kv_indices(
    int32_t* req_to_token_ptr,
    const int64_t* req_pool_indices_ptr,
    const int32_t* chunk_starts_ptr,
    const int32_t* chunk_seq_lens_ptr,
    const int32_t* chunk_cu_seq_lens_ptr,
    int32_t* chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t req_pool_index = req_pool_indices_ptr[pid];
  int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
  int32_t chunk_start_pos = chunk_starts_ptr[pid];
  int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
#pragma unroll(32)
  for (int32_t offset = 0; offset < chunk_seq_len; offset++) {
    chunk_kv_indices_ptr[chunk_kv_indices_offset + offset] =
        req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
  }
}

void dcu_create_chunked_prefix_cache_kv_indices(
    at::Tensor req_to_token_ptr,
    const at::Tensor req_pool_indices_ptr,
    const at::Tensor chunk_starts_ptr,
    const at::Tensor chunk_seq_lens_ptr,
    const at::Tensor chunk_cu_seq_lens_ptr,
    at::Tensor chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs) {
  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
  const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
  const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
  const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
  int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(
      req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1, chunk_cu_seq_lens_ptr1,
      chunk_kv_indices_ptr1, col_num, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
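A minimal sketch of calling the newly registered dcu_get_last_loc op from Python, with toy tensors whose dtypes follow the pointer types in the kernel above; purely illustrative and assumes the ROCm sgl-kernel extension is built and loaded:

    import torch
    from sgl_kernel.kvcacheio import dcu_get_last_loc

    req_to_token = torch.arange(2 * 8, dtype=torch.int32, device="cuda").reshape(2, 8)
    req_pool_indices = torch.tensor([1, 0], dtype=torch.int64, device="cuda")
    prefix_lens = torch.tensor([3, 0], dtype=torch.int64, device="cuda")

    last_loc = dcu_get_last_loc(req_to_token, req_pool_indices, prefix_lens)
    # Expected: [req_to_token[1, 2], -1] == [10, -1]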
sgl-kernel/include/sgl_kernel_ops.h
@@ -538,6 +538,7 @@ void segment_packbits(
/*
 * From csrc/kvcacheio
 */
void dcu_create_extend_after_decode_spec_info(
    const at::Tensor verified_id,
    const at::Tensor seq_lens,
@@ -545,6 +546,49 @@ void dcu_create_extend_after_decode_spec_info(
    at::Tensor positions,
    at::Tensor new_verified_id,
    int64_t bs);

void dcu_create_chunked_prefix_cache_kv_indices(
    at::Tensor req_to_token,
    const at::Tensor req_pool_indices,
    const at::Tensor chunk_starts,
    const at::Tensor chunk_seq_lens,
    const at::Tensor chunk_cu_seq_lens,
    at::Tensor chunk_kv_indices,
    int64_t col_num,
    int64_t bs);

void dcu_create_flashmla_kv_indices(
    const at::Tensor& req_to_token,
    const at::Tensor& req_pool_indices,
    const at::Tensor& page_kernel_lens,
    const c10::optional<at::Tensor>& kv_start_idx,
    at::Tensor& kv_indices,
    int64_t req_to_token_stride,
    int64_t kv_indices_stride,
    int64_t PAGED_SIZE);

void dcu_assign_extend_cache_locs(
    const at::Tensor req_pool_indices,
    const at::Tensor req_to_token,
    const at::Tensor start_offset,
    const at::Tensor end_offset,
    at::Tensor out_cache_loc,
    int64_t pool_len,
    int64_t bs);

at::Tensor dcu_get_last_loc(
    const at::Tensor req_to_token, const at::Tensor req_pool_indices, const at::Tensor prefix_lens);

void dcu_assign_req_to_token_pool(
    const at::Tensor req_pool_indices_ptr,
    at::Tensor req_to_token_ptr,
    const at::Tensor allocate_lens_ptr,
    at::Tensor new_allocate_lens,
    at::Tensor out_cache_loc_ptr,
    int64_t shape,
    int64_t bs);

void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
sgl-kernel/python/sgl_kernel/flash_mla.py
@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
def
dcu_create_flashmla_kv_indices
(
req_to_token_ptr
,
req_pool_indices_ptr
,
page_kernel_lens_ptr
,
kv_start_idx
,
kv_indices_ptr
,
req_to_token_ptr_stride
,
kv_indices_ptr_stride
,
PAGED_SIZE
=
64
,
):
torch
.
ops
.
sgl_kernel
.
dcu_create_flashmla_kv_indices
(
req_to_token_ptr
,
req_pool_indices_ptr
,
page_kernel_lens_ptr
,
kv_start_idx
,
kv_indices_ptr
,
req_to_token_ptr_stride
,
kv_indices_ptr_stride
,
PAGED_SIZE
,
)
def
get_mla_metadata
(
cache_seqlens
:
torch
.
Tensor
,
...
...