sglang · Commit e483ab6d (Unverified)

enable marlin fp8 blockwise (#8990)

Authored Aug 19, 2025 by Enrique Shockwave · Committed by GitHub on Aug 18, 2025
Parent: 720cd308 · Changes: 2
Showing 2 changed files with 92 additions and 84 deletions (+92 -84)

  python/sglang/srt/layers/quantization/fp8.py        +83 -84
  python/sglang/srt/layers/quantization/fp8_utils.py   +9  -0
python/sglang/srt/layers/quantization/fp8.py (view file @ e483ab6d)
@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
@@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (
-            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
-        )
-        # Disable marlin for ROCm
-        if _is_hip:
-            self.use_marlin = False
+        self.use_marlin = False
+        if _is_cuda and MARLIN_FP8_AVAILABLE:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable
 
         self.block_quant = self.quant_config.weight_block_size is not None
-        if self.block_quant:
-            # Marlin doesn't support block-wise fp8
-            self.use_marlin = False
 
         self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
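Note (illustration, not part of the commit): a minimal, self-contained sketch of the enable rule introduced above. The helpers prefixed with an underscore are stand-ins; in fp8.py the real flags and utilities (_is_cuda, MARLIN_FP8_AVAILABLE, get_bool_env_var, can_auto_enable_marlin_fp8) come from sglang's own modules.

import os


def _get_bool_env_var(name: str) -> bool:
    # Simplified stand-in for sglang's get_bool_env_var.
    return os.environ.get(name, "").lower() in ("1", "true")


def _can_auto_enable_marlin_fp8() -> bool:
    # Stand-in for the helper added to fp8_utils.py in this commit.
    return True


def decide_use_marlin(is_cuda: bool, marlin_fp8_available: bool) -> bool:
    # Mirrors the new __init__ logic: Marlin FP8 is used either when forced via
    # SGLANG_FORCE_FP8_MARLIN or when the device qualifies for auto-enablement.
    use_marlin = False
    if is_cuda and marlin_fp8_available:
        force_marlin = _get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
        auto_enable = _can_auto_enable_marlin_fp8()
        use_marlin = force_marlin or auto_enable
    return use_marlin

Before this change, use_marlin required the environment variable and was reset to False whenever weight_block_size was set; after it, block-wise FP8 layers keep the flag and are handled at the prepare_fp8_layer_for_marlin call site further down.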
@@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
@@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
                 layer.input_scale = None
             elif _is_cpu:
                 assert (
@@ -352,18 +347,16 @@ class Fp8LinearMethod(LinearMethodBase):
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
-            return
-
-        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
-                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
+                # apply per-channel quantization default as
+                # cutlass sgl-kernel and marlin only support per-channel scale
                 qweight, weight_scale = per_token_group_quant_fp8(
                     layer.weight, layer.weight.shape[-1]
                 )
@@ -380,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
             # If checkpoint is fp8, handle that there are N scales for N
             # shards in a fused module
             else:
-                layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, requires_grad=False)
+                layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
                 if (
@@ -390,7 +383,7 @@ class Fp8LinearMethod(LinearMethodBase):
hasattr
(
self
.
quant_config
,
"linear_activation_scheme"
)
and
self
.
quant_config
.
linear_activation_scheme
==
"static"
):
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
...
...
@@ -406,13 +399,17 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
-                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                weight, weight_scale, input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=weight_scale,
                         input_scale=layer.input_scale,
                     )
+                )
                 if input_scale is not None:
-                    layer.input_scale = Parameter(input_scale, requires_grad=False)
+                    layer.input_scale = Parameter(
+                        input_scale, requires_grad=False
+                    )
 
             weight_scale, weight = requantize_with_max_scale(
                 weight=weight,
@@ -435,7 +432,9 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         if self.use_marlin:
-            prepare_fp8_layer_for_marlin(layer)
+            if self.block_quant:
+                layer.weight_block_size = self.quant_config.weight_block_size
+            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
             # Activations not quantized for marlin.
             del layer.input_scale
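Note (illustration, not part of the commit): a sketch of the new call shape at the Marlin preparation site. Only weight_block_size and the "(layer, not block_quant)" argument pattern come from the diff; the stub function, the wrapper, and the example block size [128, 128] are stand-ins, not sglang APIs.

from types import SimpleNamespace


def prepare_fp8_layer_for_marlin_stub(layer, per_channel_scales: bool) -> None:
    # Stand-in for sglang's prepare_fp8_layer_for_marlin; the real function
    # repacks FP8 weights and scales into the Marlin kernel's layout.
    layer.marlin_per_channel_scales = per_channel_scales


def run_marlin_prep(layer, block_quant: bool, config_block_size) -> None:
    # Mirrors the changed call site: block-quantized layers first record their
    # block shape on the layer, then Marlin prep is told whether the scales
    # are per-channel (True only for non-block layers).
    if block_quant:
        layer.weight_block_size = config_block_size
    prepare_fp8_layer_for_marlin_stub(layer, not block_quant)


layer = SimpleNamespace()
run_marlin_prep(layer, block_quant=True, config_block_size=[128, 128])
print(layer.weight_block_size, layer.marlin_per_channel_scales)  # [128, 128] False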
python/sglang/srt/layers/quantization/fp8_utils.py (view file @ e483ab6d)
@@ -789,3 +789,12 @@ def apply_fp8_linear(
         bias,
         input.dtype,
     )
+
+
+def can_auto_enable_marlin_fp8() -> bool:
+    try:
+        major, minor = get_device_capability()
+        sm = major * 10 + minor
+        return 80 <= sm < 89
+    except Exception:
+        return False
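Note (illustration, not part of the commit): the helper above auto-enables Marlin only for compute capabilities in the range 8.0 to 8.8, i.e. GPUs that meet Marlin's minimum architecture but predate native FP8 tensor-core support (SM 8.9 Ada and SM 9.0 Hopper are excluded). A self-contained restatement of the same rule with example capabilities; the GPU-to-capability mapping in the comments is for orientation only.

def auto_enable_for(capability: tuple[int, int]) -> bool:
    # Same range check as can_auto_enable_marlin_fp8, but taking the capability
    # as an argument instead of querying the current device.
    major, minor = capability
    sm = major * 10 + minor
    return 80 <= sm < 89


# Example compute capabilities: A100 is SM 8.0, RTX 3090 is SM 8.6,
# RTX 4090 is SM 8.9, H100 is SM 9.0.
assert auto_enable_for((8, 0))       # A100: no native FP8 units, use Marlin
assert auto_enable_for((8, 6))       # RTX 3090: same
assert not auto_enable_for((8, 9))   # RTX 4090: native FP8 available
assert not auto_enable_for((9, 0))   # H100: native FP8 available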