Commit 82605747 (Unverified)
Authored Sep 27, 2025 by Yueyang Pan; committed by GitHub, Sep 27, 2025
fix: fp8 quantization failure of qwen 2.5 VL 7B model (#10112)
Signed-off-by: PanJason <pyyjason@gmail.com>

parent 37f3325b
Showing 5 changed files with 81 additions and 14 deletions
python/sglang/srt/layers/linear.py                  +21 -4
python/sglang/srt/layers/parameter.py               +23 -6
python/sglang/srt/layers/quantization/w8a8_int8.py  +13 -3
python/sglang/srt/layers/utils.py                   +23 -0
python/sglang/srt/models/qwen2_5_vl.py              +1 -1
python/sglang/srt/layers/linear.py

@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
     _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs

 if TYPE_CHECKING:
...
@@ -625,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
             if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    output_dim, start_idx, shard_size
-                )
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )

         # Special case for AQLM codebooks.
         elif is_metadata:
...
@@ -1302,7 +1310,16 @@ class RowParallelLinear(LinearBase):
                     shard_size,
                 )
             else:
-                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[input_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, input_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        input_dim, start_idx, shard_size
+                    )

         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
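Reviewer note on why the guard is needed, as a minimal runnable sketch (not part of the commit): once the vision MLP width is rounded up to a multiple of 8, the parameter the linear layer allocates is larger than the tensor stored in the checkpoint, so the last tensor-parallel shard runs past the end of the loaded weight and a bare narrow() would raise. The numbers below assume Qwen2.5-VL's vision intermediate_size of 3420 and tp_size=2; both are illustrative.

import torch

checkpoint_rows = 3420                             # rows actually stored on disk
padded_rows = ((checkpoint_rows + 7) // 8) * 8     # 3424, size of the allocated parameter
tp_size = 2
shard_size = padded_rows // tp_size                # 1712 rows per rank

loaded_weight = torch.randn(checkpoint_rows, 1280)

for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    end_idx = start_idx + shard_size
    if end_idx > loaded_weight.shape[0]:
        # rank 1: 3424 > 3420, so narrow() would fail; the commit zero-pads instead
        print(f"rank {tp_rank}: shard [{start_idx}, {end_idx}) overruns the checkpoint")
    else:
        shard = loaded_weight.narrow(0, start_idx, shard_size)
        print(f"rank {tp_rank}: narrow ok, shard shape {tuple(shard.shape)}")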
python/sglang/srt/layers/parameter.py

@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter

+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu

 __all__ = [
...
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )
         else:
             if not use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    self.output_dim, tp_rank * shard_size, shard_size
-                )
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                start_idx = tp_rank * shard_size
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[self.output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, self.output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        self.output_dim, start_idx, shard_size
+                    )

         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
...
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
             return
         else:
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            start_idx = tp_rank * shard_size
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[self.input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, self.input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, start_idx, shard_size
+                )

         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -393,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )

         x_q, x_scale = per_token_quant_int8(x)
-        return int8_scaled_mm(
-            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
-        )
+        x_q_2d = x_q.view(-1, x_q.shape[-1])
+        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
+        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
+        output = int8_scaled_mm(
+            x_q_2d,
+            layer.weight,
+            x_scale_2d,
+            layer.weight_scale,
+            out_dtype=x.dtype,
+            bias=bias,
+        )
+        return output.view(output_shape)


 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
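Reviewer note: this hunk suggests int8_scaled_mm operates on 2D operands, while the vision tower can hand the linear method a 3D activation such as [batch, seq, hidden]; the input and its per-token scales are therefore flattened to 2D and the leading dimensions restored on the output. A minimal sketch of the same reshape pattern, with a plain matmul standing in for int8_scaled_mm (an assumption, not the actual kernel):

import torch

def matmul_last_dim(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Flatten all leading dims so a 2D-only kernel can be used,
    # then restore the original leading dims on the output.
    x_2d = x.view(-1, x.shape[-1])   # e.g. [batch * seq, hidden]
    out_2d = x_2d @ w                # stand-in for int8_scaled_mm
    return out_2d.view(*x.shape[:-1], w.shape[1])

x = torch.randn(2, 5, 16)   # 3D activation, as a vision tower may produce
w = torch.randn(16, 32)
assert matmul_last_dim(x, w).shape == (2, 5, 32)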
python/sglang/srt/layers/utils.py

@@ -15,6 +15,29 @@ def get_layer_id(weight_name):
     return None


+def pad_or_narrow_weight(
+    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
+) -> torch.Tensor:
+    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
+    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
+
+    if valid_size > 0:
+        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
+        pad_shape = list(loaded_weight.shape)
+        pad_shape[input_dim] = shard_size - valid_size
+        pad = torch.zeros(
+            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+        )
+        return torch.cat([loaded_slice, pad], dim=input_dim)
+
+    # All padding
+    pad_shape = list(loaded_weight.shape)
+    pad_shape[input_dim] = shard_size
+    return torch.zeros(
+        pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+    )
+
+
 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
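A quick usage sketch of the new helper (values are illustrative, and it assumes sglang is importable): when the requested shard runs partly past the source tensor, the valid rows are kept and the remainder is zero-filled; when the shard starts entirely past the end, an all-zero shard of the requested size is returned.

import torch

from sglang.srt.layers.utils import pad_or_narrow_weight  # the helper added above

w = torch.arange(10.0).unsqueeze(1)   # shape [10, 1]

# Shard [8, 12) overruns a length-10 tensor: rows 8-9 kept, 2 zero rows appended.
partial = pad_or_narrow_weight(w, input_dim=0, start_idx=8, shard_size=4)
assert partial.shape == (4, 1) and partial[2:].eq(0).all()

# Shard [12, 16) starts past the end: an all-zero shard is returned.
empty = pad_or_narrow_weight(w, input_dim=0, start_idx=12, shard_size=4)
assert empty.shape == (4, 1) and empty.eq(0).all()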
python/sglang/srt/models/qwen2_5_vl.py

@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
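The one-line model change is the standard round-up-to-multiple idiom; assuming Qwen2.5-VL's vision intermediate_size of 3420, it yields 3424, which is what makes the padded loading paths above necessary. A worked check:

def round_up_to_multiple(x: int, m: int = 8) -> int:
    # ((x + m - 1) // m) * m rounds x up to the next multiple of m
    return ((x + m - 1) // m) * m

assert round_up_to_multiple(3420) == 3424   # Qwen2.5-VL vision MLP width (assumed value)
assert round_up_to_multiple(3424) == 3424   # already-aligned values are unchanged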