change / sglang / Commits / 0dd6cda2

Unverified commit 0dd6cda2, authored Mar 09, 2025 by HandH1998, committed by GitHub on Mar 09, 2025

Apply sgl w8a8 fp8 kernel (#3148)

parent 9fb48f95
Showing 13 changed files with 523 additions and 37 deletions (+523, -37)
python/sglang/srt/configs/model_config.py                  +3    -1
python/sglang/srt/layers/linear.py                         +1    -1
python/sglang/srt/layers/parameter.py                      +10   -0
python/sglang/srt/layers/quantization/__init__.py          +2    -0
python/sglang/srt/layers/quantization/blockwise_int8.py    +1    -2
python/sglang/srt/layers/quantization/fp8.py               +21   -20
python/sglang/srt/layers/quantization/fp8_kernel.py        +130  -3
python/sglang/srt/layers/quantization/fp8_utils.py         +148  -8
python/sglang/srt/layers/quantization/modelopt_quant.py    +5    -1
python/sglang/srt/layers/quantization/w8a8_fp8.py          +126  -0
python/sglang/srt/server_args.py                           +1    -0
python/sglang/srt/utils.py                                 +8    -0
python/sglang/test/test_block_fp8.py                       +67   -1
python/sglang/srt/configs/model_config.py (+3, -1)

@@ -250,9 +250,11 @@ class ModelConfig:
             "compressed-tensors",
             "experts_int8",
             "w8a8_int8",
+            "w8a8_fp8",
         ]
         compatible_quantization_methods = {
-            "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
+            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
+            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
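For reference, the new mapping entry means a w8a8_fp8 override is accepted for checkpoints whose config reports a compressed-tensors quantization method. A minimal sketch of that compatibility check follows; it is illustrative only and not the actual ModelConfig logic.

compatible_quantization_methods = {
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
    "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
}

def override_is_compatible(checkpoint_method: str, override: str) -> bool:
    # The override is allowed when the checkpoint's declared method is listed
    # as compatible with the requested method (hypothetical helper).
    return checkpoint_method in compatible_quantization_methods.get(override, [])

assert override_is_compatible("compressed-tensors", "w8a8_fp8")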
python/sglang/srt/layers/linear.py (+1, -1)

@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
+    BlockQuantScaleParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)
python/sglang/srt/layers/parameter.py (+10, -0)

@@ -16,6 +16,7 @@ __all__ = [
     "ModelWeightParameter",
     "ChannelQuantScaleParameter",
     "GroupQuantScaleParameter",
+    "BlockQuantScaleParameter",
    "PackedColumnParameter",
     "RowvLLMParameter",
 ]
@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
     pass


+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
+
+
 class PerTensorScaleParameter(BasevLLMParameter):
     """
     Parameter class for scales where the number of scales is
python/sglang/srt/layers/quantization/__init__.py (+2, -0)

@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -50,6 +51,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "w8a8_int8": W8A8Int8Config,
+    "w8a8_fp8": W8A8Fp8Config,
 }
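A quick sketch of how the new entry is reached through this registry (a usage sketch, assuming an environment where sglang and its kernel dependencies import cleanly):

from sglang.srt.layers.quantization import QUANTIZATION_METHODS

config_cls = QUANTIZATION_METHODS["w8a8_fp8"]
quant_config = config_cls.from_config({})  # W8A8Fp8Config takes no options
print(quant_config.get_name(), quant_config.get_min_capability())  # w8a8_fp8 89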
python/sglang/srt/layers/quantization/blockwise_int8.py (+1, -2)

@@ -13,12 +13,11 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
 from sglang.srt.utils import set_weight_attrs
python/sglang/srt/layers/quantization/fp8.py (+21, -20)

@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
-    apply_fp8_linear,
     convert_to_channelwise,
-    cutlass_fp8_supported,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import (
+    BlockQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
-    BlockQuantScaleParameter,
+    apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
+    cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import (
@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
-                assert weight_scale.numel() == 1
-                weight_scale = convert_to_channelwise(
-                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
-                )
+            if self.cutlass_fp8_supported or self.use_marlin:
+                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
+                )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # per-tensor quantization
+                qweight, weight_scale = input_to_float8(layer.weight)

             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
+            # cutlass sgl-kernel and marlin only support per-channel scale
+            if self.cutlass_fp8_supported or self.use_marlin:
                 weight = layer.weight
                 weight_scale = convert_to_channelwise(
                     layer.weight_scale, layer.logical_widths
                 )
-
-            # If using w8a8, torch._scaled_mm needs per tensor, so
-            # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
                 weight = layer.weight
                 weight_scale = layer.weight_scale

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if is_hip():
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
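For intuition, calling per_token_group_quant_fp8 on the weight with group_size equal to the last dimension amounts to per-output-channel quantization: one scale per row, derived from the row absmax, just as in the Triton kernel shown in fp8_kernel.py below. A rough plain-torch equivalent, for illustration only:

import torch

def per_channel_quant_fp8_reference(weight: torch.Tensor):
    # One scale per output row, derived from the row-wise absolute maximum,
    # mirroring y_s = absmax / fp8_max in the Triton kernel.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = weight.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
    qweight = (weight / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return qweight, scale

w = torch.randn(128, 256, dtype=torch.float32)
qw, s = per_channel_quant_fp8_reference(w)
# The kernel path then lays the scales out per output channel for the GEMM:
# weight_scale = weight_scale.t().contiguous()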
python/sglang/srt/layers/quantization/fp8_kernel.py (+130, -3)

@@ -29,7 +29,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn

 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_fp8
+    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

 logger = logging.getLogger(__name__)
@@ -70,7 +70,8 @@ def _per_token_group_quant_fp8(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
@@ -140,7 +141,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+        dtype: The dype of output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -241,6 +242,132 @@ def sglang_per_token_group_quant_fp8(
     return x_q, x_s


+def sglang_per_token_quant_fp8(
+    x: torch.Tensor,
+    dtype: torch.dtype = fp8_type_,
+):
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[0],
+        1,
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_quant_fp8(x, x_q, x_s)
+
+    return x_q, x_s
+
+
+@triton.jit
+def _static_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    y_s_repeat_ptr,
+    # Stride of input
+    y_stride,
+    # Collums of input
+    N,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+    REPEAT_SCALE: tl.constexpr,
+):
+    """A Triton-accelerated function to perform quantization using the given scale on a
+    tensor
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    if REPEAT_SCALE:
+        y_s_repeat_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    y_s = tl.load(y_s_ptr).to(tl.float32)
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    if REPEAT_SCALE:
+        tl.store(y_s_repeat_ptr, y_s)
+
+
+def static_quant_fp8(
+    x: torch.Tensor,
+    x_s: torch.Tensor,
+    repeat_scale: bool = False,
+    dtype: torch.dtype = fp8_type_,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform static quantization using the given scale on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tenosr with ndim >= 2.
+        x_s: The quantization scale.
+        repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
+        dtype: The dype of output tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+    """
+    assert x.is_contiguous(), "`x` is not contiguous"
+    assert x_s.numel() == 1, "only supports per-tensor scale"
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+
+    if is_hip_:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    if repeat_scale:
+        x_s_repeat = torch.empty(
+            (M, 1),
+            device=x.device,
+            dtype=torch.float32,
+        )
+    else:
+        x_s_repeat = None
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    _static_quant_fp8[(M,)](
+        x,
+        x_q,
+        x_s,
+        x_s_repeat,
+        N,
+        N,
+        fp8_min=fp8_min,
+        fp8_max=fp8_max,
+        BLOCK=BLOCK,
+        REPEAT_SCALE=repeat_scale,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    x_s = x_s_repeat if repeat_scale else x_s
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_fp8_matmul(
     # Pointers to inputs and output
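A hedged usage sketch of the newly added static_quant_fp8, based on the signature above; it assumes a CUDA device and a build where the Triton kernels compile:

import torch
from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
# A single per-tensor scale, as required by the assert inside static_quant_fp8.
x_s = (x.abs().max() / torch.finfo(torch.float8_e4m3fn).max).to(torch.float32)

# repeat_scale=True broadcasts the per-tensor scale to one scale per row so
# that cutlass-style kernels expecting per-token scales can consume it.
x_q, x_scale = static_quant_fp8(x, x_s, repeat_scale=True)
assert x_q.dtype == torch.float8_e4m3fn and x_scale.shape == (16, 1)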
python/sglang/srt/layers/quantization/fp8_utils.py (+148, -8)

@@ -2,13 +2,23 @@ import os
 from typing import List, Optional, Tuple

 import torch
 from packaging.version import Version

-from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
+    static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_cuda_version,
+    get_device_capability,
+    is_hip,
+)
+
+use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get(
+    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
+)

 is_hip_ = is_hip()
 if is_hip_ and get_bool_env_var("CK_MOE"):
@@ -18,6 +28,25 @@ _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
+
+    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+
+    if use_vllm_cutlass_w8a8_fp8_kernel:
+        from vllm import _custom_ops as ops
+    else:
+        from sgl_kernel import fp8_scaled_mm
+
+
+def cutlass_fp8_supported():
+    if not _is_cuda:
+        return False
+    major, minor = get_device_capability()
+    cuda_version = get_cuda_version()
+    if major >= 9:
+        return cuda_version >= (12, 0)
+    elif major == 8 and minor == 9:
+        return cuda_version >= (12, 4)
+    return False


 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
@@ -158,10 +187,121 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale


-class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
-    """
-    Parameter class for weight scales loaded for weights with
-    block-wise quantization. Uses both column and row parallelism.
-    """
+def apply_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    input_scale_ub: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
+) -> torch.Tensor:
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
+    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+    if input_scale is not None:
+        assert input_scale.numel() == 1
+        # broadcast per-tensor scale to per-token scale when supporting cutlass
+        qinput, x_scale = static_quant_fp8(
+            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+        )
+    else:
+        # default use per-token quantization if dynamic
+        if _is_cuda:
+            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+        else:
+            qinput, x_scale = per_token_group_quant_fp8(
+                input_2d, group_size=input_2d.shape[1]
+            )
+
+    if cutlass_fp8_supported:
+        if use_vllm_cutlass_w8a8_fp8_kernel:
+            # Fall back to vllm cutlass w8a8 fp8 kernel
+            output = ops.cutlass_scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+        else:
+            assert (
+                weight_scale.numel() == weight.shape[1]
+            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+            output = fp8_scaled_mm(
+                qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
+            )
+        return output.view(*output_shape)
+
+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
+    else:
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            # A fix for discrepancy in scaled_mm which returns tuple
+            # for torch < 2.5 and a single value in torch >= 2.5
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+
+            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            #   C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            #   C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+
+            # Making sure the dummy tensor is on the same device as the weight
+            global TORCH_DEVICE_IDENTITY
+            if TORCH_DEVICE_IDENTITY.device != weight.device:
+                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
+            # GEMM
+            # This computes C = (X * W).
+            # Output in fp32 to allow subsequent ops to happen in-place
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                scale_a=TORCH_DEVICE_IDENTITY,
+                scale_b=TORCH_DEVICE_IDENTITY,
+                out_dtype=torch.float32,
+            )
+            # A fix for discrepancy in scaled_mm which returns tuple
+            # for torch < 2.5 and a single value in torch >= 2.5
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+            # Unpad (undo num_token_padding)
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
+            pass
+
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
+            if bias is not None:
+                output = output + bias
+            return output.to(dtype=input.dtype).view(*output_shape)
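The fallback branch relies on the scaling identity spelled out in the comments above. A small self-contained check of that identity (shapes and names are illustrative only):

import torch

torch.manual_seed(0)
X = torch.randn(4, 8)    # stand-in for the quantized activations
W = torch.randn(8, 6)    # stand-in for the quantized weights
s_x = torch.rand(4, 1)   # per-token activation scales
s_w = torch.rand(1, 6)   # per-channel weight scales
bias = torch.randn(6)

# C = (s_x * X) @ (s_w * W) + bias  ==  s_x * s_w * (X @ W) + bias
lhs = (s_x * X) @ (s_w * W) + bias
rhs = (X @ W) * s_x * s_w + bias
assert torch.allclose(lhs, rhs, atol=1e-5)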
python/sglang/srt/layers/quantization/modelopt_quant.py (+5, -1)

@@ -7,7 +7,7 @@ import torch
 from torch.nn.parameter import Parameter
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear,
+    convert_to_channelwise,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear

 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             layer.weight, layer.weight_scale, layer.logical_widths
         )
         layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        # cutlass sgl-kernel only supports per-channel scale
+        if self.cutlass_fp8_supported:
+            max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
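Roughly, "convert to channelwise" here means repeating the single max scale of each fused logical shard across that shard's output channels, so the cutlass kernel sees one scale per output channel. An illustrative sketch of that idea (not the vllm implementation):

import torch

logical_widths = [1024, 1024, 512]            # e.g. fused q/k/v partitions
max_w_scale = torch.tensor([0.01, 0.02, 0.015])

channelwise = torch.cat(
    [s.expand(w) for s, w in zip(max_w_scale, logical_widths)]
)
assert channelwise.numel() == sum(logical_widths)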
python/sglang/srt/layers/quantization/w8a8_fp8.py (new file, 0 → 100644, +126)

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import (
    apply_fp8_linear,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip


class W8A8Fp8Config(QuantizationConfig):
    """Config class for W8A8 FP8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_name(self) -> str:
        return "w8a8_fp8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return W8A8Fp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class W8A8Fp8LinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: W8A8Fp8Config):
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = layer.weight
        weight_scale = layer.weight_scale.detach()
        if is_hip():
            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=weight, weight_scale=weight_scale
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs
    ):
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        return apply_fp8_linear(
            x,
            layer.weight,
            layer.weight_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
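For orientation, these are the tensor shapes create_weights above registers per shard, i.e. what a w8a8_fp8 checkpoint is expected to supply before process_weights_after_loading transposes the weight (the numbers below are illustrative):

import torch

output_partition_sizes = [1024, 1024]   # fused output partitions (example)
input_size_per_partition = 4096

weight = torch.empty(
    sum(output_partition_sizes), input_size_per_partition, dtype=torch.float8_e4m3fn
)
weight_scale = torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32)
# After loading, apply() calls apply_fp8_linear(x, layer.weight, layer.weight_scale),
# which handles the dynamic per-token activation quantization internally.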
python/sglang/srt/server_args.py (+1, -0)

@@ -405,6 +405,7 @@ class ServerArgs:
             "gguf",
             "modelopt",
             "w8a8_int8",
+            "w8a8_fp8",
         ],
         help="The quantization method.",
     )
python/sglang/srt/utils.py (+8, -0)

@@ -52,11 +52,13 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from packaging.version import Version, parse
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -1431,6 +1433,12 @@ def rank0_print(msg: str):
     print(msg, flush=True)


+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
 def launch_dummy_health_check_server(host, port):
     import uvicorn
     from fastapi import FastAPI, Response
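get_cuda_version returns an integer tuple, so callers such as cutlass_fp8_supported in fp8_utils.py can gate on it with ordinary tuple comparison, e.g.:

from sglang.srt.utils import get_cuda_version

# (12, 4) >= (12, 0) and so on; returns (0, 0) on builds without CUDA.
if get_cuda_version() >= (12, 4):
    print("CUDA runtime new enough for the sm89 cutlass fp8 path")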
python/sglang/test/test_block_fp8.py (+67, -1)

@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
+    static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -63,7 +64,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
             out, scale = per_token_group_quant_fp8(x, group_size)

         self.assertTrue(
-            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
         )
         self.assertTrue(torch.allclose(scale, ref_scale))
@@ -85,6 +86,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
                 self._per_token_group_quant_fp8(*params)


+# For test
+def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
+    """Function to perform static quantization on an input tensor `x` using native torch.
+
+    It converts the tensor values into float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+    """
+    assert x.is_contiguous(), "`x` is not contiguous"
+    assert x_s.numel() == 1, "only supports per-tensor scale"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
+    x_s_inv = 1.0 / x_s
+    x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
+    x_q = x_q.reshape(x.shape)
+
+    return x_q, x_s
+
+
+class TestStaticQuantFP8(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _static_quant_fp8(self, num_tokens, d, dtype, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(num_tokens, d, dtype=dtype)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+        x_s = x.max() / fp8_max
+
+        with torch.inference_mode():
+            ref_out, _ = native_static_quant_fp8(x, x_s)
+            out, _ = static_quant_fp8(x, x_s, repeat_scale=True)
+
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+
+    def test_static_quant_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._static_quant_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
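One way to run just the new static-quant tests from Python (a sketch; a CUDA device is required, as enforced by setUpClass):

import unittest
from sglang.test.test_block_fp8 import TestStaticQuantFP8

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestStaticQuantFP8)
unittest.TextTestRunner(verbosity=2).run(suite)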