sglang / commit dd865bef (unverified)

[Hotfix] solve fp8 w8a8 ci test fail (#4531)

Authored Mar 18, 2025 by Xiaoyu Zhang; committed by GitHub on Mar 17, 2025.
Parent: d373a48c

Showing 5 changed files with 113 additions and 420 deletions (+113, -420).
Changed files:

python/sglang/srt/layers/quantization/fp8.py         +12   -3
python/sglang/srt/layers/quantization/fp8_utils.py   +73   -70
python/sglang/srt/layers/quantization/gptq.py        +10   -6
python/sglang/srt/layers/quantization/utils.py       +16   -341
scripts/ci_install_dependency.sh                     +2    -0
python/sglang/srt/layers/quantization/fp8.py

@@ -799,8 +799,17 @@ class Fp8MoEMethod:
                     layer.w13_weight[expert_id][start : start + shard_size, :],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
-                layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                    ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                )
+                if _is_cuda:
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                else:
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = vllm_ops.scaled_fp8_quant(
+                        dq_weight, max_w13_scales[expert_id]
+                    )
                 start += shard_size
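For context, the hunk above re-quantizes each dequantized w13 shard with the expert's shared maximum scale, now dispatching to sgl_scaled_fp8_quant on CUDA builds and to vllm_ops.scaled_fp8_quant otherwise. The sketch below shows the same dequantize-then-requantize round trip in plain PyTorch; requantize_shard and its tensors are illustrative stand-ins, not the sglang kernels:

    import torch

    # Illustrative stand-in for what scaled_fp8_quant does in this context:
    # dequantize a shard with its own scale, then re-quantize it with the
    # shared max scale so all shards of an expert use one scale.
    def requantize_shard(shard_fp8, old_scale, max_scale):
        finfo = torch.finfo(torch.float8_e4m3fn)
        dq = shard_fp8.to(torch.float32) * old_scale      # dequantize
        q = (dq / max_scale).clamp(finfo.min, finfo.max)  # re-quantize with shared scale
        return q.to(torch.float8_e4m3fn)

    finfo = torch.finfo(torch.float8_e4m3fn)
    w = torch.randn(128, 256)
    old_scale = w.abs().max() / finfo.max
    shard = (w / old_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    max_scale = old_scale * 1.5  # pretend another shard required a larger scale
    print(requantize_shard(shard, old_scale, max_scale).dtype)  # torch.float8_e4m3fn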
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -15,6 +15,13 @@ from sglang.srt.utils import (
     is_hip,
 )
 
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 _is_hip = is_hip()

@@ -27,13 +34,8 @@ if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
 
-    if use_vllm_cutlass_w8a8_fp8_kernel:
-        try:
-            from vllm import _custom_ops as ops
-
-            VLLM_AVAILABLE = True
-        except ImportError:
-            VLLM_AVAILABLE = False
+    if use_vllm_cutlass_w8a8_fp8_kernel and VLLM_AVAILABLE:
+        from vllm import _custom_ops as ops
     else:
         from sgl_kernel import fp8_scaled_mm

@@ -253,6 +255,7 @@ def apply_fp8_linear(
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
+    else:
         per_tensor_weights = weight_scale.numel() == 1
         per_tensor_activations = x_scale.numel() == 1
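The net effect of these hunks is that fp8_utils no longer assumes vllm is importable: availability is probed once at import time and combined with the USE_VLLM_CUTLASS_W8A8_FP8_KERNEL opt-in before the vLLM kernel path is selected. A self-contained sketch of that guard pattern follows; env_flag is a stand-in for sglang's get_bool_env_var and the backend strings are placeholders, not the real kernel objects:

    import os

    def env_flag(name: str, default: str = "false") -> bool:
        # Stand-in for sglang's get_bool_env_var.
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

    try:
        import vllm  # noqa: F401
        VLLM_AVAILABLE = True
    except ImportError:
        VLLM_AVAILABLE = False

    if env_flag("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL") and VLLM_AVAILABLE:
        backend = "vllm cutlass w8a8 fp8"      # real code imports vllm._custom_ops here
    else:
        backend = "sgl_kernel fp8_scaled_mm"   # default sgl-kernel path

    print("fp8 linear backend:", backend)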
python/sglang/srt/layers/quantization/gptq.py

@@ -6,7 +6,6 @@ import torch
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.utils import scalar_types
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.utils import is_cuda

@@ -133,11 +132,16 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-    # (num_bits, is_sym) -> quant_type
-    TYPE_MAP = {
-        (4, True): scalar_types.uint4b8,
-        (8, True): scalar_types.uint8b128,
-    }
+    if VLLM_AVAILABLE:
+        from vllm.scalar_type import scalar_types
+
+        # (num_bits, is_sym) -> quant_type
+        TYPE_MAP = {
+            (4, True): scalar_types.uint4b8,
+            (8, True): scalar_types.uint8b128,
+        }
+    else:
+        raise ImportError("vllm is not installed")
 
     def __init__(
         self,
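With this change GPTQMarlinConfig only defines TYPE_MAP when vllm (and its scalar_types) can be imported, and raises ImportError otherwise. The sketch below shows how such a (num_bits, is_sym) lookup is typically consumed; the string values are placeholders for vllm's scalar_types objects and resolve_quant_type is a hypothetical helper, not sglang API:

    # Placeholder mapping; the real code maps to vllm.scalar_type.scalar_types members.
    TYPE_MAP = {
        (4, True): "uint4b8",
        (8, True): "uint8b128",
    }

    def resolve_quant_type(weight_bits: int, is_sym: bool) -> str:
        key = (weight_bits, is_sym)
        if key not in TYPE_MAP:
            raise ValueError(
                f"Unsupported GPTQ-Marlin quantization: bits={weight_bits}, sym={is_sym}"
            )
        return TYPE_MAP[key]

    print(resolve_quant_type(4, True))   # uint4b8
    print(resolve_quant_type(8, True))   # uint8b128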
python/sglang/srt/layers/quantization/utils.py

 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/scalar_type.py
-import functools
-import struct
-from dataclasses import dataclass
-from enum import Enum
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import torch
 
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
 
 def is_layer_skipped(
     prefix: str,

@@ -102,341 +106,12 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-        weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale)
+        if _is_cuda:
+            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
+        else:
+            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(weight_dq, max_w_scale)
         start = end
 
     return max_w_scale, weight
-
-
-# Mirrors enum in `core/scalar_type.hpp`
-class NanRepr(Enum):
-    NONE = 0  # nans are not supported
-    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
-    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
-
-
-# This ScalarType class is a parallel implementation of the C++ ScalarType
-# class found in csrc/core/scalar_type.hpp. These two classes should be kept
-# in sync until the inductor fully supports custom C++ classes.
-@dataclass(frozen=True)
-class ScalarType:
-    """
-    ScalarType can represent a wide range of floating point and integer
-    types, in particular it can be used to represent sub-byte data types
-    (something that torch.dtype currently does not support). It is also
-    capable of representing types with a bias, i.e.:
-      `stored_value = value + bias`,
-    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
-    of 8). The implementation for this class can be found in
-    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
-    with that file.
-    """
-
-    exponent: int
-    """
-    Number of bits in the exponent if this is a floating point type
-    (zero if this an integer type)
-    """
-
-    mantissa: int
-    """
-    Number of bits in the mantissa if this is a floating point type,
-    or the number bits representing an integer excluding the sign bit if
-    this an integer type.
-    """
-
-    signed: bool
-    "If the type is signed (i.e. has a sign bit)"
-
-    bias: int
-    """
-    bias used to encode the values in this scalar type
-    (value = stored_value - bias, default 0) for example if we store the
-    type as an unsigned integer with a bias of 128 then the value 0 will be
-    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
-    """
-
-    _finite_values_only: bool = False
-    """
-    Private: if infs are supported, used `has_infs()` instead.
-    """
-
-    nan_repr: NanRepr = NanRepr.IEEE_754
-    """
-    How NaNs are represent in this scalar type, returns NanRepr value.
-    (not applicable for integer types)
-    """
-
-    def _floating_point_max_int(self) -> int:
-        assert (
-            self.mantissa <= 52 and self.exponent <= 11
-        ), f"Cannot represent max/min as a double for type {self.__str__()}"
-
-        max_mantissa = (1 << self.mantissa) - 1
-        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
-            max_mantissa = max_mantissa - 1
-
-        max_exponent = (1 << self.exponent) - 2
-        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
-            assert (
-                self.exponent < 11
-            ), f"Cannot represent max/min as a double for type {self.__str__()}"
-            max_exponent = max_exponent + 1
-
-        # adjust the exponent to match that of a double
-        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
-        # e is the exponent bits), there is some precedent for non-standard
-        # biases, example `float8_e4m3b11fnuz` here:
-        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
-        # complication we are just assuming the standard exponent bias until
-        # there is a need to support non-standard biases
-        exponent_bias = (1 << (self.exponent - 1)) - 1
-        exponent_bias_double = (1 << 10) - 1  # double e = 11
-
-        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
-
-        # shift the mantissa and exponent into the proper positions for an
-        # IEEE double and bitwise-or them together.
-        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
-
-    def _floating_point_max(self) -> float:
-        double_raw = self._floating_point_max_int()
-        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
-
-    def _raw_max(self) -> Union[int, float]:
-        if self.is_floating_point():
-            return self._floating_point_max()
-        else:
-            assert (
-                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
-            ), "Cannot represent max as an int"
-            return (1 << self.mantissa) - 1
-
-    def _raw_min(self) -> Union[int, float]:
-        if self.is_floating_point():
-            assert (
-                self.is_signed()
-            ), "We currently assume all floating point types are signed"
-            sign_bit_double = 1 << 63
-
-            max_raw = self._floating_point_max_int()
-            min_raw = max_raw | sign_bit_double
-            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
-        else:
-            assert (
-                not self.is_signed() or self.size_bits <= 64
-            ), "Cannot represent min as a int64_t"
-
-            if self.is_signed():
-                return -(1 << (self.size_bits - 1))
-            else:
-                return 0
-
-    @functools.cached_property
-    def id(self) -> int:
-        """
-        Convert the ScalarType to an int which can be passed to pytorch custom
-        ops. This layout of the int must be kept in sync with the C++
-        ScalarType's from_id method.
-        """
-        val = 0
-        offset = 0
-
-        def or_and_advance(member, bit_width):
-            nonlocal val
-            nonlocal offset
-            bit_mask = (1 << bit_width) - 1
-            val = val | (int(member) & bit_mask) << offset
-            offset = offset + bit_width
-
-        or_and_advance(self.exponent, 8)
-        or_and_advance(self.mantissa, 8)
-        or_and_advance(self.signed, 1)
-        or_and_advance(self.bias, 32)
-        or_and_advance(self._finite_values_only, 1)
-        or_and_advance(self.nan_repr.value, 8)
-
-        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
-
-        return val
-
-    @property
-    def size_bits(self) -> int:
-        return self.exponent + self.mantissa + int(self.signed)
-
-    def min(self) -> Union[int, float]:
-        """
-        Min representable value for this scalar type.
-        (accounting for bias if there is one)
-        """
-        return self._raw_min() - self.bias
-
-    def max(self) -> Union[int, float]:
-        """
-        Max representable value for this scalar type.
-        (accounting for bias if there is one)
-        """
-        return self._raw_max() - self.bias
-
-    def is_signed(self) -> bool:
-        """
-        If the type is signed (i.e. has a sign bit), same as `signed`
-        added for consistency with:
-        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
-        """
-        return self.signed
-
-    def is_floating_point(self) -> bool:
-        "If the type is a floating point type"
-        return self.exponent != 0
-
-    def is_integer(self) -> bool:
-        "If the type is an integer type"
-        return self.exponent == 0
-
-    def has_bias(self) -> bool:
-        "If the type has a non-zero bias"
-        return self.bias != 0
-
-    def has_infs(self) -> bool:
-        "If the type is floating point and supports infinity"
-        return not self._finite_values_only
-
-    def has_nans(self) -> bool:
-        return self.nan_repr != NanRepr.NONE.value
-
-    def is_ieee_754(self) -> bool:
-        """
-        If the type is a floating point type that follows IEEE 754
-        conventions
-        """
-        return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
-
-    def __str__(self) -> str:
-        """
-        naming generally follows: https://github.com/jax-ml/ml_dtypes
-        for floating point types (leading f) the scheme is:
-        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
-        flags:
-          - no-flags: means it follows IEEE 754 conventions
-          - f: means finite values only (no infinities)
-          - n: means nans are supported (non-standard encoding)
-        for integer types the scheme is:
-        `[u]int<size_bits>[b<bias>]`
-          - if bias is not present it means its zero
-        """
-        if self.is_floating_point():
-            ret = (
-                "float"
-                + str(self.size_bits)
-                + "_e"
-                + str(self.exponent)
-                + "m"
-                + str(self.mantissa)
-            )
-
-            if not self.is_ieee_754():
-                if self._finite_values_only:
-                    ret = ret + "f"
-                if self.nan_repr != NanRepr.NONE:
-                    ret = ret + "n"
-
-            return ret
-        else:
-            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
-            if self.has_bias():
-                ret = ret + "b" + str(self.bias)
-            return ret
-
-    def __repr__(self) -> str:
-        return "ScalarType." + self.__str__()
-
-    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
-    # opcheck to work.
-    def __len__(self) -> int:
-        raise TypeError
-
-    #
-    # Convenience Constructors
-    #
-
-    @classmethod
-    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
-        "Create a signed integer scalar type (size_bits includes sign-bit)."
-        ret = cls(0, size_bits - 1, True, bias if bias else 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
-        """Create a unsigned integer scalar type."""
-        ret = cls(0, size_bits, False, bias if bias else 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
-        """
-        Create a standard floating point type
-        (i.e. follows IEEE 754 conventions).
-        """
-        assert mantissa > 0 and exponent > 0
-        ret = cls(exponent, mantissa, True, 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def float_(
-        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
-    ) -> "ScalarType":
-        """
-        Create a non-standard floating point type
-        (i.e. does not follow IEEE 754 conventions).
-        """
-        assert mantissa > 0 and exponent > 0
-        assert nan_repr != NanRepr.IEEE_754, (
-            "use `float_IEEE754` constructor for floating point types that "
-            "follow IEEE 754 conventions"
-        )
-        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-
-# naming generally follows: https://github.com/jax-ml/ml_dtypes
-# for floating point types (leading f) the scheme is:
-#  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
-#  flags:
-#  - no-flags: means it follows IEEE 754 conventions
-#  - f: means finite values only (no infinities)
-#  - n: means nans are supported (non-standard encoding)
-# for integer types the scheme is:
-#  `[u]int<size_bits>[b<bias>]`
-#  - if bias is not present it means its zero
-
-
-class scalar_types:
-    int4 = ScalarType.int_(4, None)
-    uint4 = ScalarType.uint(4, None)
-    int8 = ScalarType.int_(8, None)
-    uint8 = ScalarType.uint(8, None)
-    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
-    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
-    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
-    float16_e5m10 = ScalarType.float_IEEE754(5, 10)
-
-    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
-    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
-
-    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-    float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)
-
-    # "gptq" types
-    uint2b2 = ScalarType.uint(2, 2)
-    uint3b4 = ScalarType.uint(3, 4)
-    uint4b8 = ScalarType.uint(4, 8)
-    uint8b128 = ScalarType.uint(8, 128)
-
-    # colloquial names
-    bfloat16 = float16_e8m7
-    float16 = float16_e5m10
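The bulk of the deletion is the local copy of vllm's ScalarType machinery, whose bias encoding (stored_value = value + bias) is what gives the gptq.py types their names uint4b8 and uint8b128. A tiny worked example of that encoding, independent of either library:

    def biased_uint_range(size_bits: int, bias: int):
        # Raw stored values are 0 .. 2**size_bits - 1; subtracting the bias
        # gives the representable logical values (stored_value = value + bias).
        raw_max = (1 << size_bits) - 1
        return -bias, raw_max - bias

    print(biased_uint_range(4, 8))    # (-8, 7)      -> uint4b8
    print(biased_uint_range(8, 128))  # (-128, 127)  -> uint8b128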
scripts/ci_install_dependency.sh

@@ -27,3 +27,5 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12
 # For DeepSeek-VL2
 pip install timm
+
+pip install sgl-kernel==0.0.5.post3 --force-reinstall
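The CI script now force-reinstalls a pinned sgl-kernel wheel. A small optional check (an assumption, not part of the script) to confirm which version actually landed in the environment:

    from importlib.metadata import PackageNotFoundError, version

    try:
        print("sgl-kernel version:", version("sgl-kernel"))
    except PackageNotFoundError:
        print("sgl-kernel is not installed")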