change / sglang / Commits / a1175a4e

Commit a1175a4e, authored Nov 22, 2025 by maxiao1

Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5

Parents: 0c006b88, 31653dd9
Changes: 62

Showing 20 changed files with 2608 additions and 409 deletions (+2608, -409)
python/sglang/bench_serving.py                                             +10   -0
python/sglang/srt/_custom_ops.py                                           +173  -3
python/sglang/srt/configs/model_config.py                                  +5    -1
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py    +286  -0
python/sglang/srt/distributed/parallel_state.py                            +13   -5
python/sglang/srt/environ.py                                               +11   -0
python/sglang/srt/layers/attention/attention_registry.py                   +5    -1
python/sglang/srt/layers/attention/dcu_mla_backend.py                      +679  -0
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py    +2    -1
python/sglang/srt/layers/attention/flashattention_backend.py               +131  -127
python/sglang/srt/layers/attention/flashattention_interface.py             +96   -0
python/sglang/srt/layers/attention/flashmla_backend.py                     +48   -20
python/sglang/srt/layers/attention/lightop_concat.py                       +67   -0
python/sglang/srt/layers/attention/nsa_backend.py                          +2    -1
python/sglang/srt/layers/attention/xpu_backend.py                          +2    -1
python/sglang/srt/layers/layernorm.py                                      +44   -9
python/sglang/srt/layers/linear.py                                         +26   -5
python/sglang/srt/layers/moe/ep_moe/kernels.py                             +124  -188
python/sglang/srt/layers/moe/ep_moe/layer.py                               +865  -37
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py                 +19   -10
python/sglang/bench_serving.py  (view file @ a1175a4e)

...
@@ -848,10 +848,12 @@ class BenchmarkMetrics:
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p95_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p95_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
...

...
@@ -1721,10 +1723,12 @@ def calculate_metrics(
        * 1000,
        # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
...

...
@@ -2052,6 +2056,12 @@ async def benchmark(
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
    print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
...
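For context, a minimal sketch of how the new p95/p99 fields above are derived from the raw per-request timing lists (variable names follow the diff; the sample values are made up):

    import numpy as np

    ttfts = [0.120, 0.150, 0.980, 0.135]      # hypothetical per-request TTFTs in seconds
    p95_ttft_ms = np.percentile(ttfts or 0, 95) * 1000
    p99_ttft_ms = np.percentile(ttfts or 0, 99) * 1000
    # `ttfts or 0` keeps np.percentile from raising on an empty list when the
    # backend does not support streaming.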
python/sglang/srt/_custom_ops.py  (view file @ a1175a4e)

...
@@ -4,10 +4,24 @@ from typing import List, Optional, Tuple

import torch

-from sglang.srt.utils import is_hip, is_hpu, is_npu
+from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

try:
    from lmslim import quant_ops
    from lmslim import quant_tools
except Exception:
    print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")

try:
    import lightop
except Exception:
    print("INFO: Please install lightop if you want to infer awq of marlin.\n")

logger = logging.getLogger(__name__)

use_vllm_custom_allreduce = get_bool_env_var("USE_VLLM_CUSTOM_ALLREDUCE", default="false")
use_dcu_custom_allreduce = get_bool_env_var("USE_DCU_CUSTOM_ALLREDUCE", default="true")

if not is_hpu():
    try:
...

...
@@ -15,6 +29,11 @@ if not is_hpu():
    except ImportError as e:
        logger.warning("Failed to import from custom_ar with %r", e)

if use_dcu_custom_allreduce:
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)

if not is_hip() and not is_npu():
    custom_op = sgl_kernel.allreduce
...

...
@@ -54,8 +73,79 @@ if not is_hip() and not is_npu():
    ) -> None:
        custom_op.register_graph_buffers(fa, handles, offsets)

elif is_hip and use_dcu_custom_allreduce:
    # custom ar
    def init_custom_ar(
        ipc_tensors: list[torch.Tensor],
        rank_data: torch.Tensor,
        rank: int,
        fully_connected: bool,
    ) -> int:
        return torch.ops._C_custom_ar.init_custom_ar(
            ipc_tensors, rank_data, rank, fully_connected
        )

    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

    def dispose(fa: int) -> None:
        torch.ops._C_custom_ar.dispose(fa)

    def meta_size() -> int:
        return torch.ops._C_custom_ar.meta_size()

    def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)

    def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: list[list[int]], offsets: list[list[int]]
    ) -> None:
        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

    def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
        return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)

    def open_mem_handle(mem_handle: torch.Tensor):
        return torch.ops._C_custom_ar.open_mem_handle(mem_handle)

    def free_shared_buffer(ptr: int) -> None:
        torch.ops._C_custom_ar.free_shared_buffer(ptr)

    def read_cache(
        keys: torch.Tensor,
        values: torch.Tensor,
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        torch.ops._C_cache_ops.read_cache(
            keys, values, key_caches, value_caches, slot_mapping, kv_cache_dtype
        )

    def write_cache_multi_layers(
        keys: torch.Tensor,
        values: torch.Tensor,
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        torch.ops._C_cache_ops.write_cache_multi_layers(
            keys, values, key_caches, value_caches, slot_mapping, kv_cache_dtype
        )

else:
    # ROCM custom allreduce
    # sgl_kernel ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
...

...
@@ -163,3 +253,83 @@ def mscclpp_allreduce(
    context: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    nthreads: int,
    nblocks: int,
) -> None:
    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)


def triton_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
    best_config: Optional[list] = None,
) -> torch.Tensor:
    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)


def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
    `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like found in DeepSeek V3 we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
    "if the extent of a dimension in the source shape is between 1 and the
    corresponding extent in the target shape we repeat each element along
    that dimension src_shape[dim] // target_shape[dim] times consecutively"

    example if we have:
        a = [[1, 2],   and target_shape = (2, 4)
             [3, 4]]
    then we would expand a to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]

    currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    # m = a.shape[0]
    # n = b.shape[1]
    # cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    # if current_platform.is_rocm() or not cutlass_compatible_b:
    #     from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
    #         triton_scaled_mm)
    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    # out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    # return out
    # return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def rocblas_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def triton_int8_gemm_helper(
    m: int,
    n: int,
    k: int,
    per_token_act_quant: bool,
    per_out_channel_weight_quant: bool,
    use_bias: bool,
    out_dtype: type[torch.dtype] = torch.float16,
    device: str = "cuda:0",
    best_config: Optional[list] = None,
    repeat: Optional[int] = 2,
):
    return quant_tools.triton_int8_gemm_helper(
        m,
        n,
        k,
        per_token_act_quant,
        per_out_channel_weight_quant,
        use_bias,
        out_dtype,
        device,
        best_config,
        repeat,
    )
\ No newline at end of file
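As a hedged illustration of the "group" broadcast rule described in the cutlass_scaled_mm docstring above, the per-128-element scales can be expanded in plain torch like this (expand_group_scale is a made-up helper, not part of the diff):

    import torch

    def expand_group_scale(scale: torch.Tensor, target_shape) -> torch.Tensor:
        # Repeat each scale entry along every dimension whose extent is smaller
        # than the target, so it lines up element-wise with the quantized tensor.
        out = scale
        for dim, (s, t) in enumerate(zip(scale.shape, target_shape)):
            if s != t:
                out = out.repeat_interleave(t // s, dim=dim)
        return out

    # e.g. scale_a of shape [M, K // 128] expanded against a of shape [M, K]:
    # (expand_group_scale(scale_a, a.shape) * a) reproduces the broadcasting the
    # fused kernel performs internally.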
python/sglang/srt/configs/model_config.py  (view file @ a1175a4e)

...
@@ -635,7 +635,9 @@ class ModelConfig:
            "petit_nvfp4",
            "quark",
            "mxfp4",
            "auto-round",
            "slimquant_w4a8_marlin",
            "w8a8_int8",
            "slimquant_marlin",
        ]
        optimized_quantization_methods = [
            "fp8",
...

...
@@ -655,6 +657,8 @@ class ModelConfig:
            "qoq",
            "w4afp8",
            "petit_nvfp4",
            "slimquant_w4a8_marlin",
            "slimquant_marlin",
        ]
        compatible_quantization_methods = {
            "modelopt_fp8": ["modelopt"],
...
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (view file @ a1175a4e)

...
@@ -34,6 +34,21 @@ except ImportError:
_is_cuda = is_cuda()
_is_hip = is_hip()

try:
    if ops.use_vllm_custom_allreduce and not _is_hip:
        # Use vLLM custom allreduce
        ops.meta_size()
    elif ops.use_dcu_custom_allreduce:
        ops.meta_size()
    else:
        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
        import sgl_kernel  # noqa: F401
    custom_ar = True
except Exception:
    # For CPUs
    custom_ar = False

logger = logging.getLogger(__name__)
...

...
@@ -416,3 +431,274 @@ class CustomAllreduce:
    def __del__(self):
        self.close()


class DCUCustomAllreduce:
    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]

    # max_size: max supported allreduce size
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_size=8192 * 512,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if not custom_ar:
            # disable because of missing custom allreduce library
            # e.g. in a non-GPU environment
            logger.info(
                "Custom allreduce is disabled because "
                "of missing custom allreduce library"
            )
            return

        self.group = group
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "CustomAllreduce should be attached to a non-NCCL group."
        )

        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom allreduce for multi-node case.
            logger.warning(
                "Custom allreduce is disabled because this process group"
                " spans across nodes."
            )
            return

        rank = dist.get_rank(group=self.group)
        self.rank = rank
        world_size = dist.get_world_size(group=self.group)
        # if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
        if world_size > 16:
            return
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.
            return

        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "Custom allreduce is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.",
                world_size,
                str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))

        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        # assert current_platform.is_cuda_alike()
        # fully_connected = current_platform.is_fully_connected(
        #     physical_device_ids)
        if _is_cuda or _is_hip:
            fully_connected = is_full_nvlink(physical_device_ids, world_size)

        # if world_size > 2 and not fully_connected:
        if not fully_connected:
            max_size = 32 * 8192 * 2
            # if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
            #     logger.warning(
            #         "Custom allreduce is disabled because it's not supported on"
            #         " more than two PCIe-only GPUs. To silence this warning, "
            #         "specify disable_custom_all_reduce=True explicitly.")
            #     return
            logger.warning(
                "We are using PCIe's custom allreduce."
                "If the performance is poor, we can add "
                "--disable-custom-all-reduce in the instruction."
            )

        # test P2P capability, this checks software/cudaruntime support
        # this is expensive to compute at the first time
        # then we cache the result
        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
        if not _is_hip and not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly."
            )
            return

        self.disabled = False
        # Buffers memory are owned by this Python class and passed to C++.
        # Meta data composes of two parts: meta data for synchronization and a
        # temporary buffer for storing intermediate allreduce results.
        self.meta_ptrs = self.create_shared_buffer(
            ops.meta_size() + max_size, group=group, uncached=True
        )
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed
        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen only
        # needs less than 10000 of registered tuples.
        self.rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
        )
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        self.fully_connected = fully_connected
        self._ptr = ops.init_custom_ar(
            self.meta_ptrs, self.rank_data, rank, self.fully_connected
        )
        ops.register_buffer(self._ptr, self.buffer_ptrs)

    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

    def register_graph_buffers(self):
        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
        logger.info("Registering %d cuda graph addresses", len(offset))
        # We cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
        all_data[self.rank] = [handle, offset]
        ranks = sorted(dist.get_process_group_ranks(group=self.group))
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(
                all_data[i], src=rank, group=self.group, device="cpu"
            )
        # Unpack list of tuples to tuple of lists.
        handles = [d[0] for d in all_data]  # type: ignore
        offsets = [d[1] for d in all_data]  # type: ignore
        ops.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires input byte size to be multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        return inp_size <= self.max_size

    def all_reduce(
        self,
        inp: torch.Tensor,
        *,
        out: torch.Tensor = None,
        registered: bool = False,
    ):
        """Performs an out-of-place all reduce.

        If registered is True, this assumes inp's pointer is already
        IPC-registered. Otherwise, inp is first copied into a pre-registered
        buffer.
        """
        if out is None:
            out = torch.empty_like(inp)
        if registered:
            ops.all_reduce(self._ptr, inp, out, 0, 0)
        else:
            ops.all_reduce(
                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
            )
        return out

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        """The main allreduce API that provides support for cuda graph."""
        # When custom allreduce is disabled, this will be None.
        if self.disabled or not self.should_custom_ar(input):
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                return self.all_reduce(input, registered=False)
            else:
                # If warm up, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.empty_like(input)
        else:
            # Note: outside of cuda graph context, custom allreduce incurs a
            # cost of cudaMemcpy, which should be small (<=1% of overall
            # latency) compared to the performance gain of using custom kernels
            return self.all_reduce(input, registered=False)

    def close(self):
        if not self.disabled and self._ptr:
            if ops is not None:
                ops.dispose(self._ptr)
            self._ptr = 0
            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)

    def __del__(self):
        self.close()

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int,
        group: Optional[ProcessGroup] = None,
        uncached: Optional[bool] = False,
    ) -> list[int]:
        pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)

        pointers: list[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer)  # type: ignore
            else:
                pointers.append(ops.open_mem_handle(h))
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: list[int],
        group: Optional[ProcessGroup] = None,
        rank: Optional[int] = 0,
    ) -> None:
        if rank is None:
            rank = dist.get_rank(group=group)
        if ops is not None:
            ops.free_shared_buffer(pointers[rank])
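A rough usage sketch of the communicator added above (the surrounding group setup is assumed; cpu_group and device_group are placeholders): custom_all_reduce returns None whenever the fast path cannot be used, and the caller is expected to fall back to the regular all-reduce.

    ca = DCUCustomAllreduce(group=cpu_group, device=torch.device("cuda:0"))

    def tp_all_reduce(t: torch.Tensor) -> torch.Tensor:
        out = ca.custom_all_reduce(t)   # None => size/alignment/world-size check failed
        if out is None:
            torch.distributed.all_reduce(t, group=device_group)
            return t
        return out

    # During CUDA graph capture, wrap the capture in ca.capture() so the graph's
    # buffer addresses get registered via register_graph_buffers() on exit.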
python/sglang/srt/distributed/parallel_state.py  (view file @ a1175a4e)

...
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
    is_xpu,
    supports_custom_op,
)
from sglang.srt import _custom_ops as ops

_is_npu = is_npu()
_is_cpu = is_cpu()
...

...
@@ -327,7 +328,7 @@ class GroupCoordinator:
            # Lazy import to avoid documentation build error
            from sglang.srt.distributed.device_communicators.custom_all_reduce import (
-               CustomAllreduce,
+               CustomAllreduce, DCUCustomAllreduce
            )
            from sglang.srt.distributed.device_communicators.pymscclpp import (
                PyMscclppCommunicator,
...

...
@@ -371,10 +372,17 @@ class GroupCoordinator:
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            try:
-               self.ca_comm = CustomAllreduce(
-                   group=self.cpu_group,
-                   device=self.device,
-               )
+               if is_hip() and ops.use_dcu_custom_allreduce:
+                   self.ca_comm = DCUCustomAllreduce(
+                       group=self.cpu_group,
+                       device=self.device,
+                   )
+               else:
+                   self.ca_comm = CustomAllreduce(
+                       group=self.cpu_group,
+                       device=self.device,
+                       max_size=ca_max_size,
+                   )
            except Exception as e:
                logger.warning(
                    f"Setup Custom allreduce failed with {e}. To silence this "
...
python/sglang/srt/environ.py  (view file @ a1175a4e)

...
@@ -188,6 +188,17 @@ class Envs:
    SGLANG_USE_AITER = EnvBool(False)
    SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
    SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)

    # DCU Lightop
    SGLANG_USE_LIGHTOP = EnvBool(False)
    # Fused
    SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD = EnvBool(False)
    SGLANG_USE_OPT_CAT = EnvBool(False)
    SGLANG_USE_FUSED_RMS_QUANT = EnvBool(False)
    SGLANG_USE_FUSED_SILU_MUL_QUANT = EnvBool(False)

    # Quantization
    SGLANG_INT4_WEIGHT = EnvBool(False)
...
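These new flags are plain boolean environment variables. As a sketch of how such a flag is typically read (the parsing below is an assumption about the behavior of a helper like get_bool_env_var, not its actual implementation):

    import os

    def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
        # Accept the usual truthy spellings; anything else counts as False.
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")

    use_lightop = get_bool_env_var_sketch("SGLANG_USE_LIGHTOP")
    use_opt_cat = get_bool_env_var_sketch("SGLANG_USE_OPT_CAT")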
python/sglang/srt/layers/attention/attention_registry.py  (view file @ a1175a4e)

...
@@ -99,7 +99,6 @@ def create_triton_backend(runner):
    return TritonAttnBackend(runner)


@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
...

...
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
    return FlashMLABackend(runner)


@register_attention_backend("dcu_mla")
def create_dcu_mla_backend(runner):
    from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend

    return DCUMLABackend(runner)


@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
...
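The dcu_mla registration above relies on the registry decorator pattern. A minimal self-contained sketch of that pattern (the dictionary and lookup below are illustrative, not sglang's actual internals):

    ATTENTION_BACKENDS = {}

    def register_attention_backend(name):
        def decorator(factory):
            ATTENTION_BACKENDS[name] = factory
            return factory
        return decorator

    @register_attention_backend("dcu_mla")
    def create_dcu_mla_backend(runner):
        from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
        return DCUMLABackend(runner)

    # Later, the runner resolves the backend by name, e.g.:
    # backend = ATTENTION_BACKENDS[server_args.attention_backend](model_runner)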
python/sglang/srt/layers/attention/dcu_mla_backend.py  (new file, 0 → 100644, view file @ a1175a4e)

This diff is collapsed.
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py  (view file @ a1175a4e)

...
@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F

-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,
...
python/sglang/srt/layers/attention/flashattention_backend.py  (view file @ a1175a4e)

...
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass
...

...
@@ -328,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
        self.skip_prefill = skip_prefill
        self.is_hybrid = model_runner.is_hybrid
+       self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
+       self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        if self.is_hybrid:
            self.full_to_swa_index_mapping = (
                model_runner.token_to_kv_pool.full_to_swa_index_mapping
...

...
@@ -596,9 +599,11 @@ class FlashAttentionBackend(AttentionBackend):
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
-           if any(forward_batch.extend_prefix_lens_cpu) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
+           if (
+               any(forward_batch.extend_prefix_lens_cpu)
+               or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+               or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2  # nhb
+           ):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
...

...
@@ -608,10 +613,13 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.max_seq_len_q = metadata.max_seq_len_k
            metadata.cu_seqlens_q = metadata.cu_seqlens_k

-       # Setup local attention if enabled
-       if forward_batch.forward_mode == ForwardMode.EXTEND:
+       # # Setup local attention if enabled
+       # if forward_batch.forward_mode == ForwardMode.EXTEND:
+       #     self._init_local_attn_metadata(forward_batch, metadata, device)
+       if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
            self._init_local_attn_metadata(forward_batch, metadata, device)

        # Encoder metadata for cross attention
        if forward_batch.encoder_lens is not None:
            assert (
...

...
@@ -668,10 +676,11 @@ class FlashAttentionBackend(AttentionBackend):
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        if not self.use_mla:
            if k_rope is None:
                forward_batch.token_to_kv_pool.set_kv_buffer(
-                   layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                   layer, cache_loc, k, v,
+                   # layer.k_scale, layer.v_scale
                )
            else:
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                    layer,
...

...
@@ -690,7 +699,8 @@ class FlashAttentionBackend(AttentionBackend):
            layer.sliding_window_size is not None and layer.sliding_window_size > -1
        )
        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
-       k_descale, v_descale = None, None
+       # k_descale, v_descale = None, None
+       k_descale, v_descale = self.k_scale, self.v_scale
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
...

...
@@ -704,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                k_descale = layer.k_scale.expand(descale_shape)
                v_descale = layer.v_scale.expand(descale_shape)
-           q = q.to(self.kv_cache_dtype)
+           # q = q.to(self.kv_cache_dtype)
            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        causal = True
...

...
@@ -774,61 +784,59 @@ class FlashAttentionBackend(AttentionBackend):
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)

-           result = flash_attn_with_kvcache(
-               q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-               k_cache=key_cache,
-               v_cache=value_cache,
-               page_table=page_table,
-               cache_seqlens=cache_seqlens,
-               cu_seqlens_q=cu_seqlens_q,
-               cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-               max_seqlen_q=max_seqlen_q,
-               softmax_scale=layer.scaling,
-               causal=False if use_cascade_attn else causal,
-               window_size=window_size,
-               softcap=layer.logit_cap,
-               k_descale=k_descale,
-               v_descale=v_descale,
-               return_softmax_lse=use_cascade_attn,
-               num_splits=self.num_splits,
-               **kwargs,
-           )
-           if use_cascade_attn:
-               o, softmax_lse, *rest = result
-               o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
-                   q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                   k_cache=key_cache,
-                   v_cache=value_cache,
-                   page_table=self.forward_metadata_spec_decode_expand.page_table,
-                   cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                   cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                   cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                   max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+           if forward_batch.attn_attend_prefix_cache:
+               assert not get_global_server_args().disable_chunked_prefix_cache
+               # MHA for chunked prefix kv cache when running model with MLA
+               assert forward_batch.prefix_chunk_idx is not None
+               assert forward_batch.prefix_chunk_cu_seq_lens is not None
+               assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+               chunk_idx = forward_batch.prefix_chunk_idx
+               assert chunk_idx >= 0
+               assert forward_batch.mha_return_lse
+               output = flash_attn_varlen_func(
+                   q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                   k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                   v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                   cu_seqlens_q=metadata.cu_seqlens_q,
+                   cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                   max_seqlen_q=metadata.max_seq_len_q,
+                   max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=True,
                    num_splits=self.num_splits,
                    **kwargs,
                )
-               o, _ = merge_state_v2_wrapper(
-                   o,
-                   softmax_lse.T.contiguous(),
-                   o_expand,
-                   softmax_lse_expand.T.contiguous(),
-               )
            else:
-               o = result
+               output = flash_attn_varlen_func(
+                   q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                   k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                   v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                   cu_seqlens_q=metadata.cu_seqlens_q,
+                   cu_seqlens_k=metadata.cu_seqlens_q,
+                   max_seqlen_q=metadata.max_seq_len_q,
+                   max_seqlen_k=metadata.max_seq_len_q,
+                   softmax_scale=layer.scaling,
+                   causal=True,
+                   k_descale=k_descale,
+                   v_descale=v_descale,
+                   return_softmax_lse=forward_batch.mha_return_lse,
+                   **kwargs,
+               )
+           if forward_batch.mha_return_lse:
+               output, lse, *rest = output
+               lse = torch.transpose(lse, 0, 1).contiguous()
+               return output, lse
+           return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
        else:
            if (
                forward_batch.attn_attend_prefix_cache is not None
                and not forward_batch.forward_mode.is_target_verify()
-               and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
+               and not forward_batch.forward_mode.is_draft_extend()
            ):
                # Do multi-head attention with chunked prefix cache
                if forward_batch.attn_attend_prefix_cache:
                    assert not get_global_server_args().disable_chunked_prefix_cache
...

...
@@ -843,39 +851,32 @@ class FlashAttentionBackend(AttentionBackend):
                    assert forward_batch.mha_return_lse
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                       k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                       v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                       k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                       v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                        softmax_scale=layer.scaling,
                        causal=False,
                        k_descale=k_descale,
                        v_descale=v_descale,
                        return_softmax_lse=True,
                        **kwargs,
                    )
                else:
                    # MHA for extend part of sequence without attending prefix kv cache
                    cu_seqlens_k = (
                        metadata.cu_seqlens_q
                        if not forward_batch.mha_one_shot
                        else metadata.cu_seqlens_k
                    )
                    max_seqlen_k = (
                        metadata.max_seq_len_q
                        if not forward_batch.mha_one_shot
                        else metadata.max_seq_len_k
                    )
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                       k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                       v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                       k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                       v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=cu_seqlens_k,
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=max_seqlen_k,
                        softmax_scale=layer.scaling,
                        causal=True,
                        k_descale=k_descale,
                        v_descale=v_descale,
                        return_softmax_lse=forward_batch.mha_return_lse,
                        **kwargs,
                    )
...

...
@@ -985,10 +986,16 @@ class FlashAttentionBackend(AttentionBackend):
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )
-       if not self.use_mla:
-           forward_batch.token_to_kv_pool.set_kv_buffer(
-               layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-           )
+       # if not self.use_mla:
+       if k_rope is None:
+           if not self.use_mla:
+               forward_batch.token_to_kv_pool.set_kv_buffer(
+                   layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+               )
+           else:
+               forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
        else:
            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                layer,
...

...
@@ -1030,7 +1037,8 @@ class FlashAttentionBackend(AttentionBackend):
        if sinks is not None:
            kwargs["sinks"] = sinks

-       k_descale, v_descale = None, None
+       # k_descale, v_descale = None, None
+       k_descale, v_descale = self.k_scale, self.v_scale
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
...

...
@@ -1044,7 +1052,6 @@ class FlashAttentionBackend(AttentionBackend):
            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None

        if not self.use_mla:
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
...

...
@@ -1096,65 +1103,62 @@ class FlashAttentionBackend(AttentionBackend):
                    **kwargs,
                )
            else:
                cu_seqlens_q = metadata.cu_seqlens_q
                max_seqlen_q = metadata.max_seq_len_q
                page_table = metadata.page_table
                cache_seqlens = metadata.cache_seqlens_int32
                cu_seqlens_k = metadata.cu_seqlens_k
                q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+               key_cache = key_cache.view(-1, self.page_size, layer.tp_k_head_num, layer.head_dim)
+               value_cache = value_cache.view(-1, self.page_size, layer.tp_v_head_num, layer.head_dim)
-               # Default: single-token self-attention
-               result = flash_attn_with_kvcache(
-                   q=q_reshaped,
-                   k_cache=key_cache,
-                   v_cache=value_cache,
-                   page_table=page_table,
-                   cache_seqlens=cache_seqlens,
-                   cu_seqlens_q=metadata.cu_seqlens_q,
-                   cu_seqlens_k_new=cu_seqlens_k,
-                   max_seqlen_q=max_seqlen_q,
-                   softmax_scale=layer.scaling,
-                   causal=False if use_cascade_attn else causal,
-                   window_size=window_size,
-                   softcap=layer.logit_cap,
-                   k_descale=k_descale,
-                   v_descale=v_descale,
-                   return_softmax_lse=use_cascade_attn,
-                   num_splits=self.num_splits,
-                   **kwargs,
-               )
-               if use_cascade_attn:
-                   o, softmax_lse, *rest = result
-                   o_expand, softmax_lse_expand, *rest_expand = (
-                       flash_attn_with_kvcache(
-                           q=q_reshaped,
-                           k_cache=key_cache,
-                           v_cache=value_cache,
-                           page_table=self.forward_metadata_spec_decode_expand.page_table,
-                           cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                           cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                           cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                           max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                           softmax_scale=layer.scaling,
-                           causal=False,
-                           window_size=window_size,
-                           softcap=layer.logit_cap,
-                           k_descale=k_descale,
-                           v_descale=v_descale,
-                           return_softmax_lse=True,
-                           num_splits=self.num_splits,
-                           **kwargs,
-                       )
-                   )
-                   o, _ = merge_state_v2(
-                       o,
-                       softmax_lse.T.contiguous(),
-                       o_expand,
-                       softmax_lse_expand.T.contiguous(),
-                   )
-               else:
-                   o = result
+               if layer.is_cross_attention:
+                   page_table = metadata.encoder_page_table
+                   cache_seqlens = metadata.encoder_lens_int32
+                   cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                   window_size = (-1, -1)
+               if max_seqlen_q > 1:
+                   result = flash_attn_varlen_func(
+                       q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                       k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                       v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                       cu_seqlens_q=cu_seqlens_q,
+                       cu_seqlens_k=cu_seqlens_k,
+                       max_seqlen_q=max_seqlen_q,
+                       max_seqlen_k=max_seqlen_q,
+                       softmax_scale=layer.scaling,
+                       causal=True,
+                       window_size=window_size,
+                       softcap=layer.logit_cap,
+                       k_descale=k_descale,
+                       v_descale=v_descale,
+                       return_softmax_lse=use_cascade_attn,
+                       num_splits=self.num_splits,
+                       **kwargs,
+                   )
+               else:
+                   result = flash_attn_with_kvcache(
+                       q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                       k_cache=key_cache,
+                       v_cache=value_cache,
+                       page_table=page_table,
+                       cache_seqlens=cache_seqlens,
+                       cu_seqlens_q=cu_seqlens_q,
+                       cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                       max_seqlen_q=max_seqlen_q,
+                       softmax_scale=layer.scaling,
+                       causal=True,
+                       window_size=window_size,
+                       softcap=layer.logit_cap,
+                       k_descale=k_descale,
+                       v_descale=v_descale,
+                       return_softmax_lse=use_cascade_attn,
+                       num_splits=self.num_splits,
+                       **kwargs,
+                   )
+               o = result
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
...
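The cascade paths above combine two partial attention results through merge_state_v2 using their log-sum-exp (LSE) values. A minimal torch sketch of that merge rule (not the sgl_kernel API; shapes assumed [tokens, heads, dim] for the outputs and [tokens, heads] for the LSEs):

    import torch

    def merge_attention_states(o1, lse1, o2, lse2):
        lse = torch.logaddexp(lse1, lse2)          # combined normalizer
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of partial result 1
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # weight of partial result 2
        return o1 * w1 + o2 * w2, lse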
python/sglang/srt/layers/attention/flashattention_interface.py  (new file, 0 → 100644, view file @ a1175a4e)

from flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_interface,
    flash_attn_with_kvcache as flash_attn_with_kvcache_interface,
)
from typing import Optional, Union

import torch


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_with_kvcache_interface(
        q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
        k_cache=k_cache.view(q.dtype),
        v_cache=v_cache.view(q.dtype),
        block_table=page_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_softmax_lse=return_softmax_lse,
        num_splits=num_splits,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_varlen_func_interface(
        q=q,
        k=k.view(q.dtype),
        v=v.view(q.dtype),
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=return_softmax_lse,
        softcap=softcap,
    )
\ No newline at end of file
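Note that both wrappers accept the fa3-style signature used by the sglang call sites but forward only a subset of it to the upstream flash_attn entry points (for example, window_size, k_descale/v_descale, num_splits and sinks are dropped in flash_attn_varlen_func). A hypothetical call through the varlen wrapper, with assumed shapes:

    out = flash_attn_varlen_func(
        q=q,                        # [total_q_tokens, num_q_heads, head_dim]
        k=k, v=v,                   # [total_k_tokens, num_kv_heads, head_dim]
        cu_seqlens_q=cu_seqlens_q,  # int32 prefix sums, length batch_size + 1
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=head_dim ** -0.5,
        causal=True,
    )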
python/sglang/srt/layers/attention/flashmla_backend.py  (view file @ a1175a4e)

...
@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

from sglang.srt.utils import get_bool_env_var
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
...

...
@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

    def init_forward_metadata(self, forward_batch: ForwardBatch):

        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var(
            "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
        )
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(
...

...
@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                dtype=torch.int32,
                device=forward_batch.seq_lens.device,
            )
-           create_flashmla_kv_indices_triton[(bs,)](
-               self.req_to_token,
-               forward_batch.req_pool_indices,
-               forward_batch.seq_lens,
-               None,
-               block_kv_indices,
-               self.req_to_token.stride(0),
-               max_seqlen_pad,
-           )
+           if use_sglang_create_flashmla_kv_indices_triton:
+               dcu_create_flashmla_kv_indices(
+                   req_to_token_ptr=self.req_to_token,
+                   req_pool_indices_ptr=forward_batch.req_pool_indices,
+                   page_kernel_lens_ptr=forward_batch.seq_lens,
+                   kv_start_idx=None,
+                   kv_indices_ptr=block_kv_indices,
+                   req_to_token_ptr_stride=self.req_to_token.stride(0),
+                   kv_indices_ptr_stride=max_seqlen_pad,
+               )
+           else:
+               create_flashmla_kv_indices_triton[(bs,)](
+                   self.req_to_token,
+                   forward_batch.req_pool_indices,
+                   forward_batch.seq_lens,
+                   None,
+                   block_kv_indices,
+                   self.req_to_token.stride(0),
+                   max_seqlen_pad,
+               )
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads,
...

...
@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                dtype=torch.int32,
                device=seq_lens.device,
            )
-           create_flashmla_kv_indices_triton[(bs,)](
-               self.req_to_token,
-               forward_batch.req_pool_indices,
-               seq_lens,
-               None,
-               block_kv_indices,
-               self.req_to_token.stride(0),
-               max_seqlen_pad,
-           )
+           if use_sglang_create_flashmla_kv_indices_triton:
+               dcu_create_flashmla_kv_indices(
+                   req_to_token_ptr=self.req_to_token,
+                   req_pool_indices_ptr=forward_batch.req_pool_indices,
+                   page_kernel_lens_ptr=forward_batch.seq_lens,
+                   kv_start_idx=None,
+                   kv_indices_ptr=block_kv_indices,
+                   req_to_token_ptr_stride=self.req_to_token.stride(0),
+                   kv_indices_ptr_stride=max_seqlen_pad,
+               )
+           else:
+               create_flashmla_kv_indices_triton[(bs,)](
+                   self.req_to_token,
+                   forward_batch.req_pool_indices,
+                   forward_batch.seq_lens,
+                   None,
+                   block_kv_indices,
+                   self.req_to_token.stride(0),
+                   max_seqlen_pad,
+               )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
...

...
@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            )
        else:
            super().init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
...
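Both kernels above fill the same block KV index table: for each request, the ids of the KV-cache pages backing its sequence, padded out to max_seqlen_pad columns. An illustrative, non-Triton construction under an assumed page size; the tensor names follow the call sites, but the function itself is hypothetical:

    import torch

    def build_block_kv_indices(req_to_token, req_pool_indices, seq_lens, max_seqlen_pad, page_size=64):
        bs = req_pool_indices.numel()
        out = torch.zeros(bs, max_seqlen_pad, dtype=torch.int32, device=seq_lens.device)
        for i in range(bs):
            row = req_to_token[req_pool_indices[i]]
            num_pages = (int(seq_lens[i]) + page_size - 1) // page_size
            # take the first token slot of each page and convert it to a page id
            out[i, :num_pages] = (row[: num_pages * page_size : page_size] // page_size).to(torch.int32)
        return out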
python/sglang/srt/layers/attention/lightop_concat.py  (new file, 0 → 100644, view file @ a1175a4e)

from __future__ import annotations

import warnings

import torch

from sglang.srt.utils import get_bool_env_var, direct_register_custom_op

_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")

if _USE_OPT_CAT:
    try:
        from lightop import ds_cat  # type: ignore
    except ImportError:  # pragma: no cover
        ds_cat = None
        warnings.warn(
            "SGLANG_USE_OPT_CAT is enabled but lightop.ds_cat could not be imported; "
            "falling back to torch.cat"
        )
else:
    ds_cat = None


# TODO: registering this op on its own still has some issues
def ds_cat_wrapper(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    ds_cat(A, B, C, mode)
    return C


def ds_cat_fake(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
    # use the standard cat as the fake implementation
    return torch.cat([A, B], dim=dim)


direct_register_custom_op(
    op_name="ds_cat",
    op_func=ds_cat_wrapper,
    mutates_args=[],  # no mutated arguments, only a return value
    fake_impl=ds_cat_fake,
)


def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
    mode = 0
    if dim != 0:
        return torch.ops.sglang.ds_cat(A, B, dim, mode)
    assert False, "not support"


# def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
#     assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
#     output_shape = list(A.shape)
#     output_shape[dim] = A.shape[dim] + B.shape[dim]
#     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
#     mode = 0
#     if dim != 0:
#         ds_cat(A, B, C, mode)
#         return C
#     assert False, "not support"
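A hypothetical call to concat_decode_opt (shapes are made up): it only handles rank-3 tensors concatenated along dim=2, and callers are expected to guard on ds_cat being available, otherwise use torch.cat as the warning above suggests.

    import torch

    A = torch.randn(1, 8, 512, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(1, 8, 64, device="cuda", dtype=torch.bfloat16)

    if ds_cat is not None:
        C = concat_decode_opt(A, B, dim=2)   # lightop fused concat
    else:
        C = torch.cat([A, B], dim=2)         # plain fallback

    assert C.shape == (1, 8, 576)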
python/sglang/srt/layers/attention/nsa_backend.py  (view file @ a1175a4e)

...
@@ -47,7 +47,8 @@ if _is_hip:
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
else:
-   from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+   # from sgl_kernel.flash_attn import flash_attn_with_kvcache
+   from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache


@dataclass(frozen=True)
...
python/sglang/srt/layers/attention/xpu_backend.py  (view file @ a1175a4e)

...
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


class XPUAttentionBackend(AttentionBackend):
...
python/sglang/srt/layers/layernorm.py  (view file @ a1175a4e)

...
@@ -160,21 +160,53 @@ class RMSNorm(CustomOp):
            return output, residual_out
        return rms_norm(x, self.weight.data, self.variance_epsilon)

    # def forward_hip(
    #     self,
    #     x: torch.Tensor,
    #     residual: Optional[torch.Tensor] = None,
    # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    #     if not x.is_contiguous():
    #         # NOTE: Remove this if aiter kernel supports discontinuous input
    #         x = x.contiguous()
    #     if residual is not None:
    #         if _vllm_version < Version("0.9"):
    #             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
    #             return x, residual
    #         else:
    #             residual_out = torch.empty_like(x)
    #             output = torch.empty_like(x)
    #             fused_add_rms_norm(
    #                 output,
    #                 x,
    #                 residual_out,
    #                 residual,
    #                 self.weight.data,
    #                 self.variance_epsilon,
    #             )
    #             return output, residual_out
    #     out = torch.empty_like(x)
    #     rms_norm(out, x, self.weight.data, self.variance_epsilon)
    #     return out

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
-   ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+   ):
        if not x.is_contiguous():
            # NOTE: Remove this if aiter kernel supports discontinuous input
            x = x.contiguous()
        if residual is not None:
-           if _vllm_version < Version("0.9"):
-               fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+           try:
+               fused_add_rms_norm(
+                   x,
+                   residual,
+                   self.weight.data,
+                   self.variance_epsilon,
+               )
                return x, residual
-           else:
-               residual_out = torch.empty_like(x)
+           except TypeError:
+               output = torch.empty_like(x)
+               residual_out = torch.empty_like(x)
                fused_add_rms_norm(
                    output,
                    x,
...

...
@@ -184,10 +216,13 @@ class RMSNorm(CustomOp):
                    self.variance_epsilon,
                )
                return output, residual_out
        out = torch.empty_like(x)
        rms_norm(out, x, self.weight.data, self.variance_epsilon)
        return out

    def forward_native(
        self,
        x: torch.Tensor,
...
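For reference, the unfused semantics that rms_norm / fused_add_rms_norm implement, written as a plain-torch sketch (the real calls dispatch to fused HIP/CUDA kernels):

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

    def fused_add_rms_norm_ref(x, residual, weight, eps):
        residual_out = x + residual                 # updated residual stream
        return rms_norm_ref(residual_out, weight, eps), residual_out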
python/sglang/srt/layers/linear.py  (view file @ a1175a4e)

...
@@ -45,6 +45,18 @@ _is_hip = is_hip()
_disable_hip_linear_quant = _is_hip and get_bool_env_var("SGLANG_ROCM_DISABLE_LINEARQUANT")

_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
if _use_fused_rms_quant:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        print(f"Error: Import fused rmsquant error: {e}")
if _use_fused_silu_mul_quant:
    try:
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        print(f"Error: Import fused silu_mul_quant error: {e}")

logger = logging.getLogger(__name__)
...

...
@@ -1360,7 +1372,7 @@ class RowParallelLinear(LinearBase):
            # It does not support additional parameters.
            param.load_row_parallel_weight(loaded_weight)

-   def forward(self, input_, skip_all_reduce=False):
+   def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
        if self.input_is_parallel:
            input_parallel = input_
        else:
...

...
@@ -1374,10 +1386,19 @@ class RowParallelLinear(LinearBase):
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-       with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
-           output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+       if use_fused_silu_mul_quant:
+           xq, xs = lm_fuse_silu_mul_quant(input_parallel)
+           silu_quant_args = [xq, xs]
+           with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+               output_parallel = self.quant_method.apply(
+                   self, input_parallel, bias=bias_, silu_quant_args=silu_quant_args
+               )
+               sm.tag(output_parallel)
+       else:
+           with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+               output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+               sm.tag(output_parallel)

        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
            output = tensor_model_parallel_all_reduce(output_parallel)
...
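The forward path above follows the usual row-parallel recipe. A condensed sketch of its reduction logic (distributed setup assumed; the function and names are illustrative, not the sglang implementation):

    import torch

    def row_parallel_forward(x_shard, weight_shard, bias, tp_rank, tp_size, skip_all_reduce=False):
        partial = x_shard @ weight_shard.t()        # per-rank partial output
        if bias is not None and tp_rank == 0:
            partial = partial + bias                # add the bias exactly once across TP ranks
        if tp_size > 1 and not skip_all_reduce:
            torch.distributed.all_reduce(partial)   # sum the partial outputs
        return partial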
python/sglang/srt/layers/moe/ep_moe/kernels.py
View file @
a1175a4e
...
...
@@ -2,7 +2,6 @@ import logging
import
torch
import
triton
from
sglang.srt.utils
import
ceil_div
,
is_cuda
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -1015,196 +1014,133 @@ def zero_experts_compute_triton(
return
output
from
triton.language.extra
import
libdevice
from
typing
import
Optional
@
triton
.
jit
def
compute_problem_sizes_w4a8_kernel
(
masked_m_ptr
,
problem_sizes1_ptr
,
problem_sizes2_ptr
,
n
,
k
,
num_experts
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
axis
=
0
)
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
pid
<
num_experts
final_occurrences
=
tl
.
load
(
masked_m_ptr
+
pid
,
mask
=
mask
,
other
=
0
)
ps1_idx_0
=
pid
*
3
ps1_idx_1
=
ps1_idx_0
+
1
ps1_idx_2
=
ps1_idx_0
+
2
ps2_idx_0
=
pid
*
3
ps2_idx_1
=
ps2_idx_0
+
1
ps2_idx_2
=
ps2_idx_0
+
2
ps1_mask_0
=
ps1_idx_0
<
num_experts
*
3
ps1_mask_1
=
ps1_idx_1
<
num_experts
*
3
ps1_mask_2
=
ps1_idx_2
<
num_experts
*
3
ps2_mask_0
=
ps2_idx_0
<
num_experts
*
3
ps2_mask_1
=
ps2_idx_1
<
num_experts
*
3
ps2_mask_2
=
ps2_idx_2
<
num_experts
*
3
tl
.
store
(
problem_sizes1_ptr
+
ps1_idx_0
,
2
*
n
,
mask
=
ps1_mask_0
)
tl
.
store
(
problem_sizes1_ptr
+
ps1_idx_1
,
final_occurrences
,
mask
=
ps1_mask_1
)
tl
.
store
(
problem_sizes1_ptr
+
ps1_idx_2
,
k
,
mask
=
ps1_mask_2
)
tl
.
store
(
problem_sizes2_ptr
+
ps2_idx_0
,
k
,
mask
=
ps2_mask_0
)
tl
.
store
(
problem_sizes2_ptr
+
ps2_idx_1
,
final_occurrences
,
mask
=
ps2_mask_1
)
tl
.
store
(
problem_sizes2_ptr
+
ps2_idx_2
,
n
,
mask
=
ps2_mask_2
)
def
compute_problem_sizes_w4a8
(
masked_m
,
problem_sizes1
,
problem_sizes2
,
n
,
k
,
num_experts
):
BLOCK_SIZE
=
256
grid
=
lambda
meta
:
(
triton
.
cdiv
(
num_experts
,
meta
[
"BLOCK_SIZE"
]),)
compute_problem_sizes_w4a8_kernel
[
grid
](
masked_m
,
problem_sizes1
,
problem_sizes2
,
n
,
k
,
num_experts
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
return
problem_sizes1
,
problem_sizes2
def
deepep_ll_get_cutlass_w4a8_moe_mm_data
(
masked_m
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
def
_per_token_quant_int8_one_kernel_opt
(
x_ptr
,
xq_ptr
,
scale_ptr
,
stride_x
,
stride_xq
,
N
,
T_dim
,
tokens_per_expert_ptr
,
BLOCK
:
tl
.
constexpr
):
problem_sizes1
,
problem_sizes2
=
compute_problem_sizes_w4a8
(
masked_m
,
problem_sizes1
,
problem_sizes2
,
n
,
k
,
num_experts
)
return
(
problem_sizes1
.
to
(
torch
.
int32
),
problem_sizes2
.
to
(
torch
.
int32
),
)
@triton.jit
def _silu_and_mul_post_per_tensor_quant_kernel(
    input_ptr,
    stride_input_expert,
    stride_input_token,
    stride_input_dim,
    output_ptr,
    stride_output_expert,
    stride_output_token,
    stride_output_dim,
    scale_ptr,  # per-tensor quantization scale; the kernel multiplies by 1 / scale
    masked_m_ptr,
    inner_dim,
    fp8_max,
    fp8_min,
    BLOCK_N: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """
    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.

    Shape:
        input:  [E, T_padded, 2*D] -> gate: [:, :, :D], up: [:, :, D:]
        output: [E, T_padded, D], dtype=float8_e4m3fn
    """
    expert_id = tl.program_id(2)
    block_id_token = tl.program_id(1)
    block_id_dim = tl.program_id(0)
    num_token_blocks = tl.num_programs(1)

    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)

    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
    stride_input_token = tl.cast(stride_input_token, tl.int32)
    stride_output_token = tl.cast(stride_output_token, tl.int32)

    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_d = offset_d < inner_dim

    # base pointers for current expert and dim block
    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d

    for token_idx in tl.range(
        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
    ):
        gate_ptr = input_base_offs + token_idx * stride_input_token
        up_ptr = gate_ptr + inner_dim
        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)

        # SiLU: x * sigmoid(x)
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)

        gate_up = up * gate
        scaled = gate_up * scale
        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)

        out_ptr = output_base_offs + token_idx * stride_output_token
        tl.store(out_ptr, output_q, mask=mask_d)


def silu_and_mul_masked_post_per_tensor_quant_fwd(
    input: torch.Tensor,
    output: torch.Tensor,
    masked_m: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """
    Fused SiLU + Mul + Per-Tensor Quantization to FP8.

    Args:
        input:    [expert_num, token_num_padded, 2 * inner_dim]
        output:   [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
        masked_m: [expert_num], actual token count for each expert
        scale:    [1] or [expert_num], quantization scale (per-tensor or per-expert)

    Returns:
        output tensor
    """
    assert input.is_contiguous()
    assert output.is_contiguous()
    assert output.dtype == torch.float8_e4m3fn
    assert input.ndim == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0
    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]

    expert_num = input.shape[0]
    # 3584
    inner_dim = input.shape[-1] // 2
    BLOCK_N = 256
    BLOCK_M = 64 if expert_num < 4 else 32
    NUM_STAGES = 3
    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
    grid = (
        hidden_dim_split_block_num,
        BLOCK_M,
        expert_num,
    )
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = -fp8_max
    _silu_and_mul_post_per_tensor_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        scale,
        masked_m,
        inner_dim,
        fp8_max,
        fp8_min,
        BLOCK_N=BLOCK_N,
        NUM_STAGE=NUM_STAGES,
    )
    return output
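A minimal eager-mode reference sketch for sanity-checking the fused kernel above (assumes a PyTorch build with torch.float8_e4m3fn); it skips the kernel's intermediate cast of the SiLU output back to the input dtype, so comparisons should allow a small tolerance.

import torch

def silu_mul_per_tensor_quant_ref(x, masked_m, scale):
    # x: [E, T, 2*D] float, masked_m: [E] valid token counts, scale: scalar tensor.
    E, T, two_d = x.shape
    d = two_d // 2
    gate, up = x[..., :d], x[..., d:]
    y = torch.nn.functional.silu(gate) * up
    finfo = torch.finfo(torch.float8_e4m3fn)
    out = torch.zeros(E, T, d, dtype=torch.float32, device=x.device)
    for e in range(E):
        m = int(masked_m[e])
        out[e, :m] = (y[e, :m] / scale).clamp(min=-finfo.max, max=finfo.max)
    return out.to(torch.float8_e4m3fn)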
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    T_dim,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    # One program per (expert, token) row; rows beyond the expert's valid
    # token count are skipped.
    row_id = tl.program_id(0)
    if tokens_per_expert_ptr is not None:
        e = row_id // T_dim
        t = row_id % T_dim
        num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
        if t >= num_valid_tokens_for_e:
            return

    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    x_q = libdevice.nearbyint(x_q).to(tl.int8)
    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


@triton.jit
def _per_token_quant_int8_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    E_dim,
    T_dim,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    # Grid-stride loop over all (expert, token) rows; padded rows are skipped
    # but the loop keeps striding so work stays balanced across programs.
    token_idx_start = tl.program_id(0)
    grid_size = tl.num_programs(0)
    num_total_tokens = E_dim * T_dim
    for token_idx in range(token_idx_start, num_total_tokens, grid_size):
        is_valid_token = True
        if tokens_per_expert_ptr is not None:
            e = token_idx // T_dim
            t = token_idx % T_dim
            num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
            if t >= num_valid_tokens_for_e:
                is_valid_token = False
        if is_valid_token:
            cols = tl.arange(0, BLOCK)
            mask = cols < N
            x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask, other=0.0).to(
                tl.float32
            )
            absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
            scale_x = absmax / 127
            x_q = x * (127 / absmax)
            x_q = libdevice.nearbyint(x_q).to(tl.int8)
            tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
            tl.store(scale_ptr + token_idx, scale_x)


def per_token_quant_int8_triton_opt(
    x: torch.Tensor, tokens_per_expert: Optional[torch.Tensor] = None
):
    if x.dim() != 3:
        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
    E, T, H = x.shape
    N = H
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
        num_warps = 1

    num_tokens = E * T
    grid_opt = num_tokens
    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
        grid_opt = max(1, num_tokens // (T // 256))
        _per_token_quant_int8_kernel_opt[(grid_opt,)](
            x,
            x_q,
            scales,
            stride_x=x.stride(-2),
            stride_xq=x_q.stride(-2),
            N=N,
            E_dim=E,
            T_dim=T,
            tokens_per_expert_ptr=tokens_per_expert,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
    else:
        _per_token_quant_int8_one_kernel_opt[(grid_opt,)](
            x,
            x_q,
            scales,
            stride_x=x.stride(-2),
            stride_xq=x_q.stride(-2),
            N=N,
            T_dim=T,
            tokens_per_expert_ptr=tokens_per_expert,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
    return x_q, scales
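A plain-PyTorch reference sketch for the per-token path, using the same scale convention as the kernels above (scale = row absmax / 127, round to nearest). The kernels only write rows with t < tokens_per_expert[e], so any comparison should be restricted to those rows.

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # x: [E, T, H] -> (int8 [E, T, H], float32 scales [E, T, 1]); quantizes all rows.
    x_f = x.float()
    absmax = x_f.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scales = absmax / 127.0
    x_q = torch.round(x_f / scales).to(torch.int8)
    return x_q, scales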
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @
a1175a4e
This diff is collapsed.
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
View file @
a1175a4e
...
...
@@ -65,7 +65,7 @@ def inplace_fused_experts(
    topk_ids: torch.Tensor,
    b1: Optional[torch.Tensor] = None,
    b2: Optional[torch.Tensor] = None,
-   activation: str = "silu",
+   activation: int = 0,  # 0 silu, 1 gelu
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
...
...
@@ -84,6 +84,8 @@ def inplace_fused_experts(
    gemm1_limit: Optional[float] = None,
    filter_expert: bool = True,
) -> None:
+   if isinstance(activation, int):
+       activation = "silu" if activation == 0 else "gelu"
    fused_experts_impl(
        hidden_states,
        w1,
...
...
@@ -123,7 +125,7 @@ def inplace_fused_experts_fake(
    topk_ids: torch.Tensor,
    b1: Optional[torch.Tensor] = None,
    b2: Optional[torch.Tensor] = None,
-   activation: str = "silu",
+   activation: int = 0,  # 0 silu, 1 gelu
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
...
...
@@ -161,7 +163,7 @@ def outplace_fused_experts(
    topk_ids: torch.Tensor,
    b1: Optional[torch.Tensor] = None,
    b2: Optional[torch.Tensor] = None,
-   activation: str = "silu",
+   activation: int = 0,  # 0 silu, 1 gelu
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
...
...
@@ -181,6 +183,8 @@ def outplace_fused_experts(
    gemm1_limit: Optional[float] = None,
    filter_expert: bool = True,
) -> torch.Tensor:
+   if isinstance(activation, int):
+       activation = "silu" if activation == 0 else "gelu"
    return fused_experts_impl(
        hidden_states,
        w1,
...
@@ -220,7 +224,7 @@ def outplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
b1
:
Optional
[
torch
.
Tensor
]
=
None
,
b2
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
activation
:
int
=
0
,
#0 silu 1 gelu
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
...
...
@@ -273,9 +277,12 @@ def fused_experts(
    block_shape: Optional[List[int]] = None,
):
    topk_weights, topk_ids, _ = topk_output
    filter_expert = (
        moe_runner_config.num_experts is None
        or moe_runner_config.num_experts != moe_runner_config.num_local_experts
    )
+   act_id = (
+       0
+       if (
+           moe_runner_config.activation == 0
+           or (
+               isinstance(moe_runner_config.activation, str)
+               and moe_runner_config.activation.lower() == "silu"
+           )
+       )
+       else 1
+   )
    if moe_runner_config.inplace:
        assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
...
...
@@ -287,7 +294,7 @@ def fused_experts(
            topk_ids,
            b1,
            b2,
-           moe_runner_config.activation,
+           act_id,
            moe_runner_config.apply_router_weight_on_input,
            use_fp8_w8a8,
            use_int8_w8a8,
...
...
@@ -316,7 +323,7 @@ def fused_experts(
            topk_ids,
            b1,
            b2,
-           moe_runner_config.activation,
+           act_id,
            moe_runner_config.apply_router_weight_on_input,
            use_fp8_w8a8,
            use_int8_w8a8,
...
...
@@ -366,7 +373,7 @@ def fused_experts_impl(
    b1: Optional[torch.Tensor] = None,
    b2: Optional[torch.Tensor] = None,
    inplace: bool = False,
-   activation: str = "silu",
+   activation: int = 0,  # 0 silu, 1 gelu
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
...
...
@@ -386,6 +393,8 @@ def fused_experts_impl(
    gemm1_limit: Optional[float] = None,
    filter_expert: bool = True,
):
+   if isinstance(activation, int):
+       activation = "silu" if activation == 0 else "gelu"
    padded_size = padding_size
    if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
        padded_size = 0
...
...