Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc410e86
Unverified
Commit
cc410e86
authored
Jan 01, 2026
by
Kyuyeun Kim
Committed by
GitHub
Jan 02, 2026
Browse files
[Bugfix] Fix weight_loader v1 block scale (#31103)
Signed-off-by:
Kyuyeun Kim
<
kyuyeunk@google.com
>
parent
825c2dc1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
27 deletions
+40
-27
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+40
-27
No files found.
vllm/model_executor/layers/linear.py
View file @
cc410e86
...
...
@@ -80,6 +80,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
def
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
):
assert
weight_block_size
is
not
None
block_n
=
weight_block_size
[
0
]
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_size
=
(
shard_size
+
block_n
-
1
)
//
block_n
return
shard_size
,
shard_offset
def
adjust_bitsandbytes_4bit_shard
(
param
:
Parameter
,
shard_offsets
:
dict
[
str
,
tuple
[
int
,
int
]],
loaded_shard_id
:
str
)
->
tuple
[
int
,
int
]:
...
...
@@ -763,8 +771,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
if
output_dim
is
not
None
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
self
.
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
self
.
tp_size
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
)
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
...
...
@@ -867,24 +885,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
if
isinstance
(
param
,
BlockQuantScaleParameter
):
assert
self
.
quant_method
is
not
None
# Assume the weight block size has been set by quant method
assert
hasattr
(
self
,
"weight_block_size"
)
weight_block_size
=
self
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
(
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
+
block_n
-
1
)
//
block_n
)
//
self
.
tp_size
shard_size
=
(
(
self
.
output_sizes
[
loaded_shard_id
]
+
block_n
-
1
)
//
block_n
//
self
.
tp_size
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
)
else
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
self
.
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
self
.
tp_size
shard_offset
//
=
self
.
tp_size
shard_size
//
=
self
.
tp_size
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
...
...
@@ -1066,16 +1077,11 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset
=
self
.
_get_shard_offset_mapping
(
loaded_shard_id
)
shard_size
=
self
.
_get_shard_size_mapping
(
loaded_shard_id
)
# Note(simon): This is needed for Qwen3's fp8 quantization.
if
isinstance
(
param
,
BlockQuantScaleParameter
):
assert
self
.
quant_method
is
not
None
# Assume the weight block size has been set by quant method
assert
hasattr
(
self
,
"weight_block_size"
)
weight_block_size
=
self
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_size
=
(
shard_size
+
block_n
-
1
)
//
block_n
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
)
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
...
...
@@ -1208,6 +1214,13 @@ class QKVParallelLinear(ColumnParallelLinear):
elif
loaded_shard_id
==
"v"
:
shard_offset
=
(
self
.
num_heads
+
self
.
num_kv_heads
)
*
self
.
head_size
shard_size
=
self
.
num_kv_heads
*
self
.
v_head_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
)
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment