Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e17e4488
Unverified
Commit
e17e4488
authored
Mar 05, 2025
by
Isotr0py
Committed by
GitHub
Mar 05, 2025
Browse files
[LoRA] Remove linear hack outside transformers backend (#14177)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
257e200a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
142 additions
and
105 deletions
+142
-105
vllm/lora/layers.py
vllm/lora/layers.py
+30
-21
vllm/lora/utils.py
vllm/lora/utils.py
+0
-10
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+110
-61
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+2
-13
No files found.
vllm/lora/layers.py
View file @
e17e4488
...
@@ -395,17 +395,20 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
...
@@ -395,17 +395,20 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if
x
.
ndim
==
3
and
output
.
ndim
==
3
:
output
=
output
.
flatten
(
0
,
1
)
x
=
x
.
flatten
(
0
,
1
)
self
.
punica_wrapper
.
add_lora_linear
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
punica_wrapper
.
add_lora_linear
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
self
.
lora_b_stacked
,
self
.
lora_bias_stacked
,
1.0
,
self
.
lora_bias_stacked
,
1.0
,
self
.
output_slices
)
self
.
output_slices
)
return
output
return
output
@
classmethod
def
get_source_layer
(
cls
,
source_layer
:
nn
.
Module
)
->
type
:
# Check parent_cls in case source_layer is a HFCompatibleLinear.
return
getattr
(
source_layer
,
"parent_cls"
,
type
(
source_layer
))
class
ReplicatedLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
class
ReplicatedLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
...
@@ -418,7 +421,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -418,7 +421,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
def
forward
(
def
forward
(
self
,
input_
:
torch
.
Tensor
self
,
input_
:
torch
.
Tensor
)
->
Tuple
[
Optional
[
torch
.
Tensor
]
,
Optional
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]
]
:
"""Forward of ReplicatedLinearWithLoRA
"""Forward of ReplicatedLinearWithLoRA
Args:
Args:
...
@@ -436,6 +439,10 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -436,6 +439,10 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
output_bias
=
(
self
.
base_layer
.
bias
output_bias
=
(
self
.
base_layer
.
bias
if
self
.
base_layer
.
skip_bias_add
else
None
)
if
self
.
base_layer
.
skip_bias_add
else
None
)
if
not
self
.
base_layer
.
return_bias
:
return
output
return
output
,
output_bias
return
output
,
output_bias
# ReplicatedLinear should always be replaced, regardless of the fully
# ReplicatedLinear should always be replaced, regardless of the fully
...
@@ -448,8 +455,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -448,8 +455,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list
:
List
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
],
model_config
:
Optional
[
PretrainedConfig
],
)
->
bool
:
)
->
bool
:
source_layer
=
cls
.
get_source_layer
(
source_layer
)
return
type
(
source_layer
)
is
ReplicatedLinear
return
source_layer
is
ReplicatedLinear
class
ColumnParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
class
ColumnParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
...
@@ -512,7 +518,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -512,7 +518,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def
forward
(
def
forward
(
self
,
input_
:
torch
.
Tensor
self
,
input_
:
torch
.
Tensor
)
->
Tuple
[
Optional
[
torch
.
Tensor
]
,
Optional
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]
]
:
"""Forward of ColumnParallelLinear
"""Forward of ColumnParallelLinear
Args:
Args:
...
@@ -532,6 +538,10 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -532,6 +538,10 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
else
:
output
=
output_parallel
output
=
output_parallel
if
not
self
.
base_layer
.
return_bias
:
return
output
output_bias
=
(
self
.
base_layer
.
bias
output_bias
=
(
self
.
base_layer
.
bias
if
self
.
base_layer
.
skip_bias_add
else
None
)
if
self
.
base_layer
.
skip_bias_add
else
None
)
return
output
,
output_bias
return
output
,
output_bias
...
@@ -545,9 +555,8 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -545,9 +555,8 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list
:
List
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
],
model_config
:
Optional
[
PretrainedConfig
],
)
->
bool
:
)
->
bool
:
source_layer
=
cls
.
get_source_layer
(
source_layer
)
return
type
(
source_layer
)
is
ColumnParallelLinear
or
(
return
source_layer
is
ColumnParallelLinear
or
(
type
(
source_layer
)
is
MergedColumnParallelLinear
source_layer
is
MergedColumnParallelLinear
and
len
(
packed_modules_list
)
==
1
)
and
len
(
packed_modules_list
)
==
1
)
...
@@ -689,8 +698,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -689,8 +698,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
packed_modules_list
:
List
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
],
model_config
:
Optional
[
PretrainedConfig
],
)
->
bool
:
)
->
bool
:
source_layer
=
cls
.
get_source_layer
(
source_layer
)
return
(
type
(
source_layer
)
is
MergedColumnParallelLinear
return
(
source_layer
is
MergedColumnParallelLinear
and
len
(
packed_modules_list
)
==
2
)
and
len
(
packed_modules_list
)
==
2
)
...
@@ -758,8 +766,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -758,8 +766,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
source_layer
=
cls
.
get_source_layer
(
source_layer
)
return
type
(
source_layer
)
is
QKVParallelLinear
and
len
(
return
source_layer
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
1
packed_modules_list
)
==
1
...
@@ -820,8 +827,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
...
@@ -820,8 +827,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
packed_modules_list
:
List
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
],
model_config
:
Optional
[
PretrainedConfig
],
)
->
bool
:
)
->
bool
:
source_layer
=
cls
.
get_source_layer
(
source_layer
)
return
(
type
(
source_layer
)
is
QKVParallelLinear
return
(
source_layer
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
3
)
and
len
(
packed_modules_list
)
==
3
)
...
@@ -855,7 +861,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -855,7 +861,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def
forward
(
def
forward
(
self
,
input_
:
torch
.
Tensor
self
,
input_
:
torch
.
Tensor
)
->
Tuple
[
Optional
[
torch
.
Tensor
]
,
Optional
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]
]
:
"""Forward of RowParallelLinear
"""Forward of RowParallelLinear
Args:
Args:
...
@@ -890,6 +896,10 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -890,6 +896,10 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
else
:
else
:
output
=
output_
output
=
output_
output_bias
=
self
.
base_layer
.
bias
output_bias
=
self
.
base_layer
.
bias
if
not
self
.
base_layer
.
return_bias
:
return
output
return
output
,
output_bias
return
output
,
output_bias
@
property
@
property
...
@@ -906,8 +916,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -906,8 +916,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list
:
List
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
],
model_config
:
Optional
[
PretrainedConfig
],
)
->
bool
:
)
->
bool
:
source_layer
=
cls
.
get_source_layer
(
source_layer
)
return
type
(
source_layer
)
is
RowParallelLinear
return
source_layer
is
RowParallelLinear
class
LogitsProcessorWithLoRA
(
BaseLayerWithLoRA
):
class
LogitsProcessorWithLoRA
(
BaseLayerWithLoRA
):
...
...
vllm/lora/utils.py
View file @
e17e4488
...
@@ -67,16 +67,6 @@ def from_layer(layer: nn.Module,
...
@@ -67,16 +67,6 @@ def from_layer(layer: nn.Module,
packed_modules_list
=
packed_modules_list
,
packed_modules_list
=
packed_modules_list
,
model_config
=
model_config
):
model_config
=
model_config
):
instance_layer
=
lora_cls
(
layer
)
instance_layer
=
lora_cls
(
layer
)
if
layer
.
__class__
.
__name__
==
"HFCompatibleLinear"
:
# HACK: Make the forward method compatible with the original
# forward method of the instance_layer.
original_forward
=
instance_layer
.
forward
def
new_forward
(
input
):
input
=
input
.
squeeze
(
0
)
return
original_forward
(
input
)[
0
]
# noqa: B023
instance_layer
.
forward
=
new_forward
instance_layer
.
create_lora_weights
(
max_loras
,
lora_config
,
instance_layer
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
model_config
)
return
instance_layer
return
instance_layer
...
...
vllm/model_executor/layers/linear.py
View file @
e17e4488
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
itertools
import
itertools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Optional
from
typing
import
Optional
,
Union
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -152,6 +152,7 @@ class LinearBase(torch.nn.Module):
...
@@ -152,6 +152,7 @@ class LinearBase(torch.nn.Module):
skip_bias_add: If true, skip adding bias but instead return it.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
quant_config: Quantization configure.
return_bias: If true, return bias together with outputs in forward pass.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -162,6 +163,8 @@ class LinearBase(torch.nn.Module):
...
@@ -162,6 +163,8 @@ class LinearBase(torch.nn.Module):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -178,9 +181,11 @@ class LinearBase(torch.nn.Module):
...
@@ -178,9 +181,11 @@ class LinearBase(torch.nn.Module):
else
:
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
prefix
=
prefix
)
self
.
return_bias
=
return_bias
def
forward
(
self
,
def
forward
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]:
self
,
x
:
torch
.
Tensor
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -198,20 +203,25 @@ class ReplicatedLinear(LinearBase):
...
@@ -198,20 +203,25 @@ class ReplicatedLinear(LinearBase):
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
input_size
:
int
,
self
,
output_size
:
int
,
input_size
:
int
,
bias
:
bool
=
True
,
output_size
:
int
,
skip_bias_add
:
bool
=
False
,
bias
:
bool
=
True
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
skip_bias_add
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
prefix
:
str
=
""
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
):
super
().
__init__
(
input_size
,
super
().
__init__
(
input_size
,
output_size
,
output_size
,
skip_bias_add
,
skip_bias_add
,
params_dtype
,
params_dtype
,
quant_config
,
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
,
return_bias
=
return_bias
)
# All the linear layer supports quant method.
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -254,12 +264,15 @@ class ReplicatedLinear(LinearBase):
...
@@ -254,12 +264,15 @@ class ReplicatedLinear(LinearBase):
f
"to a parameter of size
{
param
.
size
()
}
"
)
f
"to a parameter of size
{
param
.
size
()
}
"
)
param
.
data
.
copy_
(
loaded_weight
)
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
def
forward
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]:
self
,
x
:
torch
.
Tensor
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
@@ -293,16 +306,20 @@ class ColumnParallelLinear(LinearBase):
...
@@ -293,16 +306,20 @@ class ColumnParallelLinear(LinearBase):
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
input_size
:
int
,
self
,
output_size
:
int
,
input_size
:
int
,
bias
:
bool
=
True
,
output_size
:
int
,
gather_output
:
bool
=
False
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
gather_output
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
skip_bias_add
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
):
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
input_size
self
.
input_size_per_partition
=
input_size
...
@@ -315,8 +332,13 @@ class ColumnParallelLinear(LinearBase):
...
@@ -315,8 +332,13 @@ class ColumnParallelLinear(LinearBase):
for
output_size
in
self
.
output_sizes
for
output_size
in
self
.
output_sizes
]
]
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
quant_config
,
prefix
)
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
,
return_bias
=
return_bias
)
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
...
@@ -393,7 +415,9 @@ class ColumnParallelLinear(LinearBase):
...
@@ -393,7 +415,9 @@ class ColumnParallelLinear(LinearBase):
loaded_weight
=
loaded_weight
.
reshape
(
1
)
loaded_weight
=
loaded_weight
.
reshape
(
1
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
self
,
input_
)
->
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]:
def
forward
(
self
,
input_
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
# Matrix multiply.
...
@@ -405,6 +429,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -405,6 +429,8 @@ class ColumnParallelLinear(LinearBase):
else
:
else
:
output
=
output_parallel
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
@@ -439,15 +465,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -439,15 +465,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
input_size
:
int
,
self
,
output_sizes
:
list
[
int
],
input_size
:
int
,
bias
:
bool
=
True
,
output_sizes
:
list
[
int
],
gather_output
:
bool
=
False
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
gather_output
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
skip_bias_add
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
prefix
:
str
=
""
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
@@ -458,7 +488,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -458,7 +488,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
,
return_bias
=
return_bias
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -711,16 +742,20 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -711,16 +742,20 @@ class QKVParallelLinear(ColumnParallelLinear):
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
hidden_size
:
int
,
self
,
head_size
:
int
,
hidden_size
:
int
,
total_num_heads
:
int
,
head_size
:
int
,
total_num_kv_heads
:
Optional
[
int
]
=
None
,
total_num_heads
:
int
,
bias
:
bool
=
True
,
total_num_kv_heads
:
Optional
[
int
]
=
None
,
skip_bias_add
:
bool
=
False
,
bias
:
bool
=
True
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
skip_bias_add
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
prefix
:
str
=
""
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
self
.
total_num_heads
=
total_num_heads
...
@@ -753,7 +788,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -753,7 +788,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
,
return_bias
=
return_bias
)
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
shard_offset_mapping
=
{
shard_offset_mapping
=
{
...
@@ -1048,16 +1084,20 @@ class RowParallelLinear(LinearBase):
...
@@ -1048,16 +1084,20 @@ class RowParallelLinear(LinearBase):
quant_config: Quantization configure.
quant_config: Quantization configure.
"""
"""
def
__init__
(
self
,
def
__init__
(
input_size
:
int
,
self
,
output_size
:
int
,
input_size
:
int
,
bias
:
bool
=
True
,
output_size
:
int
,
input_is_parallel
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
input_is_parallel
:
bool
=
True
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
skip_bias_add
:
bool
=
False
,
reduce_results
:
bool
=
True
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
):
# Divide the weight matrix along the first dimension.
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -1065,8 +1105,13 @@ class RowParallelLinear(LinearBase):
...
@@ -1065,8 +1105,13 @@ class RowParallelLinear(LinearBase):
self
.
output_size_per_partition
=
output_size
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
self
.
output_partition_sizes
=
[
output_size
]
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
quant_config
,
prefix
)
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
,
return_bias
=
return_bias
)
self
.
input_is_parallel
=
input_is_parallel
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
...
@@ -1145,7 +1190,9 @@ class RowParallelLinear(LinearBase):
...
@@ -1145,7 +1190,9 @@ class RowParallelLinear(LinearBase):
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
self
,
input_
)
->
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]:
def
forward
(
self
,
input_
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
self
.
input_is_parallel
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
input_parallel
=
input_
else
:
else
:
...
@@ -1169,6 +1216,8 @@ class RowParallelLinear(LinearBase):
...
@@ -1169,6 +1216,8 @@ class RowParallelLinear(LinearBase):
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
vllm/model_executor/models/transformers.py
View file @
e17e4488
...
@@ -96,23 +96,12 @@ def replace_linear_class(
...
@@ -96,23 +96,12 @@ def replace_linear_class(
"rowwise"
:
RowParallelLinear
,
"rowwise"
:
RowParallelLinear
,
}.
get
(
style
,
ReplicatedLinear
)
}.
get
(
style
,
ReplicatedLinear
)
class
HFCompatibleLinear
(
vllm_linear_cls
):
return
vllm_linear_cls
(
"""
Wrapper class that removes `output_bias` from returned output.
"""
# NOTE: The LoRA layer needs to use `parent_cls`.
@
property
def
parent_cls
(
self
)
->
type
:
return
vllm_linear_cls
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
super
().
forward
(
input
)[
0
]
return
HFCompatibleLinear
(
input_size
=
linear
.
in_features
,
input_size
=
linear
.
in_features
,
output_size
=
linear
.
out_features
,
output_size
=
linear
.
out_features
,
bias
=
linear
.
bias
is
not
None
,
bias
=
linear
.
bias
is
not
None
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
return_bias
=
False
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment