sglang: Unverified commit 11d760d5

FP4 weight loading and inference (2/2) (#3972)

Authored Apr 08, 2025 by Trevor Morris; committed by GitHub on Apr 08, 2025.
Parent commit: 5039d547
Showing 6 changed files with 262 additions and 1 deletion.
- python/sglang/srt/configs/model_config.py (+1, -0)
- python/sglang/srt/layers/linear.py (+1, -0)
- python/sglang/srt/layers/quantization/__init__.py (+5, -1)
- python/sglang/srt/layers/quantization/modelopt_quant.py (+246, -0)
- python/sglang/srt/server_args.py (+1, -0)
- sgl-kernel/README.md (+8, -0)
python/sglang/srt/configs/model_config.py

@@ -279,6 +279,7 @@ class ModelConfig:

```python
            "moe_wna16",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
```
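For context, a minimal sketch of how a compatibility table like this is typically consulted: the key is the method requested with `--quantization`, and the values are checkpoint-declared methods treated as compatible. The helper below is illustrative only, not the actual check in `model_config.py`.

```python
# Illustrative compatibility check (assumed semantics, not sglang's code).
def is_compatible(requested: str, checkpoint_method: str) -> bool:
    compatible_quantization_methods = {
        "modelopt_fp4": ["modelopt"],
        "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
        "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
    }
    if requested == checkpoint_method:
        return True
    return checkpoint_method in compatible_quantization_methods.get(requested, [])

# An NVFP4 checkpoint typically reports "modelopt" in its HF config, so
# requesting "modelopt_fp4" is still accepted.
assert is_compatible("modelopt_fp4", "modelopt")
```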
python/sglang/srt/layers/linear.py

@@ -47,6 +47,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [

```python
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp4LinearMethod",
    "IPEXAWQLinearMethod",
]
```
python/sglang/srt/layers/quantization/__init__.py

@@ -59,7 +59,10 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import

```diff
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
```

@@ -69,6 +72,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {

```python
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "modelopt_fp4": ModelOptFp4Config,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
    "moe_wna16": MoeWNA16Config,
```
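A small sanity sketch of what the registration above provides; the imports and asserts are illustrative, not a specific sglang call site.

```python
from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS
from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp4Config,
    ModelOptFp8Config,
)

# Both ModelOpt entries now resolve to config classes from modelopt_quant.
assert BASE_QUANTIZATION_METHODS["modelopt"] is ModelOptFp8Config
assert BASE_QUANTIZATION_METHODS["modelopt_fp4"] is ModelOptFp4Config
```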
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -22,6 +22,10 @@ from sglang.srt.layers.quantization.utils import (

```python
    requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda_available

if is_cuda_available():
    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

# Initialize logger for the module
logger = logging.getLogger(__name__)
```
@@ -215,3 +219,245 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):

```python
    def __init__(self, quant_config: ModelOptFp8Config):
        super().__init__(quant_config)


class ModelOptFp4Config(QuantizationConfig):
    """Config class for FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: str = None,
        group_size: int = None,
        exclude_modules: List[str] = None,
    ) -> None:
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected nvfp4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> str:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 100

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        if not quant_method in ["FP8", "NVFP4"]:
            raise ValueError(
                f"ModelOpt currently only supports: FP8, NVFP4"
                " quantizations in sglang. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
        group_size = quant_config["group_size"]
        exclude_modules = quant_config["exclude_modules"]
        if not (group_size and kv_cache_quant_algo and exclude_modules):
            raise ValueError(
                "NVFP4 quantization requires group size and "
                "kv_cache_quant_algo specified in "
                "hf_quant_config.json"
            )
        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            group_size,
            exclude_modules,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if self.exclude_modules and any(
            module in prefix for module in self.exclude_modules
        ):
            return None
        if isinstance(layer, LinearBase):
            return ModelOptFp4LinearMethod(self)
        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
```
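For reference, a minimal sketch of the `quantization` block that `from_config` above expects to find in `hf_quant_config.json`. The keys mirror the lookups in the code; the concrete values (NVFP4, FP8 KV cache, group size 16, `lm_head` excluded) are illustrative assumptions rather than values from a real checkpoint.

```python
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

# Illustrative hf_quant_config.json contents, shown as the already-parsed dict
# that ModelOptFp4Config.from_config() receives. Values are assumptions.
example_hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    }
}

fp4_config = ModelOptFp4Config.from_config(example_hf_quant_config)
assert fp4_config.is_checkpoint_nvfp4_serialized
assert fp4_config.group_size == 16
```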
The same file continues with the linear method that consumes this config:

```python
class ModelOptFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape       |
    |----------------|---------------|-------------|
    | input_scale    | torch.float32 | scalar      |
    | weight         | NVFP4(SE2M1)  | [1, X, y/2] |
    | weight_scale   | FP8-E4M3      | [X, Y]      |
    | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is "
                "not multiple of 16"
            )

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 data is packed in one uint8 in the input dimension
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)
```
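A quick worked example of the shapes `create_weights` allocates, assuming a single 4096-wide output partition, 4096 input features, and group size 16 (illustrative numbers, not from any particular model):

```python
# Illustrative shape check for the parameters registered above (assumed sizes).
out_features, in_features, group_size = 4096, 4096, 16

packed_weight_shape = (out_features, in_features // 2)         # (4096, 2048) uint8: two FP4 values per byte
block_scale_shape = (out_features, in_features // group_size)  # (4096, 256) float8_e4m3fn: one scale per 16 inputs
per_tensor_scales = (1,)                                        # input_scale / weight_scale_2: one float32 per partition

print(packed_weight_shape, block_scale_shape, per_tensor_scales)
```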
After loading, the per-tensor scales are collapsed and the block scales are padded and interleaved for the kernel:

```python
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Pad and blockwise interleave weight_scale
        scales = layer.weight_scale
        scale_ndim = scales.ndim
        if scale_ndim == 2:
            scales = scales.unsqueeze(0)
        assert scales.ndim == 3
        B, M, K = scales.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
        padded_scales[:B, :M, :K] = scales
        batches, rows, cols = padded_scales.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
        padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
        padded_scales = padded_scales.contiguous().cuda()
        padded_scales = (
            padded_scales.reshape(M, K)
            if scale_ndim == 2
            else padded_scales.reshape(B, M, K)
        )
        layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
```
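A small numeric sketch of the rounding step above, assuming a (300, 7) block-scale tensor purely for illustration:

```python
# Illustration of the padding targets for an assumed (M, K) = (300, 7) block-scale tensor.
round_up_multiple = lambda x, m: (x + m - 1) // m * m

M, K = 300, 7
M_padded, K_padded = round_up_multiple(M, 128), round_up_multiple(K, 4)
print(M_padded, K_padded)  # 384 8: padded up to multiples of (128, 4) before interleaving
```

The padded tensor is then reshaped into `(rows // 128, 4, 32, cols // 4, 4)` tiles and permuted so each 128x4 tile of scales lands in the interleaved order that `cutlass_scaled_fp4_mm` consumes.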
At inference time, `apply` quantizes the activation to FP4 and calls the FP4 GEMM kernel:

```python
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        output_dtype = x.dtype
        x_m, _ = x.shape
        w_n, _ = layer.weight.shape
        output_shape = [x_m, w_n]

        # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)

        assert x_fp4.dtype == torch.uint8
        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
        assert layer.weight.dtype == torch.uint8
        assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        out = cutlass_scaled_fp4_mm(
            x_fp4,
            layer.weight,
            x_scale_interleaved,
            layer.weight_scale_interleaved,
            layer.alpha,
            output_dtype,
        )
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
```
python/sglang/srt/server_args.py

@@ -495,6 +495,7 @@ class ServerArgs:

```python
            "bitsandbytes",
            "gguf",
            "modelopt",
            "modelopt_fp4",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
```
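With `modelopt_fp4` accepted as a `--quantization` choice, an NVFP4 checkpoint can be selected at engine creation. A minimal sketch, assuming the offline `sglang.Engine` API forwards `ServerArgs` fields as keyword arguments; the model path is a placeholder:

```python
import sglang as sgl

# Placeholder path: any ModelOpt NVFP4-serialized checkpoint directory
# (one that ships an hf_quant_config.json) would go here.
llm = sgl.Engine(
    model_path="/models/nvfp4-checkpoint",
    quantization="modelopt_fp4",
)
print(llm.generate("Hello", {"max_new_tokens": 16}))
```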
sgl-kernel/README.md

@@ -156,6 +156,14 @@ unset CCACHE_READONLY

````markdown
python -m uv build --wheel -Cbuild-dir=build --color=always .
```

##### Configuring CMake Build Options

CMake options can be configured by adding `-Ccmake.define.<option>=<value>` to the `uv build` flags.
For example, to enable building FP4 kernels, use:

```bash
python -m uv build --wheel -Cbuild-dir=build -Ccmake.define.SGL_KERNEL_ENABLE_FP4=1 --color=always .
```

See CMakeLists.txt for more options.

### Testing & Benchmarking

1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests); if you need to skip a test, use `@pytest.mark.skipif`
````