change / sglang / Commits / 49b87774

Unverified commit 49b87774, authored Jul 17, 2025 by Cheng Wan, committed by GitHub on Jul 17, 2025

Refactor: move all quantization-related code to `srt/layer/quantization` (#7989)

parent 02404a1e

Changes: 22 files in the commit; this page shows 2 changed files with 13 additions and 67 deletions (+13 -67)

python/sglang/srt/layers/quantization/w8a8_int8.py    +12 -28
python/sglang/srt/layers/vocab_parallel_embedding.py  +1 -39
python/sglang/srt/layers/quantization/w8a8_int8.py  (+12 -28)

 from __future__ import annotations

+import importlib
+import sys
 from types import MappingProxyType

@@ -11,21 +13,19 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import (
-    LinearMethodBase,
-    RowParallelLinear,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
     cpu_has_amx_support,
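For orientation, the `per_token_quant_int8` import above refers to sglang's INT8 activation-quantization kernel. The sketch below is not that kernel; it is a minimal, hedged reference of what per-token (per-row) symmetric INT8 quantization computes, with a made-up function name, so the W8A8 naming in this file is easier to follow.

```python
import torch


def per_token_int8_quant_reference(x: torch.Tensor):
    """Illustrative reference only: each row gets its own scale so that the
    row's maximum magnitude maps to 127, then values are rounded to int8."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_q, scale


x = torch.randn(4, 16)
x_q, scale = per_token_int8_quant_reference(x)
x_hat = x_q.to(torch.float32) * scale  # dequantize for a quick error check
print((x - x_hat).abs().max())
```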
@@ -229,14 +229,14 @@ class W8A8Int8Config(QuantizationConfig):
         return []

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+    def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
         return cls(config)

     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
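The `from_config` / `get_quant_method` pair above is the usual QuantizationConfig interface: build a config object from the checkpoint's quantization dict, then ask it, per layer, for a quantize method (or None to leave the layer unquantized). Below is a self-contained sketch of that dispatch pattern using toy class names (none of these names are sglang's), just to show the shape of the contract.

```python
from __future__ import annotations

from typing import Any, Dict, Optional

import torch


class ToyQuantMethod:
    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        return layer(x)  # placeholder: a real method would call a quantized kernel


class ToyQuantConfig:
    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ToyQuantConfig:
        return cls(config)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[ToyQuantMethod]:
        # Dispatch on layer type, mirroring how the real config checks
        # LinearBase / FusedMoE before returning a method object.
        if isinstance(layer, torch.nn.Linear):
            return ToyQuantMethod()
        return None


cfg = ToyQuantConfig.from_config({"quant_method": "w8a8_int8"})
print(cfg.get_quant_method(torch.nn.Linear(4, 4), prefix="layers.0.mlp"))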
@@ -374,7 +374,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         )


-class W8A8Int8MoEMethod:
+class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.

     Supports loading INT8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -385,25 +385,7 @@ class W8A8Int8MoEMethod:
         quant_config: The quantization config.
     """

-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: W8A8Int8Config):
         self.quant_config = quant_config

     def create_weights(
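The deleted `__new__` rebuilt the class at instantiation time with `type(name, bases, namespace)` so that `W8A8Int8MoEMethod` could end up inheriting from `FusedMoEMethodBase` without importing it at module load. Now that `FusedMoEMethodBase` lives in `sglang.srt.layers.quantization.base_config`, plain inheritance replaces the trick. The toy, self-contained illustration below (not sglang code, invented class names) contrasts the two approaches.

```python
class Base:
    def hello(self) -> str:
        return "from Base"


class DynamicChild:
    """Mimics the removed trick: becomes a Base subclass only when instantiated."""

    def __new__(cls, *args, **kwargs):
        # Build a replacement class at runtime that inherits from Base and
        # carries over this class's attributes.
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        return super(new_cls, new_cls).__new__(new_cls)


class StaticChild(Base):
    """The simpler equivalent used after the refactor: direct inheritance."""


print(isinstance(DynamicChild(), Base))  # True, but only via the runtime-built class
print(isinstance(StaticChild(), Base))   # True, statically
```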
@@ -885,13 +867,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.linear import RowParallelLinear
+
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)


-class NPU_W8A8MoEMethod:
+class NPU_W8A8MoEMethod(FusedMoEMethodBase):
     """MoE method for NPU quantization.

     This class search for specific quantization
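In the hunk above, the `RowParallelLinear` import moves into the method body, a common way to defer a dependency until call time (presumably to avoid a circular import between the linear and quantization modules; that motivation is an assumption here, not stated in the diff). A minimal sketch of the pattern, with `json` standing in for the deferred module:

```python
def apply(payload: dict) -> str:
    # Deferred import: resolved on the first call rather than when this
    # module itself is imported, which breaks an import cycle.
    import json

    return json.dumps(payload)


print(apply({"tp_rank": 0}))
```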
python/sglang/srt/layers/vocab_parallel_embedding.py  (+1 -39)

@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple

 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter

 from sglang.srt.distributed import (

@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()

 logger = logging.getLogger(__name__)


-class UnquantizedEmbeddingMethod(QuantizeMethodBase):
-    """Unquantized method for embeddings."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for embedding layer."""
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
-
-    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-        return F.embedding(input_, layer.weight)
-
-
 def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
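`pad_vocab_size`, shown as unchanged context at the end of the hunk, rounds the vocabulary size up to the next multiple of `pad_to` by ceiling division. A quick standalone check, reimplemented here so it runs without sglang installed:

```python
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Ceiling division, then scale back up to the nearest multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


assert pad_vocab_size(32000) == 32000            # already a multiple of 64
assert pad_vocab_size(32001) == 32064            # rounded up to the next multiple
assert pad_vocab_size(151936, pad_to=128) == 151936
```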