change / sglang · Commits

Commit e483ab6d (Unverified)
authored Aug 19, 2025 by Enrique Shockwave, committed by GitHub Aug 18, 2025

enable marlin fp8 blockwise (#8990)

parent 720cd308

Showing 2 changed files with 92 additions and 84 deletions:
  python/sglang/srt/layers/quantization/fp8.py        +83  -84
  python/sglang/srt/layers/quantization/fp8_utils.py    +9   -0
python/sglang/srt/layers/quantization/fp8.py

@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
@@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (
-            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
-        )
-        # Disable marlin for ROCm
-        if _is_hip:
-            self.use_marlin = False
+        self.use_marlin = False
+        if _is_cuda and MARLIN_FP8_AVAILABLE:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable

         self.block_quant = self.quant_config.weight_block_size is not None
-        if self.block_quant:
-            # Marlin doesn't support block-wise fp8
-            self.use_marlin = False

         self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
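Read against the removed lines, the gate also moves from "disable on ROCm" to "only consider CUDA": the Marlin path now stays off unless _is_cuda and MARLIN_FP8_AVAILABLE both hold, and within that it is turned on either by the SGLANG_FORCE_FP8_MARLIN env var or by the new auto-detection helper. A minimal sketch of that decision, with the platform flags passed in as plain booleans for illustration (not sglang's API):

# Illustrative sketch only; mirrors the constructor logic in the hunk above.
def resolve_use_marlin(is_cuda: bool, marlin_fp8_available: bool,
                       force_marlin: bool, auto_enable: bool) -> bool:
    """force_marlin ~ SGLANG_FORCE_FP8_MARLIN, auto_enable ~ can_auto_enable_marlin_fp8()."""
    use_marlin = False
    if is_cuda and marlin_fp8_available:
        use_marlin = force_marlin or auto_enable
    return use_marlin

# On non-CUDA platforms (is_cuda=False) the kernel stays off even when forced:
assert resolve_use_marlin(False, True, True, True) is False
# On CUDA it turns on if either the env var forces it or auto-detection applies:
assert resolve_use_marlin(True, True, False, True) is True

Note also that the unconditional "Marlin doesn't support block-wise fp8" override is gone, which is what lets block-quantized layers reach the Marlin preparation step later in this diff.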
@@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
@@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
                 layer.input_scale = None
-
             elif _is_cpu:
                 assert (
@@ -352,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -352,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
return
return
else
:
else
:
weight
,
weight_scale
=
layer
.
weight
.
data
,
layer
.
weight_scale_inv
.
data
weight
,
weight_scale
=
layer
.
weight
.
data
,
layer
.
weight_scale_inv
.
data
layer
.
weight
=
torch
.
nn
.
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale_inv
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
weight_scale
,
requires_grad
=
False
else
:
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
return
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
cutlass_fp8_supported
or
self
.
use_marlin
:
# apply per-channel quantization default as
# cutlass sgl-kernel and marlin only support per-channel scale
qweight
,
weight_scale
=
per_token_group_quant_fp8
(
layer
.
weight
,
layer
.
weight
.
shape
[
-
1
]
)
weight_scale
=
weight_scale
.
t
().
contiguous
()
else
:
# per-tensor quantization
qweight
,
weight_scale
=
input_to_float8
(
layer
.
weight
)
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
None
# If checkpoint not serialized fp8, quantize the weights.
# If checkpoint is fp8, handle that there are N scales for N
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# shards in a fused module
if
self
.
cutlass_fp8_supported
or
self
.
use_marlin
:
# apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
qweight
,
weight_scale
=
per_token_group_quant_fp8
(
layer
.
weight
,
layer
.
weight
.
shape
[
-
1
]
)
weight_scale
=
weight_scale
.
t
().
contiguous
()
else
:
else
:
# per-tensor quantization
layer
.
weight_scale
=
Parameter
(
qweight
,
weight_scale
=
input_to_float8
(
layer
.
weight
)
layer
.
weight_scale
.
data
,
requires_grad
=
False
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
None
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
if
(
hasattr
(
self
.
quant_config
,
"activation_scheme"
)
and
self
.
quant_config
.
activation_scheme
==
"static"
)
or
(
hasattr
(
self
.
quant_config
,
"linear_activation_scheme"
)
and
self
.
quant_config
.
linear_activation_scheme
==
"static"
):
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
)
if
(
hasattr
(
self
.
quant_config
,
"activation_scheme"
)
and
self
.
quant_config
.
activation_scheme
==
"static"
)
or
(
hasattr
(
self
.
quant_config
,
"linear_activation_scheme"
)
and
self
.
quant_config
.
linear_activation_scheme
==
"static"
):
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
# cutlass sgl-kernel and marlin only support per-channel scale
# cutlass sgl-kernel and marlin only support per-channel scale
if
self
.
cutlass_fp8_supported
or
self
.
use_marlin
:
if
self
.
cutlass_fp8_supported
or
self
.
use_marlin
:
weight
=
layer
.
weight
weight
=
layer
.
weight
weight_scale
=
convert_to_channelwise
(
weight_scale
=
convert_to_channelwise
(
layer
.
weight_scale
,
layer
.
logical_widths
layer
.
weight_scale
,
layer
.
logical_widths
)
)
else
:
else
:
# Dequant -> Quant with max scale so we can run per tensor.
# Dequant -> Quant with max scale so we can run per tensor.
weight
=
layer
.
weight
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
weight_scale
=
layer
.
weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
# If ROCm, normalize the weights and scales to e4m3fnuz
if
_is_fp8_fnuz
:
if
_is_fp8_fnuz
:
weight
,
weight_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
,
weight_scale
,
input_scale
=
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
,
input_scale
=
layer
.
input_scale
,
)
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
weight
,
weight
=
weight
,
weight_scale
=
weight_scale
,
weight_scale
=
weight_scale
,
input_scale
=
layer
.
input_scale
,
logical_widths
=
layer
.
logical_widths
,
)
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
weight
,
weight_scale
=
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
# Update layer with new values.
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
if
(
if
(
hasattr
(
self
.
quant_config
,
"activation_scheme"
)
hasattr
(
self
.
quant_config
,
"activation_scheme"
)
and
self
.
quant_config
.
activation_scheme
==
"static"
and
self
.
quant_config
.
activation_scheme
==
"static"
)
or
(
)
or
(
hasattr
(
self
.
quant_config
,
"linear_activation_scheme"
)
hasattr
(
self
.
quant_config
,
"linear_activation_scheme"
)
and
self
.
quant_config
.
linear_activation_scheme
==
"static"
and
self
.
quant_config
.
linear_activation_scheme
==
"static"
):
):
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
)
if
self
.
use_marlin
:
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
if
self
.
block_quant
:
layer
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
prepare_fp8_layer_for_marlin
(
layer
,
not
self
.
block_quant
)
# Activations not quantized for marlin.
# Activations not quantized for marlin.
del
layer
.
input_scale
del
layer
.
input_scale
...
...
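In the per-channel branch above (unchanged by this commit, only re-indented), per_token_group_quant_fp8(layer.weight, layer.weight.shape[-1]) uses a group size equal to the last dimension, which amounts to one scale per output row. A minimal standalone sketch of that idea in plain PyTorch (not sglang's Triton kernel; torch.float8_e4m3fn requires a recent PyTorch build):

import torch

w = torch.randn(4, 8)                                   # [out_features, in_features]
fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for e4m3fn
scale = (w.abs().amax(dim=-1, keepdim=True) / fp8_max).clamp(min=1e-12)  # one scale per row
qweight = (w / scale).to(torch.float8_e4m3fn)           # quantized weight
dequant = qweight.to(torch.float32) * scale             # approximate reconstruction
print(scale.shape, (w - dequant).abs().max())           # torch.Size([4, 1]), small error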
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -789,3 +789,12 @@ def apply_fp8_linear(
         bias,
         input.dtype,
     )
+
+
+def can_auto_enable_marlin_fp8() -> bool:
+    try:
+        major, minor = get_device_capability()
+        sm = major * 10 + minor
+        return 80 <= sm < 89
+    except Exception:
+        return False
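The capability window 80 <= sm < 89 covers Ampere-class GPUs (e.g. A100, A40, RTX 3090), which lack native FP8 tensor-core GEMM, while SM 8.9+ (Ada, Hopper) keep the regular FP8 path. A quick way to check what this helper would return on the local GPU, using plain PyTorch instead of sglang's get_device_capability wrapper (assumes a CUDA build of PyTorch and at least one visible GPU):

import torch

major, minor = torch.cuda.get_device_capability()
sm = major * 10 + minor
print(f"SM{sm}: marlin fp8 auto-enable -> {80 <= sm < 89}")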