Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6c13752
Unverified
Commit
a6c13752
authored
Feb 25, 2026
by
Isotr0py
Committed by
GitHub
Feb 24, 2026
Browse files
[Misc] Add shard_id validation for MergedColumnLinear (#35055)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
4572a06a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
7 deletions
+67
-7
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+67
-7
No files found.
vllm/model_executor/layers/linear.py
View file @
a6c13752
...
@@ -66,15 +66,23 @@ WEIGHT_LOADER_V2_SUPPORTED = [
...
@@ -66,15 +66,23 @@ WEIGHT_LOADER_V2_SUPPORTED = [
]
]
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
def
adjust_marlin_shard
(
marlin_tile_size
=
getattr
(
param
,
"marlin_tile_size"
,
None
)
param
:
Parameter
,
shard_size
:
int
,
shard_offset
:
int
,
)
->
tuple
[
int
,
int
]:
marlin_tile_size
:
int
|
None
=
getattr
(
param
,
"marlin_tile_size"
,
None
)
if
marlin_tile_size
is
None
:
if
marlin_tile_size
is
None
:
return
shard_size
,
shard_offset
return
shard_size
,
shard_offset
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
def
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
):
def
adjust_block_scale_shard
(
weight_block_size
:
tuple
[
int
,
...]
|
None
,
shard_size
:
int
,
shard_offset
:
int
,
)
->
tuple
[
int
,
int
]:
assert
weight_block_size
is
not
None
assert
weight_block_size
is
not
None
block_n
=
weight_block_size
[
0
]
block_n
=
weight_block_size
[
0
]
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
...
@@ -83,7 +91,9 @@ def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
...
@@ -83,7 +91,9 @@ def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
def
adjust_bitsandbytes_4bit_shard
(
def
adjust_bitsandbytes_4bit_shard
(
param
:
Parameter
,
shard_offsets
:
dict
[
str
,
tuple
[
int
,
int
]],
loaded_shard_id
:
str
param
:
Parameter
,
shard_offsets
:
dict
[
str
,
tuple
[
int
,
int
]],
loaded_shard_id
:
str
,
)
->
tuple
[
int
,
int
]:
)
->
tuple
[
int
,
int
]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
...
@@ -97,7 +107,11 @@ def adjust_bitsandbytes_4bit_shard(
...
@@ -97,7 +107,11 @@ def adjust_bitsandbytes_4bit_shard(
return
quantized_size
,
quantized_offset
return
quantized_size
,
quantized_offset
def
adjust_scalar_to_fused_array
(
param
,
loaded_weight
,
shard_id
):
def
adjust_scalar_to_fused_array
(
param_data
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
int
|
str
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""For fused modules (QKV and MLP) we have an array of length
"""For fused modules (QKV and MLP) we have an array of length
N that holds 1 scale for each "logical" matrix. So the param
N that holds 1 scale for each "logical" matrix. So the param
is an array of length N. The loaded_weight corresponds to
is an array of length N. The loaded_weight corresponds to
...
@@ -117,12 +131,14 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
...
@@ -117,12 +131,14 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
assert
loaded_weight
.
shape
[
0
]
==
1
assert
loaded_weight
.
shape
[
0
]
==
1
loaded_weight
=
loaded_weight
[
0
]
loaded_weight
=
loaded_weight
[
0
]
return
param
[
shard_id
],
loaded_weight
return
param
_data
[
shard_id
],
loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
# bitsandbytes shard offsets.
def
left_shift_bitsandbytes_4bit_shard
(
bnb_weight_attrs
:
dict
[
str
,
Any
]):
def
left_shift_bitsandbytes_4bit_shard
(
bnb_weight_attrs
:
dict
[
str
,
Any
],
)
->
tuple
[
dict
[
str
,
Any
],
dict
[
str
,
Any
]]:
"""
"""
Separate the BitsAndBytes 4-bit shard.
Separate the BitsAndBytes 4-bit shard.
...
@@ -681,12 +697,41 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -681,12 +697,41 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp
=
disable_tp
,
disable_tp
=
disable_tp
,
)
)
def
validate_shard_id
(
self
,
loaded_shard_id
:
int
|
tuple
[
int
,
...]
|
None
):
if
loaded_shard_id
is
None
:
return
if
isinstance
(
loaded_shard_id
,
tuple
):
for
idx
in
loaded_shard_id
:
if
not
(
0
<=
idx
<
len
(
self
.
output_sizes
)):
raise
ValueError
(
f
"Shard id index
{
idx
}
should be between 0 and "
f
"
{
len
(
self
.
output_sizes
)
-
1
}
. Got shard id
{
loaded_shard_id
}
."
)
if
len
(
loaded_shard_id
)
>
1
and
any
(
b
-
a
!=
1
for
a
,
b
in
zip
(
loaded_shard_id
[:
-
1
],
loaded_shard_id
[
1
:])
):
raise
ValueError
(
"Shard id with multiple indices should be consecutive. "
f
"Got shard id
{
loaded_shard_id
}
."
)
return
elif
isinstance
(
loaded_shard_id
,
int
):
if
loaded_shard_id
<
0
or
loaded_shard_id
>=
len
(
self
.
output_sizes
):
raise
ValueError
(
f
"Shard id should be between 0 and
{
len
(
self
.
output_sizes
)
-
1
}
. "
f
"Got shard id
{
loaded_shard_id
}
."
)
return
raise
ValueError
(
"This line should not be reached"
)
def
weight_loader
(
def
weight_loader
(
self
,
self
,
param
:
Parameter
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
):
self
.
validate_shard_id
(
loaded_shard_id
)
# FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
if
isinstance
(
loaded_shard_id
,
tuple
):
if
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Shard id with multiple indices is not supported in weight_loader, "
"Shard id with multiple indices is not supported in weight_loader, "
...
@@ -874,6 +919,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -874,6 +919,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
):
self
.
validate_shard_id
(
loaded_shard_id
)
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
):
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
):
if
isinstance
(
param
,
PerTensorScaleParameter
):
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
...
@@ -1005,6 +1051,18 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1005,6 +1051,18 @@ class QKVParallelLinear(ColumnParallelLinear):
disable_tp
=
disable_tp
,
disable_tp
=
disable_tp
,
)
)
def
validate_shard_id
(
self
,
loaded_shard_id
:
str
|
None
):
if
loaded_shard_id
is
None
:
return
if
isinstance
(
loaded_shard_id
,
str
):
if
loaded_shard_id
not
in
[
"q"
,
"k"
,
"v"
]:
raise
ValueError
(
"Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
f
"got shard id
{
loaded_shard_id
}
."
)
return
raise
ValueError
(
"This line should not be reached"
)
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
shard_offset_mapping
=
{
shard_offset_mapping
=
{
"q"
:
0
,
"q"
:
0
,
...
@@ -1073,6 +1131,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1073,6 +1131,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
str
|
None
=
None
,
loaded_shard_id
:
str
|
None
=
None
,
):
):
self
.
validate_shard_id
(
loaded_shard_id
)
if
loaded_shard_id
is
None
:
# special case for certain models
if
loaded_shard_id
is
None
:
# special case for certain models
if
isinstance
(
param
,
PerTensorScaleParameter
):
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_qkv_weight
(
param
.
load_qkv_weight
(
...
@@ -1112,6 +1171,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1112,6 +1171,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
str
|
None
=
None
,
loaded_shard_id
:
str
|
None
=
None
,
):
):
self
.
validate_shard_id
(
loaded_shard_id
)
# Special case for GGUF
# Special case for GGUF
# initialize GGUF param after we know the quantize type
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment