change / sglang / Commits / a1175a4e

Commit a1175a4e, authored Nov 22, 2025 by maxiao1
Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5
Parents: 0c006b88, 31653dd9

Changes: 62. Showing 20 changed files with 2608 additions and 409 deletions (+2608 -409).
Changed files:
  python/sglang/bench_serving.py  +10 -0
  python/sglang/srt/_custom_ops.py  +173 -3
  python/sglang/srt/configs/model_config.py  +5 -1
  python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  +286 -0
  python/sglang/srt/distributed/parallel_state.py  +13 -5
  python/sglang/srt/environ.py  +11 -0
  python/sglang/srt/layers/attention/attention_registry.py  +5 -1
  python/sglang/srt/layers/attention/dcu_mla_backend.py  +679 -0
  python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py  +2 -1
  python/sglang/srt/layers/attention/flashattention_backend.py  +131 -127
  python/sglang/srt/layers/attention/flashattention_interface.py  +96 -0
  python/sglang/srt/layers/attention/flashmla_backend.py  +48 -20
  python/sglang/srt/layers/attention/lightop_concat.py  +67 -0
  python/sglang/srt/layers/attention/nsa_backend.py  +2 -1
  python/sglang/srt/layers/attention/xpu_backend.py  +2 -1
  python/sglang/srt/layers/layernorm.py  +44 -9
  python/sglang/srt/layers/linear.py  +26 -5
  python/sglang/srt/layers/moe/ep_moe/kernels.py  +124 -188
  python/sglang/srt/layers/moe/ep_moe/layer.py  +865 -37
  python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  +19 -10
python/sglang/bench_serving.py

@@ -848,10 +848,12 @@ class BenchmarkMetrics:
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p95_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p95_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float

@@ -1721,10 +1723,12 @@ def calculate_metrics(
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,

@@ -2052,6 +2056,12 @@ async def benchmark(
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
    print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
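A note on the percentile pattern above: `np.percentile(ttfts or 0, 95)` falls back to a scalar 0 when the list is empty (for example, when the backend does not stream), so the printed statistics stay well defined. The following is a minimal, self-contained sketch of the same computation with made-up latency values, not part of the commit:

    import numpy as np

    # Hypothetical TTFT samples in seconds; an empty list mimics a backend
    # that does not stream, in which case the `or 0` fallback kicks in.
    ttfts = [0.112, 0.098, 0.143, 0.105]

    mean_ttft_ms = np.mean(ttfts or 0) * 1000
    p95_ttft_ms = np.percentile(ttfts or 0, 95) * 1000
    p99_ttft_ms = np.percentile(ttfts or 0, 99) * 1000

    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", p95_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", p99_ttft_ms))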
python/sglang/srt/_custom_ops.py

@@ -4,10 +4,24 @@ from typing import List, Optional, Tuple
import torch

-from sglang.srt.utils import is_hip, is_hpu, is_npu
+from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

try:
    from lmslim import quant_ops
    from lmslim import quant_tools
except Exception:
    print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")

try:
    import lightop
except Exception:
    print("INFO: Please install lightop if you want to infer awq of marlin.\n")

logger = logging.getLogger(__name__)

use_vllm_custom_allreduce = get_bool_env_var("USE_VLLM_CUSTOM_ALLREDUCE", default="false")
use_dcu_custom_allreduce = get_bool_env_var("USE_DCU_CUSTOM_ALLREDUCE", default="true")

if not is_hpu():
    try:

@@ -15,6 +29,11 @@ if not is_hpu():
    except ImportError as e:
        logger.warning("Failed to import from custom_ar with %r", e)

if use_dcu_custom_allreduce:
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)

if not is_hip() and not is_npu():
    custom_op = sgl_kernel.allreduce

@@ -54,8 +73,79 @@ if not is_hip() and not is_npu():
    ) -> None:
        custom_op.register_graph_buffers(fa, handles, offsets)

elif is_hip and use_dcu_custom_allreduce:
    # custom ar
    def init_custom_ar(
        ipc_tensors: list[torch.Tensor],
        rank_data: torch.Tensor,
        rank: int,
        fully_connected: bool,
    ) -> int:
        return torch.ops._C_custom_ar.init_custom_ar(
            ipc_tensors, rank_data, rank, fully_connected
        )

    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

    def dispose(fa: int) -> None:
        torch.ops._C_custom_ar.dispose(fa)

    def meta_size() -> int:
        return torch.ops._C_custom_ar.meta_size()

    def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)

    def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: list[list[int]], offsets: list[list[int]]
    ) -> None:
        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

    def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
        return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)

    def open_mem_handle(mem_handle: torch.Tensor):
        return torch.ops._C_custom_ar.open_mem_handle(mem_handle)

    def free_shared_buffer(ptr: int) -> None:
        torch.ops._C_custom_ar.free_shared_buffer(ptr)

    def read_cache(
        keys: torch.Tensor,
        values: torch.Tensor,
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        torch.ops._C_cache_ops.read_cache(
            keys, values, key_caches, value_caches, slot_mapping, kv_cache_dtype
        )

    def write_cache_multi_layers(
        keys: torch.Tensor,
        values: torch.Tensor,
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        torch.ops._C_cache_ops.write_cache_multi_layers(
            keys, values, key_caches, value_caches, slot_mapping, kv_cache_dtype
        )

else:
-    # ROCM custom allreduce
+    # sgl_kernel ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,

@@ -163,3 +253,83 @@ def mscclpp_allreduce(
    context: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    nthreads: int,
    nblocks: int,
) -> None:
    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)


def triton_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
    best_config: Optional[list] = None,
) -> torch.Tensor:
    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)


def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
    `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like found in DeepSeek V3 we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and
        corresponding extent in the target shape we repeat each element along
        that dimension src_shape[dim] // target_shape[dim] times consecutively"

    example if we have:
        a = [[1, 2], and target_shape = (2, 4)
             [3, 4]]
    then we would expand a to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]

    currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    # m = a.shape[0]
    # n = b.shape[1]
    # cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    # if current_platform.is_rocm() or not cutlass_compatible_b:
    #     from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
    #         triton_scaled_mm)
    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    # out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    # return out
    # return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def rocblas_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def triton_int8_gemm_helper(
    m: int,
    n: int,
    k: int,
    per_token_act_quant: bool,
    per_out_channel_weight_quant: bool,
    use_bias: bool,
    out_dtype: type[torch.dtype] = torch.float16,
    device: str = "cuda:0",
    best_config: Optional[list] = None,
    repeat: Optional[int] = 2,
):
    return quant_tools.triton_int8_gemm_helper(
        m, n, k, per_token_act_quant, per_out_channel_weight_quant, use_bias,
        out_dtype, device, best_config, repeat,
    )
\ No newline at end of file
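To make the "group" broadcast rule in the `cutlass_scaled_mm` docstring concrete, here is a small hypothetical sketch (not part of the commit) that expands per-block scales to the full tensor shape with `torch.repeat_interleave`; the shapes and the block size of 128 follow the stated constraint `scale_a.shape * [1, 128] == a.shape`:

    import torch

    # Hypothetical shapes: a is (M, K) with one scale per 128-wide block along K,
    # i.e. scale_a has shape (M, K // 128).
    M, K = 4, 256
    a = torch.randn(M, K)
    scale_a = torch.rand(M, K // 128)

    # Repeat each scale 128 times consecutively along K so it lines up with a,
    # which is the "group" broadcast the docstring describes.
    scale_a_full = scale_a.repeat_interleave(128, dim=1)
    assert scale_a_full.shape == a.shape

    scaled_a = scale_a_full * a  # what the fused kernel applies before the matmul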
python/sglang/srt/configs/model_config.py

@@ -635,7 +635,9 @@ class ModelConfig:
            "petit_nvfp4",
            "quark",
            "mxfp4",
            "auto-round",
            "slimquant_w4a8_marlin",
            "w8a8_int8",
            "slimquant_marlin",
        ]
        optimized_quantization_methods = [
            "fp8",

@@ -655,6 +657,8 @@ class ModelConfig:
            "qoq",
            "w4afp8",
            "petit_nvfp4",
            "slimquant_w4a8_marlin",
            "slimquant_marlin",
        ]
        compatible_quantization_methods = {
            "modelopt_fp8": ["modelopt"],
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

@@ -34,6 +34,21 @@ except ImportError:
_is_cuda = is_cuda()
_is_hip = is_hip()

try:
    if ops.use_vllm_custom_allreduce and not _is_hip:
        # Use vLLM custom allreduce
        ops.meta_size()
    elif ops.use_dcu_custom_allreduce:
        ops.meta_size()
    else:
        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
        import sgl_kernel  # noqa: F401
    custom_ar = True
except Exception:
    # For CPUs
    custom_ar = False

logger = logging.getLogger(__name__)

@@ -416,3 +431,274 @@ class CustomAllreduce:
    def __del__(self):
        self.close()


class DCUCustomAllreduce:
    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]

    # max_size: max supported allreduce size
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_size=8192 * 512,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if not custom_ar:
            # disable because of missing custom allreduce library
            # e.g. in a non-GPU environment
            logger.info(
                "Custom allreduce is disabled because "
                "of missing custom allreduce library"
            )
            return

        self.group = group

        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "CustomAllreduce should be attached to a non-NCCL group."
        )

        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom allreduce for multi-node case.
            logger.warning(
                "Custom allreduce is disabled because this process group"
                " spans across nodes."
            )
            return

        rank = dist.get_rank(group=self.group)
        self.rank = rank
        world_size = dist.get_world_size(group=self.group)
        # if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
        if world_size > 16:
            return
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.
            return

        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "Custom allreduce is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.",
                world_size,
                str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))

        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        # assert current_platform.is_cuda_alike()
        # fully_connected = current_platform.is_fully_connected(
        #     physical_device_ids)
        if _is_cuda or _is_hip:
            fully_connected = is_full_nvlink(physical_device_ids, world_size)
        # if world_size > 2 and not fully_connected:
        if not fully_connected:
            max_size = 32 * 8192 * 2
            # if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
            #     logger.warning(
            #         "Custom allreduce is disabled because it's not supported on"
            #         " more than two PCIe-only GPUs. To silence this warning, "
            #         "specify disable_custom_all_reduce=True explicitly.")
            #     return
            logger.warning(
                "We are using PCIe's custom allreduce."
                "If the performance is poor, we can add "
                "--disable-custom-all-reduce in the instruction."
            )
        # test P2P capability, this checks software/cudaruntime support
        # this is expensive to compute at the first time
        # then we cache the result
        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
        if not _is_hip and not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly."
            )
            return

        self.disabled = False
        # Buffers memory are owned by this Python class and passed to C++.
        # Meta data composes of two parts: meta data for synchronization and a
        # temporary buffer for storing intermediate allreduce results.
        self.meta_ptrs = self.create_shared_buffer(
            ops.meta_size() + max_size, group=group, uncached=True
        )
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed
        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen only
        # needs less than 10000 of registered tuples.
        self.rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
        )
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        self.fully_connected = fully_connected
        self._ptr = ops.init_custom_ar(
            self.meta_ptrs, self.rank_data, rank, self.fully_connected
        )
        ops.register_buffer(self._ptr, self.buffer_ptrs)

    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

    def register_graph_buffers(self):
        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
        logger.info("Registering %d cuda graph addresses", len(offset))
        # We cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
        all_data[self.rank] = [handle, offset]
        ranks = sorted(dist.get_process_group_ranks(group=self.group))
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(
                all_data[i], src=rank, group=self.group, device="cpu"
            )
        # Unpack list of tuples to tuple of lists.
        handles = [d[0] for d in all_data]  # type: ignore
        offsets = [d[1] for d in all_data]  # type: ignore
        ops.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires input byte size to be multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        return inp_size <= self.max_size

    def all_reduce(
        self,
        inp: torch.Tensor,
        *,
        out: torch.Tensor = None,
        registered: bool = False,
    ):
        """Performs an out-of-place all reduce.

        If registered is True, this assumes inp's pointer is already
        IPC-registered. Otherwise, inp is first copied into a pre-registered
        buffer.
        """
        if out is None:
            out = torch.empty_like(inp)
        if registered:
            ops.all_reduce(self._ptr, inp, out, 0, 0)
        else:
            ops.all_reduce(
                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
            )
        return out

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        """The main allreduce API that provides support for cuda graph."""
        # When custom allreduce is disabled, this will be None.
        if self.disabled or not self.should_custom_ar(input):
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                return self.all_reduce(input, registered=False)
            else:
                # If warm up, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.empty_like(input)
        else:
            # Note: outside of cuda graph context, custom allreduce incurs a
            # cost of cudaMemcpy, which should be small (<=1% of overall
            # latency) compared to the performance gain of using custom kernels
            return self.all_reduce(input, registered=False)

    def close(self):
        if not self.disabled and self._ptr:
            if ops is not None:
                ops.dispose(self._ptr)
            self._ptr = 0
            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)

    def __del__(self):
        self.close()

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int,
        group: Optional[ProcessGroup] = None,
        uncached: Optional[bool] = False,
    ) -> list[int]:
        pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)

        pointers: list[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer)  # type: ignore
            else:
                pointers.append(ops.open_mem_handle(h))
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: list[int],
        group: Optional[ProcessGroup] = None,
        rank: Optional[int] = 0,
    ) -> None:
        if rank is None:
            rank = dist.get_rank(group=group)
        if ops is not None:
            ops.free_shared_buffer(pointers[rank])
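For orientation, the eligibility test in `should_custom_ar` boils down to checks on the input tensor: its byte size must be a multiple of 16, it must be weakly contiguous, and it must fit in the pre-registered buffer. A standalone sketch of the size arithmetic only (the contiguity check is omitted, and the default `max_size` is reused purely for illustration):

    import torch

    def fits_custom_allreduce(inp: torch.Tensor, max_size: int = 8192 * 512) -> bool:
        # Byte size of the tensor, mirroring inp.numel() * inp.element_size() above.
        inp_size = inp.numel() * inp.element_size()
        # Custom allreduce requires the byte size to be a multiple of 16
        # and small enough to fit the pre-registered IPC buffer.
        return inp_size % 16 == 0 and inp_size <= max_size

    x = torch.empty(4096, 4096, dtype=torch.bfloat16)  # 32 MiB of data
    print(fits_custom_allreduce(x))  # False with the default 4 MiB buffer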
python/sglang/srt/distributed/parallel_state.py

@@ -54,6 +54,7 @@ from sglang.srt.utils import (
    is_xpu,
    supports_custom_op,
)
from sglang.srt import _custom_ops as ops

_is_npu = is_npu()
_is_cpu = is_cpu()

@@ -327,7 +328,7 @@ class GroupCoordinator:
            # Lazy import to avoid documentation build error
            from sglang.srt.distributed.device_communicators.custom_all_reduce import (
                CustomAllreduce,
                DCUCustomAllreduce,
            )
            from sglang.srt.distributed.device_communicators.pymscclpp import (
                PyMscclppCommunicator,

@@ -371,10 +372,17 @@ class GroupCoordinator:
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            try:
-                self.ca_comm = CustomAllreduce(
-                    group=self.cpu_group,
-                    device=self.device,
-                )
+                if is_hip() and ops.use_dcu_custom_allreduce:
+                    self.ca_comm = DCUCustomAllreduce(
+                        group=self.cpu_group,
+                        device=self.device,
+                    )
+                else:
+                    self.ca_comm = CustomAllreduce(
+                        group=self.cpu_group,
+                        device=self.device,
+                        max_size=ca_max_size,
+                    )
            except Exception as e:
                logger.warning(
                    f"Setup Custom allreduce failed with {e}. To silence this "
python/sglang/srt/environ.py

@@ -188,6 +188,17 @@ class Envs:
    SGLANG_USE_AITER = EnvBool(False)
    SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
    SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)

    # DCU Lightop
    SGLANG_USE_LIGHTOP = EnvBool(False)
    # Fused
    SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD = EnvBool(False)
    SGLANG_USE_OPT_CAT = EnvBool(False)
    SGLANG_USE_FUSED_RMS_QUANT = EnvBool(False)
    SGLANG_USE_FUSED_SILU_MUL_QUANT = EnvBool(False)

    # Quantization
    SGLANG_INT4_WEIGHT = EnvBool(False)
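These flags follow the same name-plus-default pattern as the boolean environment variables read in `_custom_ops.py` above. A rough sketch of how such a toggle can be probed from plain `os.environ`, assuming the usual truthy spellings (this mirrors, but is not, sglang's actual `EnvBool` implementation):

    import os

    def env_bool(name: str, default: bool = False) -> bool:
        # Unset -> default; set -> True only for common truthy spellings.
        value = os.environ.get(name)
        if value is None:
            return default
        return value.strip().lower() in ("1", "true", "yes", "on")

    # e.g. export SGLANG_USE_LIGHTOP=1 before launching to enable the DCU lightop path.
    use_lightop = env_bool("SGLANG_USE_LIGHTOP", default=False)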
python/sglang/srt/layers/attention/attention_registry.py

@@ -99,7 +99,6 @@ def create_triton_backend(runner):
    return TritonAttnBackend(runner)


@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend

@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
    return FlashMLABackend(runner)


@register_attention_backend("dcu_mla")
def create_dcu_mla_backend(runner):
    from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend

    return DCUMLABackend(runner)


@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
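The `@register_attention_backend("dcu_mla")` hook above plugs the new backend into a name-to-factory registry. A generic, dependency-free sketch of that decorator pattern follows; the registry dict and the placeholder return value are illustrative only, not sglang's actual implementation:

    # Minimal registry-decorator sketch: map a backend name to a factory callable.
    ATTENTION_BACKENDS = {}

    def register_attention_backend(name):
        def decorator(factory):
            ATTENTION_BACKENDS[name] = factory
            return factory
        return decorator

    @register_attention_backend("dcu_mla")
    def create_dcu_mla_backend(runner):
        # In sglang the factory lazily imports DCUMLABackend and returns
        # DCUMLABackend(runner); a placeholder tuple stands in here.
        return ("DCUMLABackend", runner)

    backend = ATTENTION_BACKENDS["dcu_mla"]("dummy_runner")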
python/sglang/srt/layers/attention/dcu_mla_backend.py (new file, 0 → 100644)

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import torch
import triton

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
from sglang.srt.utils import get_bool_env_var

try:
    from flash_mla import (
        flash_mla_with_kvcache,
        flash_mla_with_kvcache_quantization,
        get_mla_metadata,
    )

    _has_flash_mla = True
except Exception:
    try:
        from vllm.attention.ops.flashmla import flash_mla_with_kvcache, get_mla_metadata

        _has_flash_mla = False
    except Exception:
        raise ImportError(
            "Can not import FlashMLA. Please perform the following operations to use flashmla:\n"
            "    pip install flash-mla\n"
            "    or\n"
            "    pip install vllm"
        )

PAGE_SIZE = 64  # forced to 64

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInput


@dataclass
class VllmMLADecodeMetadata:
    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    num_splits: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None

    def __init__(
        self,
        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        num_splits: Optional[torch.Tensor] = None,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        self.flashmla_metadata = flashmla_metadata
        self.num_splits = num_splits
        self.block_kv_indices = block_kv_indices


class DCUMLABackend(AttentionBackend):
    def __init__(
        self,
        model_runner: "ModelRunner",
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        if model_runner.server_args.page_size != PAGE_SIZE:
            raise ValueError(
                f"dcu_mla backend requires page_size={PAGE_SIZE}, "
                f"but got the {model_runner.server_args.page_size}"
            )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.device = model_runner.device
        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)

        self.max_context_len = model_runner.model_config.context_len
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        self.forward_metadata: Union[VllmMLADecodeMetadata] = None

        self.skip_prefill = skip_prefill
        if not skip_prefill:
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.flashattn_backend = FlashAttentionBackend(
                model_runner,
                skip_prefill=False,
            )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var(
            "SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON"
        )
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(
                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
            )
            block_kv_indices = torch.full(
                (bs, max_seqlen_pad),
                -1,
                dtype=torch.int32,
                device=forward_batch.seq_lens.device,
            )
            if use_sglang_create_flashmla_kv_indices_triton:
                dcu_create_flashmla_kv_indices(
                    req_to_token_ptr=self.req_to_token,
                    req_pool_indices_ptr=forward_batch.req_pool_indices,
                    page_kernel_lens_ptr=forward_batch.seq_lens,
                    kv_start_idx=None,
                    kv_indices_ptr=block_kv_indices,
                    req_to_token_ptr_stride=self.req_to_token.stride(0),
                    kv_indices_ptr_stride=max_seqlen_pad,
                )
            else:
                create_flashmla_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    block_kv_indices,
                    self.req_to_token.stride(0),
                    max_seqlen_pad,
                )
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32), self.num_q_heads, 1
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits, block_kv_indices
            )
        elif forward_batch.forward_mode.is_target_verify():
            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            block_kv_indices = torch.full(
                (bs, max_seqlen_pad),
                -1,
                dtype=torch.int32,
                device=seq_lens.device,
            )
            if use_sglang_create_flashmla_kv_indices_triton:
                dcu_create_flashmla_kv_indices(
                    req_to_token_ptr=self.req_to_token,
                    req_pool_indices_ptr=forward_batch.req_pool_indices,
                    page_kernel_lens_ptr=forward_batch.seq_lens,
                    kv_start_idx=None,
                    kv_indices_ptr=block_kv_indices,
                    req_to_token_ptr_stride=self.req_to_token.stride(0),
                    kv_indices_ptr_stride=max_seqlen_pad,
                )
            else:
                create_flashmla_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    block_kv_indices,
                    self.req_to_token.stride(0),
                    max_seqlen_pad,
                )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits, block_kv_indices
            )
        else:
            if not self.skip_prefill:
                # === DRAFT_EXTEND_V2 MLA metadata === nhb
                if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
                    bs = forward_batch.batch_size
                    seq_lens_cpu = forward_batch.seq_lens_cpu
                    seq_lens = forward_batch.seq_lens
                    max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
                    block_kv_indices = torch.full(
                        (bs, max_seqlen_pad),
                        -1,
                        dtype=torch.int32,
                        device=seq_lens.device,
                    )
                    # Call the Triton kernel to generate block_kv_indices
                    if use_sglang_create_flashmla_kv_indices_triton:
                        dcu_create_flashmla_kv_indices(
                            req_to_token_ptr=self.req_to_token.to(torch.int32),
                            req_pool_indices_ptr=forward_batch.req_pool_indices.to(torch.int32),
                            page_kernel_lens_ptr=forward_batch.seq_lens.to(torch.int32),
                            kv_start_idx=None,
                            kv_indices_ptr=block_kv_indices.to(torch.int32),
                            req_to_token_ptr_stride=self.req_to_token.stride(0),
                            kv_indices_ptr_stride=max_seqlen_pad,
                        )
                    else:
                        create_flashmla_kv_indices_triton[(bs,)](
                            self.req_to_token,
                            forward_batch.req_pool_indices,
                            forward_batch.seq_lens,
                            None,
                            block_kv_indices,
                            self.req_to_token.stride(0),
                            max_seqlen_pad,
                        )
                    # MLA
                    mla_metadata, num_splits = get_mla_metadata(
                        seq_lens.to(torch.int32),
                        self.num_q_heads,
                        1,
                    )
                    # save forward_metadata
                    self.forward_metadata = VllmMLADecodeMetadata(
                        mla_metadata,
                        num_splits,
                        block_kv_indices,
                    )
                self.flashattn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
            cuda_graph_kv_indices = torch.full(
                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
                1,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = block_kv_indices

        if self.num_draft_tokens:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
        else:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
                self.num_q_heads,
                1,
            )
        self.cuda_graph_mla_metadata = mla_metadata
        self.cuda_graph_num_splits = num_splits
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
    ):
        if forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        else:
            if not self.skip_prefill:
                self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                    bs,
                    num_tokens,
                    req_pool_indices,
                    seq_lens,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            seq_lens = seq_lens[:bs]
            seq_lens_cpu = seq_lens_cpu[:bs]
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens[:bs] + self.num_draft_tokens
            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        else:
            if not self.skip_prefill:
                self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                    bs,
                    req_pool_indices,
                    seq_lens,
                    seq_lens_sum,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                    seq_lens_cpu,
                )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def _call_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
    ):
        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
        )
        return o

    def _call_fp8_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
        k_scale=None,
        kv_cache_dtype=None,
    ):
        assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
        o, _ = flash_mla_with_kvcache_quantization(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
            k_scale=k_scale,
            kv_cache_dtype=kv_cache_dtype,
        )
        return o

    @torch._dynamo.disable()  # NOTE: FP8 cache decode does not support torch.compile
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        num_draft_tokens = (
            self.num_draft_tokens if self.num_draft_tokens is not None else 0
        )

        if self.data_type in (
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2,
            torch.float8_e5m2fnuz,
        ):
            if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
                kv_cache_dtype = "fp8_e4m3"
            else:
                kv_cache_dtype = "fp8_e5m2"
            k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale

            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                layer.scaling,
                k_scale,
                kv_cache_dtype=kv_cache_dtype,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                layer.scaling,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    @torch._dynamo.disable()
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        sinks=None,
    ):
        if (
            forward_batch.forward_mode == ForwardMode.EXTEND
            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
        ):
            if not self.skip_prefill:
                return self.flashattn_backend.forward_extend(
                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
                )
            else:
                raise RuntimeError("skip prefill but use forward_extend")

        cache_loc = forward_batch.out_cache_loc
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        num_draft_tokens = (
            self.num_draft_tokens if self.num_draft_tokens is not None else 0
        )

        if self.data_type in (
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2,
            torch.float8_e5m2fnuz,
        ):
            if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
                kv_cache_dtype = "fp8_e4m3"
            else:
                kv_cache_dtype = "fp8_e5m2"
            k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale

            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                layer.scaling,
                k_scale,
                kv_cache_dtype=kv_cache_dtype,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                layer.scaling,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)


class DCUMLAMultiStepDraftBackend:
    """
    Wrap multiple flashmla attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        if topk > 1:
            raise ValueError(
                "Currently FlashMLA only supports topk=1 for speculative decoding"
            )
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )

        self.attn_backends = []
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends.append(
                DCUMLABackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                    kv_last_page_len_buf=None,
                )
            )

    def common_template(
        self,
        forward_batch: ForwardBatch,
        call_fn: Callable,
    ):
        assert forward_batch.spec_info is not None

        for i in range(self.speculative_num_steps - 1):
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, block_kv_indices=None
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )

        self.common_template(forward_batch, call_fn)
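A quick note on the sizing logic that recurs throughout this backend: the KV page table is padded to whole pages of PAGE_SIZE = 64 tokens, so max_seqlen_pad = ceil(max_seq_len / 64) and block_kv_indices is a (batch_size, max_seqlen_pad) int32 tensor initialised to -1 for unused page slots. A small standalone illustration with made-up sequence lengths (a sketch, not the backend's code):

    import torch

    PAGE_SIZE = 64  # the dcu_mla backend forces page_size=64
    seq_lens = torch.tensor([70, 130, 64], dtype=torch.int32)

    # Same arithmetic as triton.cdiv(seq_lens.max().item(), PAGE_SIZE): ceiling division.
    max_seqlen_pad = -(-int(seq_lens.max().item()) // PAGE_SIZE)  # 130 tokens -> 3 pages

    # One row of page indices per request; -1 marks pages the request does not use.
    block_kv_indices = torch.full((len(seq_lens), max_seqlen_pad), -1, dtype=torch.int32)
    print(max_seqlen_pad, block_kv_indices.shape)  # 3 torch.Size([3, 3])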
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py

@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F

-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
a1175a4e
...
...
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sgl_kernel
import
merge_state_v2
from
sgl_kernel.flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from
sglang.srt.layers.attention.flashattention_interface
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
@
dataclass
...
...
@@ -328,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
self
.
use_mla
=
model_runner
.
model_config
.
attention_arch
==
AttentionArch
.
MLA
self
.
skip_prefill
=
skip_prefill
self
.
is_hybrid
=
model_runner
.
is_hybrid
self
.
k_scale
=
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
self
.
v_scale
=
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
if
self
.
is_hybrid
:
self
.
full_to_swa_index_mapping
=
(
model_runner
.
token_to_kv_pool
.
full_to_swa_index_mapping
...
...
@@ -596,9 +599,11 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
]
if
any
(
forward_batch
.
extend_prefix_lens_cpu
)
or
forward_batch
.
forward_mode
.
is_draft_extend
(
include_v2
=
True
):
if
(
any
(
forward_batch
.
extend_prefix_lens_cpu
)
or
forward_batch
.
forward_mode
==
ForwardMode
.
DRAFT_EXTEND
or
forward_batch
.
forward_mode
==
ForwardMode
.
DRAFT_EXTEND_V2
#nhb
):
extend_seq_lens
=
forward_batch
.
extend_seq_lens
metadata
.
max_seq_len_q
=
max
(
forward_batch
.
extend_seq_lens_cpu
)
metadata
.
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
...
...
@@ -608,10 +613,13 @@ class FlashAttentionBackend(AttentionBackend):
metadata
.
max_seq_len_q
=
metadata
.
max_seq_len_k
metadata
.
cu_seqlens_q
=
metadata
.
cu_seqlens_k
# Setup local attention if enabled
if
forward_batch
.
forward_mode
==
ForwardMode
.
EXTEND
:
# # Setup local attention if enabled
# if forward_batch.forward_mode == ForwardMode.EXTEND:
# self._init_local_attn_metadata(forward_batch, metadata, device)
if
forward_batch
.
forward_mode
in
(
ForwardMode
.
EXTEND
,
ForwardMode
.
DRAFT_EXTEND_V2
):
self
.
_init_local_attn_metadata
(
forward_batch
,
metadata
,
device
)
# Encoder metadata for cross attention
if
forward_batch
.
encoder_lens
is
not
None
:
assert
(
...
...
@@ -668,10 +676,11 @@ class FlashAttentionBackend(AttentionBackend):
if
not
layer
.
is_cross_attention
else
forward_batch
.
encoder_out_cache_loc
)
if
not
self
.
use_mla
:
if
k_rope
is
None
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
,
layer
.
k_scale
,
layer
.
v_scale
layer
,
cache_loc
,
k
,
v
,
#
layer.k_scale, layer.v_scale
)
else
:
forward_batch
.
token_to_kv_pool
.
set_mla_kv_buffer
(
layer
,
...
...
@@ -690,7 +699,8 @@ class FlashAttentionBackend(AttentionBackend):
layer
.
sliding_window_size
is
not
None
and
layer
.
sliding_window_size
>
-
1
)
window_size
=
(
layer
.
sliding_window_size
,
0
)
if
is_swa
else
(
-
1
,
-
1
)
k_descale
,
v_descale
=
None
,
None
# k_descale, v_descale = None, None
k_descale
,
v_descale
=
self
.
k_scale
,
self
.
v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
...
...
@@ -704,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
descale_shape
=
(
forward_batch
.
batch_size
,
layer
.
tp_k_head_num
)
k_descale
=
layer
.
k_scale
.
expand
(
descale_shape
)
v_descale
=
layer
.
v_scale
.
expand
(
descale_shape
)
q
=
q
.
to
(
self
.
kv_cache_dtype
)
#
q = q.to(self.kv_cache_dtype)
q_rope
=
q_rope
.
to
(
self
.
kv_cache_dtype
)
if
q_rope
is
not
None
else
None
k_rope
=
k_rope
.
to
(
self
.
kv_cache_dtype
)
if
k_rope
is
not
None
else
None
causal
=
True
...
...
@@ -774,61 +784,59 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k
=
metadata
.
encoder_cu_seqlens_k
window_size
=
(
-
1
,
-
1
)
result
=
flash_attn_with_kvcache
(
q
=
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
page_table
=
page_table
,
cache_seqlens
=
cache_seqlens
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k_new
=
cu_seqlens_k
if
not
use_local_attn
else
None
,
max_seqlen_q
=
max_seqlen_q
,
softmax_scale
=
layer
.
scaling
,
causal
=
False
if
use_cascade_attn
else
causal
,
window_size
=
window_size
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
use_cascade_attn
,
num_splits
=
self
.
num_splits
,
**
kwargs
,
)
if
use_cascade_attn
:
o
,
softmax_lse
,
*
rest
=
result
o_expand
,
softmax_lse_expand
,
*
rest_expand
=
flash_attn_with_kvcache
(
q
=
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
page_table
=
self
.
forward_metadata_spec_decode_expand
.
page_table
,
cache_seqlens
=
self
.
forward_metadata_spec_decode_expand
.
cache_seqlens_int32
,
cu_seqlens_q
=
self
.
forward_metadata_spec_decode_expand
.
cu_seqlens_q
,
cu_seqlens_k_new
=
self
.
forward_metadata_spec_decode_expand
.
cu_seqlens_k
,
max_seqlen_q
=
self
.
forward_metadata_spec_decode_expand
.
max_seq_len_q
,
if
forward_batch
.
attn_attend_prefix_cache
:
assert
not
get_global_server_args
().
disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA
assert
forward_batch
.
prefix_chunk_idx
is
not
None
assert
forward_batch
.
prefix_chunk_cu_seq_lens
is
not
None
assert
forward_batch
.
prefix_chunk_max_seq_lens
is
not
None
chunk_idx
=
forward_batch
.
prefix_chunk_idx
assert
chunk_idx
>=
0
assert
forward_batch
.
mha_return_lse
output
=
flash_attn_varlen_func
(
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_k
=
forward_batch
.
prefix_chunk_cu_seq_lens
[
chunk_idx
],
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_k
=
forward_batch
.
prefix_chunk_max_seq_lens
[
chunk_idx
],
softmax_scale
=
layer
.
scaling
,
causal
=
False
,
window_size
=
window_size
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
num_splits
=
self
.
num_splits
,
**
kwargs
,
)
o
,
_
=
merge_state_v2_wrapper
(
o
,
softmax_lse
.
T
.
contiguous
(),
o_expand
,
softmax_lse_expand
.
T
.
contiguous
(),
)
else
:
o
=
result
output
=
flash_attn_varlen_func
(
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_k
=
metadata
.
cu_seqlens_q
,
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_k
=
metadata
.
max_seq_len_q
,
softmax_scale
=
layer
.
scaling
,
causal
=
True
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
forward_batch
.
mha_return_lse
,
**
kwargs
,
)
if
forward_batch
.
mha_return_lse
:
output
,
lse
,
*
rest
=
output
lse
=
torch
.
transpose
(
lse
,
0
,
1
).
contiguous
()
return
output
,
lse
return
output
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
v_head_dim
)
else
:
if
(
forward_batch
.
attn_attend_prefix_cache
is
not
None
and
not
forward_batch
.
forward_mode
.
is_target_verify
()
and
not
forward_batch
.
forward_mode
.
is_draft_extend
(
include_v2
=
True
)
):
and
not
forward_batch
.
forward_mode
.
is_draft_extend
()
):
# Do multi-head attention with chunked prefix cache
if
forward_batch
.
attn_attend_prefix_cache
:
assert
not
get_global_server_args
().
disable_chunked_prefix_cache
...
...
@@ -843,39 +851,32 @@ class FlashAttentionBackend(AttentionBackend):
assert
forward_batch
.
mha_return_lse
output
=
flash_attn_varlen_func
(
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
to
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
to
(
q
.
dtype
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_k
=
forward_batch
.
prefix_chunk_cu_seq_lens
[
chunk_idx
],
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_k
=
forward_batch
.
prefix_chunk_max_seq_lens
[
chunk_idx
],
softmax_scale
=
layer
.
scaling
,
causal
=
False
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
**
kwargs
,
)
else
:
# MHA for extend part of sequence without attending prefix kv cache
cu_seqlens_k
=
(
metadata
.
cu_seqlens_q
if
not
forward_batch
.
mha_one_shot
else
metadata
.
cu_seqlens_k
)
max_seqlen_k
=
(
metadata
.
max_seq_len_q
if
not
forward_batch
.
mha_one_shot
else
metadata
.
max_seq_len_k
)
output
=
flash_attn_varlen_func
(
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
to
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
to
(
q
.
dtype
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
layer
.
scaling
,
causal
=
True
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
forward_batch
.
mha_return_lse
,
**
kwargs
,
)
...
...
@@ -985,10 +986,16 @@ class FlashAttentionBackend(AttentionBackend):
if
not
layer
.
is_cross_attention
else
forward_batch
.
encoder_out_cache_loc
)
if
not
self
.
use_mla
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
,
layer
.
k_scale
,
layer
.
v_scale
)
# if not self.use_mla:
if
k_rope
is
None
:
if
not
self
.
use_mla
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
,
layer
.
k_scale
,
layer
.
v_scale
)
else
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
)
else
:
forward_batch
.
token_to_kv_pool
.
set_mla_kv_buffer
(
layer
,
...
...
@@ -1030,7 +1037,8 @@ class FlashAttentionBackend(AttentionBackend):
        if sinks is not None:
            kwargs["sinks"] = sinks

-       k_descale, v_descale = None, None
+       # k_descale, v_descale = None, None
+       k_descale, v_descale = self.k_scale, self.v_scale
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
...
@@ -1044,7 +1052,6 @@ class FlashAttentionBackend(AttentionBackend):
        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        if not self.use_mla:
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
...
@@ -1096,65 +1103,62 @@ class FlashAttentionBackend(AttentionBackend):
                    **kwargs,
                )
            else:
-               page_table = metadata.page_table
-               cache_seqlens = metadata.cache_seqlens_int32
-               cu_seqlens_k = metadata.cu_seqlens_k
-               max_seqlen_q = metadata.max_seq_len_q
-               q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
-               # Default: single-token self-attention
-               result = flash_attn_with_kvcache(
-                   q=q_reshaped,
-                   k_cache=key_cache,
-                   v_cache=value_cache,
-                   page_table=page_table,
-                   cache_seqlens=cache_seqlens,
-                   cu_seqlens_q=metadata.cu_seqlens_q,
-                   cu_seqlens_k_new=cu_seqlens_k,
-                   max_seqlen_q=max_seqlen_q,
-                   softmax_scale=layer.scaling,
-                   causal=False if use_cascade_attn else causal,
-                   window_size=window_size,
-                   softcap=layer.logit_cap,
-                   k_descale=k_descale,
-                   v_descale=v_descale,
-                   return_softmax_lse=use_cascade_attn,
-                   num_splits=self.num_splits,
-                   **kwargs,
-               )
-               if use_cascade_attn:
-                   o, softmax_lse, *rest = result
-                   o_expand, softmax_lse_expand, *rest_expand = (
-                       flash_attn_with_kvcache(
-                           q=q_reshaped,
-                           k_cache=key_cache,
-                           v_cache=value_cache,
-                           page_table=self.forward_metadata_spec_decode_expand.page_table,
-                           cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                           cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                           cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                           max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                           softmax_scale=layer.scaling,
-                           causal=False,
-                           window_size=window_size,
-                           softcap=layer.logit_cap,
-                           k_descale=k_descale,
-                           v_descale=v_descale,
-                           return_softmax_lse=True,
-                           num_splits=self.num_splits,
-                           **kwargs,
-                       )
-                   )
-                   o, _ = merge_state_v2(
-                       o,
-                       softmax_lse.T.contiguous(),
-                       o_expand,
-                       softmax_lse_expand.T.contiguous(),
-                   )
-               else:
-                   o = result
+               cu_seqlens_q = metadata.cu_seqlens_q
+               max_seqlen_q = metadata.max_seq_len_q
+               page_table = metadata.page_table
+               cache_seqlens = metadata.cache_seqlens_int32
+               cu_seqlens_k = metadata.cu_seqlens_k
+               key_cache = key_cache.view(
+                   -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+               )
+               value_cache = value_cache.view(
+                   -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+               )
+               if layer.is_cross_attention:
+                   page_table = metadata.encoder_page_table
+                   cache_seqlens = metadata.encoder_lens_int32
+                   cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                   window_size = (-1, -1)
+               if max_seqlen_q > 1:
+                   result = flash_attn_varlen_func(
+                       q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                       k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                       v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                       cu_seqlens_q=cu_seqlens_q,
+                       cu_seqlens_k=cu_seqlens_k,
+                       max_seqlen_q=max_seqlen_q,
+                       max_seqlen_k=max_seqlen_q,
+                       softmax_scale=layer.scaling,
+                       causal=True,
+                       window_size=window_size,
+                       softcap=layer.logit_cap,
+                       k_descale=k_descale,
+                       v_descale=v_descale,
+                       return_softmax_lse=use_cascade_attn,
+                       num_splits=self.num_splits,
+                       **kwargs,
+                   )
+               else:
+                   result = flash_attn_with_kvcache(
+                       q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                       k_cache=key_cache,
+                       v_cache=value_cache,
+                       page_table=page_table,
+                       cache_seqlens=cache_seqlens,
+                       cu_seqlens_q=cu_seqlens_q,
+                       cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                       max_seqlen_q=max_seqlen_q,
+                       softmax_scale=layer.scaling,
+                       causal=True,
+                       window_size=window_size,
+                       softcap=layer.logit_cap,
+                       k_descale=k_descale,
+                       v_descale=v_descale,
+                       return_softmax_lse=use_cascade_attn,
+                       num_splits=self.num_splits,
+                       **kwargs,
+                   )
+               o = result
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
...
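Note: the cascade/spec-decode path above stitches two partial attention results together with merge_state_v2 and their softmax log-sum-exp values. As a minimal reference of the underlying math (this is a sketch in plain PyTorch, not the sgl_kernel implementation; the function name and shapes are illustrative):

import torch

def merge_attention_states(o1, lse1, o2, lse2):
    # o*: [tokens, heads, head_dim], lse*: [tokens, heads] natural-log LSE of each partial softmax.
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    denom = w1 + w2
    o = o1 * (w1 / denom).unsqueeze(-1) + o2 * (w2 / denom).unsqueeze(-1)
    lse = max_lse + torch.log(denom)
    return o, lse

The weights are exactly the renormalized softmax masses of the two partial attention windows, which is why the merged result equals attention over the union of the two key sets.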
python/sglang/srt/layers/attention/flashattention_interface.py 0 → 100644
View file @ a1175a4e

from flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_interface,
    flash_attn_with_kvcache as flash_attn_with_kvcache_interface,
)
from typing import Optional, Union

import torch


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_with_kvcache_interface(
        q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
        k_cache=k_cache.view(q.dtype),
        v_cache=v_cache.view(q.dtype),
        block_table=page_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_softmax_lse=return_softmax_lse,
        num_splits=num_splits,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_varlen_func_interface(
        q=q,
        k=k.view(q.dtype),
        v=v.view(q.dtype),
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=return_softmax_lse,
        softcap=softcap,
    )
\ No newline at end of file
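Note: this shim keeps the sglang-style keyword signature while forwarding to the upstream flash_attn package (block_table instead of page_table, dtype reinterpretation of the KV cache). A minimal decode-style call through the wrapper might look as follows; the shapes, page size 1, and CUDA/bf16 setup are illustrative assumptions, not values taken from the commit:

import torch
from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache

# 2 decode tokens, 8 heads, head_dim 128; paged KV cache with page_block_size 1.
q = torch.randn(2, 8, 128, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(1024, 1, 8, 128, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn(1024, 1, 8, 128, device="cuda", dtype=torch.bfloat16)
page_table = torch.arange(2 * 64, device="cuda", dtype=torch.int32).view(2, 64)
cache_seqlens = torch.tensor([17, 33], device="cuda", dtype=torch.int32)

out = flash_attn_with_kvcache(
    q=q, k_cache=k_cache, v_cache=v_cache,
    page_table=page_table, cache_seqlens=cache_seqlens,
    max_seqlen_q=1, softmax_scale=128 ** -0.5, causal=True,
)

Arguments the wrapper accepts but does not forward (k_descale, sinks, ver, ...) are silently ignored by this interface.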
python/sglang/srt/layers/attention/flashmla_backend.py
View file @ a1175a4e
...
@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var
+from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
...
@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var(
+            "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
+        )
         bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
             max_seqlen_pad = triton.cdiv(
...
@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 dtype=torch.int32,
                 device=forward_batch.seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 forward_batch.seq_lens.to(torch.int32),
                 self.num_q_heads,
...
@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 dtype=torch.int32,
                 device=seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
...
@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
         else:
             super().init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
         self,
         max_bs: int,
...
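Note: the change above is a runtime switch between the generic Triton kv-indices kernel and the DCU-specific dcu_create_flashmla_kv_indices, selected by the SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO environment flag. The dispatch pattern itself, with stand-in functions instead of the real kernels, looks like this (a self-contained sketch; the two _create_kv_indices helpers are placeholders, not sglang/sgl_kernel APIs):

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Close enough to sglang.srt.utils.get_bool_env_var for this sketch.
    return os.getenv(name, default).lower() in ("1", "true", "yes")

def _vendor_create_kv_indices(seq_lens):   # stand-in for dcu_create_flashmla_kv_indices
    return [list(range(n)) for n in seq_lens]

def _triton_create_kv_indices(seq_lens):   # stand-in for create_flashmla_kv_indices_triton
    return [list(range(n)) for n in seq_lens]

def build_kv_indices(seq_lens):
    if get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"):
        return _vendor_create_kv_indices(seq_lens)
    return _triton_create_kv_indices(seq_lens)

print(build_kv_indices([3, 5]))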
python/sglang/srt/layers/attention/lightop_concat.py 0 → 100644
View file @ a1175a4e

from __future__ import annotations

import warnings

import torch

from sglang.srt.utils import get_bool_env_var, direct_register_custom_op

_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")

if _USE_OPT_CAT:
    try:
        from lightop import ds_cat  # type: ignore
    except ImportError:  # pragma: no cover
        ds_cat = None
        warnings.warn(
            "SGLANG_USE_OPT_CAT is enabled but lightop.ds_cat could not be imported; "
            "falling back to torch.cat"
        )
else:
    ds_cat = None


# TODO: registering this op on its own still has some issues
def ds_cat_wrapper(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    ds_cat(A, B, C, mode)
    return C


def ds_cat_fake(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
    # Use standard torch.cat as the fake (meta) implementation
    return torch.cat([A, B], dim=dim)


direct_register_custom_op(
    op_name="ds_cat",
    op_func=ds_cat_wrapper,
    mutates_args=[],  # no arguments are mutated; only the return value is produced
    fake_impl=ds_cat_fake,
)


def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
    mode = 0
    if dim != 0:
        return torch.ops.sglang.ds_cat(A, B, dim, mode)
    assert False, "not support"


# def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
#     assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
#     output_shape = list(A.shape)
#     output_shape[dim] = A.shape[dim] + B.shape[dim]
#     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
#     mode = 0
#     if dim != 0:
#         ds_cat(A, B, C, mode)
#         return C
#     assert False, "not support"
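Note: concat_decode_opt only accepts 3-D tensors concatenated along dim 2, and the fake implementation shows the reference semantics are those of torch.cat. A hedged sanity-check sketch, using torch.cat directly so it runs without lightop installed (shapes are illustrative):

import torch

# Two decode-shaped tensors [batch, heads, dim]; the optimized path must match torch.cat on dim=2.
A = torch.randn(4, 16, 512)
B = torch.randn(4, 16, 64)
ref = torch.cat([A, B], dim=2)
assert ref.shape == (4, 16, 576)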
python/sglang/srt/layers/attention/nsa_backend.py
View file @ a1175a4e
...
@@ -47,7 +47,8 @@ if _is_hip:
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )
 else:
-    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+    from sgl_kernel.flash_attn import flash_attn_varlen_func
+    # from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache


 @dataclass(frozen=True)
...
python/sglang/srt/layers/attention/xpu_backend.py
View file @ a1175a4e
...
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

 from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


 class XPUAttentionBackend(AttentionBackend):
...
python/sglang/srt/layers/layernorm.py
View file @ a1175a4e
...
@@ -160,21 +160,53 @@ class RMSNorm(CustomOp):
             return output, residual_out
         return rms_norm(x, self.weight.data, self.variance_epsilon)

+    # def forward_hip(
+    #     self,
+    #     x: torch.Tensor,
+    #     residual: Optional[torch.Tensor] = None,
+    # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    #     if not x.is_contiguous():
+    #         # NOTE: Remove this if aiter kernel supports discontinuous input
+    #         x = x.contiguous()
+    #     if residual is not None:
+    #         if _vllm_version < Version("0.9"):
+    #             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+    #             return x, residual
+    #         else:
+    #             residual_out = torch.empty_like(x)
+    #             output = torch.empty_like(x)
+    #             fused_add_rms_norm(
+    #                 output,
+    #                 x,
+    #                 residual_out,
+    #                 residual,
+    #                 self.weight.data,
+    #                 self.variance_epsilon,
+    #             )
+    #             return output, residual_out
+    #     out = torch.empty_like(x)
+    #     rms_norm(out, x, self.weight.data, self.variance_epsilon)
+    #     return out

     def forward_hip(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ):
         if not x.is_contiguous():
             # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            try:
+                fused_add_rms_norm(
+                    x,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
                 return x, residual
-            else:
-                residual_out = torch.empty_like(x)
+            except TypeError:
+                output = torch.empty_like(x)
+                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,
...
@@ -184,10 +216,13 @@ class RMSNorm(CustomOp):
                     self.variance_epsilon,
                 )
                 return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out

     def forward_native(
         self,
         x: torch.Tensor,
...
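Note: instead of branching on _vllm_version, the new forward_hip probes the fused_add_rms_norm signature at call time: it tries the older 4-argument in-place form and falls back to the newer 6-argument out-of-place form when a TypeError is raised. A self-contained sketch of that probe pattern, with a dummy in-place kernel standing in for the real aiter/vLLM op (names and the dummy math are assumptions for illustration):

import torch

def _fused_add_rms_norm_inplace(x, residual, weight, eps):
    # Stand-in for the older 4-argument kernel: x += residual, then RMS-normalize x in place.
    x.add_(residual)
    residual.copy_(x)
    var = x.pow(2).mean(-1, keepdim=True)
    x.mul_(torch.rsqrt(var + eps)).mul_(weight)

def add_rms_norm(x, residual, weight, eps, kernel=_fused_add_rms_norm_inplace):
    try:
        kernel(x, residual, weight, eps)                      # old signature: mutates x/residual
        return x, residual
    except TypeError:
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        kernel(output, x, residual_out, residual, weight, eps)  # new 6-argument signature
        return output, residual_out

The probe keeps one code path working against both library versions without importing or parsing version strings.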
python/sglang/srt/layers/linear.py
View file @ a1175a4e
...
@@ -45,6 +45,18 @@ _is_hip = is_hip()
 _disable_hip_linear_quant = _is_hip and get_bool_env_var("SGLANG_ROCM_DISABLE_LINEARQUANT")
+_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
+_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
+
+if _use_fused_rms_quant:
+    try:
+        from lmslim.quantize.quant_ops import lm_faster_rmsquant
+    except Exception as e:
+        print(f"Error: Import fused rmsquant error: {e}")
+
+if _use_fused_silu_mul_quant:
+    try:
+        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
+    except Exception as e:
+        print(f"Error: Import fused silu_mul_quant error: {e}")

 logger = logging.getLogger(__name__)
...
@@ -1360,7 +1372,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_, skip_all_reduce=False):
+    def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
...
@@ -1374,10 +1386,19 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
-            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        if use_fused_silu_mul_quant:
+            xq, xs = lm_fuse_silu_mul_quant(input_parallel)
+            silu_quant_args = [xq, xs]
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                output_parallel = self.quant_method.apply(
+                    self, input_parallel, bias=bias_, silu_quant_args=silu_quant_args
+                )
+                sm.tag(output_parallel)
+        else:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+                sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
...
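Note: lm_fuse_silu_mul_quant comes from lmslim and is not shown in this commit; what the fused path computes is SiLU(gate) * up followed by per-token symmetric int8 quantization. A plain-PyTorch reference of that computation, useful for sanity-checking the fused kernel (this is a sketch, not the lmslim implementation, and the exact scale layout it returns is an assumption):

import torch

def silu_mul_quant_reference(x: torch.Tensor):
    # x: [tokens, 2*d], gate in the first half, up in the second half.
    gate, up = x.chunk(2, dim=-1)
    y = torch.nn.functional.silu(gate) * up
    absmax = y.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = absmax / 127.0
    yq = torch.round(y / scale).clamp(-127, 127).to(torch.int8)
    return yq, scale.float()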
python/sglang/srt/layers/moe/ep_moe/kernels.py
View file @ a1175a4e
...
@@ -2,7 +2,6 @@ import logging

 import torch
 import triton
-from sglang.srt.utils import ceil_div, is_cuda

 logger = logging.getLogger(__name__)
...
@@ -1015,196 +1014,133 @@ def zero_experts_compute_triton(
     return output

-@triton.jit
-def compute_problem_sizes_w4a8_kernel(
-    masked_m_ptr, problem_sizes1_ptr, problem_sizes2_ptr, n, k, num_experts,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = pid < num_experts
-    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
-    ps1_idx_0 = pid * 3
-    ps1_idx_1 = ps1_idx_0 + 1
-    ps1_idx_2 = ps1_idx_0 + 2
-    ps2_idx_0 = pid * 3
-    ps2_idx_1 = ps2_idx_0 + 1
-    ps2_idx_2 = ps2_idx_0 + 2
-    ps1_mask_0 = ps1_idx_0 < num_experts * 3
-    ps1_mask_1 = ps1_idx_1 < num_experts * 3
-    ps1_mask_2 = ps1_idx_2 < num_experts * 3
-    ps2_mask_0 = ps2_idx_0 < num_experts * 3
-    ps2_mask_1 = ps2_idx_1 < num_experts * 3
-    ps2_mask_2 = ps2_idx_2 < num_experts * 3
-    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
-    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
-    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
-    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
-    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
-    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
-
-
-def compute_problem_sizes_w4a8(masked_m, problem_sizes1, problem_sizes2, n, k, num_experts):
-    BLOCK_SIZE = 256
-    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
-    compute_problem_sizes_w4a8_kernel[grid](
-        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts, BLOCK_SIZE=BLOCK_SIZE,
-    )
-    return problem_sizes1, problem_sizes2
-
-
-def deepep_ll_get_cutlass_w4a8_moe_mm_data(
-    masked_m, problem_sizes1, problem_sizes2, num_experts, n, k,
-):
-    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
-        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
-    )
-    return (problem_sizes1.to(torch.int32), problem_sizes2.to(torch.int32))
-
-
-@triton.jit
-def _silu_and_mul_post_per_tensor_quant_kernel(
-    input_ptr, stride_input_expert, stride_input_token, stride_input_dim,
-    output_ptr, stride_output_expert, stride_output_token, stride_output_dim,
-    scale_ptr, masked_m_ptr, inner_dim, fp8_max, fp8_min,
-    BLOCK_N: tl.constexpr, NUM_STAGE: tl.constexpr,
-):
-    """
-    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
-    Shape:
-        input:  [E, T_padded, 2*D] -> gate: [:, :, :D], up: [:, :, D:]
-        output: [E, T_padded, D], dtype=float8_e4m3fn
-    """
-    expert_id = tl.program_id(2)
-    block_id_token = tl.program_id(1)
-    block_id_dim = tl.program_id(0)
-    num_token_blocks = tl.num_programs(1)
-    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
-    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
-    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
-    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
-    stride_input_token = tl.cast(stride_input_token, tl.int32)
-    stride_output_token = tl.cast(stride_output_token, tl.int32)
-    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
-    mask_d = offset_d < inner_dim
-    # base pointers for current expert and dim block
-    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
-    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
-    for token_idx in tl.range(
-        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
-    ):
-        gate_ptr = input_base_offs + token_idx * stride_input_token
-        up_ptr = gate_ptr + inner_dim
-        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
-        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
-        # SiLU: x * sigmoid(x)
-        gate = gate / (1 + tl.exp(-gate))
-        gate = gate.to(input_ptr.dtype.element_ty)
-        gate_up = up * gate
-        scaled = gate_up * scale
-        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
-        out_ptr = output_base_offs + token_idx * stride_output_token
-        tl.store(out_ptr, output_q, mask=mask_d)
-
-
-def silu_and_mul_masked_post_per_tensor_quant_fwd(
-    input: torch.Tensor,
-    output: torch.Tensor,
-    masked_m: torch.Tensor,
-    scale: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Fused SiLU + Mul + Per-Tensor Quantization to FP8.
-    Args:
-        input:    [expert_num, token_num_padded, 2 * inner_dim]
-        output:   [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
-        masked_m: [expert_num], actual token count for each expert
-        scale:    [1] or [expert_num], quantization scale (per-tensor or per-expert)
-    Returns:
-        output tensor
-    """
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-    assert output.dtype == torch.float8_e4m3fn
-    assert input.ndim == 3
-    assert input.shape[0] == masked_m.shape[0]
-    assert input.shape[-1] % 2 == 0
-    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
-    expert_num = input.shape[0]
-    inner_dim = input.shape[-1] // 2  # 3584
-    BLOCK_N = 256
-    BLOCK_M = 64 if expert_num < 4 else 32
-    NUM_STAGES = 3
-    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
-    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    _silu_and_mul_post_per_tensor_quant_kernel[grid](
-        input, *input.stride(),
-        output, *output.stride(),
-        scale, masked_m, inner_dim, fp8_max, fp8_min,
-        BLOCK_N=BLOCK_N, NUM_STAGE=NUM_STAGES,
-    )
-    return output
+from triton.language.extra import libdevice
+from typing import Optional
+
+
+@triton.jit
+def _per_token_quant_int8_one_kernel_opt(
+    x_ptr, xq_ptr, scale_ptr,
+    stride_x, stride_xq,
+    N, T_dim, tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    if tokens_per_expert_ptr is not None:
+        e = row_id // T_dim
+        t = row_id % T_dim
+        num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+        if t >= num_valid_tokens_for_e:
+            return
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = libdevice.nearbyint(x_q).to(tl.int8)
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+@triton.jit
+def _per_token_quant_int8_kernel_opt(
+    x_ptr, xq_ptr, scale_ptr,
+    stride_x, stride_xq,
+    N, E_dim, T_dim, tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    token_idx_start = tl.program_id(0)
+    grid_size = tl.num_programs(0)
+    num_total_tokens = E_dim * T_dim
+    for token_idx in range(token_idx_start, num_total_tokens, grid_size):
+        is_valid_token = True
+        if tokens_per_expert_ptr is not None:
+            e = token_idx // T_dim
+            t = token_idx % T_dim
+            num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+            if t >= num_valid_tokens_for_e:
+                is_valid_token = False
+        if is_valid_token:
+            cols = tl.arange(0, BLOCK)
+            mask = cols < N
+            x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+            absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+            scale_x = absmax / 127
+            x_q = x * (127 / absmax)
+            x_q = libdevice.nearbyint(x_q).to(tl.int8)
+            tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
+            tl.store(scale_ptr + token_idx, scale_x)
+
+
+def per_token_quant_int8_triton_opt(
+    x: torch.Tensor, tokens_per_expert: Optional[torch.Tensor] = None
+):
+    if x.dim() != 3:
+        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
+    E, T, H = x.shape
+    N = H
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        num_warps = 1
+    num_tokens = E * T
+    grid_opt = num_tokens
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        grid_opt = max(1, num_tokens // (T // 256))
+        _per_token_quant_int8_kernel_opt[(grid_opt,)](
+            x, x_q, scales,
+            stride_x=x.stride(-2), stride_xq=x_q.stride(-2),
+            N=N, E_dim=E, T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK, num_warps=num_warps, num_stages=1,
+        )
+    else:
+        _per_token_quant_int8_one_kernel_opt[(grid_opt,)](
+            x, x_q, scales,
+            stride_x=x.stride(-2), stride_xq=x_q.stride(-2),
+            N=N, T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK, num_warps=num_warps, num_stages=1,
+        )
+    return x_q, scales
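Note: the added kernels perform per-token symmetric int8 quantization over a padded [E, T, H] expert layout, skipping rows beyond tokens_per_expert[e]. A plain-PyTorch reference of the same semantics, handy for comparing against the Triton output (a sketch; the zeroing of padded rows is only so that comparisons ignore garbage in padding):

import torch

def per_token_quant_int8_reference(x: torch.Tensor, tokens_per_expert: torch.Tensor):
    # x: [E, T, H]; rows at or beyond tokens_per_expert[e] are padding.
    absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scales = absmax / 127.0                     # [E, T, 1], matches the kernel's scale layout
    x_q = torch.round(x.float() / scales).to(torch.int8)
    E, T, _ = x.shape
    valid = torch.arange(T).unsqueeze(0) < tokens_per_expert.unsqueeze(1)  # [E, T]
    x_q = x_q * valid.unsqueeze(-1)
    scales = scales * valid.unsqueeze(-1)
    return x_q, scales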
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @ a1175a4e

 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from collections import defaultdict
+
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import (
+    SlimQuantCompressedTensorsMarlinConfig,
+)
+from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig

 import torch
 import torch.distributed as dist

 from sglang.srt import single_batch_overlap
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
+    per_token_quant_int8_triton_opt,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.token_dispatcher.deepep import (
...
@@ -23,7 +33,8 @@ from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu, direct_register_custom_op
 from sglang.srt.utils.offloader import get_offloader

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
...
@@ -31,6 +42,8 @@ if TYPE_CHECKING:
         DeepEPNormalDispatchOutput,
         DispatchOutput,
     )

+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8

 _is_hip = is_hip()
 _is_npu = is_npu()
...
@@ -46,8 +59,400 @@ if _use_aiter:

 logger = logging.getLogger(__name__)


# ------ custom ops for lightop
def m_grouped_w4a8_gemm_nt_masked_wrapper(
    a0: torch.Tensor, a1: torch.Tensor,
    b0: torch.Tensor, b1: torch.Tensor,
    d: torch.Tensor, masked_m: torch.Tensor, expected_m_per_group: int,
) -> torch.Tensor:
    return m_grouped_w4a8_gemm_nt_masked(
        (a0, a1), (b0, b1), d, masked_m, expected_m_per_group, config={"MODE": 1000}
    )


def m_grouped_w4a8_gemm_nt_masked_fake(
    a0: torch.Tensor, a1: torch.Tensor,
    b0: torch.Tensor, b1: torch.Tensor,
    d: torch.Tensor, masked_m: torch.Tensor, expected_m_per_group: int,
) -> torch.Tensor:
    return d


def m_grouped_w8a8_gemm_nt_masked_wrapper(
    a0: torch.Tensor, a1: torch.Tensor,
    b0: torch.Tensor, b1: torch.Tensor,
    d: torch.Tensor, masked_m: torch.Tensor, expected_m_per_group: int,
) -> torch.Tensor:
    return m_grouped_w8a8_gemm_nt_masked(
        (a0, a1), (b0, b1), d, masked_m, expected_m_per_group, config={"MODE": 1000}
    )


def m_grouped_w8a8_gemm_nt_masked_fake(
    a0: torch.Tensor, a1: torch.Tensor,
    b0: torch.Tensor, b1: torch.Tensor,
    d: torch.Tensor, masked_m: torch.Tensor, expected_m_per_group: int,
) -> torch.Tensor:
    return d


def fuse_silu_mul_quant_ep_wrapper(
    input: torch.Tensor,
    tokens_per_expert: Optional[torch.Tensor] = None,
    num_local_tokens_tensor: Optional[torch.Tensor] = None,
    topk: int = 1,
    expect_m: int = -1,
) -> tuple[torch.Tensor, torch.Tensor]:
    return fuse_silu_mul_quant_ep(
        input, tokens_per_expert, num_local_tokens_tensor, topk, expect_m
    )


def fuse_silu_mul_quant_ep_fake(
    input: torch.Tensor,
    tokens_per_expert: Optional[torch.Tensor] = None,
    num_local_tokens_tensor: Optional[torch.Tensor] = None,
    topk: int = 1,
    expect_m: int = -1,
) -> tuple[torch.Tensor, torch.Tensor]:
    E, T, H = input.shape
    d = H // 2
    output = torch.empty(E, T, d, dtype=torch.int8, device=input.device)
    scales = torch.empty((E, T, 1), device=input.device, dtype=torch.float32)
    return output, scales


direct_register_custom_op(
    op_name="m_grouped_w4a8_gemm_nt_masked",
    op_func=m_grouped_w4a8_gemm_nt_masked_wrapper,
    mutates_args=[],
    fake_impl=m_grouped_w4a8_gemm_nt_masked_fake,
)
direct_register_custom_op(
    op_name="m_grouped_w8a8_gemm_nt_masked",
    op_func=m_grouped_w8a8_gemm_nt_masked_wrapper,
    mutates_args=[],
    fake_impl=m_grouped_w8a8_gemm_nt_masked_fake,
)
direct_register_custom_op(
    op_name="fuse_silu_mul_quant_ep",
    op_func=fuse_silu_mul_quant_ep_wrapper,
    mutates_args=[],
    fake_impl=fuse_silu_mul_quant_ep_fake,
)
# ------
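Note: the registrations above expose the lightop kernels under torch.ops.sglang.* together with shape-only fake implementations so torch.compile and CUDA-graph capture can trace through them. The same registration pattern, reduced to a self-contained example with a trivial op (this uses the torch.library.custom_op API available in recent PyTorch; sglang's direct_register_custom_op wraps an equivalent registration):

import torch

@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Real implementation: multiply each row of x by its scale.
    return x * s.unsqueeze(-1)

@scale_rows.register_fake
def _(x, s):
    # Fake/meta implementation: only shapes and dtypes matter for tracing.
    return torch.empty_like(x)

y = torch.ops.demo.scale_rows(torch.randn(4, 8), torch.ones(4))

Later code in this file calls the registered ops via torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked and torch.ops.sglang.fuse_silu_mul_quant_ep in exactly this fashion.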
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
    temp = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(temp, 23)
    mant = torch.bitwise_and(temp, 0x7FFFFF)
    is_ru = torch.logical_and(
        torch.logical_and((mant > 0), (exp != 0xFE)),
        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
    )
    exp = torch.where(is_ru, exp + 1, exp)
    new_x = exp.to(torch.uint8).view(torch.int)
    return new_x.transpose(1, 2).contiguous().transpose(1, 2)


class EPMoE(FusedMoE):
    """
    MoE Expert Parallel Impl
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
        gemm1_alpha: Optional[float] = None,
        gemm1_clamp_limit: Optional[float] = None,
        with_bias: bool = False,
    ):
        super().__init__(
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_fused_shared_experts=num_fused_shared_experts,
            layer_id=layer_id,
            top_k=top_k,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            activation=activation,
            # apply_router_weight_on_input=apply_router_weight_on_input,
            routed_scaling_factor=routed_scaling_factor,
            gemm1_alpha=gemm1_alpha,
            gemm1_clamp_limit=gemm1_clamp_limit,
            with_bias=with_bias,
        )

        self.intermediate_size = intermediate_size

        if isinstance(quant_config, Fp8Config):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme
            self.use_w4a8_marlin = False
            self.use_w8a8_marlin = False
        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
            self.use_w4a8_marlin = True
            self.use_w8a8_marlin = False
        elif isinstance(quant_config, SlimQuantCompressedTensorsMarlinConfig):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
            self.use_w4a8_marlin = False
            self.use_w8a8_marlin = True
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None
            self.use_w4a8_marlin = False
            self.use_w8a8_marlin = False

    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
            return self.forward_deepgemm(hidden_states, topk_output)
        else:
            return super().forward(hidden_states, topk_output)

    def forward_deepgemm(
        self,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
    ):
        self.w13_weight_fp8 = (
            self.w13_weight,
            (
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
        )
        self.w2_weight_fp8 = (
            self.w2_weight,
            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
        )

        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        hidden_states_shape = hidden_states.shape
        hidden_states_dtype = hidden_states.dtype
        hidden_states_device = hidden_states.device
        topk_weights, topk_ids, _ = topk_output

        if not self.use_block_quant:
            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
            scale_block_size = 128
            w13_weight_scale_n = 2 * (
                (self.intermediate_size + scale_block_size - 1) // scale_block_size
            )
            w13_weight_scale_k = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w13_weight_scale = (
                self.w13_weight_scale.unsqueeze(1)
                .repeat_interleave(w13_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w13_weight_scale_k, dim=2)
            )
            self.w13_weight_fp8 = (self.w13_weight, w13_weight_scale)
            w2_weight_scale_n = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale_k = (
                self.intermediate_size + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale = (
                self.w2_weight_scale.unsqueeze(1)
                .repeat_interleave(w2_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w2_weight_scale_k, dim=2)
            )
            self.w2_weight_fp8 = (self.w2_weight, w2_weight_scale)

        # PreReorder
        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
            moe_ep_deepgemm_preprocess(
                topk_ids,
                self.num_experts,
                hidden_states,
                self.top_k,
                self.start_expert_id,
                self.end_expert_id,
                self.block_shape,
            )
        )

        dispose_tensor(hidden_states)

        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            b, s_mn, s_k = gateup_input_scale.shape
            assert (
                s_mn % 4 == 0 and s_k % 4 == 0
            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"

        # GroupGemm-0
        gateup_input_fp8 = (
            gateup_input,
            (
                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(gateup_input_scale)
            ),
        )
        num_groups, m, k = gateup_input_fp8[0].size()
        n = self.w13_weight.size(1)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            gateup_input_fp8,
            self.w13_weight_fp8,
            gateup_output,
            masked_m,
            expected_m,
        )
        del gateup_input
        del gateup_input_fp8

        # Act
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=hidden_states_device,
            dtype=self.fp8_dtype,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=hidden_states_device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output

        # GroupGemm-1
        n = self.w2_weight.size(1)
        down_input_fp8 = (
            down_input,
            (
                down_input_scale
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
            ),
        )
        down_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            down_input_fp8,
            self.w2_weight_fp8,
            down_output,
            masked_m,
            expected_m,
        )
        del down_input
        del down_input_fp8

        # PostReorder
        output = torch.empty(
            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
        )
        post_reorder_triton_kernel[(hidden_states_shape[0],)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states_shape[1],
            m_max * self.start_expert_id,
            BLOCK_SIZE=512,
        )
        if self.moe_runner_config.routed_scaling_factor is not None:
            output *= self.moe_runner_config.routed_scaling_factor
        return output


-class DeepEPMoE(FusedMoE):
+class DeepEPMoE(EPMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
    Mooncake EP shares the same class, as they expose the same interface.
...
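Note: when the weights are per-tensor quantized, forward_deepgemm expands each expert's single scale into the [E, n_blocks, k_blocks] layout the block-quant grouped GEMM expects by repeating it. A reduced sketch of exactly that expansion (the sizes here are illustrative, not values from the commit):

import torch

E, N, K, block = 2, 512, 384, 128
per_tensor_scale = torch.rand(E)  # one scale per expert

n_blocks = (N + block - 1) // block
k_blocks = (K + block - 1) // block
per_block_scale = (
    per_tensor_scale.unsqueeze(1).repeat_interleave(n_blocks, dim=1)
    .unsqueeze(2).repeat_interleave(k_blocks, dim=2)
)
assert per_block_scale.shape == (E, n_blocks, k_blocks)

Every block simply inherits the expert-wide scale, so the block-quant kernel produces the same numerics as a per-tensor dequantization would.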
@@ -112,11 +517,28 @@ class DeepEPMoE(FusedMoE):
         self.deepep_mode = get_deepep_mode()

-        if self.deepep_mode.enable_low_latency() and not _is_npu:
-            # NPU supports low_latency deepep without deepgemm
-            assert (
-                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=self.deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
+
+        # if self.deepep_mode.enable_low_latency() and not _is_npu:
+        #     # NPU supports low_latency deepep without deepgemm
+        #     assert (
+        #         deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+        #     ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"

         if _use_aiter:
             # expert_mask is of size (self.num_local_experts + 1),
             # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
...
@@ -130,6 +552,23 @@ class DeepEPMoE(FusedMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
+        # elif not _is_npu:
+        #     self.w13_weight_fp8 = (
+        #         self.w13_weight,
+        #         (
+        #             self.w13_weight_scale_inv
+        #             if self.use_block_quant
+        #             else self.w13_weight_scale
+        #         ),
+        #     )
+        #     self.w2_weight_fp8 = (
+        #         self.w2_weight,
+        #         (
+        #             self.w2_weight_scale_inv
+        #             if self.use_block_quant
+        #             else self.w2_weight_scale
+        #         ),
+        #     )

     def forward(
         self,
...
@@ -189,35 +628,39 @@ class DeepEPMoE(FusedMoE):
             output = self.forward_aiter(dispatch_output)
         elif _is_npu:
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
-            output = self.forward_npu(dispatch_output)
-        elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
-            if self.use_w4afp8:
-                output = self.forward_cutlass_w4afp8(dispatch_output)
-            else:
-                assert False, "forward_deepgemm_contiguous is deprecated"
-        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
-            if (
-                get_moe_runner_backend().is_flashinfer_cutedsl()
-                and self.quant_config.get_name() == "modelopt_fp4"
-            ):
-                output = self.forward_flashinfer_cutedsl(
-                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
-                )
-            elif self.use_w4afp8:
-                output = self.forward_cutlass_w4afp8_masked(dispatch_output)
-            else:
-                assert False, "forward_deepgemm_masked is deprecated"
-        combine_input_wrapper = (
-            DeepEPNormalCombineInput
-            if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
-            else DeepEPLLCombineInput
-        )
-        return combine_input_wrapper(
-            hidden_states=output,
-            topk_ids=dispatch_output.topk_ids,
-            topk_weights=dispatch_output.topk_weights,
-        )
+            return self.forward_npu(dispatch_output)
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            # assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
+            elif self.use_w4a8_marlin:
+                return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
+            elif self.use_w8a8_marlin:
+                return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
+            else:
+                raise ValueError(f"Dispatch output is not supported")
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            if self.use_w4a8_marlin:
+                return self.forward_groupgemm_w4a8_marlin_masked(dispatch_output)
+            elif self.use_w8a8_marlin:
+                return self.forward_groupgemm_w8a8_marlin_masked(dispatch_output)
+            if (
+                get_moe_runner_backend().is_flashinfer_cutedsl()
+                and self.quant_config.get_name() == "modelopt_fp4"
+            ):
+                return self.forward_flashinfer_cutedsl(
+                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
+                )
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            assert down_gemm_overlap_args is None
+            return self.forward_deepgemm_masked(dispatch_output)
+        else:
+            raise ValueError(
+                f"Dispatch output format {dispatch_output.format} is not supported"
+            )

     def combine(
         self,
...
@@ -267,6 +710,292 @@ class DeepEPMoE(FusedMoE):
             expert_mask=self.expert_mask,
         )

    def forward_deepgemm_w4a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (
            dispatch_output
        )
        # hidden_states_int8, hidden_states_scale = hidden_states_int8
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states.bfloat16()
        expert_output = self.quant_method.apply_ep(
            x=hidden_states,
            w1=self.w13_weight,
            w2=self.w2_weight,
            topk_ids=topk_idx,
            topk_weights=topk_weights,
            global_num_experts=self.moe_runner_config.num_experts,
            expert_map=self.expert_map,
            activation=self.moe_runner_config.activation,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
            use_nn_moe=False,
            w1_scale=self.w13_weight_scale,
            w2_scale=self.w2_weight_scale,
            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
        )
        return expert_output

    def forward_groupgemm_w8a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (
            dispatch_output
        )
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states.bfloat16()

        device = hidden_states.device
        M = hidden_states.shape[0]
        K = hidden_states.shape[1]
        topk = topk_idx.shape[1]

        active_experts = set()
        token_expert_pos = [None] * M
        for t in range(M):
            lst = []
            for pos in range(topk):
                e = int(topk_idx[t, pos].item())
                if e >= 0:
                    lst.append((e, pos))
                    active_experts.add(e)
            token_expert_pos[t] = lst

        active_experts = sorted(list(active_experts))
        num_active = len(active_experts)
        if num_active == 0:
            return hidden_states.bfloat16()

        counts = defaultdict(int)
        for t in range(M):
            for (e, pos) in token_expert_pos[t]:
                counts[e] += 1

        per_expert_block = {}
        for e in active_experts:
            cnt = counts.get(e, 0)
            if cnt <= 0:
                per_expert_block[e] = 0
            else:
                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
                per_expert_block[e] = max(256, needed)

        expert_slot_offset = {}
        offset = 0
        for e in active_experts:
            expert_slot_offset[e] = offset
            offset += per_expert_block[e]
        pad_M = offset

        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
        m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
        slot_counters = {e: 0 for e in active_experts}
        token_row_weight_list = {t: [] for t in range(M)}
        for t in range(M):
            for (e, pos) in token_expert_pos[t]:
                start = expert_slot_offset[e]
                slot = slot_counters[e]
                if slot >= per_expert_block[e]:
                    raise RuntimeError(
                        f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}"
                    )
                row = start + slot
                hidden_states_packed[row] = hidden_states[t]
                m_indices[row] = int(e)
                slot_counters[e] += 1
                w = topk_weights[t, pos].to(device=device)
                w_f = w.float() if w.dtype != torch.float32 else w
                token_row_weight_list[t].append((row, w_f))

        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
        N = self.w13_weight.size(1)
        gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
        m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a1_all, q_a1_scale),
            (self.w13_weight, self.w13_weight_scale),
            gateup_output,
            m_indices,
        )
        q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
        down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
        down_output = m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a2_all, q_a2_scale),
            (self.w2_weight, self.w2_weight_scale),
            down_output,
            m_indices,
        )

        result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
        for t in range(M):
            pairs = token_row_weight_list[t]
            if not pairs:
                continue
            acc = None
            for (row, w) in pairs:
                vec = down_output[row].float()
                weighted = vec * w
                acc = weighted if acc is None else (acc + weighted)
            result[t] = acc.to(result.dtype)
        return result

    def forward_deepgemm_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        (
            hidden_states,
            hidden_states_scale,
            topk_ids,
            topk_weights,
            num_recv_tokens_per_expert,
        ) = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        if num_recv_tokens_per_expert is None:
            return hidden_states.bfloat16()
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states.bfloat16()
        M, K = hidden_states.size()
        N = self.w13_weight.size(1)
        scale_block_size = 128

        w13_weight_fp8 = (
            self.w13_weight,
            (
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
        )
        w2_weight_fp8 = (
            self.w2_weight,
            (
                self.w2_weight_scale_inv
                if self.use_block_quant
                else self.w2_weight_scale
            ),
        )

        hidden_states_shape = hidden_states.shape
        hidden_states_device = hidden_states.device
        hidden_states_dtype = hidden_states.dtype

        input_tensor = [
            torch.empty(
                (all_tokens, K),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            ),
            (
                # TODO check whether need `zeros`
                torch.zeros(
                    (ceil_div(K // 128, 4), all_tokens),
                    device=hidden_states.device,
                    dtype=torch.int,
                ).transpose(0, 1)
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else torch.empty(
                    (all_tokens, K // 128),
                    device=hidden_states.device,
                    dtype=torch.float32,
                )
            ),
        ]
        m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
        output_index = torch.empty_like(topk_ids)

        if get_offloader().forbid_copy_engine_usage:
            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(num_recv_tokens_per_expert)
        else:
            num_recv_tokens_per_expert_gpu = torch.tensor(
                num_recv_tokens_per_expert,
                dtype=torch.int32,
                pin_memory=True,
                device="cpu",
            ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states,
            hidden_states_scale,
            topk_ids,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            input_tensor[0],
            input_tensor[1],
            m_indices,
            output_index,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        dispose_tensor(hidden_states)

        gateup_output = torch.empty(
            (all_tokens, N), device=hidden_states_device, dtype=torch.bfloat16
        )
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            input_tensor[1] = tma_align_input_scale(input_tensor[1])
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            input_tensor, w13_weight_fp8, gateup_output, m_indices
        )
        del input_tensor
        down_input = torch.empty(
            (all_tokens, N // 2), device=gateup_output.device, dtype=torch.bfloat16
        )
        silu_and_mul(gateup_output.view(-1, N), down_input)
        del gateup_output
        down_output = torch.empty(
            (all_tokens, K), device=hidden_states_device, dtype=torch.bfloat16
        )
        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
            down_input,
            scale_block_size,
            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del down_input
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            down_input_scale = tma_align_input_scale(down_input_scale)
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            (down_input_fp8, down_input_scale), w2_weight_fp8, down_output, m_indices
        )
        del down_input_fp8, down_input_scale

        gather_out = torch.empty(
            hidden_states_shape, device=hidden_states_device, dtype=torch.bfloat16
        )
        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
        return gather_out

    def forward_flashinfer_cutedsl(
        self,
        dispatch_output: DeepEPLLDispatchOutput,
...
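Note: forward_groupgemm_w8a8_marlin_contiguous packs each active expert's tokens into a contiguous block padded up to a multiple of 256 rows and records the expert id per row in m_indices (with -1 marking padding). A compact, self-contained sketch of just that bookkeeping, with the GEMMs left out (shapes and the example routing are illustrative):

import torch

def pack_tokens_by_expert(topk_idx: torch.Tensor, block: int = 256):
    # topk_idx: [M, topk] with -1 for dropped slots.
    counts = {}
    for e in topk_idx.flatten().tolist():
        if e >= 0:
            counts[e] = counts.get(e, 0) + 1
    offsets, offset = {}, 0
    for e in sorted(counts):
        offsets[e] = offset
        offset += max(block, ((counts[e] + block - 1) // block) * block)
    m_indices = torch.full((offset,), -1, dtype=torch.int32)
    for e, cnt in counts.items():
        m_indices[offsets[e]:offsets[e] + cnt] = e
    return offset, offsets, m_indices

pad_M, offs, m_indices = pack_tokens_by_expert(torch.tensor([[0, 2], [2, -1]]))
# pad_M == 512 here: each of the two active experts gets a 256-row block.

The padding keeps every expert's row block aligned for the grouped contiguous GEMM; the per-token weighted sum over the packed rows then reconstructs the original [M, K] output.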
@@ -296,7 +1025,106 @@ class DeepEPMoE(FusedMoE):
             dispatch_output=dispatch_output,
         )

-    def forward_cutlass_w4afp8_masked(
+    def forward_groupgemm_w4a8_marlin_masked(
         self,
         dispatch_output: DeepEPLLOutput,
     ):
+        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+        # base shapes
+        num_groups, m, k = hidden_states.size()
+        expected_m = min(m, expected_m)
+        # ---- first quant: ensure float input for quantizer ----
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
+        # ---- weights & scales ----
+        w13_weight = self.w13_weight
+        w13_scales = self.w13_weight_scale
+        w2_weight = self.w2_weight
+        w2_scales = self.w2_weight_scale
+        n1 = w13_scales.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
+        )
+        # ---- first GEMM ----
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale, w13_weight, w13_scales, gateup_output, masked_m, expected_m,
+        )
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        # ---- second GEMM ----
+        n2 = w2_scales.size(1)
+        down_output = torch.empty(
+            (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
+        )
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale, w2_weight, w2_scales, down_output, masked_m, expected_m,
+        )
+        return down_output
+
+    def forward_groupgemm_w8a8_marlin_masked(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+        # base shapes
+        num_groups, m, k = hidden_states.size()
+        expected_m = min(m, expected_m)
+        # ---- first quant: ensure float input for quantizer ----
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
+        # ---- weights & scales ----
+        w13_weight = self.w13_weight
+        w13_scales = self.w13_weight_scale
+        w2_weight = self.w2_weight
+        w2_scales = self.w2_weight_scale
+        n1 = w13_scales.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
+        )
+        # ---- first GEMM ----
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale, w13_weight, w13_scales, gateup_output, masked_m, expected_m,
+        )
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        # ---- second GEMM ----
+        n2 = w2_scales.size(1)
+        down_output = torch.empty(
+            (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
+        )
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale, w2_weight, w2_scales, down_output, masked_m, expected_m,
+        )
+        return down_output

     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLDispatchOutput,
     ):
...
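Note: both masked marlin paths above follow the same pipeline: per-token int8 quantization of the dispatched activations, grouped GEMM-1 into [E, m, 2n], fused SiLU-mul-quant, then grouped GEMM-2 back to the hidden size, with masked_m limiting how many rows per expert are real. A shape-only reference of that flow with dense matmuls standing in for the lightop kernels and the quantization steps omitted (a sketch under those assumptions, not the optimized implementation):

import torch

def masked_grouped_moe_reference(x, w13, w2, masked_m):
    # x: [E, m, k], w13: [E, 2*n, k], w2: [E, k, n]; masked_m: valid rows per expert.
    E, m, k = x.shape
    out = torch.zeros(E, m, w2.shape[1], dtype=torch.bfloat16)
    for e in range(E):
        rows = int(masked_m[e])
        gate_up = x[e, :rows].float() @ w13[e].float().T          # GEMM-1: [rows, 2*n]
        gate, up = gate_up.chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up                  # fused activation (quant omitted)
        out[e, :rows] = (act @ w2[e].float().T).to(torch.bfloat16)  # GEMM-2: [rows, k]
    return out

out = masked_grouped_moe_reference(
    torch.randn(2, 8, 16), torch.randn(2, 12, 16), torch.randn(2, 16, 6), torch.tensor([5, 8])
)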
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
View file @ a1175a4e
...
@@ -65,7 +65,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -84,6 +84,8 @@ def inplace_fused_experts(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ) -> None:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     fused_experts_impl(
         hidden_states,
         w1,
...
@@ -123,7 +125,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -161,7 +163,7 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -181,6 +183,8 @@ def outplace_fused_experts(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ) -> torch.Tensor:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     return fused_experts_impl(
         hidden_states,
         w1,
...
@@ -220,7 +224,7 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -273,9 +277,12 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
     filter_expert = (
         moe_runner_config.num_experts is None
         or moe_runner_config.num_experts != moe_runner_config.num_local_experts
     )
+    act_id = (
+        0
+        if (
+            moe_runner_config.activation == 0
+            or (
+                isinstance(moe_runner_config.activation, str)
+                and moe_runner_config.activation.lower() == "silu"
+            )
+        )
+        else 1
+    )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
...
@@ -287,7 +294,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
...
@@ -316,7 +323,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
...
@@ -366,7 +373,7 @@ def fused_experts_impl(
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -386,6 +393,8 @@ def fused_experts_impl(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ):
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
...
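Note: the fused_experts entry points now take activation as an int (0 for silu, 1 for gelu) for custom-op friendliness and normalize back to a string internally, while still accepting the legacy string form. A tiny helper capturing that mapping:

def normalize_activation(activation) -> str:
    # Accept both the legacy string form and the new integer form (0 = "silu", 1 = "gelu").
    if isinstance(activation, int):
        return "silu" if activation == 0 else "gelu"
    return activation

assert normalize_activation(0) == "silu"
assert normalize_activation("gelu") == "gelu"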