change / sglang / Commits / a50eb0e6

Commit a50eb0e6 authored Nov 19, 2025 by yiqa
Merge remote-tracking branch 'origin/v0.5.4_dev_yiqa' into v0.5.4_dev_yiqa
Parents: 474b26ab 0dc51b09

Changes: 16
Showing 16 changed files with 500 additions and 186 deletions (+500 -186)
python/sglang/srt/layers/attention/dcu_mla_backend.py  +106 -47
python/sglang/srt/layers/attention/flashattention_backend.py  +125 -107
python/sglang/srt/layers/attention/flashattention_interface.py  +15 -14
python/sglang/srt/layers/moe/ep_moe/layer.py  +69 -2
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py  +22 -3
python/sglang/srt/managers/schedule_batch.py  +1 -0
python/sglang/srt/managers/scheduler.py  +1 -1
python/sglang/srt/mem_cache/common.py  +2 -1
python/sglang/srt/model_executor/forward_batch_info.py  +3 -2
python/sglang/srt/model_executor/model_runner.py  +1 -0
python/sglang/srt/speculative/eagle_info.py  +22 -9
sgl-kernel/csrc/attention/merge_attn_states.cu  +4 -0
sgl-kernel/csrc/common_extension_rocm.cc  +3 -0
sgl-kernel/csrc/kvcacheio/transfer.cu  +102 -0
sgl-kernel/include/sgl_kernel_ops.h  +8 -0
sgl-kernel/python/sgl_kernel/kvcacheio.py  +16 -0
python/sglang/srt/layers/attention/dcu_mla_backend.py  View file @ a50eb0e6

@@ -103,54 +103,112 @@ class DCUMLABackend(AttentionBackend):
             skip_prefill=False,
         )

-    def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode_or_idle():
-            bs = forward_batch.batch_size
-            max_seqlen_pad = triton.cdiv(
-                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
-            )
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=forward_batch.seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                forward_batch.seq_lens.to(torch.int32),
-                self.num_q_heads,
-                1,
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata,
-                num_splits,
-                block_kv_indices,
-            )
-        elif forward_batch.forward_mode.is_target_verify():
-            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
-            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
-            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata,
-                num_splits,
-                block_kv_indices,
-            )
-        else:
-            if not self.skip_prefill:
-                self.flashattn_backend.init_forward_metadata(forward_batch)
+    def _build_decode_metadata(
+        self, forward_batch: ForwardBatch, seq_lens: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        bs = forward_batch.batch_size
+        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+        # Paging scheme follows the official vLLM blog post
+        block_kv_indices = torch.full(
+            (bs, max_seqlen_pad),
+            -1,
+            dtype=torch.int32,
+            device=seq_lens.device,
+        )
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            forward_batch.req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+        )
+        mla_metadata, num_splits = get_mla_metadata(
+            seq_lens.to(torch.int32),
+            self.num_q_heads,
+            1,
+        )
+        return (mla_metadata, num_splits), num_splits, block_kv_indices
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # decode uses flashmla
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata,
+                num_splits_t,
+                block_kv_indices,
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata,
+                num_splits_t,
+                block_kv_indices,
+            )
+        else:
+            if not self.skip_prefill:
+                # === DRAFT_EXTEND_V2 MLA metadata === nhb
+                if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
+                    bs = forward_batch.batch_size
+                    seq_lens_cpu = forward_batch.seq_lens_cpu
+                    seq_lens = forward_batch.seq_lens
+                    max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+                    block_kv_indices = torch.full(
+                        (bs, max_seqlen_pad),
+                        -1,
+                        dtype=torch.int32,
+                        device=seq_lens.device,
+                    )
+                    # Call the Triton kernel to build block_kv_indices
+                    create_flashmla_kv_indices_triton[(bs,)](
+                        self.req_to_token,
+                        forward_batch.req_pool_indices,
+                        seq_lens,
+                        None,
+                        block_kv_indices,
+                        self.req_to_token.stride(0),
+                        max_seqlen_pad,
+                    )
+                    # MLA metadata
+                    mla_metadata, num_splits = get_mla_metadata(
+                        seq_lens.to(torch.int32),
+                        self.num_q_heads,
+                        1,
+                    )
+                    # save forward_metadata
+                    self.forward_metadata = VllmMLADecodeMetadata(
+                        mla_metadata,
+                        num_splits,
+                        block_kv_indices,
+                    )
+                self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
...
@@ -389,7 +447,7 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
...
@@ -401,7 +459,7 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            forward_batch.seq_lens.to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
             k_scale,
             kv_cache_dtype=kv_cache_dtype,
...
@@ -411,7 +469,7 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            forward_batch.seq_lens.to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
         )
...
@@ -431,12 +489,9 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if save_kv_cache:
-            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
-        if ((forward_batch.forward_mode == ForwardMode.EXTEND
-            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
+        if (
+            forward_batch.forward_mode == ForwardMode.EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
         ):
             if not self.skip_prefill:
                 return self.flashattn_backend.forward_extend(
...
@@ -449,14 +504,19 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
...
@@ -468,7 +528,7 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
             k_scale,
             kv_cache_dtype=kv_cache_dtype,
...
@@ -478,13 +538,12 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
         )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

 class DCUMLAMultiStepDraftBackend:
     """
     Wrap multiple flashmla attention backends as one for multiple consecutive
...
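Background for the `_build_decode_metadata` helper above: it pads every request's block table to the batch-wide maximum page count and marks unused slots with -1 before the Triton kernel fills in real page indices. The snippet below reproduces only that shape logic in plain PyTorch as an illustration; `PAGE_SIZE` and the sequence lengths are invented values, and the actual index filling done by `create_flashmla_kv_indices_triton` is not shown.

# Illustration only: the padded block-table shape allocated by _build_decode_metadata.
import torch

PAGE_SIZE = 64                              # assumed page size
seq_lens = torch.tensor([100, 257, 5])      # toy per-request KV lengths
bs = seq_lens.numel()

# Same ceil-division as triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
max_seqlen_pad = (int(seq_lens.max()) + PAGE_SIZE - 1) // PAGE_SIZE

# Unused slots stay -1 so the kernel can tell real pages from padding.
block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
print(block_kv_indices.shape)               # torch.Size([3, 5]) for the toy lengths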
python/sglang/srt/layers/attention/flashattention_backend.py  View file @ a50eb0e6

@@ -329,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
         self.is_hybrid = model_runner.is_hybrid
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
+        self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
         if self.is_hybrid:
             self.full_to_swa_index_mapping = (
                 model_runner.token_to_kv_pool.full_to_swa_index_mapping
...
@@ -598,6 +600,7 @@ class FlashAttentionBackend(AttentionBackend):
         if (
             any(forward_batch.extend_prefix_lens_cpu)
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2  # nhb
         ):
             extend_seq_lens = forward_batch.extend_seq_lens
             metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
...
@@ -608,10 +611,13 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_q = metadata.max_seq_len_k
             metadata.cu_seqlens_q = metadata.cu_seqlens_k

-            # Setup local attention if enabled
-            if forward_batch.forward_mode == ForwardMode.EXTEND:
+            # # Setup local attention if enabled
+            # if forward_batch.forward_mode == ForwardMode.EXTEND:
+            #     self._init_local_attn_metadata(forward_batch, metadata, device)
+            if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
                 self._init_local_attn_metadata(forward_batch, metadata, device)

         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
             assert (
...
@@ -668,10 +674,16 @@ class FlashAttentionBackend(AttentionBackend):
                 if not layer.is_cross_attention
                 else forward_batch.encoder_out_cache_loc
             )
-            if not self.use_mla:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+            # if not self.use_mla:
+            if k_rope is None:
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
...
@@ -690,7 +702,8 @@ class FlashAttentionBackend(AttentionBackend):
             layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
         window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
...
@@ -774,55 +787,53 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)

-            result = flash_attn_with_kvcache(
-                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=False if use_cascade_attn else causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-                return_softmax_lse=use_cascade_attn,
-                num_splits=self.num_splits,
-                **kwargs,
-            )
-            if use_cascade_attn:
-                o, softmax_lse, *rest = result
-                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
-                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=self.forward_metadata_spec_decode_expand.page_table,
-                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                    softmax_scale=layer.scaling,
-                    causal=False,
-                    window_size=window_size,
-                    softcap=layer.logit_cap,
-                    k_descale=k_descale,
-                    v_descale=v_descale,
-                    return_softmax_lse=True,
-                    num_splits=self.num_splits,
-                    **kwargs,
-                )
-                o, _ = merge_state_v2_wrapper(
-                    o,
-                    softmax_lse.T.contiguous(),
-                    o_expand,
-                    softmax_lse_expand.T.contiguous(),
-                )
-            else:
-                o = result
+            if forward_batch.attn_attend_prefix_cache:
+                assert not get_global_server_args().disable_chunked_prefix_cache
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert forward_batch.prefix_chunk_max_seq_lens is not None
+                chunk_idx = forward_batch.prefix_chunk_idx
+                assert chunk_idx >= 0
+                assert forward_batch.mha_return_lse
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                    **kwargs,
+                )
+            else:
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=metadata.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
+                )
+            if forward_batch.mha_return_lse:
+                output, lse, *rest = output
+                lse = torch.transpose(lse, 0, 1).contiguous()
+                return output, lse
+            return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if (
                 forward_batch.attn_attend_prefix_cache is not None
...
@@ -851,6 +862,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                     softmax_scale=layer.scaling,
                     causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=True,
                     **kwargs,
                 )
...
@@ -865,6 +878,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=metadata.max_seq_len_q,
                     softmax_scale=layer.scaling,
                     causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=forward_batch.mha_return_lse,
                     **kwargs,
                 )
...
@@ -974,10 +989,16 @@ class FlashAttentionBackend(AttentionBackend):
                 if not layer.is_cross_attention
                 else forward_batch.encoder_out_cache_loc
             )
-            if not self.use_mla:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+            # if not self.use_mla:
+            if k_rope is None:
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
...
@@ -1019,7 +1040,8 @@ class FlashAttentionBackend(AttentionBackend):
         if sinks is not None:
             kwargs["sinks"] = sinks

-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
...
@@ -1033,7 +1055,6 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
             )
...
@@ -1085,65 +1106,62 @@ class FlashAttentionBackend(AttentionBackend):
                     **kwargs,
                 )
             else:
-                page_table = metadata.page_table
-                cache_seqlens = metadata.cache_seqlens_int32
-                cu_seqlens_k = metadata.cu_seqlens_k
-                max_seqlen_q = metadata.max_seq_len_q
-                q_reshaped = q.contiguous().view(
-                    -1, layer.tp_q_head_num, layer.head_dim
-                )
-                # Default: single-token self-attention
-                result = flash_attn_with_kvcache(
-                    q=q_reshaped,
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=page_table,
-                    cache_seqlens=cache_seqlens,
-                    cu_seqlens_q=metadata.cu_seqlens_q,
-                    cu_seqlens_k_new=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_q,
-                    softmax_scale=layer.scaling,
-                    causal=False if use_cascade_attn else causal,
-                    window_size=window_size,
-                    softcap=layer.logit_cap,
-                    k_descale=k_descale,
-                    v_descale=v_descale,
-                    return_softmax_lse=use_cascade_attn,
-                    num_splits=self.num_splits,
-                    **kwargs,
-                )
-                if use_cascade_attn:
-                    o, softmax_lse, *rest = result
-                    o_expand, softmax_lse_expand, *rest_expand = (
-                        flash_attn_with_kvcache(
-                            q=q_reshaped,
-                            k_cache=key_cache,
-                            v_cache=value_cache,
-                            page_table=self.forward_metadata_spec_decode_expand.page_table,
-                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                            softmax_scale=layer.scaling,
-                            causal=False,
-                            window_size=window_size,
-                            softcap=layer.logit_cap,
-                            k_descale=k_descale,
-                            v_descale=v_descale,
-                            return_softmax_lse=True,
-                            num_splits=self.num_splits,
-                            **kwargs,
-                        )
-                    )
-                    o, _ = merge_state_v2(
-                        o,
-                        softmax_lse.T.contiguous(),
-                        o_expand,
-                        softmax_lse_expand.T.contiguous(),
-                    )
-                else:
-                    o = result
+                cu_seqlens_q = metadata.cu_seqlens_q
+                max_seqlen_q = metadata.max_seq_len_q
+                page_table = metadata.page_table
+                cu_seqlens_k = metadata.cu_seqlens_k
+                cache_seqlens = metadata.cache_seqlens_int32
+                key_cache = key_cache.view(
+                    -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+                )
+                value_cache = value_cache.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+                )
+                if layer.is_cross_attention:
+                    page_table = metadata.encoder_page_table
+                    cache_seqlens = metadata.encoder_lens_int32
+                    cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                    window_size = (-1, -1)
+                if max_seqlen_q > 1:
+                    result = flash_attn_varlen_func(
+                        q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k=cu_seqlens_k,
+                        max_seqlen_q=max_seqlen_q,
+                        max_seqlen_k=max_seqlen_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        window_size=window_size,
+                        softcap=layer.logit_cap,
+                        k_descale=k_descale,
+                        v_descale=v_descale,
+                        return_softmax_lse=use_cascade_attn,
+                        num_splits=self.num_splits,
+                        **kwargs,
+                    )
+                else:
+                    result = flash_attn_with_kvcache(
+                        q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k_cache=key_cache,
+                        v_cache=value_cache,
+                        page_table=page_table,
+                        cache_seqlens=cache_seqlens,
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                        max_seqlen_q=max_seqlen_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        window_size=window_size,
+                        softcap=layer.logit_cap,
+                        k_descale=k_descale,
+                        v_descale=v_descale,
+                        return_softmax_lse=use_cascade_attn,
+                        num_splits=self.num_splits,
+                        **kwargs,
+                    )
+                o = result
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
...
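The reworked extend path above returns `(output, lse)` when `forward_batch.mha_return_lse` is set so that per-chunk results can later be combined by their log-sum-exp values. As a reference for what such a merge computes (this is not the actual `merge_state_v2` kernel; shapes and the helper name below are illustrative only):

# Reference-only sketch of combining two partial attention outputs with their LSEs.
import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: [tokens, heads, head_dim], lse*: [tokens, heads]
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    denom = (w1 + w2).unsqueeze(-1)
    return (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom

o1 = torch.randn(4, 2, 8); o2 = torch.randn(4, 2, 8)
lse1 = torch.randn(4, 2);  lse2 = torch.randn(4, 2)
merged = merge_partial_attention(o1, lse1, o2, lse2)
print(merged.shape)  # torch.Size([4, 2, 8])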
python/sglang/srt/layers/attention/flashattention_interface.py  View file @ a50eb0e6

@@ -41,18 +41,18 @@ def flash_attn_with_kvcache(
     ver=3,
 ):
     return flash_attn_with_kvcache_interface(
         q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
-        k_cache=k_cache,
-        v_cache=v_cache,
+        k_cache=k_cache.view(q.dtype),
+        v_cache=v_cache.view(q.dtype),
         block_table=page_table,
         cache_seqlens=cache_seqlens,
         softmax_scale=softmax_scale,
         causal=causal,
         window_size=window_size,
         softcap=softcap,
         return_softmax_lse=return_softmax_lse,
         num_splits=num_splits,
     )

 def flash_attn_varlen_func(
     q,
...
@@ -83,8 +83,8 @@ def flash_attn_varlen_func(
 ):
     return flash_attn_varlen_func_interface(
         q=q,
-        k=k,
-        v=v,
+        k=k.view(q.dtype),
+        v=v.view(q.dtype),
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
...
@@ -92,4 +92,5 @@ def flash_attn_varlen_func(
         softmax_scale=softmax_scale,
         causal=causal,
         return_attn_probs=return_softmax_lse,
+        softcap=softcap,
     )
\ No newline at end of file
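Both wrappers above now call `.view(q.dtype)` on the cached K/V tensors. `Tensor.view(dtype)` reinterprets the underlying storage without copying, which is what lets a cache held under a different dtype be fed to a kernel that expects `q`'s dtype. A minimal, self-contained demonstration (the dtypes here are arbitrary examples, not the ones used at runtime):

import torch

q = torch.randn(4, 8, dtype=torch.bfloat16)
k_cache = torch.randn(4, 8, dtype=torch.bfloat16).view(torch.uint8)  # cache stored as raw bytes
print(k_cache.dtype)                  # torch.uint8
k_for_kernel = k_cache.view(q.dtype)  # reinterpret back to bfloat16, no copy
print(k_for_kernel.dtype)             # torch.bfloat16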
python/sglang/srt/layers/moe/ep_moe/layer.py  View file @ a50eb0e6

@@ -3,6 +3,7 @@ from __future__ import annotations

 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
 import torch.distributed as dist
...
@@ -39,7 +40,7 @@ if TYPE_CHECKING:
         DeepEPNormalOutput,
         DispatchOutput,
     )
-    from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep
+    from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
     from lmslim.layers.gemm.int8_utils import per_token_quant_int8

 _is_hip = is_hip()
...
@@ -127,6 +128,7 @@ class EPMoE(FusedMoE):
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
             self.use_w4a8_marlin = False
+            self.use_w8a8_marlin = False
         elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
...
@@ -137,12 +139,25 @@ class EPMoE(FusedMoE):
             self.use_fp8_w8a8 = False
             self.activation_scheme = None
             self.use_w4a8_marlin = True
+            self.use_w8a8_marlin = False
+        elif isinstance(quant_config, SlimQuantCompressedTensorsMarlinConfig):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = False
+            self.activation_scheme = None
+            self.use_w4a8_marlin = False
+            self.use_w8a8_marlin = True
         else:
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
             self.use_w4a8_marlin = False
+            self.use_w8a8_marlin = False

     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
...
@@ -498,6 +513,8 @@ class DeepEPMoE(EPMoE):
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if self.use_w4a8_marlin:
                 return self.forward_groupgemm_w4a8_marlin_masked(dispatch_output)
+            elif self.use_w8a8_marlin:
+                return self.forward_groupgemm_w8a8_marlin_masked(dispatch_output)
             else:
                 if (
                     get_moe_runner_backend().is_flashinfer_cutedsl()
...
@@ -783,7 +800,7 @@ class DeepEPMoE(EPMoE):
         # base shapes
         num_groups, m, k = hidden_states.size()
-        expected_m = m // 2  # shape required by the operator
+        expected_m = min(m, expected_m)

         # ---- first quant: ensure float input for quantizer ----
         q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
...
@@ -822,6 +839,56 @@ class DeepEPMoE(EPMoE):

         return down_output

+    def forward_groupgemm_w8a8_marlin_masked(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        # base shapes
+        num_groups, m, k = hidden_states.size()
+        expected_m = min(m, expected_m)
+
+        # ---- first quant: ensure float input for quantizer ----
+        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+
+        # ---- weights & scales ----
+        w13_weight = self.w13_weight
+        w13_scales = self.w13_weight_scale
+        w2_weight = self.w2_weight
+        w2_scales = self.w2_weight_scale
+
+        n1 = w13_scales.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
+        )
+
+        # ---- first GEMM ----
+        m_grouped_w8a8_gemm_nt_masked(
+            (q_a1_all, q_a1_scale),
+            (w13_weight, w13_scales),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+
+        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+
+        # ---- second GEMM ----
+        n2 = w2_scales.size(1)
+        down_output = torch.empty(
+            (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
+        )
+        m_grouped_w8a8_gemm_nt_masked(
+            (q_a2_all, q_a2_scale),
+            (w2_weight, w2_scales),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+
+        return down_output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
...
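`forward_groupgemm_w8a8_marlin_masked` delegates the heavy lifting to `lightop.m_grouped_w8a8_gemm_nt_masked`, whose internal layout is not part of this diff. The sketch below is only a pure-PyTorch reference of the semantics implied by the call site: a grouped int8 GEMM with per-token activation scales, per-channel weight scales, and a per-group valid-row count `masked_m`. The scale shapes here are assumptions for illustration, not the kernel's real API.

import torch

def grouped_w8a8_gemm_nt_masked_ref(a_q, a_scale, w_q, w_scale, out, masked_m):
    # a_q: [groups, m, k] int8, a_scale: [groups, m, 1]
    # w_q: [groups, n, k] int8 (transposed-B layout), w_scale: [groups, n]
    # Only the first masked_m[g] rows of group g are valid; the rest are untouched.
    groups = a_q.size(0)
    for g in range(groups):
        rows = int(masked_m[g])
        acc = a_q[g, :rows].float() @ w_q[g].float().t()          # accumulate in fp32
        out[g, :rows] = (acc * a_scale[g, :rows] * w_scale[g]).to(out.dtype)

g, m, k, n = 2, 4, 8, 6
a_q = torch.randint(-128, 127, (g, m, k), dtype=torch.int8)
a_scale = torch.rand(g, m, 1)
w_q = torch.randint(-128, 127, (g, n, k), dtype=torch.int8)
w_scale = torch.rand(g, n)
out = torch.zeros(g, m, n, dtype=torch.bfloat16)
grouped_w8a8_gemm_nt_masked_ref(a_q, a_scale, w_q, w_scale, out, masked_m=torch.tensor([4, 2]))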
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py  View file @ a50eb0e6

@@ -15,6 +15,7 @@ from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import set_weight_attrs
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend

 try:
     from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
 except Exception:
...
@@ -39,6 +40,18 @@ def get_w8a8_int8_marlin_weights(
     return weight

+def w8a8_nt_kpack2_marlin_weight(
+    w8a8_w,  # [size_n, size_k // 2]
+    k_tile=16,
+    n_tile=16,
+):
+    assert w8a8_w.dtype == torch.int8, "w8a8_w must be int8"
+    size_n, size_k = w8a8_w.shape
+    assert size_n % k_tile == 0 and size_k % n_tile == 0, "k_tile / n_tile must evenly divide the corresponding dimensions"
+    q = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
+    q = q.permute((0, 2, 1, 3)).contiguous()
+    q = q.reshape((size_n // k_tile, size_k * k_tile))
+    return q

 class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
...
@@ -65,7 +78,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
             "weights"
         )
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations"
         )
+        self.use_deepep = get_moe_a2a_backend().is_deepep()
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN
         )
...
@@ -138,13 +151,19 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w1_marlin_list = []
         for ii in range(layer.w13_weight.shape[0]):
-            w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
+            if not self.use_deepep:
+                w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
+            else:
+                w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
             w1_marlin_list.append(w1_marlin_in)
         w1_marlin = torch.stack(w1_marlin_list, dim=0)

         w2_marlin_list = []
         for ii in range(layer.w2_weight.shape[0]):
-            w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
+            if not self.use_deepep:
+                w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
+            else:
+                w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
             w2_marlin_list.append(w2_marlin_in)
         w2_marlin = torch.stack(w2_marlin_list, dim=0)
...
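To see what the tile regrouping in `w8a8_nt_kpack2_marlin_weight` does, the toy below runs the same reshape/permute/reshape sequence with 2x2 tiles on a small stand-in weight so the effect is visible; the real function uses 16x16 tiles on int8 expert weights.

import torch

n_tile = k_tile = 2
w = torch.arange(4 * 4, dtype=torch.int8).reshape(4, 4)   # stand-in weight, not a real checkpoint
size_n, size_k = w.shape
q = w.reshape(size_n // n_tile, n_tile, size_k // k_tile, k_tile)
q = q.permute(0, 2, 1, 3).contiguous()                    # swap the tile axes
q = q.reshape(size_n // k_tile, size_k * k_tile)          # flatten to the packed layout
print(q)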
python/sglang/srt/managers/schedule_batch.py  View file @ a50eb0e6

@@ -1193,6 +1193,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = seq_lens_tensor
         self.seq_lens_cpu = seq_lens_cpu
         self.extend_num_tokens = extend_num_tokens
+        self.loc_tensor = torch.tensor([-1], device=self.device)

         # Allocate memory
         out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
...
python/sglang/srt/managers/scheduler.py  View file @ a50eb0e6

@@ -1940,7 +1940,7 @@ class Scheduler(
             batch.spec_info = batch_result.next_draft_input
             batch.spec_info.future_indices = future_indices
+            batch.sampling_info.is_all_greedy = True  # nhb
             # batch.spec_info = EagleDraftInput(
             #     future_indices=future_indices,
             #     verify_done=batch_result.next_draft_input.verify_done,
...
python/sglang/srt/mem_cache/common.py  View file @ a50eb0e6

@@ -356,7 +356,8 @@ def alloc_for_extend(
     else:
         # Paged allocation - build last_loc
         last_loc = [
-            (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            # (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            (t[-1:] if len(t) > 0 else batch.loc_tensor)
             for t in prefix_tensors
         ]
         out_cache_loc = alloc_paged_token_slots_extend(
...
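The swap above from a per-call `torch.tensor([-1], device=batch.device)` to the preallocated `batch.loc_tensor` avoids building a one-element device tensor (and its host-to-device copy) for every empty prefix inside the list comprehension. A minimal sketch of the reuse pattern with invented inputs:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
loc_tensor = torch.tensor([-1], device=device)      # allocated once per batch
prefix_tensors = [torch.arange(3, device=device), torch.empty(0, device=device)]

# Empty prefixes fall back to the shared placeholder instead of a fresh tensor.
last_loc = [t[-1:] if len(t) > 0 else loc_tensor for t in prefix_tensors]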
python/sglang/srt/model_executor/forward_batch_info.py  View file @ a50eb0e6

@@ -123,12 +123,13 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2

-    def is_extend_or_draft_extend_or_mixed(self):
+    def is_extend_or_draft_extend_or_mixed(self):  # nhb
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2
         )

     def is_cuda_graph(self):
...
python/sglang/srt/model_executor/model_runner.py  View file @ a50eb0e6

@@ -2241,6 +2241,7 @@ class ModelRunner:
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
         if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
...
python/sglang/srt/speculative/eagle_info.py  View file @ a50eb0e6

@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info

 if is_cuda():
     from sgl_kernel import (
...
@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
     new_seq_lens: Optional[torch.Tensor] = None
     verify_done: Optional[torch.cuda.Event] = None

+    use_sglang_create_extend_after_decode_spec_info = get_bool_env_var(
+        "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
+    )

     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_DRAFT)
...
@@ -679,14 +682,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
         self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)

-        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
-            batch.input_ids,
-            batch.seq_lens,
-            self.accept_length,
-            self.positions,
-            self.verified_id,
-            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
-        )
+        if self.use_sglang_create_extend_after_decode_spec_info:
+            dcu_create_extend_after_decode_spec_info(
+                verified_id=batch.input_ids,
+                seq_lens=batch.seq_lens,
+                accept_lens=self.accept_length,
+                positions=self.positions,
+                new_verified_id=self.verified_id,
+                bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
+            )
+        else:
+            create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+                batch.input_ids,
+                batch.seq_lens,
+                self.accept_length,
+                self.positions,
+                self.verified_id,
+                next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
+            )

     def generate_attn_arg_prefill(
         self,
...
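For reference, the positions and new_verified_id produced by either kernel called above can be described in a few lines of plain PyTorch. The function below mirrors the intent of the CUDA kernels added in transfer.cu (fill the last `accept_len` token positions per request and pick the last accepted token id); it is purely illustrative and not the shipped implementation.

import torch

def extend_after_decode_spec_info_ref(verified_id, seq_lens, accept_lens):
    positions = torch.empty(int(accept_lens.sum()), dtype=torch.long)
    new_verified_id = torch.empty(len(seq_lens), dtype=verified_id.dtype)
    offset = 0
    for i, (seq_len, acc) in enumerate(zip(seq_lens.tolist(), accept_lens.tolist())):
        # positions of the accepted tokens at the tail of the sequence
        positions[offset : offset + acc] = torch.arange(seq_len - acc, seq_len)
        # last accepted token id becomes the new verified id for this request
        new_verified_id[i] = verified_id[offset + acc - 1]
        offset += acc
    return positions, new_verified_id

pos, ids = extend_after_decode_spec_info_ref(
    verified_id=torch.tensor([11, 12, 13, 21], dtype=torch.int32),
    seq_lens=torch.tensor([10, 7]),
    accept_lens=torch.tensor([3, 1]),
)
print(pos.tolist())   # [7, 8, 9, 6]
print(ids.tolist())   # [13, 21]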
sgl-kernel/csrc/attention/merge_attn_states.cu  View file @ a50eb0e6

 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_bf16.h>
+#endif
 #include <algorithm>
 #include <optional>
...
sgl-kernel/csrc/common_extension_rocm.cc  View file @ a50eb0e6

@@ -131,6 +131,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def(
+      "dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
+  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
   m.def(
       "dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
   m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
   m.def(
       "dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
...
sgl-kernel/csrc/kvcacheio/transfer.cu  View file @ a50eb0e6

@@ -693,6 +693,65 @@ __global__ void launch_alloc_extend_kernel(
     out_indices[output_idx] = start_loc * page_size + offset;
   }
 }

+__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int32_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+  int64_t seq_length = seq_lens_ptr[pid];
+  int32_t accept_length = accept_lens_ptr[pid];
+  int32_t accept_len_cumsum = 0;
+  for (int32_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int32_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+  int32_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}
+
+__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int64_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+  int64_t seq_length = seq_lens_ptr[pid];
+  int64_t accept_length = accept_lens_ptr[pid];
+  int64_t accept_len_cumsum = 0;
+  for (int64_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int64_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+  int64_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}
+
 void dcu_alloc_decode_kernel(
     const at::Tensor seq_lens_ptr,
...
@@ -714,6 +773,49 @@ void dcu_alloc_decode_kernel(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs) {
+  const int32_t* verified_id_ptr;
+  const int64_t* seq_lens_ptr;
+  const int32_t* accept_lens_ptr_int32;
+  const int64_t* accept_lens_ptr_int64;
+  int64_t* positions_ptr;
+  int32_t* new_verified_id_ptr;
+
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+
+  if (accept_lens.dtype() == torch::kInt32) {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+};
+
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
...
sgl-kernel/include/sgl_kernel_ops.h  View file @ a50eb0e6

@@ -538,6 +538,14 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs);
+
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
...
sgl-kernel/python/sgl_kernel/kvcacheio.py  View file @ a50eb0e6

@@ -9,6 +9,22 @@ def is_hip() -> bool:
 _is_hip = is_hip()

+def dcu_create_extend_after_decode_spec_info(
+    verified_id: torch.Tensor,
+    seq_lens: torch.Tensor,
+    accept_lens: torch.Tensor,
+    positions: torch.Tensor,
+    new_verified_id: torch.Tensor,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(
+        verified_id,
+        seq_lens,
+        accept_lens,
+        positions,
+        new_verified_id,
+        bs,
+    )

 def dcu_alloc_extend_kernel(
     pre_lens_ptr: torch.Tensor,
...
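A hypothetical call of the new wrapper, assuming `sgl_kernel` is built with the ROCm/DCU extension registered in common_extension_rocm.cc; the dtypes follow the C++ dispatch (int32 `verified_id`/`accept_lens`, int64 `seq_lens`/`positions`). It will not run without that build and a GPU.

import torch
from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info

bs = 2
verified_id = torch.tensor([11, 12, 13, 21], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([10, 7], dtype=torch.int64, device="cuda")
accept_lens = torch.tensor([3, 1], dtype=torch.int32, device="cuda")
positions = torch.empty(4, dtype=torch.int64, device="cuda")       # sum of accept_lens
new_verified_id = torch.empty(bs, dtype=torch.int32, device="cuda")

dcu_create_extend_after_decode_spec_info(
    verified_id, seq_lens, accept_lens, positions, new_verified_id, bs
)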