change / sglang · Commits

Commit e483ab6d (Unverified)
Authored Aug 19, 2025 by Enrique Shockwave; committed by GitHub, Aug 18, 2025

enable marlin fp8 blockwise (#8990)

Parent: 720cd308
Showing 2 changed files with 92 additions and 84 deletions (+92 −84)
python/sglang/srt/layers/quantization/fp8.py        +83 −84
python/sglang/srt/layers/quantization/fp8_utils.py   +9 −0
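This commit lets the weight-only Marlin FP8 kernel handle block-wise FP8 checkpoints instead of being switched off whenever block quantization is detected, and adds an auto-enable probe so suitable GPUs pick the path up without manual flags. The SGLANG_FORCE_FP8_MARLIN environment variable from the diff still forces it; a minimal sketch of doing that from Python, assuming "1" is accepted as a true value and the variable is set before sglang constructs the quantization method:

import os

# Hypothetical usage: force the Marlin weight-only FP8 path.
# The variable name comes from fp8.py below; "1" is assumed to be read as
# true by get_bool_env_var.
os.environ["SGLANG_FORCE_FP8_MARLIN"] = "1"

# ...then import or launch sglang as usual, so Fp8LinearMethod sees the flag
# when it is constructed.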
python/sglang/srt/layers/quantization/fp8.py

@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,

@@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (
-            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
-        )
-        # Disable marlin for ROCm
-        if _is_hip:
-            self.use_marlin = False
+        self.use_marlin = False
+        if _is_cuda and MARLIN_FP8_AVAILABLE:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable

         self.block_quant = self.quant_config.weight_block_size is not None
-        if self.block_quant:
-            # Marlin doesn't support block-wise fp8
-            self.use_marlin = False

         self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
@@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:

@@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
                 layer.input_scale = None
             elif _is_cpu:
                 assert (
@@ -352,18 +347,16 @@ class Fp8LinearMethod(LinearMethodBase):
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
-            return
-
-        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.data, requires_grad=False)

         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
-                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
+                # apply per-channel quantization default as
+                # cutlass sgl-kernel and marlin only support per-channel scale
                 qweight, weight_scale = per_token_group_quant_fp8(
                     layer.weight, layer.weight.shape[-1]
                 )
@@ -380,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
-            layer.weight_scale = torch.nn.Parameter(
+            layer.weight_scale = Parameter(
                 layer.weight_scale.data, requires_grad=False
             )
             if (

@@ -390,7 +383,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 hasattr(self.quant_config, "linear_activation_scheme")
                 and self.quant_config.linear_activation_scheme == "static"
             ):
-                layer.input_scale = torch.nn.Parameter(
+                layer.input_scale = Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
@@ -406,13 +399,17 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
-                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=weight_scale,
-                    input_scale=layer.input_scale,
-                )
+                weight, weight_scale, input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        input_scale=layer.input_scale,
+                    )
+                )
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale, requires_grad=False)

             weight_scale, weight = requantize_with_max_scale(
                 weight=weight,
@@ -435,7 +432,9 @@ class Fp8LinearMethod(LinearMethodBase):
         )

         if self.use_marlin:
-            prepare_fp8_layer_for_marlin(layer)
+            if self.block_quant:
+                layer.weight_block_size = self.quant_config.weight_block_size
+            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
             # Activations not quantized for marlin.
             del layer.input_scale
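Read in isolation, the fp8.py change boils down to: on CUDA, when the Marlin FP8 kernel is available, use_marlin is set either by the force flag or by the new capability probe, and block-wise quantization no longer switches it off; the Marlin weight preparation then receives not self.block_quant as a second argument and, for block-quantized layers, the weight_block_size from the quant config. A small self-contained sketch of that enable decision (the boolean inputs stand in for the module's real globals and for get_bool_env_var; it is an illustration, not the module's API):

def should_use_marlin_fp8(
    is_cuda: bool,
    marlin_fp8_available: bool,
    force_marlin: bool,
    can_auto_enable: bool,
) -> bool:
    # Mirrors the new logic in Fp8LinearMethod.__init__: Marlin FP8 is only
    # considered on CUDA with the kernel built, and is taken either when
    # forced via SGLANG_FORCE_FP8_MARLIN or when the capability probe says
    # the GPU benefits from it. There is no block-quant opt-out anymore.
    if not (is_cuda and marlin_fp8_available):
        return False
    return force_marlin or can_auto_enable


# Example: an Ampere A100 (no native FP8 hardware) with the kernel available
# now takes the Marlin path even for block-wise FP8 checkpoints.
assert should_use_marlin_fp8(True, True, False, True) is True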
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -789,3 +789,12 @@ def apply_fp8_linear(
         bias,
         input.dtype,
     )
+
+
+def can_auto_enable_marlin_fp8() -> bool:
+    try:
+        major, minor = get_device_capability()
+        sm = major * 10 + minor
+        return 80 <= sm < 89
+    except Exception:
+        return False
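The auto-enable helper returns True only for compute capability 8.0 through 8.8, i.e. GPUs such as the Ampere A100 (SM 80) that lack native FP8 tensor cores; from SM 89 (Ada) upward the hardware FP8 path is available instead, and any failure to query the device falls back to False. A rough standalone equivalent using torch directly (the helper itself uses sglang's get_device_capability; this is only an illustration):

import torch


def marlin_fp8_auto_enable_equivalent() -> bool:
    # Approximation of can_auto_enable_marlin_fp8() with plain torch calls.
    try:
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability()
        sm = major * 10 + minor
        # 80 <= sm < 89: pre-FP8 hardware (e.g. A100), so weight-only Marlin FP8 helps.
        # sm >= 89: native FP8 support, so the hardware path is preferred.
        return 80 <= sm < 89
    except Exception:
        return False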