Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b26b4cd0
Unverified
Commit
b26b4cd0
authored
Dec 07, 2024
by
Isotr0py
Committed by
GitHub
Dec 07, 2024
Browse files
[Misc][LoRA] Refactor and clean MergedQKVParallelLinearWithLora implementation (#10958)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
f13cf9ad
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
263 deletions
+60
-263
vllm/lora/layers.py
vllm/lora/layers.py
+60
-263
No files found.
vllm/lora/layers.py
View file @
b26b4cd0
...
@@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
Both slices must have the same size.
Both slices must have the same size.
"""
"""
def
__init__
(
self
,
base_layer
:
MergedColumnParallelLinear
)
->
None
:
def
__init__
(
self
,
base_layer
:
Union
[
MergedColumnParallelLinear
,
QKVParallelLinear
])
->
None
:
super
().
__init__
(
base_layer
)
super
().
__init__
(
base_layer
)
# There are two LoRA layers
# There are two LoRA layers
self
.
n_slices
=
len
(
self
.
base_layer
.
output_sizes
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size
output_sizes
=
self
.
base_layer
.
output_sizes
self
.
output_slices
=
tuple
(
divide
(
output_size
,
self
.
tp_size
)
for
output_size
in
output_sizes
)
self
.
n_slices
=
len
(
self
.
output_slices
)
self
.
output_ids
=
(
self
.
tp_rank
,
)
*
self
.
n_slices
def
create_lora_weights
(
def
create_lora_weights
(
self
,
self
,
...
@@ -559,15 +569,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -559,15 +569,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""
"""
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
self
.
n_slices
==
2
and
self
.
base_layer
.
output_sizes
[
0
]
==
self
.
base_layer
.
output_sizes
[
1
]):
raise
ValueError
(
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size."
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
lora_a_output_size_per_partition
=
(
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
...
@@ -585,22 +586,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -585,22 +586,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
self
.
output_size
//
2
,
output_size
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
n
_slices
)
)
)
for
output_size
in
self
.
output
_slices
)
if
lora_config
.
bias_enabled
:
if
lora_config
.
bias_enabled
:
self
.
lora_bias_stacked
=
tuple
(
self
.
lora_bias_stacked
=
tuple
(
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
self
.
output_size
//
2
,
output_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
n_slices
))
)
for
output_size
in
self
.
output_slices
)
self
.
output_dim
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
self
.
output_slices
=
(
self
.
output_dim
,
self
.
output_dim
)
def
slice_lora_a
(
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
...
@@ -610,27 +609,21 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -610,27 +609,21 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
slice_lora_b
(
def
slice_lora_b
(
self
,
lora_b
:
List
[
Union
[
torch
.
Tensor
,
None
]]
self
,
lora_b
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
#NOTE: lora_b contains 2 subloras, and each sublora could be None.
for
i
,
(
shard_id
,
shard_size
)
in
enumerate
(
shard_size
=
self
.
output_dim
zip
(
self
.
output_ids
,
self
.
output_slices
)):
start_idx
=
self
.
tp_rank
*
shard_size
if
(
lora_b_i
:
=
lora_b
[
i
])
is
not
None
:
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
lora_b
[
i
]
=
lora_b_i
[:,
shard_size
*
shard_id
:
shard_size
*
lora_b
=
[
(
shard_id
+
1
)]
lora_b
[
0
][:,
start_idx
:
end_idx
]
if
lora_b
[
0
]
is
not
None
else
None
,
lora_b
[
1
][:,
start_idx
:
end_idx
]
if
lora_b
[
1
]
is
not
None
else
None
,
]
return
lora_b
return
lora_b
def
slice_bias
(
def
slice_bias
(
self
,
bias
:
List
[
Union
[
torch
.
Tensor
,
self
,
bias
:
List
[
Union
[
torch
.
Tensor
,
None
]])
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
None
]])
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
# NOTE : each bias could be None.
for
i
,
(
shard_id
,
shard_size
)
in
enumerate
(
shard_size
=
self
.
output_dim
zip
(
self
.
output_ids
,
self
.
output_slices
)):
start_idx
=
self
.
tp_rank
*
shard_size
if
(
bias_i
:
=
bias
[
i
])
is
not
None
:
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
bias
[
i
]
=
bias_i
[
shard_size
*
shard_id
:
shard_size
*
bias
=
[
(
shard_id
+
1
)]
bias
[
0
][
start_idx
:
end_idx
]
if
bias
[
0
]
is
not
None
else
None
,
bias
[
1
][
start_idx
:
end_idx
]
if
bias
[
1
]
is
not
None
else
None
]
return
bias
return
bias
def
set_lora
(
def
set_lora
(
...
@@ -649,30 +642,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -649,30 +642,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if
lora_bias
is
not
None
:
if
lora_bias
is
not
None
:
lora_bias
=
self
.
slice_bias
(
lora_bias
)
lora_bias
=
self
.
slice_bias
(
lora_bias
)
if
lora_a
[
0
]
is
not
None
:
for
i
in
range
(
self
.
n_slices
):
self
.
lora_a_stacked
[
0
][
if
(
lora_a_i
:
=
lora_a
[
i
])
is
not
None
:
index
,
0
,
:
lora_a
[
0
].
shape
[
1
],
:
lora_a
[
0
].
shape
[
0
]].
copy_
(
self
.
lora_a_stacked
[
i
][
lora_a
[
0
].
T
,
non_blocking
=
True
)
index
,
0
,
:
lora_a_i
.
shape
[
1
],
:
lora_a_i
.
shape
[
0
]].
copy_
(
self
.
lora_b_stacked
[
0
][
lora_a_i
.
T
,
non_blocking
=
True
)
index
,
0
,
:
lora_b
[
0
].
shape
[
1
],
:
lora_b
[
0
].
shape
[
0
]].
copy_
(
if
(
lora_b_i
:
=
lora_b
[
i
])
is
not
None
:
lora_b
[
0
].
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
i
][
if
lora_bias
is
not
None
and
lora_bias
[
0
]
is
not
None
:
index
,
0
,
:
lora_b_i
.
shape
[
1
],
:
lora_b_i
.
shape
[
0
]].
copy_
(
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
lora_b_i
.
T
,
non_blocking
=
True
)
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
[
0
][
index
,
0
,
:
lora_bias
[
0
].
shape
[
0
]].
copy_
(
if
lora_bias
is
not
None
:
lora_bias
[
0
].
T
,
non_blocking
=
True
)
if
lora_a
[
1
]
is
not
None
:
self
.
lora_a_stacked
[
1
][
index
,
0
,
:
lora_a
[
1
].
shape
[
1
],
:
lora_a
[
1
].
shape
[
0
]].
copy_
(
lora_a
[
1
].
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
1
][
index
,
0
,
:
lora_b
[
1
].
shape
[
1
],
:
lora_b
[
1
].
shape
[
0
]].
copy_
(
lora_b
[
1
].
T
,
non_blocking
=
True
)
if
lora_bias
is
not
None
and
lora_bias
[
1
]
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
[
1
][
index
,
0
,
:
lora_bias
[
1
].
shape
[
0
]].
copy_
(
for
i
in
range
(
self
.
n_slices
):
lora_bias
[
1
].
T
,
non_blocking
=
True
)
if
(
lora_bias_i
:
=
lora_bias
[
i
])
is
not
None
:
self
.
lora_bias_stacked
[
i
][
index
,
0
,
:
lora_bias_i
.
shape
[
0
]].
copy_
(
lora_bias_i
.
T
,
non_blocking
=
True
)
@
classmethod
@
classmethod
@
_not_fully_sharded_can_replace
@
_not_fully_sharded_can_replace
...
@@ -755,8 +743,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -755,8 +743,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
packed_modules_list
)
==
1
packed_modules_list
)
==
1
class
MergedQKVParallelLinearWithLora
(
ColumnParallelLinearWithLoRA
):
class
MergedQKVParallelLinearWithLora
(
Merged
ColumnParallelLinearWithLoRA
):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
"""
Merged
ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
(q_proj + k_proj + v_proj -> qkv_proj).
...
@@ -773,22 +761,6 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -773,22 +761,6 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
"""
The main reason for overloading this function is to handle inconsistent
weight dimensions in qkv lora.
"""
self
.
lora_config
=
lora_config
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
self
.
n_slices
==
3
):
raise
ValueError
(
"LoRAColumnParallelLinear3Slice requires 3 slices."
)
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
self
.
base_layer
.
head_size
)
self
.
base_layer
.
head_size
)
self
.
kv_proj_shard_size
=
(
self
.
base_layer
.
num_kv_heads
*
self
.
kv_proj_shard_size
=
(
self
.
base_layer
.
num_kv_heads
*
...
@@ -796,203 +768,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -796,203 +768,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
q_shard_id
=
self
.
tp_rank
self
.
q_shard_id
=
self
.
tp_rank
self
.
kv_shard_id
=
self
.
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
self
.
kv_shard_id
=
self
.
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
# q, k, v
self
.
lora_a_stacked
=
(
torch
.
zeros
(
max_loras
,
1
,
lora_a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
torch
.
zeros
(
max_loras
,
1
,
lora_a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
torch
.
zeros
(
max_loras
,
1
,
lora_a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
)
self
.
lora_b_stacked
=
(
torch
.
zeros
(
max_loras
,
1
,
self
.
q_proj_shard_size
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
torch
.
zeros
(
max_loras
,
1
,
self
.
kv_proj_shard_size
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
torch
.
zeros
(
max_loras
,
1
,
self
.
kv_proj_shard_size
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
)
if
lora_config
.
bias_enabled
:
self
.
lora_bias_stacked
=
(
torch
.
zeros
(
max_loras
,
1
,
self
.
q_proj_shard_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
torch
.
zeros
(
max_loras
,
1
,
self
.
kv_proj_shard_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
torch
.
zeros
(
max_loras
,
1
,
self
.
kv_proj_shard_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
)
self
.
output_slices
=
(
self
.
output_slices
=
(
self
.
q_proj_shard_size
,
self
.
q_proj_shard_size
,
self
.
kv_proj_shard_size
,
self
.
kv_proj_shard_size
,
self
.
kv_proj_shard_size
,
self
.
kv_proj_shard_size
,
)
)
self
.
packed_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
output_ids
=
(
self
.
standard_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
q_shard_id
,
# lazily initialized.
self
.
kv_shard_id
,
self
.
indices
:
torch
.
Tensor
self
.
kv_shard_id
,
self
.
indices_len
:
List
[
int
]
)
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
lora_b_q
,
lora_b_k
,
lora_b_v
=
None
,
None
,
None
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
(
self
.
q_shard_id
+
1
),
]
if
lora_b
[
1
]
is
not
None
:
lora_b_k
=
lora_b
[
1
][:,
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
),
]
if
lora_b
[
2
]
is
not
None
:
lora_b_v
=
lora_b
[
2
][:,
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
),
]
lora_b
=
[
lora_b_q
,
lora_b_k
,
lora_b_v
]
return
lora_b
def
slice_bias
(
self
,
bias
:
List
[
Union
[
torch
.
Tensor
,
None
]])
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
bias_q
,
bias_k
,
bias_v
=
bias
if
bias_q
is
not
None
:
bias_q
=
bias_q
[
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
(
self
.
q_shard_id
+
1
)]
if
bias_k
is
not
None
:
bias_k
=
bias_k
[
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
if
bias_v
is
not
None
:
bias_v
=
bias_v
[
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
bias
=
[
bias_q
,
bias_k
,
bias_v
]
return
bias
def
set_lora
(
def
create_lora_weights
(
self
,
self
,
index
:
int
,
max_loras
:
int
,
lora_a
:
torch
.
Tensor
,
lora_config
:
LoRAConfig
,
lora_b
:
torch
.
Tensor
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
)
->
None
:
lora_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
"""
):
The main reason for overloading this function is to handle inconsistent
self
.
reset_lora
(
index
)
weight dimensions in qkv lora.
"""
if
self
.
tp_size
>
1
:
super
().
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
lora_bias
is
not
None
:
lora_bias
=
self
.
slice_bias
(
lora_bias
)
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
]
self
.
lora_b_stacked
[
0
][
index
,
0
,
:
lora_b_q
.
shape
[
1
],
:
lora_b_q
.
shape
[
0
]].
copy_
(
lora_b_q
.
T
,
non_blocking
=
True
)
if
lora_b
[
1
]
is
not
None
:
lora_b_k
=
lora_b
[
1
]
self
.
lora_b_stacked
[
1
][
index
,
0
,
:
lora_b_k
.
shape
[
1
],
:
lora_b_k
.
shape
[
0
]].
copy_
(
lora_b_k
.
T
,
non_blocking
=
True
)
if
lora_b
[
2
]
is
not
None
:
lora_b_v
=
lora_b
[
2
]
self
.
lora_b_stacked
[
2
][
index
,
0
,
:
lora_b_v
.
shape
[
1
],
:
lora_b_v
.
shape
[
0
]].
copy_
(
lora_b_v
.
T
,
non_blocking
=
True
)
if
lora_a
[
0
]
is
not
None
:
self
.
lora_a_stacked
[
0
][
index
,
0
,
:
lora_a
[
0
].
shape
[
1
],
:
lora_a
[
0
].
shape
[
0
]].
copy_
(
lora_a
[
0
].
T
,
non_blocking
=
True
)
if
lora_a
[
1
]
is
not
None
:
self
.
lora_a_stacked
[
1
][
index
,
0
,
:
lora_a
[
1
].
shape
[
1
],
:
lora_a
[
1
].
shape
[
0
]].
copy_
(
lora_a
[
1
].
T
,
non_blocking
=
True
)
if
lora_a
[
2
]
is
not
None
:
self
.
lora_a_stacked
[
2
][
index
,
0
,
:
lora_a
[
2
].
shape
[
1
],
:
lora_a
[
2
].
shape
[
0
]].
copy_
(
lora_a
[
2
].
T
,
non_blocking
=
True
)
if
lora_bias
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
if
lora_bias
[
0
]
is
not
None
:
self
.
lora_bias_stacked
[
0
][
index
,
0
,
:
lora_bias
[
0
].
shape
[
0
]].
copy_
(
lora_bias
[
0
].
T
,
non_blocking
=
True
)
if
lora_bias
[
1
]
is
not
None
:
self
.
lora_bias_stacked
[
1
][
index
,
0
,
:
lora_bias
[
1
].
shape
[
0
]].
copy_
(
lora_bias
[
1
].
T
,
non_blocking
=
True
)
if
lora_bias
[
2
]
is
not
None
:
self
.
lora_bias_stacked
[
2
][
index
,
0
,
:
lora_bias
[
2
].
shape
[
0
]].
copy_
(
lora_bias
[
2
].
T
,
non_blocking
=
True
)
@
classmethod
@
classmethod
@
_not_fully_sharded_can_replace
@
_not_fully_sharded_can_replace
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment