Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6605af8e
Commit
6605af8e
authored
Oct 12, 2025
by
zhuwenwen
Browse files
update weight_loader_v2 layout of ColumnParallelLinear and MergedColumnParallelLinear
parent
e8700643
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
16 deletions
+56
-16
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+33
-12
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+23
-4
No files found.
vllm/model_executor/layers/linear.py
View file @
6605af8e
...
...
@@ -989,6 +989,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
int
]
=
None
):
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
loaded_shard_id
is
None
:
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
...
...
@@ -1020,11 +1022,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
.
output_sizes
[:
loaded_shard_id
])
//
self
.
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
self
.
tp_size
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
if
not
envs
.
VLLM_USE_NN
:
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
else
:
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
,
is_quantization
=
is_quantization
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
...
...
@@ -1164,6 +1174,8 @@ class QKVParallelLinear(ColumnParallelLinear):
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
str
]
=
None
):
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
loaded_shard_id
is
None
:
# special case for certain models
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
...
...
@@ -1194,12 +1206,21 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_size
=
(
shard_size
+
block_n
-
1
)
//
block_n
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
if
not
envs
.
VLLM_USE_NN
:
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
else
:
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
,
is_quantization
=
is_quantization
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -1534,7 +1555,7 @@ class RowParallelLinear(LinearBase):
assert
loaded_weight
.
numel
()
==
1
loaded_weight
=
loaded_weight
.
reshape
(
1
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
.
t
()
if
envs
.
VLLM_USE_NN
else
loaded_weight
)
def
forward
(
self
,
...
...
vllm/model_executor/parameter.py
View file @
6605af8e
...
...
@@ -8,6 +8,8 @@ from weakref import WeakValueDictionary
import
torch
from
torch.nn
import
Parameter
import
vllm.envs
as
envs
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
...
@@ -150,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset
=
kwargs
.
get
(
"shard_offset"
)
shard_size
=
kwargs
.
get
(
"shard_size"
)
is_quantization
=
kwargs
.
get
(
"is_quantization"
)
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if
isinstance
(
...
...
@@ -161,11 +164,19 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data
=
self
.
data
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
shard_size
)
if
not
envs
.
VLLM_USE_NN
or
is_quantization
:
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
int
(
not
(
self
.
output_dim
)),
shard_offset
,
shard_size
)
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
self
.
tp_rank
*
shard_size
,
shard_size
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -175,6 +186,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_size
=
kwargs
.
get
(
"shard_size"
)
shard_id
=
kwargs
.
get
(
"shard_id"
)
num_heads
=
kwargs
.
get
(
"num_heads"
)
is_quantization
=
kwargs
.
get
(
"is_quantization"
)
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if
isinstance
(
...
...
@@ -187,11 +199,18 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data
=
self
.
data
shard_id
=
(
self
.
tp_rank
if
shard_id
==
"q"
else
self
.
tp_rank
//
num_heads
)
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
shard_size
)
if
not
envs
.
VLLM_USE_NN
or
len
(
param_data
.
shape
)
==
1
or
is_quantization
:
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
int
(
not
(
self
.
output_dim
)),
shard_offset
,
shard_size
)
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
shard_id
*
shard_size
,
shard_size
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment