Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f3731273
Commit
f3731273
authored
Oct 20, 2025
by
zhuwenwen
Browse files
update the layout of load_column_parallel_weight and load_row_parallel_weight
parent
5db8533c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
36 deletions
+33
-36
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+16
-32
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+17
-4
No files found.
vllm/model_executor/layers/linear.py
View file @
f3731273
...
...
@@ -647,7 +647,7 @@ class ColumnParallelLinear(LinearBase):
if
len
(
loaded_weight
.
shape
)
==
0
:
assert
loaded_weight
.
numel
()
==
1
loaded_weight
=
loaded_weight
.
reshape
(
1
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
i
f
not
envs
.
VLLM_USE_NN
or
self
.
is_quantization
else
loaded_weight
.
t
()
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
,
i
s_quantization
=
self
.
is_quantization
)
def
forward
(
self
,
...
...
@@ -835,7 +835,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return
if
is_gguf_weight
:
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
shard_size
=
loaded_weight
.
size
(
output_dim
)
//
self
.
tp_size
start_idx
=
self
.
tp_rank
*
shard_size
...
...
@@ -1027,20 +1026,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
self
.
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
self
.
tp_size
if
not
envs
.
VLLM_USE_NN
:
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
else
:
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
,
is_quantization
=
self
.
is_quantization
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
,
is_quantization
=
self
.
is_quantization
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
...
...
@@ -1212,21 +1204,13 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_size
=
(
shard_size
+
block_n
-
1
)
//
block_n
if
not
envs
.
VLLM_USE_NN
:
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
else
:
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
,
is_quantization
=
self
.
is_quantization
)
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
,
is_quantization
=
self
.
is_quantization
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -1559,7 +1543,7 @@ class RowParallelLinear(LinearBase):
assert
loaded_weight
.
numel
()
==
1
loaded_weight
=
loaded_weight
.
reshape
(
1
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
i
f
not
envs
.
VLLM_USE_NN
or
self
.
is_quantization
else
loaded_weight
.
t
()
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
,
i
s_quantization
=
self
.
is_quantization
)
def
forward
(
self
,
...
...
vllm/model_executor/parameter.py
View file @
f3731273
...
...
@@ -140,11 +140,18 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def
output_dim
(
self
):
return
self
.
_output_dim
def
load_column_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
shard_size
=
self
.
data
.
shape
[
self
.
output_dim
]
def
load_column_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
is_quantization
:
Optional
[
bool
]):
if
not
envs
.
VLLM_USE_NN
or
len
(
self
.
data
.
shape
)
==
1
or
is_quantization
:
shard_size
=
self
.
data
.
shape
[
self
.
output_dim
]
else
:
shard_size
=
self
.
data
.
shape
[
int
(
not
(
self
.
output_dim
))]
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
self
.
tp_rank
*
shard_size
,
shard_size
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
self
.
data
.
shape
==
loaded_weight
.
shape
self
.
data
.
copy_
(
loaded_weight
)
...
...
@@ -231,8 +238,11 @@ class RowvLLMParameter(BasevLLMParameter):
def
input_dim
(
self
):
return
self
.
_input_dim
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
shard_size
=
self
.
data
.
shape
[
self
.
input_dim
]
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
is_quantization
:
Optional
[
bool
]):
if
not
envs
.
VLLM_USE_NN
or
is_quantization
:
shard_size
=
self
.
data
.
shape
[
self
.
input_dim
]
else
:
shard_size
=
self
.
data
.
shape
[
int
(
not
(
self
.
input_dim
))]
loaded_weight
=
loaded_weight
.
narrow
(
self
.
input_dim
,
self
.
tp_rank
*
shard_size
,
shard_size
)
...
...
@@ -240,6 +250,9 @@ class RowvLLMParameter(BasevLLMParameter):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
self
.
data
.
shape
==
loaded_weight
.
shape
self
.
data
.
copy_
(
loaded_weight
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment