Commit c28ad199 (Unverified)
[1/n] chore: decouple quantization implementation from vLLM dependency (#7992)
Authored Jul 17, 2025 by Peng Zhang; committed via GitHub Jul 16, 2025
Parent: 570d3343
Showing 13 changed files with 1478 additions and 616 deletions (+1478 −616).
python/sglang/srt/layers/moe/fused_moe_triton/__init__.py   +4    -1
python/sglang/srt/layers/quantization/__init__.py           +2    -4
python/sglang/srt/layers/quantization/gptq.py               +491  -119
python/sglang/srt/layers/quantization/marlin_utils.py       +781  -0
python/sglang/srt/layers/quantization/moe_wna16.py          +30   -0
python/sglang/srt/layers/quantization/quant_utils.py        +0    -166
python/sglang/srt/layers/quantization/scalar_type.py        +0    -0
python/sglang/srt/layers/quantization/utils.py              +162  -1
sgl-kernel/python/sgl_kernel/fused_moe.py                   +2    -1
sgl-kernel/tests/test_marlin_repack.py                      +2    -4
test/srt/test_gptqmodel_dynamic.py                          +4    -5
test/srt/test_int4_kernel.py                                +0    -301
test/srt/test_w4a8.py                                       +0    -14
python/sglang/srt/layers/moe/fused_moe_triton/__init__.py
from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_experts,
    get_config_file_name,
    moe_align_block_size,
    try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
    FusedMoE,
...
@@ -37,4 +38,6 @@ __all__ = [
    "fused_moe",
    "fused_experts",
    "get_config_file_name",
    "moe_align_block_size",
    "try_get_optimal_moe_config",
]
python/sglang/srt/layers/quantization/__init__.py
...
@@ -22,10 +22,6 @@ try:
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
...
@@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import (
    GPTQConfig,
    GPTQLinearMethod,
    GPTQMarlinConfig,
    GPTQMarlinLinearMethod,
    GPTQMarlinMoEMethod,
)
from sglang.srt.layers.quantization.modelopt_quant import (
...
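Illustration (not part of the commit): the GPTQ method classes that used to be re-exported from vLLM now resolve from sglang's own gptq module, which is the decoupling named in the commit title. A hedged sketch; the class names are taken from the import block above.

# Hedged sketch only.
from sglang.srt.layers.quantization.gptq import (
    GPTQConfig,
    GPTQLinearMethod,
    GPTQMarlinConfig,
    GPTQMarlinLinearMethod,
    GPTQMarlinMoEMethod,
)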
python/sglang/srt/layers/quantization/gptq.py
This diff is collapsed.
python/sglang/srt/layers/quantization/marlin_utils.py
new file (0 → 100644)
This diff is collapsed.
python/sglang/srt/layers/quantization/moe_wna16.py
...
@@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)


def get_weight_perm(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


class MoeWNA16Config(QuantizationConfig):
    """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
...
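A hedged sanity check (not part of the commit), assuming get_weight_perm is importable from the module above: for 4-bit weights the helper builds 32 * 8 * 4 = 1024 tile indices, and the interleave step only reorders them, so the result should be a true permutation.

from sglang.srt.layers.quantization.moe_wna16 import get_weight_perm  # assumed import path

perm4 = get_weight_perm(num_bits=4)
assert perm4.numel() == 1024                        # 32 iterations x 8 entries x 4 column blocks
assert sorted(perm4.tolist()) == list(range(1024))  # a permutation: nothing dropped or repeated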
python/sglang/srt/layers/quantization/quant_utils.py
deleted (100644 → 0)

# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from typing import Optional

import numpy
import torch

from sgl_kernel.scalar_type import ScalarType


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
    mask = (1 << num_bits) - 1

    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
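The helpers above are removed here and reinstated in python/sglang/srt/layers/quantization/utils.py further down. A minimal round-trip sketch (not from the commit), assuming the new import path: packing columns and unpacking them again must reproduce the original codes whenever every value fits in num_bits.

import torch

from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols  # assumed new location

size_k, size_n, num_bits = 16, 32, 4
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)        # 8 four-bit codes per int32 -> shape (16, 4)
assert packed.shape == (size_k, size_n // (32 // num_bits))

restored = unpack_cols(packed, num_bits, size_k, size_n)
assert torch.equal(restored, q_w)                        # lossless round trip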
sgl-kernel/python/sgl_kernel/scalar_type.py → python/sglang/srt/layers/quantization/scalar_type.py
File moved
python/sglang/srt/layers/quantization/utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Tuple, Union
from typing import List, Mapping, Optional, Tuple, Union

import numpy
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

_is_cuda = is_cuda()
...
@@ -143,3 +145,162 @@ def replace_parameter(
    if not isinstance(new, torch.nn.Parameter):
        new = torch.nn.Parameter(new, requires_grad=False)
    mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
    mask = (1 << num_bits) - 1

    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
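A hedged usage sketch (not part of the commit) for the quantize_weights helper added above. The scalar type and import paths follow this diff; the shapes and group size are illustrative only.

import torch

from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import quantize_weights

w = torch.randn(256, 128, dtype=torch.float16)
w_ref, w_q, w_s, w_zp = quantize_weights(
    w, scalar_types.uint4b8, group_size=64, zero_points=False
)
# w_q holds the biased integer codes, w_s one scale per (group, column),
# and w_ref the dequantized reference used for accuracy checks.
assert w_q.shape == w.shape
assert w_s.shape == (256 // 64, 128)
assert w_zp is None  # no zero points requested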
sgl-kernel/python/sgl_kernel/fused_moe.py
...
@@ -2,10 +2,11 @@ import functools
from typing import Optional

import torch
from sgl_kernel.scalar_type import scalar_types


def get_scalar_type(num_bits: int, has_zp: bool):
    from sglang.srt.layers.quantization.scalar_type import scalar_types

    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4
...
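A hedged illustration (not from the commit) of the deferred import introduced above: the scalar types are resolved from sglang only when get_scalar_type is called, so the kernel module no longer imports sgl_kernel.scalar_type at this call site. The import path for get_scalar_type is an assumption.

from sgl_kernel.fused_moe import get_scalar_type  # assumed module path
from sglang.srt.layers.quantization.scalar_type import scalar_types

# Only the zero-point branch is visible in this (truncated) hunk.
assert get_scalar_type(num_bits=4, has_zp=True) is scalar_types.uint4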
sgl-kernel/tests/test_marlin_repack.py
import math

import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack
from sgl_kernel.scalar_type import scalar_types

from sglang.srt.layers.quantization.quant_utils import (
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
    get_pack_factor,
    pack_cols,
    quantize_weights,
...
test/srt/test_gptqmodel_dynamic.py
...
@@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
        model_config=model_config, load_config=load_config, device_config=device_config
    )

    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
    from sglang.srt.layers.linear import UnquantizedLinearMethod
    from sglang.srt.layers.quantization.gptq import (
        GPTQLinearMethod,
        GPTQMarlinLinearMethod,
    )
    from sglang.srt.layers.linear import UnquantizedLinearMethod

    linear_method_cls = (
        GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod)
    )
...
@@ -162,7 +161,7 @@ class TestGPTQModelDynamicWithMarlin(CustomTestCase):
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--dtype", "float16"],
            other_args=["--dtype", "bfloat16"],
        )

    @classmethod
...
test/srt/test_int4_kernel.py
deleted (100644 → 0)

import itertools
import sys
import unittest

import torch

sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm")
# from sglang.srt.layers.moe.topk import select_experts
from sgl_kernel import fused_marlin_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

# from vllm.model_executor.layers. import select_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize,
)
from vllm.scalar_type import scalar_types


def stack_and_dev(tensors: list[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    # Reshape input
    M = A.numel() // A.shape[-1]
    B = B.t()  # Transpose weight matrix
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (K,)
    A = A.reshape(M, N)

    # As is per-token [M, 1], Bs is per-column [1, K]
    C = torch.matmul(A, B)  # [M, K]
    C = As * C * Bs.view(1, -1)  # Broadcast per-column scale

    return C.reshape(origin_C_shape).to(output_dtype)


def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """This function performs fused moe with per-column int8 quantization using native torch."""
    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)
            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )

    # Apply routing weights and sum
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def marlin_fused_moe(
    N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size
):
    quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
    if ep_size > 1:
        local_e = E // ep_size
        e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    zeros1_l = []
    g_idx1_l = []
    sort_indices1_l = []
    s1_l = []

    for i in range(w1.shape[0]):
        test_perm = torch.randperm(n=K)
        quant_res = marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res

        w_ref1_l.append(w_ref1.T)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None

    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    zeros2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n=N)
        quant_res = marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res

        w_ref2_l.append(w_ref2.T)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None

    topk_weights, topk_ids = fused_topk(a, score, topk, False)
    # topk_weights, topk_ids = FusedMoE.select_experts(
    #     hidden_states=a,
    #     router_logits=score,
    #     top_k=topk,
    #     num_expert_group=E,
    #     use_grouped_topk=False,
    #     renormalize=False,
    #     topk_group=None,
    # )

    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=E,
        expert_map=e_map,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        num_bits=num_bits,
        is_k_full=True,
    )
    return marlin_output, torch_output


class TestW8A8Int8FusedMoE(unittest.TestCase):
    DTYPES = [torch.float16]
    M = [1, 16]
    N = [128]
    K = [256]
    E = [4, 10]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]
    NUM_BITS = [4]
    EP_SIZE = [1, 4]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w4a8_int8_fused_moe(
        self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size
    ):
        torch.manual_seed(seed)
        a = torch.randn((M, K), dtype=dtype) / 10
        # Generate int8 weights
        w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2
        w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2
        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            marlin_out, ref_out = marlin_fused_moe(
                N=N,
                E=E,
                K=K,
                a=a,
                w1=w1_fp16,
                w2=w2_fp16,
                num_bits=num_bits,
                group_size=-1,
                act_order=False,
                score=score,
                topk=topk,
                ep_size=ep_size,
            )

        # Check results
        if (
            torch.mean(torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            > 0.1
        ):
            print(f"marlin_out: {marlin_out}")
            print(f"ref_out: {ref_out}")
            print(
                torch.mean(torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32)))
                / torch.mean(torch.abs(ref_out.to(torch.float32)))
            )
        torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0)

    def test_w4a8_int8_fused_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
            self.NUM_BITS,
            self.EP_SIZE,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
                num_bits=params[8],
                ep_size=params[9],
            ):
                self._w4a8_int8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
test/srt/test_w4a8.py
deleted (100644 → 0)

import sgl_kernel
import torch

x = torch.randn(10, 10, device="cuda")
qweight = torch.randn(10, 10, device="cuda")
s1_scales = torch.randn(10, device="cuda")
input_scales = torch.randn(10, device="cuda")
s1_szeros = torch.randn(10, device="cuda")
input_sum = torch.randn(10, device="cuda")
output_buffer = torch.randn(10, device="cuda")

torch.ops.sgl_kernel.gemm_forward_cuda.default(
    x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer
)