change / sglang · Commits

Commit 48542418 · authored Nov 24, 2025 by renzhc

fix lightop gemm smooth and add some trivial changes

Parent: 1a73f6a3

Showing 7 changed files with 275 additions and 12 deletions:

python/sglang/srt/_custom_ops.py (+12, -0)
python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+1, -0)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py (+29, -1)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py (+69, -8)
python/sglang/srt/layers/quantization/slimquant_w4a8.py (+8, -2)
python/sglang/srt/model_executor/forward_batch_info.py (+1, -1)
python/sglang/srt/utils/common.py (+155, -0)
python/sglang/srt/_custom_ops.py

@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def blaslt_scaled_mm(a: torch.Tensor,
                     b: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    m = a.shape[0]
    n = b.shape[0]
    k = a.shape[1]
    _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
    return out


def triton_int8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            ...
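For orientation: the new blaslt_scaled_mm runs an int8 x int8 GEMM in the 'NT' layout (row-major A of shape [M, K], B stored as [N, K]) and applies activation and weight scales. The following pure-PyTorch sketch is a numerical reference for those semantics only, not the hipBLASLt kernel; it assumes per-token scales of shape [M, 1] and per-out-channel scales of shape [N, 1], which is consistent with how the callers below quantize.

import torch

def scaled_mm_reference(a_q, b_q, scale_a, scale_b, out_dtype, bias=None):
    # a_q: int8 [M, K]; b_q: int8 [N, K] (the 'NT' layout above).
    # scale_a: [M, 1] per-token scales; scale_b: [N, 1] per-channel scales.
    acc = a_q.to(torch.float32) @ b_q.to(torch.float32).T
    out = acc * scale_a * scale_b.T
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

m, n, k = 4, 8, 16
a_q = torch.randint(-128, 128, (m, k), dtype=torch.int8)
b_q = torch.randint(-128, 128, (n, k), dtype=torch.int8)
out = scaled_mm_reference(a_q, b_q, torch.rand(m, 1), torch.rand(n, 1), torch.bfloat16)
assert out.shape == (m, n)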
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -725,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        self.packed_recv_count = self.handle = None
        return combined_hidden_states, event, hook

    @torch._dynamo.disable()
    def _get_buffer(self):
        DeepEPBuffer.set_dispatch_mode_as_low_latency()
        return DeepEPBuffer.get_deepep_buffer(
            ...
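The only change in this file is the new @torch._dynamo.disable() decorator, which keeps torch.compile from tracing into _get_buffer (global buffer bookkeeping is a common source of graph breaks). A minimal self-contained illustration of the decorator's effect, with hypothetical function names:

import torch

@torch._dynamo.disable()
def stateful_helper(x):
    # Runs eagerly; torch.compile will not trace into this function.
    return x * 2

@torch.compile
def step(x):
    return stateful_helper(x) + 1  # the compiled graph breaks around the helper

print(step(torch.ones(2)))  # tensor([3., 3.])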
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import os
import logging
from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast

...

@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import W8a8GetCacheJSON

# TODO: remove vllm dependency
logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]

...

@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config
        self.tritonsingleton = W8a8GetCacheJSON()
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        n = layer.weight.shape[0]
        k = layer.weight.shape[1]
        if self.w8a8_strategy == 1:
            if [n, k] not in self.tritonsingleton.weight_shapes:
                self.tritonsingleton.weight_shapes.append([n, k])
            json_file = self.tritonsingleton.get_w8a8json_name(n, k)
            configs_dict = self.tritonsingleton.get_triton_cache(json_file, n, k)
            if configs_dict:
                self.tritonsingleton.triton_json_dict.update(configs_dict)
                for key, value in configs_dict.items():
                    m = int(key.split('_')[0])
                    ops.triton_int8_gemm_helper(m=m,
                                                n=n,
                                                k=k,
                                                per_token_act_quant=True,
                                                per_out_channel_weight_quant=True,
                                                use_bias=False,
                                                device=layer.weight.device,
                                                best_config=value)
        elif self.w8a8_strategy == 3:
            layer.weight.data = layer.weight.data.T
        else:
            weight_data = layer.weight.data
            _weight = weight_data.T.contiguous().reshape(n, -1)
            layer.weight.data = _weight
        self.tritonsingleton.gen_model_json()
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(
        ...
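Two of the branches above repack the loaded [n, k] weight for the chosen backend. The fallback's weight.T.contiguous().reshape(n, -1) keeps the [n, k] view the rest of the stack expects while laying the data out transposed (k-major) in memory. A tiny sketch makes the effect visible:

import torch

w = torch.arange(6).reshape(2, 3)          # logical [n=2, k=3] weight
packed = w.T.contiguous().reshape(2, -1)   # same [n, k] shape, transposed storage
print(w)        # tensor([[0, 1, 2], [3, 4, 5]])
print(packed)   # tensor([[0, 3, 1], [4, 2, 5]])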
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, Optional

import torch

...

@@ -19,11 +20,14 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
from lmslim import quant_ops
from sglang.srt import _custom_ops as ops

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import int8_scaled_mm

# TODO: remove vllm deps
from sglang.srt.utils import W8a8GetCacheJSON

W8A8_TRITONJSON = W8a8GetCacheJSON()


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

...

@@ -33,6 +37,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))  # TODO

    @classmethod
    def get_min_capability(cls) -> int:
        ...

@@ -163,14 +168,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
        )
        layer.register_parameter("input_zero_point", input_zero_point)

    @torch._dynamo.disable()
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        # TODO: add cutlass_scaled_mm_azp support
        x_q, x_scale = per_token_quant_int8(x)
        # TODO: fix with lmslim/lightop
        return quant_ops.triton_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
        # return quant_ops.custom_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
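        # NOTE: the early return above makes the strategy dispatch below
        # unreachable; the branches are kept in place but bypassed.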
        if self.w8a8_strategy == 1:
            m = x_q.shape[0]
            k = x_q.shape[1]
            n = layer.weight.shape[1]
            if len(W8A8_TRITONJSON.triton_json_dict) == 0:
                best_config = None
            elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
                if m <= 16:
                    m_ = m
                elif m <= 64:
                    m_ = (m + 3) & -4  # round up to the nearest multiple of 4
                elif m <= 160:
                    m_ = (m + 7) & -8
                elif m < 200:  # 256
                    m_ = 160
                elif m < 480:  # 512
                    m_ = 256
                elif m < 960:  # 1024
                    m_ = 512
                elif m < 2048:
                    m_ = 1024
                elif m < 4096:
                    m_ = 2048
                elif m < 6000:
                    m_ = 4096
                else:
                    m_ = 8192
                best_config = W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
            else:
                best_config = None
            return ops.triton_scaled_mm(x_q,
                                        layer.weight,
                                        x_scale,
                                        layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=bias,
                                        best_config=best_config)
        elif self.w8a8_strategy == 2:
            return ops.cutlass_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
        elif self.w8a8_strategy == 3:
            return ops.blaslt_scaled_mm(x_q,
                                        layer.weight,
                                        scale_a=x_scale,
                                        scale_b=layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=None)
        else:
            return ops.rocblas_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
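The m-bucketing ladder above snaps the runtime batch dimension onto the grid of tuned configs, so a cache keyed by a finite set of "{m}_{n}_{k}" entries can serve arbitrary m. Rewritten as a standalone helper for illustration (hypothetical name, same thresholds as the branch chain):

def bucket_m(m: int) -> int:
    # <=16 exact; <=64 round up to a multiple of 4; <=160 to a multiple of 8;
    # beyond that, snap to a fixed ladder of tuned sizes.
    if m <= 16:
        return m
    if m <= 64:
        return (m + 3) & -4
    if m <= 160:
        return (m + 7) & -8
    for limit, bucket in ((200, 160), (480, 256), (960, 512),
                          (2048, 1024), (4096, 2048), (6000, 4096)):
        if m < limit:
            return bucket
    return 8192

assert bucket_m(13) == 13
assert bucket_m(50) == 52      # 50 rounded up to a multiple of 4
assert bucket_m(300) == 256    # 200 <= 300 < 480
assert bucket_m(10000) == 8192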
python/sglang/srt/layers/quantization/slimquant_w4a8.py

@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import (
    per_token_group_quant_int8,
    per_token_quant_int8
)
from sglang.srt import _custom_ops as ops
-from vllm.utils import W8a8GetCacheJSON
+from sglang.srt.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os

...

@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
        )
        layer.register_parameter("weight_scale", weight_scale)

    @torch._dynamo.disable()
    # TODO: performance optimization requires cooperation from lmslim/lightop
    def apply(self,
              layer: torch.nn.Module,
              ...

@@ -227,6 +226,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias)
        elif self.w8a8_strategy == 3:
            return ops.blaslt_scaled_mm(x_q,
                                        layer.weight,
                                        scale_a=x_scale,
                                        scale_b=layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=None)
        else:
            return ops.rocblas_scaled_mm(x_q,
                                         layer.weight,
                                         ...
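Both this file and compressed_tensors_w8a8_int8.py select the GEMM backend from the same environment variable. A hypothetical dict-based restatement of the mapping implied by the branches above:

import os

# 1 -> tuned Triton, 2 -> CUTLASS, 3 -> hipBLASLt (note: bias is dropped),
# anything else -> rocBLAS.
BACKENDS = {1: "triton_scaled_mm", 2: "cutlass_scaled_mm", 3: "blaslt_scaled_mm"}
strategy = int(os.getenv("W8A8_SUPPORT_METHODS", "1"))
print(BACKENDS.get(strategy, "rocblas_scaled_mm"))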
python/sglang/srt/model_executor/forward_batch_info.py

@@ -653,7 +653,7 @@ class ForwardBatch:
                bs=self.batch_size,
            )
        else:
-            logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
+            # logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                ...
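The create_chunked_prefix_cache_kv_indices[(self.batch_size,)](...) call is Triton's kernel[grid](args) launch syntax, here with one program instance per request in the batch. A minimal self-contained example of the same launch pattern (a trivial copy kernel; requires a CUDA/ROCm GPU):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

src = torch.arange(10, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
copy_kernel[(1,)](src, dst, 10, BLOCK=16)   # kernel[grid](args) launch
assert torch.equal(src, dst)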
python/sglang/srt/utils/common.py

@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None):
        return CachedKernel(fn, key_fn)

    return decorator


# from vllm
class W8a8GetCacheJSON:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(W8a8GetCacheJSON, cls).__new__(cls, *args, **kwargs)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        current_folder_path = os.path.dirname(os.path.abspath(__file__))
        json_folder_path = current_folder_path + '/../../lmslim/configs/w8a8'
        self.triton_json_dir = os.getenv('TRITON_JSON_DIR', json_folder_path)
        self.triton_json_dict = {}
        self.triton_moejson_dict = {}
        self.triton_json_list = []
        self.weight_shapes = []
        self.moe_weight_shapes = []
        arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
        arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
        device_name = arch_name + '_' + str(arch_cu) + 'cu'
        self.device_name = device_name
        self.topk = 1
        self.quant_method = None

    # Destructor-like hook: generates the model.json config file at the end.
    def gen_model_json(self, E: Optional[int] = 0, block_size: Optional[list] = None):
        json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
        if json_dir != "None" and os.path.exists(json_dir):
            # Generate the model config file.
            # logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
            config = {
                "layers": {
                    "linear": {
                        "shapes": [],
                        "m_range": "None",
                    },
                    "moe": {
                        "shapes": [],
                        "m_range": "None",
                        "topk": self.topk
                    }
                },
                "quantization_config": {
                    "quant_method": self.quant_method,
                    "weight_block_size": "None"
                }
            }
            # Handle MoE shapes.
            for shape in self.moe_weight_shapes:
                if len(shape) == 4:
                    # Assumes the MoE shape is in [E, N1, N2, K] format.
                    moe_config = {
                        "E": shape[0],
                        "N1": shape[1],
                        "N2": shape[2],
                        "K": shape[3],
                        # default values
                    }
                    config["layers"]["moe"]["shapes"].append(moe_config)
            for shape in self.weight_shapes:
                config["layers"]["linear"]["shapes"].append(shape)
            if block_size is not None:
                config["quantization_config"]["weight_block_size"] = block_size
            with open(json_dir + "/model.json", 'w') as f:
                json.dump(config, f, indent=4)
        # else:
        #     logger.info("LMSLIM_TUNING_JSON is not set")

    def getspec_config(self, configs_dict, M, N, K):
        if f"{M}_{N}_{K}" in configs_dict:
            return configs_dict[f"{M}_{N}_{K}"]
        else:
            return None

    def get_triton_cache(self, file_path, n, k):
        # Used outside of tuning; returns None directly if the file does not exist.
        cache_json_file = file_path
        if os.path.exists(file_path):
            # try:
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Flatten the whole cache into key:config form: [M_N_K]: [config]
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict

    def get_w8a8json_name(self, n, k):
        return self.triton_json_dir + f"/W8A8_{n}_{k}_{self.device_name}.json"

    def get_blockint8_triton_cache(self, file_path, n, k, block_n, block_k):
        cache_json_file = file_path
        if os.path.exists(file_path):
            # try:
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Flatten the whole cache into key:config form: [M_N_K]: [config]
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict

    def get_blockint8json_name(self, n, k, block_n, block_k):
        return self.triton_json_dir + f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"

    def get_moeint8json_name(self, E, N1, N2, K, TOPK,
                             block_size: Optional[list] = None,
                             use_int4_w4a8: Optional[bool] = False):
        if use_int4_w4a8:
            if block_size is not None:
                return self.triton_json_dir + f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
            else:
                return self.triton_json_dir + f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
        else:
            if block_size is not None:
                return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
            else:
                return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"

    def get_moeint8_triton_cache(self, file_path, E, N1, N2, K, TOPK):
        cache_json_file = file_path
        if os.path.exists(file_path):
            # try:
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Flatten the whole cache into key:config form: [M_N_K]: [config1, config2]
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict
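W8a8GetCacheJSON uses the classic __new__-based singleton so that every quantization method in the process shares one tuning cache; _initialize runs only on the first construction, so later W8a8GetCacheJSON() calls return the same populated state. The pattern in isolation, as a minimal sketch:

class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        self.cache = {}

a, b = Singleton(), Singleton()
assert a is b and a.cache is b.cache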