Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
40b4284f
Unverified
Commit
40b4284f
authored
Apr 09, 2025
by
Isotr0py
Committed by
GitHub
Apr 08, 2025
Browse files
[Bugfix] Handle `process_weights_after_loading` for `QKVCrossParallelLinear` (#15328)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
4ebc0b96
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
6 deletions
+33
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+21
-6
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+3
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+9
-0
No files found.
vllm/model_executor/layers/linear.py
View file @
40b4284f
...
@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
prefix
=
f
"
{
prefix
}
.kv_proj_encoder"
)
prefix
=
f
"
{
prefix
}
.kv_proj_encoder"
)
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self
.
q_size
=
self
.
q_proj_decoder
.
output_size_per_partition
self
.
kv_size
=
self
.
kv_proj_encoder
.
num_kv_heads
*
head_size
self
.
kv_size
=
self
.
kv_proj_encoder
.
num_kv_heads
*
head_size
if
bias
:
if
bias
:
...
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
def process_weights_after_loading(self) -> None:
    """Run the quantization post-load hook on every wrapped projection.

    Each layer held in ``self.proj`` is passed to
    ``self.quant_method.process_weights_after_loading``. When no
    quantization method is configured (``self.quant_method is None``)
    this is a no-op.

    The model loader calls this method explicitly for this class
    instead of going through each sub-layer's own ``quant_method``,
    because the q and kv projections are intentionally not registered
    as submodules.
    """
    quant_method = self.quant_method
    if quant_method is None:
        # Nothing to post-process; decide once instead of per layer.
        return
    for layer in self.proj.values():
        quant_method.process_weights_after_loading(layer)
@
property
@
property
def
q_proj_decoder
(
self
)
->
ColumnParallelLinear
:
def
q_proj_decoder
(
self
)
->
ColumnParallelLinear
:
layer
=
self
.
proj
[
"q_proj_decoder"
]
layer
=
self
.
proj
[
"q_proj_decoder"
]
for
name
,
param
in
self
.
named_parameters
():
for
name
,
param
in
self
.
named_parameters
():
target_param
=
getattr
(
layer
,
name
)
target_param
=
getattr
(
layer
,
name
,
None
)
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"q_proj_decoder"
)
if
target_param
is
not
None
:
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"q_proj_decoder"
)
return
layer
return
layer
@
property
@
property
def
kv_proj_encoder
(
self
)
->
QKVParallelLinear
:
def
kv_proj_encoder
(
self
)
->
QKVParallelLinear
:
layer
=
self
.
proj
[
"kv_proj_encoder"
]
layer
=
self
.
proj
[
"kv_proj_encoder"
]
for
name
,
param
in
self
.
named_parameters
():
for
name
,
param
in
self
.
named_parameters
():
target_param
=
getattr
(
layer
,
name
)
target_param
=
getattr
(
layer
,
name
,
None
)
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"kv_proj_encoder"
)
if
target_param
is
not
None
:
self
.
sync_weight_attrs
(
param
,
target_param
,
mode
=
"kv_proj_encoder"
)
return
layer
return
layer
def
sync_weight_attrs
(
def
sync_weight_attrs
(
...
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
if
loaded_shard_id
==
"q"
else
self
.
kv_proj_encoder
)
if
loaded_shard_id
==
"q"
else
self
.
kv_proj_encoder
)
target_param
=
self
.
select_proj_params
(
layer
,
param
)
target_param
=
self
.
select_proj_params
(
layer
,
param
)
shard_id_args
=
(
loaded_shard_id
,
)
if
loaded_shard_id
!=
"q"
else
()
shard_id_args
=
(
loaded_shard_id
,
)
if
loaded_shard_id
!=
"q"
else
()
layer
.
weight_loader
(
target_param
,
loaded_weight
,
*
shard_id_args
)
if
self
.
quant_method
.
__class__
.
__name__
in
WEIGHT_LOADER_V2_SUPPORTED
:
layer
.
weight_loader_v2
(
target_param
,
loaded_weight
,
*
shard_id_args
)
else
:
layer
.
weight_loader
(
target_param
,
loaded_weight
,
*
shard_id_args
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", q_size=
{
self
.
q_
proj_decoder
.
output_size_per_partition
}
"
s
+=
f
", q_size=
{
self
.
q_
size
}
"
s
+=
f
", kv_size=
{
self
.
kv_size
}
"
s
+=
f
", kv_size=
{
self
.
kv_size
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
40b4284f
...
@@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader
=
weight_loader
,
weight_loader
=
weight_loader
,
)
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
set_weight_attrs
(
scale
,
{
"scale_type"
:
"weight_scale"
})
layer
.
register_parameter
(
"weight_scale"
,
scale
)
layer
.
register_parameter
(
"weight_scale"
,
scale
)
else
:
else
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
...
@@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader
=
weight_loader
,
weight_loader
=
weight_loader
,
)
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
set_weight_attrs
(
scale
,
{
"scale_type"
:
"weight_scale"
})
# The weight_scale_inv name is intentional for deepseekv3
# The weight_scale_inv name is intentional for deepseekv3
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
...
@@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
set_weight_attrs
(
scale
,
{
"scale_type"
:
"input_scale"
})
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
else
:
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
layer
.
register_parameter
(
"input_scale"
,
None
)
...
...
vllm/model_executor/model_loader/loader.py
View file @
40b4284f
...
@@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVCrossParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
)
QuantizeMethodBase
)
from
vllm.model_executor.model_loader.tensorizer
import
(
from
vllm.model_executor.model_loader.tensorizer
import
(
...
@@ -160,6 +164,11 @@ def _initialize_model(
...
@@ -160,6 +164,11 @@ def _initialize_model(
def
_process_weights_after_loading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
def
_process_weights_after_loading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
target_device
:
torch
.
device
)
->
None
:
target_device
:
torch
.
device
)
->
None
:
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
QKVCrossParallelLinear
):
# NOTE(Isotr0py): special case for cross QKV layer because
# q and kv proj aren't registered as submodules intentionally
module
.
process_weights_after_loading
()
continue
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
# When quant methods need to process weights after loading
# When quant methods need to process weights after loading
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment