sglang commit 11d760d5 (unverified), authored Apr 08, 2025 by Trevor Morris and committed via GitHub on Apr 08, 2025.

FP4 weight loading and inference (2/2) (#3972)
Parent: 5039d547

Showing 6 changed files with 262 additions and 1 deletion (+262, -1).
Changed files:

- python/sglang/srt/configs/model_config.py (+1, -0)
- python/sglang/srt/layers/linear.py (+1, -0)
- python/sglang/srt/layers/quantization/__init__.py (+5, -1)
- python/sglang/srt/layers/quantization/modelopt_quant.py (+246, -0)
- python/sglang/srt/server_args.py (+1, -0)
- sgl-kernel/README.md (+8, -0)
python/sglang/srt/configs/model_config.py

@@ -279,6 +279,7 @@ class ModelConfig:
```python
            "moe_wna16",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
```
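For context, this table lets a user-requested quantization method override the method recorded in the checkpoint only when the two are known to be compatible; the new entry allows `modelopt_fp4` to be requested for checkpoints that advertise plain `modelopt` quantization. A minimal sketch of how such a table could be consulted; the helper function is hypothetical and not part of the commit:

```python
# Hypothetical illustration of how the compatibility table might be used.
compatible_quantization_methods = {
    "modelopt_fp4": ["modelopt"],
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
    "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
}

def can_override(requested: str, checkpoint_method: str) -> bool:
    """Return True if `requested` may override the checkpoint's own quant method."""
    return checkpoint_method in compatible_quantization_methods.get(requested, [])

assert can_override("modelopt_fp4", "modelopt")
assert not can_override("modelopt_fp4", "awq")
```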
python/sglang/srt/layers/linear.py

@@ -47,6 +47,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
```python
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp4LinearMethod",
    "IPEXAWQLinearMethod",
]
```
python/sglang/srt/layers/quantization/__init__.py

```diff
@@ -59,7 +59,10 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -69,6 +72,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
+    "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
```
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -22,6 +22,10 @@ from sglang.srt.layers.quantization.utils import (
```python
    requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda_available

if is_cuda_available():
    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

# Initialize logger for the module
logger = logging.getLogger(__name__)
```
@@ -215,3 +219,245 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
```python
    def __init__(self, quant_config: ModelOptFp8Config):
        super().__init__(quant_config)


class ModelOptFp4Config(QuantizationConfig):
    """Config class for FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: str = None,
        group_size: int = None,
        exclude_modules: List[str] = None,
    ) -> None:
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected nvfp4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> str:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 100

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        if quant_method not in ["FP8", "NVFP4"]:
            raise ValueError(
                "ModelOpt currently only supports: FP8, NVFP4"
                " quantizations in sglang. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
        group_size = quant_config["group_size"]
        exclude_modules = quant_config["exclude_modules"]
        if not (group_size and kv_cache_quant_algo and exclude_modules):
            raise ValueError(
                "NVFP4 quantization requires group size and "
                "kv_cache_quant_algo specified in "
                "hf_quant_config.json"
            )
        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            group_size,
            exclude_modules,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if self.exclude_modules and any(
            module in prefix for module in self.exclude_modules
        ):
            return None
        if isinstance(layer, LinearBase):
            return ModelOptFp4LinearMethod(self)
        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
```
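To make the parsing path concrete, here is a sketch of `from_config` applied to a hand-written `hf_quant_config.json`-style dict. The field values are illustrative assumptions, not taken from a real checkpoint, and the example assumes the usual `get_from_keys` helper inherited from the `QuantizationConfig` base class:

```python
# Illustrative input resembling hf_quant_config.json; the values are assumptions.
sample_hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    }
}

cfg = ModelOptFp4Config.from_config(sample_hf_quant_config)
assert cfg.is_checkpoint_nvfp4_serialized
assert cfg.group_size == 16
assert cfg.kv_cache_quant_algo == "FP8"
assert cfg.exclude_modules == ["lm_head"]
```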
The same hunk continues with the NVFP4 linear method:

```python
class ModelOptFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape       |
    |----------------|---------------|-------------|
    | input_scale    | torch.float32 | scalar      |
    | weight         | NVFP4(SE2M1)  | [1, X, y/2] |
    | weight_scale   | FP8-E4M3      | [X, Y]      |
    | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.

    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 values are packed into one uint8 in the input dimension
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Pad and blockwise interleave weight_scale
        scales = layer.weight_scale
        scale_ndim = scales.ndim
        if scale_ndim == 2:
            scales = scales.unsqueeze(0)
        assert scales.ndim == 3
        B, M, K = scales.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
        padded_scales[:B, :M, :K] = scales
        batches, rows, cols = padded_scales.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
        padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
        padded_scales = padded_scales.contiguous().cuda()
        padded_scales = (
            padded_scales.reshape(M, K)
            if scale_ndim == 2
            else padded_scales.reshape(B, M, K)
        )
        layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        output_dtype = x.dtype
        x_m, _ = x.shape
        w_n, _ = layer.weight.shape
        output_shape = [x_m, w_n]

        # Quantize BF16 or FP16 input to FP4 plus an interleaved block scale
        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)

        assert x_fp4.dtype == torch.uint8
        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
        assert layer.weight.dtype == torch.uint8
        assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        out = cutlass_scaled_fp4_mm(
            x_fp4,
            layer.weight,
            x_scale_interleaved,
            layer.weight_scale_interleaved,
            layer.alpha,
            output_dtype,
        )
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
```
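The scale-swizzling step above is the least obvious part of weight post-processing: the per-block FP8 scales are padded to a multiple of 128 rows and 4 columns and rearranged into the tiled layout that `cutlass_scaled_fp4_mm` consumes. Below is a standalone sketch of the same reshape/permute on dummy data; it mirrors `process_weights_after_loading` rather than adding anything new, the dimensions are chosen to already be multiples of 128 and 4 (as in practice), and the `.cuda()` call is omitted so it runs on CPU:

```python
import torch

# Dummy block scales for a layer with 256 output rows and 64 scale groups.
M, K = 256, 64
scales = torch.arange(M * K, dtype=torch.float32).reshape(M, K)

round_up = lambda x, m: (x + m - 1) // m * m
M_pad, K_pad = round_up(M, 128), round_up(K, 4)

padded = torch.zeros((1, M_pad, K_pad), dtype=scales.dtype)
padded[0, :M, :K] = scales

# Same tiling as process_weights_after_loading: split rows into 128-row tiles
# (viewed as 4 x 32) and columns into 4-wide tiles, then interleave.
tiled = padded.reshape(1, M_pad // 128, 4, 32, K_pad // 4, 4)
tiled = tiled.permute((0, 1, 4, 3, 2, 5)).contiguous()

interleaved = tiled.reshape(M, K)  # valid because M, K are already multiples of 128 and 4
print(interleaved.shape)  # torch.Size([256, 64])
```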
python/sglang/srt/server_args.py

@@ -495,6 +495,7 @@ class ServerArgs:
```python
                "bitsandbytes",
                "gguf",
                "modelopt",
                "modelopt_fp4",
                "w8a8_int8",
                "w8a8_fp8",
                "moe_wna16",
```
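With the `modelopt_fp4` choice registered here, an NVFP4-quantized ModelOpt checkpoint can be selected at launch time, e.g. `python -m sglang.launch_server --model-path <model> --quantization modelopt_fp4`. A hedged usage sketch via the offline engine follows; the model path is a placeholder, and passing `quantization` through `sgl.Engine` is assumed to mirror the `--quantization` server flag:

```python
# Assumed usage; the model path is a placeholder for an NVFP4 ModelOpt checkpoint.
import sglang as sgl

llm = sgl.Engine(
    model_path="path/to/nvfp4-quantized-model",
    quantization="modelopt_fp4",
)
print(llm.generate("Hello, my name is", {"max_new_tokens": 16}))
```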
sgl-kernel/README.md

@@ -156,6 +156,14 @@ unset CCACHE_READONLY
```bash
python -m uv build --wheel -Cbuild-dir=build --color=always .
```

##### Configuring CMake Build Options

CMake options can be configured by adding `-Ccmake.define.<option>=<value>` to the `uv build` flags. For example, to enable building FP4 kernels, use:

```bash
python -m uv build --wheel -Cbuild-dir=build -Ccmake.define.SGL_KERNEL_ENABLE_FP4=1 --color=always .
```

See CMakeLists.txt for more options.

### Testing & Benchmarking

1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests); if you need to skip some tests, please use `@pytest.mark.skipif`
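The README excerpt above mentions `@pytest.mark.skipif`, and the new FP4 path requires compute capability 100 (see `ModelOptFp4Config.get_min_capability`). A skip guard along these lines could keep FP4 kernel tests from failing on older GPUs; this is an illustrative sketch, not a test from the commit:

```python
import pytest
import torch

# Skip unless a Blackwell-class GPU (compute capability 10.0+) is available.
requires_fp4_gpu = pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() < (10, 0),
    reason="NVFP4 kernels need compute capability 10.0 or newer",
)

@requires_fp4_gpu
def test_cutlass_scaled_fp4_mm_smoke():
    # Placeholder body; a real test would exercise sgl_kernel.cutlass_scaled_fp4_mm.
    ...
```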