Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a1175a4e
Commit
a1175a4e
authored
Nov 22, 2025
by
maxiao1
Browse files
Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5
parents
0c006b88
31653dd9
Changes
62
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2608 additions
and
409 deletions
+2608
-409
python/sglang/bench_serving.py
python/sglang/bench_serving.py
+10
-0
python/sglang/srt/_custom_ops.py
python/sglang/srt/_custom_ops.py
+173
-3
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+5
-1
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
...srt/distributed/device_communicators/custom_all_reduce.py
+286
-0
python/sglang/srt/distributed/parallel_state.py
python/sglang/srt/distributed/parallel_state.py
+13
-5
python/sglang/srt/environ.py
python/sglang/srt/environ.py
+11
-0
python/sglang/srt/layers/attention/attention_registry.py
python/sglang/srt/layers/attention/attention_registry.py
+5
-1
python/sglang/srt/layers/attention/dcu_mla_backend.py
python/sglang/srt/layers/attention/dcu_mla_backend.py
+679
-0
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
...srt/layers/attention/dual_chunk_flashattention_backend.py
+2
-1
python/sglang/srt/layers/attention/flashattention_backend.py
python/sglang/srt/layers/attention/flashattention_backend.py
+131
-127
python/sglang/srt/layers/attention/flashattention_interface.py
...n/sglang/srt/layers/attention/flashattention_interface.py
+96
-0
python/sglang/srt/layers/attention/flashmla_backend.py
python/sglang/srt/layers/attention/flashmla_backend.py
+48
-20
python/sglang/srt/layers/attention/lightop_concat.py
python/sglang/srt/layers/attention/lightop_concat.py
+67
-0
python/sglang/srt/layers/attention/nsa_backend.py
python/sglang/srt/layers/attention/nsa_backend.py
+2
-1
python/sglang/srt/layers/attention/xpu_backend.py
python/sglang/srt/layers/attention/xpu_backend.py
+2
-1
python/sglang/srt/layers/layernorm.py
python/sglang/srt/layers/layernorm.py
+44
-9
python/sglang/srt/layers/linear.py
python/sglang/srt/layers/linear.py
+26
-5
python/sglang/srt/layers/moe/ep_moe/kernels.py
python/sglang/srt/layers/moe/ep_moe/kernels.py
+124
-188
python/sglang/srt/layers/moe/ep_moe/layer.py
python/sglang/srt/layers/moe/ep_moe/layer.py
+865
-37
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+19
-10
No files found.
python/sglang/bench_serving.py
View file @
a1175a4e
...
@@ -848,10 +848,12 @@ class BenchmarkMetrics:
...
@@ -848,10 +848,12 @@ class BenchmarkMetrics:
mean_ttft_ms
:
float
mean_ttft_ms
:
float
median_ttft_ms
:
float
median_ttft_ms
:
float
std_ttft_ms
:
float
std_ttft_ms
:
float
p95_ttft_ms
:
float
p99_ttft_ms
:
float
p99_ttft_ms
:
float
mean_tpot_ms
:
float
mean_tpot_ms
:
float
median_tpot_ms
:
float
median_tpot_ms
:
float
std_tpot_ms
:
float
std_tpot_ms
:
float
p95_tpot_ms
:
float
p99_tpot_ms
:
float
p99_tpot_ms
:
float
mean_itl_ms
:
float
mean_itl_ms
:
float
median_itl_ms
:
float
median_itl_ms
:
float
...
@@ -1721,10 +1723,12 @@ def calculate_metrics(
...
@@ -1721,10 +1723,12 @@ def calculate_metrics(
*
1000
,
# ttfts is empty if streaming is not supported by backend
*
1000
,
# ttfts is empty if streaming is not supported by backend
median_ttft_ms
=
np
.
median
(
ttfts
or
0
)
*
1000
,
median_ttft_ms
=
np
.
median
(
ttfts
or
0
)
*
1000
,
std_ttft_ms
=
np
.
std
(
ttfts
or
0
)
*
1000
,
std_ttft_ms
=
np
.
std
(
ttfts
or
0
)
*
1000
,
p95_ttft_ms
=
np
.
percentile
(
ttfts
or
0
,
95
)
*
1000
,
p99_ttft_ms
=
np
.
percentile
(
ttfts
or
0
,
99
)
*
1000
,
p99_ttft_ms
=
np
.
percentile
(
ttfts
or
0
,
99
)
*
1000
,
mean_tpot_ms
=
np
.
mean
(
tpots
or
0
)
*
1000
,
mean_tpot_ms
=
np
.
mean
(
tpots
or
0
)
*
1000
,
median_tpot_ms
=
np
.
median
(
tpots
or
0
)
*
1000
,
median_tpot_ms
=
np
.
median
(
tpots
or
0
)
*
1000
,
std_tpot_ms
=
np
.
std
(
tpots
or
0
)
*
1000
,
std_tpot_ms
=
np
.
std
(
tpots
or
0
)
*
1000
,
p95_tpot_ms
=
np
.
percentile
(
tpots
or
0
,
95
)
*
1000
,
p99_tpot_ms
=
np
.
percentile
(
tpots
or
0
,
99
)
*
1000
,
p99_tpot_ms
=
np
.
percentile
(
tpots
or
0
,
99
)
*
1000
,
mean_itl_ms
=
np
.
mean
(
itls
or
0
)
*
1000
,
mean_itl_ms
=
np
.
mean
(
itls
or
0
)
*
1000
,
median_itl_ms
=
np
.
median
(
itls
or
0
)
*
1000
,
median_itl_ms
=
np
.
median
(
itls
or
0
)
*
1000
,
...
@@ -2052,6 +2056,12 @@ async def benchmark(
...
@@ -2052,6 +2056,12 @@ async def benchmark(
print
(
"{:<40} {:<10.2f}"
.
format
(
"Mean TTFT (ms):"
,
metrics
.
mean_ttft_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Mean TTFT (ms):"
,
metrics
.
mean_ttft_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Median TTFT (ms):"
,
metrics
.
median_ttft_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Median TTFT (ms):"
,
metrics
.
median_ttft_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"P99 TTFT (ms):"
,
metrics
.
p99_ttft_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"P99 TTFT (ms):"
,
metrics
.
p99_ttft_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"P95 TTFT (ms):"
,
metrics
.
p95_ttft_ms
))
print
(
"{s:{c}^{n}}"
.
format
(
s
=
"Time per Output Token (excl. 1st token)"
,
n
=
50
,
c
=
"-"
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Mean TPOT (ms):"
,
metrics
.
mean_tpot_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Median TPOT (ms):"
,
metrics
.
median_tpot_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"P99 TPOT (ms):"
,
metrics
.
p99_tpot_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"P95 TPOT (ms):"
,
metrics
.
p95_tpot_ms
))
print
(
"{s:{c}^{n}}"
.
format
(
s
=
"Inter-Token Latency"
,
n
=
50
,
c
=
"-"
))
print
(
"{s:{c}^{n}}"
.
format
(
s
=
"Inter-Token Latency"
,
n
=
50
,
c
=
"-"
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Mean ITL (ms):"
,
metrics
.
mean_itl_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Mean ITL (ms):"
,
metrics
.
mean_itl_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Median ITL (ms):"
,
metrics
.
median_itl_ms
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Median ITL (ms):"
,
metrics
.
median_itl_ms
))
...
...
python/sglang/srt/_custom_ops.py
View file @
a1175a4e
...
@@ -4,10 +4,24 @@ from typing import List, Optional, Tuple
...
@@ -4,10 +4,24 @@ from typing import List, Optional, Tuple
import
torch
import
torch
from
sglang.srt.utils
import
is_hip
,
is_hpu
,
is_npu
from
sglang.srt.utils
import
get_bool_env_var
,
is_hip
,
is_hpu
,
is_npu
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
import
lightop
except
Exception
:
print
(
"INFO: Please install lightop if you want to infer awq of marlin.
\n
"
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
use_vllm_custom_allreduce
=
get_bool_env_var
(
"USE_VLLM_CUSTOM_ALLREDUCE"
,
default
=
"false"
)
use_dcu_custom_allreduce
=
get_bool_env_var
(
"USE_DCU_CUSTOM_ALLREDUCE"
,
default
=
"true"
)
if
not
is_hpu
():
if
not
is_hpu
():
try
:
try
:
...
@@ -15,6 +29,11 @@ if not is_hpu():
...
@@ -15,6 +29,11 @@ if not is_hpu():
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from custom_ar with %r"
,
e
)
logger
.
warning
(
"Failed to import from custom_ar with %r"
,
e
)
if
use_dcu_custom_allreduce
:
try
:
import
vllm._C
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
if
not
is_hip
()
and
not
is_npu
():
if
not
is_hip
()
and
not
is_npu
():
custom_op
=
sgl_kernel
.
allreduce
custom_op
=
sgl_kernel
.
allreduce
...
@@ -54,8 +73,79 @@ if not is_hip() and not is_npu():
...
@@ -54,8 +73,79 @@ if not is_hip() and not is_npu():
)
->
None
:
)
->
None
:
custom_op
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
custom_op
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
elif
is_hip
and
use_dcu_custom_allreduce
:
# custom ar
def
init_custom_ar
(
ipc_tensors
:
list
[
torch
.
Tensor
],
rank_data
:
torch
.
Tensor
,
rank
:
int
,
fully_connected
:
bool
)
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
init_custom_ar
(
ipc_tensors
,
rank_data
,
rank
,
fully_connected
)
def
all_reduce
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
reg_buffer
:
int
,
reg_buffer_sz_bytes
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce
(
fa
,
inp
,
out
,
reg_buffer
,
reg_buffer_sz_bytes
)
def
dispose
(
fa
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
dispose
(
fa
)
def
meta_size
()
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
meta_size
()
def
register_buffer
(
fa
:
int
,
ipc_tensors
:
list
[
int
])
->
None
:
return
torch
.
ops
.
_C_custom_ar
.
register_buffer
(
fa
,
ipc_tensors
)
def
get_graph_buffer_ipc_meta
(
fa
:
int
)
->
tuple
[
list
[
int
],
list
[
int
]]:
return
torch
.
ops
.
_C_custom_ar
.
get_graph_buffer_ipc_meta
(
fa
)
def
register_graph_buffers
(
fa
:
int
,
handles
:
list
[
list
[
int
]],
offsets
:
list
[
list
[
int
]])
->
None
:
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
def
allocate_shared_buffer_and_handle
(
size
:
int
)
->
tuple
[
int
,
torch
.
Tensor
]:
return
torch
.
ops
.
_C_custom_ar
.
allocate_shared_buffer_and_handle
(
size
)
def
open_mem_handle
(
mem_handle
:
torch
.
Tensor
):
return
torch
.
ops
.
_C_custom_ar
.
open_mem_handle
(
mem_handle
)
def
free_shared_buffer
(
ptr
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
free_shared_buffer
(
ptr
)
def
read_cache
(
keys
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
key_caches
:
list
[
torch
.
Tensor
],
value_caches
:
list
[
torch
.
Tensor
],
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
slot_mapping
,
kv_cache_dtype
)
def
write_cache_multi_layers
(
keys
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
key_caches
:
list
[
torch
.
Tensor
],
value_caches
:
list
[
torch
.
Tensor
],
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
slot_mapping
,
kv_cache_dtype
)
else
:
else
:
# ROCM custom allreduce
#
sgl_kernel
ROCM custom allreduce
def
init_custom_ar
(
def
init_custom_ar
(
meta
:
torch
.
Tensor
,
meta
:
torch
.
Tensor
,
...
@@ -163,3 +253,83 @@ def mscclpp_allreduce(
...
@@ -163,3 +253,83 @@ def mscclpp_allreduce(
context
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
nthreads
:
int
,
nblocks
:
int
context
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
nthreads
:
int
,
nblocks
:
int
)
->
None
:
)
->
None
:
return
sgl_kernel
.
allreduce
.
mscclpp_allreduce
(
context
,
inp
,
out
,
nthreads
,
nblocks
)
return
sgl_kernel
.
allreduce
.
mscclpp_allreduce
(
context
,
inp
,
out
,
nthreads
,
nblocks
)
def
triton_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
best_config
:
Optional
[
list
]
=
None
)
->
torch
.
Tensor
:
return
quant_ops
.
triton_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
,
best_config
)
def
cutlass_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
`cutlass_scaled_mm` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
that dimension src_shape[dim] // target_shape[dim] times consecutively"
example if we have:
a = [[1, 2], and target_shape = (2, 4)
[3, 4]]
then we would expand a to:
a = [[1, 1, 2, 2],
[3, 3, 4, 4]]
currently we only support the case:
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
shape
[
0
]
==
b
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
# m = a.shape[0]
# n = b.shape[1]
# cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
# if current_platform.is_rocm() or not cutlass_compatible_b:
# from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
# triton_scaled_mm)
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
# out = torch.empty((m, n), dtype=out_dtype, device=a.device)
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out
#return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
rocblas_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
triton_int8_gemm_helper
(
m
:
int
,
n
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
use_bias
:
bool
,
out_dtype
:
type
[
torch
.
dtype
]
=
torch
.
float16
,
device
:
str
=
"cuda:0"
,
best_config
:
Optional
[
list
]
=
None
,
repeat
:
Optional
[
int
]
=
2
):
return
quant_tools
.
triton_int8_gemm_helper
(
m
,
n
,
k
,
per_token_act_quant
,
per_out_channel_weight_quant
,
use_bias
,
out_dtype
,
device
,
best_config
,
repeat
)
\ No newline at end of file
python/sglang/srt/configs/model_config.py
View file @
a1175a4e
...
@@ -635,7 +635,9 @@ class ModelConfig:
...
@@ -635,7 +635,9 @@ class ModelConfig:
"petit_nvfp4"
,
"petit_nvfp4"
,
"quark"
,
"quark"
,
"mxfp4"
,
"mxfp4"
,
"auto-round"
,
"slimquant_w4a8_marlin"
,
"w8a8_int8"
,
"slimquant_marlin"
,
]
]
optimized_quantization_methods
=
[
optimized_quantization_methods
=
[
"fp8"
,
"fp8"
,
...
@@ -655,6 +657,8 @@ class ModelConfig:
...
@@ -655,6 +657,8 @@ class ModelConfig:
"qoq"
,
"qoq"
,
"w4afp8"
,
"w4afp8"
,
"petit_nvfp4"
,
"petit_nvfp4"
,
"slimquant_w4a8_marlin"
,
"slimquant_marlin"
,
]
]
compatible_quantization_methods
=
{
compatible_quantization_methods
=
{
"modelopt_fp8"
:
[
"modelopt"
],
"modelopt_fp8"
:
[
"modelopt"
],
...
...
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
View file @
a1175a4e
...
@@ -34,6 +34,21 @@ except ImportError:
...
@@ -34,6 +34,21 @@ except ImportError:
_is_cuda
=
is_cuda
()
_is_cuda
=
is_cuda
()
_is_hip
=
is_hip
()
_is_hip
=
is_hip
()
try
:
if
ops
.
use_vllm_custom_allreduce
and
not
_is_hip
:
# Use vLLM custom allreduce
ops
.
meta_size
()
elif
ops
.
use_dcu_custom_allreduce
:
ops
.
meta_size
()
else
:
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import
sgl_kernel
# noqa: F401
custom_ar
=
True
except
Exception
:
# For CPUs
custom_ar
=
False
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -416,3 +431,274 @@ class CustomAllreduce:
...
@@ -416,3 +431,274 @@ class CustomAllreduce:
def
__del__
(
self
):
def
__del__
(
self
):
self
.
close
()
self
.
close
()
class
DCUCustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
,
16
]
# max_size: max supported allreduce size
def
__init__
(
self
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
],
max_size
=
8192
*
512
)
->
None
:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self
.
_IS_CAPTURING
=
False
self
.
disabled
=
True
if
not
custom_ar
:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger
.
info
(
"Custom allreduce is disabled because "
"of missing custom allreduce library"
)
return
self
.
group
=
group
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"CustomAllreduce should be attached to a non-NCCL group."
)
if
not
all
(
in_the_same_node_as
(
group
,
source_rank
=
0
)):
# No need to initialize custom allreduce for multi-node case.
logger
.
warning
(
"Custom allreduce is disabled because this process group"
" spans across nodes."
)
return
rank
=
dist
.
get_rank
(
group
=
self
.
group
)
self
.
rank
=
rank
world_size
=
dist
.
get_world_size
(
group
=
self
.
group
)
# if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
if
world_size
>
16
:
return
if
world_size
==
1
:
# No need to initialize custom allreduce for single GPU case.
return
if
world_size
not
in
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
:
logger
.
warning
(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
))
return
if
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
# now `device` is a `torch.device` object
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
cuda_visible_devices
=
os
.
environ
.
get
(
"CUDA_VISIBLE_DEVICES"
,
None
)
if
cuda_visible_devices
:
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
device_ids
=
list
(
range
(
torch
.
cuda
.
device_count
()))
physical_device_id
=
device_ids
[
device
.
index
]
tensor
=
torch
.
tensor
([
physical_device_id
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
gather_list
=
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
for
_
in
range
(
world_size
)
]
dist
.
all_gather
(
gather_list
,
tensor
,
group
=
self
.
group
)
physical_device_ids
=
[
t
.
item
()
for
t
in
gather_list
]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
# assert current_platform.is_cuda_alike()
# fully_connected = current_platform.is_fully_connected(
# physical_device_ids)
if
_is_cuda
or
_is_hip
:
fully_connected
=
is_full_nvlink
(
physical_device_ids
,
world_size
)
# if world_size > 2 and not fully_connected:
if
not
fully_connected
:
max_size
=
32
*
8192
*
2
# if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
# logger.warning(
# "Custom allreduce is disabled because it's not supported on"
# " more than two PCIe-only GPUs. To silence this warning, "
# "specify disable_custom_all_reduce=True explicitly.")
# return
logger
.
warning
(
"We are using PCIe's custom allreduce."
"If the performance is poor, we can add "
"--disable-custom-all-reduce in the instruction."
)
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if
not
_is_hip
and
not
_can_p2p
(
rank
,
world_size
):
logger
.
warning
(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
)
return
self
.
disabled
=
False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self
.
meta_ptrs
=
self
.
create_shared_buffer
(
ops
.
meta_size
()
+
max_size
,
group
=
group
,
uncached
=
True
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self
.
buffer_ptrs
=
self
.
create_shared_buffer
(
max_size
,
group
=
group
)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self
.
rank_data
=
torch
.
empty
(
8
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
max_size
=
max_size
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
fully_connected
=
fully_connected
self
.
_ptr
=
ops
.
init_custom_ar
(
self
.
meta_ptrs
,
self
.
rank_data
,
rank
,
self
.
fully_connected
)
ops
.
register_buffer
(
self
.
_ptr
,
self
.
buffer_ptrs
)
@
contextmanager
def
capture
(
self
):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try
:
self
.
_IS_CAPTURING
=
True
yield
finally
:
self
.
_IS_CAPTURING
=
False
if
not
self
.
disabled
:
self
.
register_graph_buffers
()
def
register_graph_buffers
(
self
):
handle
,
offset
=
ops
.
get_graph_buffer_ipc_meta
(
self
.
_ptr
)
logger
.
info
(
"Registering %d cuda graph addresses"
,
len
(
offset
))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data
=
[[
None
,
None
]
for
_
in
range
(
dist
.
get_world_size
(
group
=
self
.
group
))]
all_data
[
self
.
rank
]
=
[
handle
,
offset
]
ranks
=
sorted
(
dist
.
get_process_group_ranks
(
group
=
self
.
group
))
for
i
,
rank
in
enumerate
(
ranks
):
dist
.
broadcast_object_list
(
all_data
[
i
],
src
=
rank
,
group
=
self
.
group
,
device
=
"cpu"
)
# Unpack list of tuples to tuple of lists.
handles
=
[
d
[
0
]
for
d
in
all_data
]
# type: ignore
offsets
=
[
d
[
1
]
for
d
in
all_data
]
# type: ignore
ops
.
register_graph_buffers
(
self
.
_ptr
,
handles
,
offsets
)
def
should_custom_ar
(
self
,
inp
:
torch
.
Tensor
):
if
self
.
disabled
:
return
False
inp_size
=
inp
.
numel
()
*
inp
.
element_size
()
# custom allreduce requires input byte size to be multiples of 16
if
inp_size
%
16
!=
0
:
return
False
if
not
is_weak_contiguous
(
inp
):
return
False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
return
inp_size
<=
self
.
max_size
def
all_reduce
(
self
,
inp
:
torch
.
Tensor
,
*
,
out
:
torch
.
Tensor
=
None
,
registered
:
bool
=
False
):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if
out
is
None
:
out
=
torch
.
empty_like
(
inp
)
if
registered
:
ops
.
all_reduce
(
self
.
_ptr
,
inp
,
out
,
0
,
0
)
else
:
ops
.
all_reduce
(
self
.
_ptr
,
inp
,
out
,
self
.
buffer_ptrs
[
self
.
rank
],
self
.
max_size
)
return
out
def
custom_all_reduce
(
self
,
input
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if
self
.
disabled
or
not
self
.
should_custom_ar
(
input
):
return
None
if
self
.
_IS_CAPTURING
:
if
torch
.
cuda
.
is_current_stream_capturing
():
return
self
.
all_reduce
(
input
,
registered
=
False
)
else
:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return
torch
.
empty_like
(
input
)
else
:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return
self
.
all_reduce
(
input
,
registered
=
False
)
def
close
(
self
):
if
not
self
.
disabled
and
self
.
_ptr
:
if
ops
is
not
None
:
ops
.
dispose
(
self
.
_ptr
)
self
.
_ptr
=
0
self
.
free_shared_buffer
(
self
.
meta_ptrs
,
rank
=
self
.
rank
)
self
.
free_shared_buffer
(
self
.
buffer_ptrs
,
rank
=
self
.
rank
)
def
__del__
(
self
):
self
.
close
()
@
staticmethod
def
create_shared_buffer
(
size_in_bytes
:
int
,
group
:
Optional
[
ProcessGroup
]
=
None
,
uncached
:
Optional
[
bool
]
=
False
)
->
list
[
int
]:
pointer
,
handle
=
ops
.
allocate_shared_buffer_and_handle
(
size_in_bytes
)
world_size
=
dist
.
get_world_size
(
group
=
group
)
rank
=
dist
.
get_rank
(
group
=
group
)
handles
=
[
None
]
*
world_size
dist
.
all_gather_object
(
handles
,
handle
,
group
=
group
)
pointers
:
list
[
int
]
=
[]
for
i
,
h
in
enumerate
(
handles
):
if
i
==
rank
:
pointers
.
append
(
pointer
)
# type: ignore
else
:
pointers
.
append
(
ops
.
open_mem_handle
(
h
))
return
pointers
@
staticmethod
def
free_shared_buffer
(
pointers
:
list
[
int
],
group
:
Optional
[
ProcessGroup
]
=
None
,
rank
:
Optional
[
int
]
=
0
)
->
None
:
if
rank
is
None
:
rank
=
dist
.
get_rank
(
group
=
group
)
if
ops
is
not
None
:
ops
.
free_shared_buffer
(
pointers
[
rank
])
python/sglang/srt/distributed/parallel_state.py
View file @
a1175a4e
...
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
...
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
is_xpu
,
is_xpu
,
supports_custom_op
,
supports_custom_op
,
)
)
from
sglang.srt
import
_custom_ops
as
ops
_is_npu
=
is_npu
()
_is_npu
=
is_npu
()
_is_cpu
=
is_cpu
()
_is_cpu
=
is_cpu
()
...
@@ -327,7 +328,7 @@ class GroupCoordinator:
...
@@ -327,7 +328,7 @@ class GroupCoordinator:
# Lazy import to avoid documentation build error
# Lazy import to avoid documentation build error
from
sglang.srt.distributed.device_communicators.custom_all_reduce
import
(
from
sglang.srt.distributed.device_communicators.custom_all_reduce
import
(
CustomAllreduce
,
CustomAllreduce
,
DCUCustomAllreduce
)
)
from
sglang.srt.distributed.device_communicators.pymscclpp
import
(
from
sglang.srt.distributed.device_communicators.pymscclpp
import
(
PyMscclppCommunicator
,
PyMscclppCommunicator
,
...
@@ -371,9 +372,16 @@ class GroupCoordinator:
...
@@ -371,9 +372,16 @@ class GroupCoordinator:
if
use_custom_allreduce
and
self
.
world_size
>
1
:
if
use_custom_allreduce
and
self
.
world_size
>
1
:
# Initialize a custom fast all-reduce implementation.
# Initialize a custom fast all-reduce implementation.
try
:
try
:
if
is_hip
()
and
ops
.
use_dcu_custom_allreduce
:
self
.
ca_comm
=
DCUCustomAllreduce
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
else
:
self
.
ca_comm
=
CustomAllreduce
(
self
.
ca_comm
=
CustomAllreduce
(
group
=
self
.
cpu_group
,
group
=
self
.
cpu_group
,
device
=
self
.
device
,
device
=
self
.
device
,
max_size
=
ca_max_size
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
warning
(
logger
.
warning
(
...
...
python/sglang/srt/environ.py
View file @
a1175a4e
...
@@ -189,6 +189,17 @@ class Envs:
...
@@ -189,6 +189,17 @@ class Envs:
SGLANG_ROCM_FUSED_DECODE_MLA
=
EnvBool
(
False
)
SGLANG_ROCM_FUSED_DECODE_MLA
=
EnvBool
(
False
)
SGLANG_ROCM_DISABLE_LINEARQUANT
=
EnvBool
(
False
)
SGLANG_ROCM_DISABLE_LINEARQUANT
=
EnvBool
(
False
)
# DCU Lightop
SGLANG_USE_LIGHTOP
=
EnvBool
(
False
)
# Fused
SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD
=
EnvBool
(
False
)
SGLANG_USE_OPT_CAT
=
EnvBool
(
False
)
SGLANG_USE_FUSED_RMS_QUANT
=
EnvBool
(
False
)
SGLANG_USE_FUSED_SILU_MUL_QUANT
=
EnvBool
(
False
)
# Quantization
# Quantization
SGLANG_INT4_WEIGHT
=
EnvBool
(
False
)
SGLANG_INT4_WEIGHT
=
EnvBool
(
False
)
SGLANG_CPU_QUANTIZATION
=
EnvBool
(
False
)
SGLANG_CPU_QUANTIZATION
=
EnvBool
(
False
)
...
...
python/sglang/srt/layers/attention/attention_registry.py
View file @
a1175a4e
...
@@ -99,7 +99,6 @@ def create_triton_backend(runner):
...
@@ -99,7 +99,6 @@ def create_triton_backend(runner):
return
TritonAttnBackend
(
runner
)
return
TritonAttnBackend
(
runner
)
@
register_attention_backend
(
"torch_native"
)
@
register_attention_backend
(
"torch_native"
)
def
create_torch_native_backend
(
runner
):
def
create_torch_native_backend
(
runner
):
from
sglang.srt.layers.attention.torch_native_backend
import
TorchNativeAttnBackend
from
sglang.srt.layers.attention.torch_native_backend
import
TorchNativeAttnBackend
...
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
...
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
return
FlashMLABackend
(
runner
)
return
FlashMLABackend
(
runner
)
@
register_attention_backend
(
"dcu_mla"
)
def
create_dcu_mla_backend
(
runner
):
from
sglang.srt.layers.attention.dcu_mla_backend
import
DCUMLABackend
return
DCUMLABackend
(
runner
)
@
register_attention_backend
(
"fa3"
)
@
register_attention_backend
(
"fa3"
)
def
create_flashattention_v3_backend
(
runner
):
def
create_flashattention_v3_backend
(
runner
):
...
...
python/sglang/srt/layers/attention/dcu_mla_backend.py
0 → 100644
View file @
a1175a4e
This diff is collapsed.
Click to expand it.
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
View file @
a1175a4e
...
@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
...
@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
sgl_kernel.flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from
sglang.srt.layers.attention.flashattention_interface
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
sgl_kernel.sparse_flash_attn
import
(
from
sgl_kernel.sparse_flash_attn
import
(
convert_vertical_slash_indexes
,
convert_vertical_slash_indexes
,
convert_vertical_slash_indexes_mergehead
,
convert_vertical_slash_indexes_mergehead
,
...
...
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
a1175a4e
...
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
...
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sgl_kernel
import
merge_state_v2
from
sgl_kernel
import
merge_state_v2
from
sgl_kernel.flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from
sglang.srt.layers.attention.flashattention_interface
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
@
dataclass
@
dataclass
...
@@ -328,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -328,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
self
.
use_mla
=
model_runner
.
model_config
.
attention_arch
==
AttentionArch
.
MLA
self
.
use_mla
=
model_runner
.
model_config
.
attention_arch
==
AttentionArch
.
MLA
self
.
skip_prefill
=
skip_prefill
self
.
skip_prefill
=
skip_prefill
self
.
is_hybrid
=
model_runner
.
is_hybrid
self
.
is_hybrid
=
model_runner
.
is_hybrid
self
.
k_scale
=
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
self
.
v_scale
=
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
if
self
.
is_hybrid
:
if
self
.
is_hybrid
:
self
.
full_to_swa_index_mapping
=
(
self
.
full_to_swa_index_mapping
=
(
model_runner
.
token_to_kv_pool
.
full_to_swa_index_mapping
model_runner
.
token_to_kv_pool
.
full_to_swa_index_mapping
...
@@ -596,9 +599,11 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -596,9 +599,11 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
]
]
if
any
(
if
(
forward_batch
.
extend_prefix_lens_cpu
any
(
forward_batch
.
extend_prefix_lens_cpu
)
)
or
forward_batch
.
forward_mode
.
is_draft_extend
(
include_v2
=
True
):
or
forward_batch
.
forward_mode
==
ForwardMode
.
DRAFT_EXTEND
or
forward_batch
.
forward_mode
==
ForwardMode
.
DRAFT_EXTEND_V2
#nhb
):
extend_seq_lens
=
forward_batch
.
extend_seq_lens
extend_seq_lens
=
forward_batch
.
extend_seq_lens
metadata
.
max_seq_len_q
=
max
(
forward_batch
.
extend_seq_lens_cpu
)
metadata
.
max_seq_len_q
=
max
(
forward_batch
.
extend_seq_lens_cpu
)
metadata
.
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
metadata
.
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
...
@@ -608,10 +613,13 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -608,10 +613,13 @@ class FlashAttentionBackend(AttentionBackend):
metadata
.
max_seq_len_q
=
metadata
.
max_seq_len_k
metadata
.
max_seq_len_q
=
metadata
.
max_seq_len_k
metadata
.
cu_seqlens_q
=
metadata
.
cu_seqlens_k
metadata
.
cu_seqlens_q
=
metadata
.
cu_seqlens_k
# Setup local attention if enabled
# # Setup local attention if enabled
if
forward_batch
.
forward_mode
==
ForwardMode
.
EXTEND
:
# if forward_batch.forward_mode == ForwardMode.EXTEND:
# self._init_local_attn_metadata(forward_batch, metadata, device)
if
forward_batch
.
forward_mode
in
(
ForwardMode
.
EXTEND
,
ForwardMode
.
DRAFT_EXTEND_V2
):
self
.
_init_local_attn_metadata
(
forward_batch
,
metadata
,
device
)
self
.
_init_local_attn_metadata
(
forward_batch
,
metadata
,
device
)
# Encoder metadata for cross attention
# Encoder metadata for cross attention
if
forward_batch
.
encoder_lens
is
not
None
:
if
forward_batch
.
encoder_lens
is
not
None
:
assert
(
assert
(
...
@@ -668,10 +676,11 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -668,10 +676,11 @@ class FlashAttentionBackend(AttentionBackend):
if
not
layer
.
is_cross_attention
if
not
layer
.
is_cross_attention
else
forward_batch
.
encoder_out_cache_loc
else
forward_batch
.
encoder_out_cache_loc
)
)
if
not
self
.
use_mla
:
if
k_rope
is
None
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
,
layer
.
k_scale
,
layer
.
v_scale
layer
,
cache_loc
,
k
,
v
,
#
layer.k_scale, layer.v_scale
)
)
else
:
else
:
forward_batch
.
token_to_kv_pool
.
set_mla_kv_buffer
(
forward_batch
.
token_to_kv_pool
.
set_mla_kv_buffer
(
layer
,
layer
,
...
@@ -690,7 +699,8 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -690,7 +699,8 @@ class FlashAttentionBackend(AttentionBackend):
layer
.
sliding_window_size
is
not
None
and
layer
.
sliding_window_size
>
-
1
layer
.
sliding_window_size
is
not
None
and
layer
.
sliding_window_size
>
-
1
)
)
window_size
=
(
layer
.
sliding_window_size
,
0
)
if
is_swa
else
(
-
1
,
-
1
)
window_size
=
(
layer
.
sliding_window_size
,
0
)
if
is_swa
else
(
-
1
,
-
1
)
k_descale
,
v_descale
=
None
,
None
# k_descale, v_descale = None, None
k_descale
,
v_descale
=
self
.
k_scale
,
self
.
v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
...
@@ -704,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -704,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
descale_shape
=
(
forward_batch
.
batch_size
,
layer
.
tp_k_head_num
)
descale_shape
=
(
forward_batch
.
batch_size
,
layer
.
tp_k_head_num
)
k_descale
=
layer
.
k_scale
.
expand
(
descale_shape
)
k_descale
=
layer
.
k_scale
.
expand
(
descale_shape
)
v_descale
=
layer
.
v_scale
.
expand
(
descale_shape
)
v_descale
=
layer
.
v_scale
.
expand
(
descale_shape
)
q
=
q
.
to
(
self
.
kv_cache_dtype
)
#
q = q.to(self.kv_cache_dtype)
q_rope
=
q_rope
.
to
(
self
.
kv_cache_dtype
)
if
q_rope
is
not
None
else
None
q_rope
=
q_rope
.
to
(
self
.
kv_cache_dtype
)
if
q_rope
is
not
None
else
None
k_rope
=
k_rope
.
to
(
self
.
kv_cache_dtype
)
if
k_rope
is
not
None
else
None
k_rope
=
k_rope
.
to
(
self
.
kv_cache_dtype
)
if
k_rope
is
not
None
else
None
causal
=
True
causal
=
True
...
@@ -774,60 +784,58 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -774,60 +784,58 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k
=
metadata
.
encoder_cu_seqlens_k
cu_seqlens_k
=
metadata
.
encoder_cu_seqlens_k
window_size
=
(
-
1
,
-
1
)
window_size
=
(
-
1
,
-
1
)
result
=
flash_attn_with_kvcache
(
if
forward_batch
.
attn_attend_prefix_cache
:
q
=
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
assert
not
get_global_server_args
().
disable_chunked_prefix_cache
k_cache
=
key_cache
,
# MHA for chunked prefix kv cache when running model with MLA
v_cache
=
value_cache
,
assert
forward_batch
.
prefix_chunk_idx
is
not
None
page_table
=
page_table
,
assert
forward_batch
.
prefix_chunk_cu_seq_lens
is
not
None
cache_seqlens
=
cache_seqlens
,
assert
forward_batch
.
prefix_chunk_max_seq_lens
is
not
None
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k_new
=
cu_seqlens_k
if
not
use_local_attn
else
None
,
chunk_idx
=
forward_batch
.
prefix_chunk_idx
max_seqlen_q
=
max_seqlen_q
,
assert
chunk_idx
>=
0
assert
forward_batch
.
mha_return_lse
output
=
flash_attn_varlen_func
(
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_k
=
forward_batch
.
prefix_chunk_cu_seq_lens
[
chunk_idx
],
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_k
=
forward_batch
.
prefix_chunk_max_seq_lens
[
chunk_idx
],
softmax_scale
=
layer
.
scaling
,
softmax_scale
=
layer
.
scaling
,
causal
=
False
if
use_cascade_attn
else
causal
,
causal
=
False
,
window_size
=
window_size
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
use_cascade_attn
,
return_softmax_lse
=
True
,
num_splits
=
self
.
num_splits
,
**
kwargs
,
**
kwargs
,
)
)
else
:
if
use_cascade_attn
:
output
=
flash_attn_varlen_func
(
o
,
softmax_lse
,
*
rest
=
result
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
o_expand
,
softmax_lse_expand
,
*
rest_expand
=
flash_attn_with_kvcache
(
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
q
=
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
k_cache
=
key_cache
,
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
v_cache
=
value_cache
,
cu_seqlens_k
=
metadata
.
cu_seqlens_q
,
page_table
=
self
.
forward_metadata_spec_decode_expand
.
page_table
,
max_seqlen_q
=
metadata
.
max_seq_len_q
,
cache_seqlens
=
self
.
forward_metadata_spec_decode_expand
.
cache_seqlens_int32
,
max_seqlen_k
=
metadata
.
max_seq_len_q
,
cu_seqlens_q
=
self
.
forward_metadata_spec_decode_expand
.
cu_seqlens_q
,
cu_seqlens_k_new
=
self
.
forward_metadata_spec_decode_expand
.
cu_seqlens_k
,
max_seqlen_q
=
self
.
forward_metadata_spec_decode_expand
.
max_seq_len_q
,
softmax_scale
=
layer
.
scaling
,
softmax_scale
=
layer
.
scaling
,
causal
=
False
,
causal
=
True
,
window_size
=
window_size
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
return_softmax_lse
=
forward_batch
.
mha_return_lse
,
num_splits
=
self
.
num_splits
,
**
kwargs
,
**
kwargs
,
)
)
o
,
_
=
merge_state_v2_wrapper
(
if
forward_batch
.
mha_return_lse
:
o
,
output
,
lse
,
*
rest
=
output
softmax_lse
.
T
.
contiguous
(),
lse
=
torch
.
transpose
(
lse
,
0
,
1
).
contiguous
()
o_expand
,
return
output
,
lse
softmax_lse_expand
.
T
.
contiguous
(),
return
output
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
v_head_dim
)
)
else
:
o
=
result
else
:
else
:
if
(
if
(
forward_batch
.
attn_attend_prefix_cache
is
not
None
forward_batch
.
attn_attend_prefix_cache
is
not
None
and
not
forward_batch
.
forward_mode
.
is_target_verify
()
and
not
forward_batch
.
forward_mode
.
is_target_verify
()
and
not
forward_batch
.
forward_mode
.
is_draft_extend
(
include_v2
=
True
)
and
not
forward_batch
.
forward_mode
.
is_draft_extend
()
):
):
# Do multi-head attention with chunked prefix cache
# Do multi-head attention with chunked prefix cache
if
forward_batch
.
attn_attend_prefix_cache
:
if
forward_batch
.
attn_attend_prefix_cache
:
...
@@ -843,39 +851,32 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -843,39 +851,32 @@ class FlashAttentionBackend(AttentionBackend):
assert
forward_batch
.
mha_return_lse
assert
forward_batch
.
mha_return_lse
output
=
flash_attn_varlen_func
(
output
=
flash_attn_varlen_func
(
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
to
(
q
.
dtype
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
to
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_k
=
forward_batch
.
prefix_chunk_cu_seq_lens
[
chunk_idx
],
cu_seqlens_k
=
forward_batch
.
prefix_chunk_cu_seq_lens
[
chunk_idx
],
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_k
=
forward_batch
.
prefix_chunk_max_seq_lens
[
chunk_idx
],
max_seqlen_k
=
forward_batch
.
prefix_chunk_max_seq_lens
[
chunk_idx
],
softmax_scale
=
layer
.
scaling
,
softmax_scale
=
layer
.
scaling
,
causal
=
False
,
causal
=
False
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
**
kwargs
,
**
kwargs
,
)
)
else
:
else
:
# MHA for extend part of sequence without attending prefix kv cache
cu_seqlens_k
=
(
metadata
.
cu_seqlens_q
if
not
forward_batch
.
mha_one_shot
else
metadata
.
cu_seqlens_k
)
max_seqlen_k
=
(
metadata
.
max_seq_len_q
if
not
forward_batch
.
mha_one_shot
else
metadata
.
max_seq_len_k
)
output
=
flash_attn_varlen_func
(
output
=
flash_attn_varlen_func
(
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
q
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
to
(
q
.
dtype
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
to
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_q
=
metadata
.
max_seq_len_q
,
max_seqlen_k
=
max_seqlen_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
layer
.
scaling
,
softmax_scale
=
layer
.
scaling
,
causal
=
True
,
causal
=
True
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
forward_batch
.
mha_return_lse
,
return_softmax_lse
=
forward_batch
.
mha_return_lse
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -985,10 +986,16 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -985,10 +986,16 @@ class FlashAttentionBackend(AttentionBackend):
if
not
layer
.
is_cross_attention
if
not
layer
.
is_cross_attention
else
forward_batch
.
encoder_out_cache_loc
else
forward_batch
.
encoder_out_cache_loc
)
)
# if not self.use_mla:
if
k_rope
is
None
:
if
not
self
.
use_mla
:
if
not
self
.
use_mla
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
,
layer
.
k_scale
,
layer
.
v_scale
layer
,
cache_loc
,
k
,
v
,
layer
.
k_scale
,
layer
.
v_scale
)
)
else
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
)
else
:
else
:
forward_batch
.
token_to_kv_pool
.
set_mla_kv_buffer
(
forward_batch
.
token_to_kv_pool
.
set_mla_kv_buffer
(
layer
,
layer
,
...
@@ -1030,7 +1037,8 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1030,7 +1037,8 @@ class FlashAttentionBackend(AttentionBackend):
if
sinks
is
not
None
:
if
sinks
is
not
None
:
kwargs
[
"sinks"
]
=
sinks
kwargs
[
"sinks"
]
=
sinks
k_descale
,
v_descale
=
None
,
None
# k_descale, v_descale = None, None
k_descale
,
v_descale
=
self
.
k_scale
,
self
.
v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
...
@@ -1044,7 +1052,6 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1044,7 +1052,6 @@ class FlashAttentionBackend(AttentionBackend):
k_rope
=
k_rope
.
to
(
self
.
kv_cache_dtype
)
if
k_rope
is
not
None
else
None
k_rope
=
k_rope
.
to
(
self
.
kv_cache_dtype
)
if
k_rope
is
not
None
else
None
if
not
self
.
use_mla
:
if
not
self
.
use_mla
:
# Do multi-head attention
# Do multi-head attention
key_cache
,
value_cache
=
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
key_cache
,
value_cache
=
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
layer
.
layer_id
layer
.
layer_id
)
)
...
@@ -1096,26 +1103,33 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1096,26 +1103,33 @@ class FlashAttentionBackend(AttentionBackend):
**
kwargs
,
**
kwargs
,
)
)
else
:
else
:
cu_seqlens_q
=
metadata
.
cu_seqlens_q
max_seqlen_q
=
metadata
.
max_seq_len_q
page_table
=
metadata
.
page_table
page_table
=
metadata
.
page_table
cache_seqlens
=
metadata
.
cache_seqlens_int32
cu_seqlens_k
=
metadata
.
cu_seqlens_k
cu_seqlens_k
=
metadata
.
cu_seqlens_k
max
_seqlen
_q
=
metadata
.
max
_seq
_
len
_q
cache
_seqlen
s
=
metadata
.
cache
_seqlen
s_int32
q_reshaped
=
q
.
contiguous
()
.
view
(
key_cache
=
key_cache
.
view
(
-
1
,
layer
.
tp_
q
_head_num
,
layer
.
head_dim
-
1
,
self
.
page_size
,
layer
.
tp_
k
_head_num
,
layer
.
head_dim
)
)
value_cache
=
value_cache
.
view
(
# Default: single-token self-attention
-
1
,
self
.
page_size
,
layer
.
tp_v_head_num
,
layer
.
head_dim
result
=
flash_attn_with_kvcache
(
)
q
=
q_reshaped
,
if
layer
.
is_cross_attention
:
k_cache
=
key_cache
,
page_table
=
metadata
.
encoder_page_table
v_cache
=
value_cache
,
cache_seqlens
=
metadata
.
encoder_lens_int32
page_table
=
page_table
,
cu_seqlens_k
=
metadata
.
encoder_cu_seqlens_k
cache_seqlens
=
cache_seqlens
,
window_size
=
(
-
1
,
-
1
)
cu_seqlens_q
=
metadata
.
cu_seqlens_q
,
if
max_seqlen_q
>
1
:
cu_seqlens_k_new
=
cu_seqlens_k
,
result
=
flash_attn_varlen_func
(
q
=
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
k
=
k
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
head_dim
).
view
(
q
.
dtype
),
v
=
v
.
view
(
-
1
,
layer
.
tp_k_head_num
,
layer
.
v_head_dim
).
view
(
q
.
dtype
),
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_q
,
softmax_scale
=
layer
.
scaling
,
softmax_scale
=
layer
.
scaling
,
causal
=
False
if
use_cascade_attn
else
causal
,
causal
=
True
,
window_size
=
window_size
,
window_size
=
window_size
,
softcap
=
layer
.
logit_cap
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
k_descale
=
k_descale
,
...
@@ -1124,36 +1138,26 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1124,36 +1138,26 @@ class FlashAttentionBackend(AttentionBackend):
num_splits
=
self
.
num_splits
,
num_splits
=
self
.
num_splits
,
**
kwargs
,
**
kwargs
,
)
)
if
use_cascade_attn
:
else
:
o
,
softmax_lse
,
*
rest
=
result
result
=
flash_attn_with_kvcache
(
o_expand
,
softmax_lse_expand
,
*
rest_expand
=
(
q
=
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
flash_attn_with_kvcache
(
q
=
q_reshaped
,
k_cache
=
key_cache
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
v_cache
=
value_cache
,
page_table
=
self
.
forward_metadata_spec_decode_expand
.
page_table
,
page_table
=
page_table
,
cache_seqlens
=
self
.
forward_metadata_spec_decode_expand
.
cache_seqlens
_int32
,
cache_seqlens
=
cache_seqlens
,
cu_seqlens_q
=
self
.
forward_metadata_spec_decode_expand
.
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k_new
=
self
.
forward_metadata_spec_decode_expand
.
cu_seqlens_k
,
cu_seqlens_k_new
=
cu_seqlens_k
if
not
use_local_attn
else
None
,
max_seqlen_q
=
self
.
forward_metadata_spec_decode_expand
.
max_seq
_
len_q
,
max_seqlen_q
=
max_seqlen_q
,
softmax_scale
=
layer
.
scaling
,
softmax_scale
=
layer
.
scaling
,
causal
=
Fals
e
,
causal
=
Tru
e
,
window_size
=
window_size
,
window_size
=
window_size
,
softcap
=
layer
.
logit_cap
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
return_softmax_lse
=
use_cascade_attn
,
num_splits
=
self
.
num_splits
,
num_splits
=
self
.
num_splits
,
**
kwargs
,
**
kwargs
,
)
)
)
o
,
_
=
merge_state_v2
(
o
,
softmax_lse
.
T
.
contiguous
(),
o_expand
,
softmax_lse_expand
.
T
.
contiguous
(),
)
else
:
o
=
result
o
=
result
else
:
else
:
# Do absorbed multi-latent attention
# Do absorbed multi-latent attention
...
...
python/sglang/srt/layers/attention/flashattention_interface.py
0 → 100644
View file @
a1175a4e
from
flash_attn
import
(
flash_attn_varlen_func
as
flash_attn_varlen_func_interface
,
flash_attn_with_kvcache
as
flash_attn_with_kvcache_interface
)
from
typing
import
Optional
,
Union
import
torch
def
flash_attn_with_kvcache
(
q
,
k_cache
,
v_cache
,
k
=
None
,
v
=
None
,
qv
=
None
,
rotary_cos
=
None
,
rotary_sin
=
None
,
cache_seqlens
:
Optional
[
Union
[
int
,
torch
.
Tensor
]]
=
None
,
cache_batch_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_leftpad
:
Optional
[
torch
.
Tensor
]
=
None
,
page_table
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_q
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_k_new
:
Optional
[
torch
.
Tensor
]
=
None
,
max_seqlen_q
:
Optional
[
int
]
=
None
,
rotary_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
q_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
k_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
v_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite context window
attention_chunk
:
Optional
[
int
]
=
None
,
softcap
=
0.0
,
# 0.0 means deactivated
rotary_interleaved
=
True
,
scheduler_metadata
=
None
,
num_splits
=
0
,
# Can be tuned for speed
pack_gqa
=
None
,
# Can be tuned for speed
sm_margin
=
0
,
# Can be tuned if some SMs are used for communication
return_softmax_lse
=
False
,
sinks
=
None
,
ver
=
3
,
):
return
flash_attn_with_kvcache_interface
(
q
=
q
.
contiguous
().
view
(
-
1
,
max_seqlen_q
,
q
.
shape
[
-
2
],
q
.
shape
[
-
1
]),
k_cache
=
k_cache
.
view
(
q
.
dtype
),
v_cache
=
v_cache
.
view
(
q
.
dtype
),
block_table
=
page_table
,
cache_seqlens
=
cache_seqlens
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
return_softmax_lse
=
return_softmax_lse
,
num_splits
=
num_splits
,
)
def
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
=
None
,
max_seqlen_k
=
None
,
seqused_q
=
None
,
seqused_k
=
None
,
page_table
=
None
,
softmax_scale
=
None
,
causal
=
False
,
qv
=
None
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
window_size
=
(
-
1
,
-
1
),
attention_chunk
=
0
,
softcap
=
0.0
,
num_splits
=
1
,
pack_gqa
=
None
,
sm_margin
=
0
,
return_softmax_lse
=
False
,
sinks
=
None
,
ver
=
3
,
):
return
flash_attn_varlen_func_interface
(
q
=
q
,
k
=
k
.
view
(
q
.
dtype
),
v
=
v
.
view
(
q
.
dtype
),
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
return_attn_probs
=
return_softmax_lse
,
softcap
=
softcap
,
)
\ No newline at end of file
python/sglang/srt/layers/attention/flashmla_backend.py
View file @
a1175a4e
...
@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
...
@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.utils
import
get_bool_env_var
from
sgl_kernel.flash_mla
import
dcu_create_flashmla_kv_indices
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
...
@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
...
@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self
.
num_draft_tokens
=
model_runner
.
server_args
.
speculative_num_draft_tokens
self
.
num_draft_tokens
=
model_runner
.
server_args
.
speculative_num_draft_tokens
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
use_sglang_create_flashmla_kv_indices_triton
=
get_bool_env_var
(
"SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
)
bs
=
forward_batch
.
batch_size
bs
=
forward_batch
.
batch_size
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
max_seqlen_pad
=
triton
.
cdiv
(
max_seqlen_pad
=
triton
.
cdiv
(
...
@@ -91,6 +95,18 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
...
@@ -91,6 +95,18 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
forward_batch
.
seq_lens
.
device
,
device
=
forward_batch
.
seq_lens
.
device
,
)
)
if
use_sglang_create_flashmla_kv_indices_triton
:
dcu_create_flashmla_kv_indices
(
req_to_token_ptr
=
self
.
req_to_token
,
req_pool_indices_ptr
=
forward_batch
.
req_pool_indices
,
page_kernel_lens_ptr
=
forward_batch
.
seq_lens
,
kv_start_idx
=
None
,
kv_indices_ptr
=
block_kv_indices
,
req_to_token_ptr_stride
=
self
.
req_to_token
.
stride
(
0
),
kv_indices_ptr_stride
=
max_seqlen_pad
,
)
else
:
create_flashmla_kv_indices_triton
[(
bs
,)](
create_flashmla_kv_indices_triton
[(
bs
,)](
self
.
req_to_token
,
self
.
req_to_token
,
forward_batch
.
req_pool_indices
,
forward_batch
.
req_pool_indices
,
...
@@ -121,10 +137,22 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
...
@@ -121,10 +137,22 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
seq_lens
.
device
,
device
=
seq_lens
.
device
,
)
)
if
use_sglang_create_flashmla_kv_indices_triton
:
dcu_create_flashmla_kv_indices
(
req_to_token_ptr
=
self
.
req_to_token
,
req_pool_indices_ptr
=
forward_batch
.
req_pool_indices
,
page_kernel_lens_ptr
=
forward_batch
.
seq_lens
,
kv_start_idx
=
None
,
kv_indices_ptr
=
block_kv_indices
,
req_to_token_ptr_stride
=
self
.
req_to_token
.
stride
(
0
),
kv_indices_ptr_stride
=
max_seqlen_pad
,
)
else
:
create_flashmla_kv_indices_triton
[(
bs
,)](
create_flashmla_kv_indices_triton
[(
bs
,)](
self
.
req_to_token
,
self
.
req_to_token
,
forward_batch
.
req_pool_indices
,
forward_batch
.
req_pool_indices
,
seq_lens
,
forward_batch
.
seq_lens
,
None
,
None
,
block_kv_indices
,
block_kv_indices
,
self
.
req_to_token
.
stride
(
0
),
self
.
req_to_token
.
stride
(
0
),
...
...
python/sglang/srt/layers/attention/lightop_concat.py
0 → 100644 (new file)

from __future__ import annotations

import warnings

import torch

from sglang.srt.utils import get_bool_env_var, direct_register_custom_op

_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")

if _USE_OPT_CAT:
    try:
        from lightop import ds_cat  # type: ignore
    except ImportError:  # pragma: no cover
        ds_cat = None
        warnings.warn(
            "SGLANG_USE_OPT_CAT is enabled but lightop.ds_cat could not be imported; "
            "falling back to torch.cat"
        )
else:
    ds_cat = None


# TODO: registering this op on its own still has some issues
def ds_cat_wrapper(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    ds_cat(A, B, C, mode)
    return C


def ds_cat_fake(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
    # Use the standard torch.cat as the fake (meta) implementation.
    return torch.cat([A, B], dim=dim)


direct_register_custom_op(
    op_name="ds_cat",
    op_func=ds_cat_wrapper,
    mutates_args=[],  # no arguments are mutated; only a new tensor is returned
    fake_impl=ds_cat_fake,
)


def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
    mode = 0
    if dim != 0:
        return torch.ops.sglang.ds_cat(A, B, dim, mode)
    assert False, "not support"


# Earlier variant that called lightop.ds_cat directly instead of the registered op:
# def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
#     assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
#     output_shape = list(A.shape)
#     output_shape[dim] = A.shape[dim] + B.shape[dim]
#     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
#     mode = 0
#     if dim != 0:
#         ds_cat(A, B, C, mode)
#         return C
#     assert False, "not support"
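A minimal usage sketch for the new helper (assumes SGLANG_USE_OPT_CAT=1, the lightop package, and a GPU; the tensor shapes are made up). Per the fake implementation above, the fused op is expected to agree with torch.cat along dim 2 for 3-D tensors:

# Hypothetical check; requires SGLANG_USE_OPT_CAT=1, the lightop package, and a GPU.
import torch
from sglang.srt.layers.attention.lightop_concat import concat_decode_opt

A = torch.randn(8, 16, 512, device="cuda", dtype=torch.bfloat16)
B = torch.randn(8, 16, 64, device="cuda", dtype=torch.bfloat16)

out = concat_decode_opt(A, B, dim=2)
ref = torch.cat([A, B], dim=2)
torch.testing.assert_close(out, ref)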
python/sglang/srt/layers/attention/nsa_backend.py

@@ -47,7 +47,8 @@ if _is_hip:
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )
 else:
-    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+    # from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache


 @dataclass(frozen=True)
...
python/sglang/srt/layers/attention/xpu_backend.py

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

 from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


 class XPUAttentionBackend(AttentionBackend):
...
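Both backends now import through sglang.srt.layers.attention.flashattention_interface, a module added elsewhere in this commit whose diff is not expanded here. A hypothetical sketch of what such an indirection layer typically looks like, not the actual file contents: a thin re-export that lets the kernel provider be swapped in one place instead of in every backend.

# Hypothetical sketch of an attention-interface shim; the real
# flashattention_interface.py added in this commit may differ.
try:
    # preferred provider
    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
except ImportError:
    # fall back to the reference flash-attn package if sgl_kernel is unavailable
    from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

__all__ = ["flash_attn_varlen_func", "flash_attn_with_kvcache"]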
python/sglang/srt/layers/layernorm.py

@@ -160,21 +160,53 @@ class RMSNorm(CustomOp):
                 return output, residual_out
         return rms_norm(x, self.weight.data, self.variance_epsilon)

+    # def forward_hip(
+    #     self,
+    #     x: torch.Tensor,
+    #     residual: Optional[torch.Tensor] = None,
+    # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    #     if not x.is_contiguous():
+    #         # NOTE: Remove this if aiter kernel supports discontinuous input
+    #         x = x.contiguous()
+    #     if residual is not None:
+    #         if _vllm_version < Version("0.9"):
+    #             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+    #             return x, residual
+    #         else:
+    #             residual_out = torch.empty_like(x)
+    #             output = torch.empty_like(x)
+    #             fused_add_rms_norm(
+    #                 output,
+    #                 x,
+    #                 residual_out,
+    #                 residual,
+    #                 self.weight.data,
+    #                 self.variance_epsilon,
+    #             )
+    #             return output, residual_out
+    #     out = torch.empty_like(x)
+    #     rms_norm(out, x, self.weight.data, self.variance_epsilon)
+    #     return out
+
     def forward_hip(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ):
         if not x.is_contiguous():
-            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            try:
+                fused_add_rms_norm(
+                    x, residual, self.weight.data, self.variance_epsilon,
+                )
                 return x, residual
-            else:
-                residual_out = torch.empty_like(x)
-                output = torch.empty_like(x)
+            except TypeError:
+                output = torch.empty_like(x)
+                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,
...
@@ -184,10 +216,13 @@ class RMSNorm(CustomOp):
                     self.variance_epsilon,
                 )
                 return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out

     def forward_native(
         self,
         x: torch.Tensor,
...
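The new forward_hip drops the _vllm_version check and instead probes the kernel signature at call time: it first tries the older in-place fused_add_rms_norm signature and, presumably when the installed kernel expects the newer out-of-place argument list, the resulting TypeError routes the call to the second form. A stand-in illustration of that dispatch pattern (the dummy kernel below is not the real aiter/vLLM op):

# Stand-in for signature probing via TypeError; the real code calls the vendor
# fused_add_rms_norm kernel, not this dummy.
import torch


def fused_add_rms_norm_old(x, residual, weight, eps):
    # old signature: updates x and residual in place, returns nothing
    residual += x
    variance = residual.pow(2).mean(-1, keepdim=True)
    x.copy_(residual * torch.rsqrt(variance + eps) * weight)


def add_rms_norm(x, residual, weight, eps, kernel=fused_add_rms_norm_old):
    try:
        kernel(x, residual, weight, eps)  # old in-place signature
        return x, residual
    except TypeError:
        # newer signature writes into freshly allocated output/residual tensors
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        kernel(output, x, residual_out, residual, weight, eps)
        return output, residual_out


out, res = add_rms_norm(torch.randn(4, 8), torch.randn(4, 8), torch.ones(8), 1e-6)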
python/sglang/srt/layers/linear.py

@@ -45,6 +45,18 @@ _is_hip = is_hip()
 _disable_hip_linear_quant = _is_hip and get_bool_env_var(
     "SGLANG_ROCM_DISABLE_LINEARQUANT"
 )
+_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
+_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
+if _use_fused_rms_quant:
+    try:
+        from lmslim.quantize.quant_ops import lm_faster_rmsquant
+    except Exception as e:
+        print(f"Error: Import fused rmsquant error: {e}")
+if _use_fused_silu_mul_quant:
+    try:
+        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
+    except Exception as e:
+        print(f"Error: Import fused silu_mul_quant error: {e}")

 logger = logging.getLogger(__name__)
...
@@ -1360,7 +1372,7 @@ class RowParallelLinear(LinearBase):
             # It does not support additional parameters.
             param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_, skip_all_reduce=False):
+    def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
...
@@ -1374,10 +1386,19 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        with use_symmetric_memory(
-            get_tp_group(), disabled=not is_allocation_symmetric()
-        ):
-            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        if use_fused_silu_mul_quant:
+            xq, xs = lm_fuse_silu_mul_quant(input_parallel)
+            silu_quant_args = [xq, xs]
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                output_parallel = self.quant_method.apply(
+                    self, input_parallel, bias=bias_, silu_quant_args=silu_quant_args
+                )
+                sm.tag(output_parallel)
+        else:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+                sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
...
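With the new flag, a caller can hand the raw gate/up projection output to RowParallelLinear.forward and let lm_fuse_silu_mul_quant produce what appears to be the quantized activation and its scale, which are forwarded to quant_method.apply as silu_quant_args. A hedged sketch of how an MLP down-projection might opt in (the function and parameter names here are illustrative, and the quant method must accept silu_quant_args):

# Illustrative caller (names are made up); assumes the quant method behind down_proj
# accepts silu_quant_args and SGLANG_USE_FUSED_SILU_MUL_QUANT is set.
def mlp_forward(gate_up_proj, down_proj, act_fn, x, use_fused_silu_mul_quant):
    gate_up, _ = gate_up_proj(x)  # [num_tokens, 2 * intermediate_size]
    if use_fused_silu_mul_quant:
        # SiLU(gate) * up and quantization are fused inside RowParallelLinear.forward
        out, _ = down_proj(gate_up, use_fused_silu_mul_quant=True)
    else:
        out, _ = down_proj(act_fn(gate_up))
    return out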
python/sglang/srt/layers/moe/ep_moe/kernels.py

@@ -2,7 +2,6 @@ import logging

 import torch
 import triton

 from sglang.srt.utils import ceil_div, is_cuda

 logger = logging.getLogger(__name__)
...
@@ -1015,196 +1014,133 @@ def zero_experts_compute_triton(
     return output

+from triton.language.extra import libdevice
+from typing import Optional

-@triton.jit
-def compute_problem_sizes_w4a8_kernel(
-    masked_m_ptr,
-    problem_sizes1_ptr,
-    problem_sizes2_ptr,
-    n,
-    k,
-    num_experts,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = pid < num_experts
-    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
-    ps1_idx_0 = pid * 3
-    ps1_idx_1 = ps1_idx_0 + 1
-    ps1_idx_2 = ps1_idx_0 + 2
-    ps2_idx_0 = pid * 3
-    ps2_idx_1 = ps2_idx_0 + 1
-    ps2_idx_2 = ps2_idx_0 + 2
-    ps1_mask_0 = ps1_idx_0 < num_experts * 3
-    ps1_mask_1 = ps1_idx_1 < num_experts * 3
-    ps1_mask_2 = ps1_idx_2 < num_experts * 3
-    ps2_mask_0 = ps2_idx_0 < num_experts * 3
-    ps2_mask_1 = ps2_idx_1 < num_experts * 3
-    ps2_mask_2 = ps2_idx_2 < num_experts * 3
-    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
-    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
-    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
-    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
-    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
-    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
-
-
-def compute_problem_sizes_w4a8(
-    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
-):
-    BLOCK_SIZE = 256
-    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
-    compute_problem_sizes_w4a8_kernel[grid](
-        masked_m,
-        problem_sizes1,
-        problem_sizes2,
-        n,
-        k,
-        num_experts,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-    return problem_sizes1, problem_sizes2
-
-
-def deepep_ll_get_cutlass_w4a8_moe_mm_data(
-    masked_m,
-    problem_sizes1,
-    problem_sizes2,
-    num_experts,
-    n,
-    k,
-):
-    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
-        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
-    )
-    return (
-        problem_sizes1.to(torch.int32),
-        problem_sizes2.to(torch.int32),
-    )
-
-
-@triton.jit
-def _silu_and_mul_post_per_tensor_quant_kernel(
-    input_ptr,
-    stride_input_expert,
-    stride_input_token,
-    stride_input_dim,
-    output_ptr,
-    stride_output_expert,
-    stride_output_token,
-    stride_output_dim,
-    scale_ptr,
-    masked_m_ptr,
-    inner_dim,
-    fp8_max,
-    fp8_min,
-    BLOCK_N: tl.constexpr,
-    NUM_STAGE: tl.constexpr,
-):
-    """
-    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
-
-    Shape:
-        input: [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D]
-        output: [E, T_padded, D], dtype=float8_e4m3fn
-    """
-    expert_id = tl.program_id(2)
-    block_id_token = tl.program_id(1)
-    block_id_dim = tl.program_id(0)
-    num_token_blocks = tl.num_programs(1)
-    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
-    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
-    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
-    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
-    stride_input_token = tl.cast(stride_input_token, tl.int32)
-    stride_output_token = tl.cast(stride_output_token, tl.int32)
-    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
-    mask_d = offset_d < inner_dim
-    # base pointers for current expert and dim block
-    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
-    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
-    for token_idx in tl.range(
-        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
-    ):
-        gate_ptr = input_base_offs + token_idx * stride_input_token
-        up_ptr = gate_ptr + inner_dim
-        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
-        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
-        # SiLU: x * sigmoid(x)
-        gate = gate / (1 + tl.exp(-gate))
-        gate = gate.to(input_ptr.dtype.element_ty)
-        gate_up = up * gate
-        scaled = gate_up * scale
-        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
-        out_ptr = output_base_offs + token_idx * stride_output_token
-        tl.store(out_ptr, output_q, mask=mask_d)
-
-
-def silu_and_mul_masked_post_per_tensor_quant_fwd(
-    input: torch.Tensor,
-    output: torch.Tensor,
-    masked_m: torch.Tensor,
-    scale: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Fused SiLU + Mul + Per-Tensor Quantization to FP8.
-
-    Args:
-        input: [expert_num, token_num_padded, 2 * inner_dim]
-        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
-        masked_m: [expert_num], actual token count for each expert
-        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
-
-    Returns:
-        output tensor
-    """
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-    assert output.dtype == torch.float8_e4m3fn
-    assert input.ndim == 3
-    assert input.shape[0] == masked_m.shape[0]
-    assert input.shape[-1] % 2 == 0
-    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
-    expert_num = input.shape[0]
-    # 3584
-    inner_dim = input.shape[-1] // 2
-    BLOCK_N = 256
-    BLOCK_M = 64 if expert_num < 4 else 32
-    NUM_STAGES = 3
-    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
-    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    _silu_and_mul_post_per_tensor_quant_kernel[grid](
-        input,
-        *input.stride(),
-        output,
-        *output.stride(),
-        scale,
-        masked_m,
-        inner_dim,
-        fp8_max,
-        fp8_min,
-        BLOCK_N=BLOCK_N,
-        NUM_STAGE=NUM_STAGES,
-    )
-    return output
+@triton.jit
+def _per_token_quant_int8_one_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    if tokens_per_expert_ptr is not None:
+        e = row_id // T_dim
+        t = row_id % T_dim
+        num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+        if t >= num_valid_tokens_for_e:
+            return
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = libdevice.nearbyint(x_q).to(tl.int8)
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+@triton.jit
+def _per_token_quant_int8_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    E_dim,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    token_idx_start = tl.program_id(0)
+    grid_size = tl.num_programs(0)
+    num_total_tokens = E_dim * T_dim
+    for token_idx in range(token_idx_start, num_total_tokens, grid_size):
+        is_valid_token = True
+        if tokens_per_expert_ptr is not None:
+            e = token_idx // T_dim
+            t = token_idx % T_dim
+            num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+            if t >= num_valid_tokens_for_e:
+                is_valid_token = False
+        if is_valid_token:
+            cols = tl.arange(0, BLOCK)
+            mask = cols < N
+            x = tl.load(
+                x_ptr + token_idx * stride_x + cols, mask=mask, other=0.0
+            ).to(tl.float32)
+            absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+            scale_x = absmax / 127
+            x_q = x * (127 / absmax)
+            x_q = libdevice.nearbyint(x_q).to(tl.int8)
+            tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
+            tl.store(scale_ptr + token_idx, scale_x)
+
+
+def per_token_quant_int8_triton_opt(
+    x: torch.Tensor, tokens_per_expert: Optional[torch.Tensor] = None
+):
+    if x.dim() != 3:
+        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
+    E, T, H = x.shape
+    N = H
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        num_warps = 1
+    num_tokens = E * T
+    grid_opt = num_tokens
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        grid_opt = max(1, num_tokens // (T // 256))
+        _per_token_quant_int8_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            E_dim=E,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+    else:
+        _per_token_quant_int8_one_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+    return x_q, scales
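per_token_quant_int8_triton_opt performs symmetric per-token INT8 quantization over the hidden dimension (scale = max(|row|) / 127), optionally skipping padded rows when tokens_per_expert is provided. For orientation, the same math for valid rows in plain PyTorch (a reference sketch, not part of the commit):

# Reference per-token symmetric INT8 quantization in plain PyTorch (illustrative).
import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    # x: [E, T, H]; one scale per (expert, token) row
    absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scales = absmax / 127
    x_q = torch.round(x.float() / scales).to(torch.int8)
    return x_q, scales


x = torch.randn(2, 4, 128)
x_q, scales = per_token_quant_int8_ref(x)
assert x_q.shape == x.shape and scales.shape == (2, 4, 1)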
python/sglang/srt/layers/moe/ep_moe/layer.py
(diff collapsed in this view; not shown)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -65,7 +65,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -84,6 +84,8 @@ def inplace_fused_experts(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ) -> None:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     fused_experts_impl(
         hidden_states,
         w1,
...
@@ -123,7 +125,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -161,7 +163,7 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -181,6 +183,8 @@ def outplace_fused_experts(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ) -> torch.Tensor:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     return fused_experts_impl(
         hidden_states,
         w1,
...
@@ -220,7 +224,7 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -273,9 +277,12 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
-    filter_expert = (
-        moe_runner_config.num_experts is None
-        or moe_runner_config.num_experts != moe_runner_config.num_local_experts
-    )
+    act_id = (
+        0
+        if (
+            moe_runner_config.activation == 0
+            or (
+                isinstance(moe_runner_config.activation, str)
+                and moe_runner_config.activation.lower() == "silu"
+            )
+        )
+        else 1
+    )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
...
@@ -287,7 +294,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
...
@@ -316,7 +323,7 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            moe_runner_config.activation,
+            act_id,
             moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
...
@@ -366,7 +373,7 @@ def fused_experts_impl(
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
...
@@ -386,6 +393,8 @@ def fused_experts_impl(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ):
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
...
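The net effect of these hunks is that activation crosses the custom-op boundary as an int code (0 = silu, 1 = gelu) and is converted back to the string name inside each entry point, while fused_experts derives act_id from moe_runner_config.activation, which may arrive as either form. Restated as small helpers (illustrative; the diff inlines this logic):

# Activation code <-> name mapping used at the custom-op boundary in this diff.
def act_id_from_config(activation) -> int:
    # 0 = silu, 1 = gelu; accepts either the int code or the string name
    if activation == 0 or (isinstance(activation, str) and activation.lower() == "silu"):
        return 0
    return 1


def act_name_from_id(activation: int) -> str:
    return "silu" if activation == 0 else "gelu"


assert act_id_from_config("SiLU") == 0
assert act_name_from_id(1) == "gelu"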