Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
571da8fc
Unverified
Commit
571da8fc
authored
Dec 05, 2024
by
Jee Jee Li
Committed by
GitHub
Dec 05, 2024
Browse files
[Misc][LoRA] Clean up the function interface of Punica (#10917)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
39c89e71
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
497 additions
and
631 deletions
+497
-631
tests/lora/test_layers.py
tests/lora/test_layers.py
+32
-10
vllm/lora/fully_sharded_layers.py
vllm/lora/fully_sharded_layers.py
+75
-100
vllm/lora/layers.py
vllm/lora/layers.py
+193
-345
vllm/lora/models.py
vllm/lora/models.py
+4
-4
vllm/lora/punica.py
vllm/lora/punica.py
+193
-172
No files found.
tests/lora/test_layers.py
View file @
571da8fc
...
...
@@ -565,7 +565,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
,
bias_enabled
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
...
...
@@ -573,7 +575,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
,
bias_enabled
=
bias_enabled
)
def
create_random_linear_replicated_layer
():
...
...
@@ -585,7 +588,12 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
lora_linear
=
ReplicatedLinearWithLoRA
(
linear
)
lora_linear
.
create_lora_weights
(
max_loras
,
lora_config
)
assert
(
lora_linear
.
n_slices
==
len
(
lora_linear
.
lora_a_stacked
)
==
len
(
lora_linear
.
lora_b_stacked
)
==
1
)
if
bias_enabled
:
assert
len
(
lora_linear
.
lora_bias_stacked
)
==
lora_linear
.
n_slices
else
:
assert
lora_linear
.
lora_bias_stacked
is
None
return
linear
,
lora_linear
for
i
in
range
(
10
):
...
...
@@ -669,8 +677,9 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
device
,
stage
)
->
None
:
device
,
stage
,
bias_enabled
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
...
...
@@ -679,7 +688,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
fully_sharded_loras
=
fully_shard
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
,
bias_enabled
=
bias_enabled
)
def
create_random_linear_parallel_layer
():
if
orientation
==
"row"
:
...
...
@@ -700,7 +710,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
if
not
fully_shard
else
ColumnParallelLinearWithShardedLoRA
(
linear
))
lora_linear
.
create_lora_weights
(
max_loras
,
lora_config
)
assert
(
lora_linear
.
n_slices
==
len
(
lora_linear
.
lora_a_stacked
)
==
len
(
lora_linear
.
lora_b_stacked
)
==
1
)
if
bias_enabled
:
assert
len
(
lora_linear
.
lora_bias_stacked
)
==
lora_linear
.
n_slices
else
:
assert
lora_linear
.
lora_bias_stacked
is
None
return
linear
,
lora_linear
for
i
in
range
(
10
):
...
...
@@ -784,8 +799,9 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
device
,
stage
)
->
None
:
device
,
stage
,
bias_enabled
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
...
...
@@ -794,7 +810,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
fully_sharded_loras
=
fully_shard
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
,
bias_enabled
=
bias_enabled
)
def
create_column_parallel_packed_layer
():
if
repeats
==
2
:
...
...
@@ -832,10 +849,16 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
num_key_value_heads
=
32
num_attention_heads
=
32
n_slices
=
repeats
lora_linear
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
=
FakeConfig
())
assert
(
lora_linear
.
n_slices
==
len
(
lora_linear
.
lora_a_stacked
)
==
len
(
lora_linear
.
lora_b_stacked
)
==
n_slices
)
if
bias_enabled
:
assert
len
(
lora_linear
.
lora_bias_stacked
)
==
lora_linear
.
n_slices
else
:
assert
lora_linear
.
lora_bias_stacked
is
None
return
linear
,
lora_linear
for
i
in
range
(
10
):
...
...
@@ -911,7 +934,6 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
512
,
lora_config
.
lora_extra_vocab_size
,
)
# lora_linear.set_mapping(*mapping_info)
lora_result
=
lora_linear
(
torch
.
cat
(
inputs
))[
0
]
expected_result
=
linear
(
torch
.
cat
(
inputs
))[
0
]
...
...
vllm/lora/fully_sharded_layers.py
View file @
571da8fc
# pylint: disable=unused-argument
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
,
cast
import
torch
import
torch.nn
as
nn
...
...
@@ -32,6 +32,44 @@ def _fully_sharded_can_replace(can_replace):
return
dec
def
_mcp_apply
(
x
,
bias
,
layer
:
ColumnParallelLinearWithLoRA
):
"""
For `ColumnParallelLinearWithLoRA` or classes that inherit from
`ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
"""
assert
(
layer
.
n_slices
==
len
(
layer
.
lora_a_stacked
)
==
len
(
layer
.
lora_b_stacked
)
==
len
(
layer
.
output_slices
))
if
layer
.
lora_bias_stacked
is
not
None
:
assert
layer
.
n_slices
==
len
(
layer
.
lora_bias_stacked
)
output
=
layer
.
base_layer
.
quant_method
.
apply
(
layer
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
# Since communication is needed, the buffer is directly initialized as a
# tensor rather than a tuple of tensor.
buffers
=
torch
.
zeros
(
(
layer
.
n_slices
,
x
.
shape
[
0
],
layer
.
lora_a_stacked
[
0
].
shape
[
2
]),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
layer
.
punica_wrapper
.
add_shrink
(
buffers
,
x
,
layer
.
lora_a_stacked
,
1.0
)
buffers
=
tensor_model_parallel_all_gather
(
buffers
)
layer
.
punica_wrapper
.
add_expand
(
output
,
buffers
,
layer
.
lora_b_stacked
,
layer
.
lora_bias_stacked
,
layer
.
output_slices
,
offset_start
=
0
,
add_input
=
True
)
output
=
output
.
view
(
*
out_orig_shape
)
# now have column partitioned and packed output
return
output
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
...
...
@@ -51,34 +89,15 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# gather operation.
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
lora_a_stacked
.
shape
[
2
]
shard_size
=
self
.
lora_a_stacked
[
0
]
.
shape
[
2
]
start_idx
=
tp_rank
*
shard_size
lora_a
=
lora_a
[:,
start_idx
:
start_idx
+
shard_size
]
return
lora_a
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
buffer
=
torch
.
zeros
(
(
x
.
shape
[
0
],
self
.
lora_a_stacked
.
shape
[
2
]),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
self
.
punica_wrapper
.
add_shrink
(
buffer
,
x
,
self
.
lora_a_stacked
,
1.0
)
buffer
=
tensor_model_parallel_all_gather
(
buffer
)
self
.
punica_wrapper
.
add_expand
(
output
,
buffer
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
add_input
=
True
)
# now have column partitioned output
output
=
output
.
view
(
*
out_orig_shape
)
return
output
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
@
_fully_sharded_can_replace
...
...
@@ -99,46 +118,6 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
)
def
_mcp_apply
(
x
,
bias
,
layer
:
QKVParallelLinearWithLora
):
"""
MergedColumnParallelLinearWithShardedLoRA and
MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method.
The main difference is the step by shard_size for lora_b which can
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
n
=
len
(
layer
.
lora_a_stacked
)
output
=
layer
.
base_layer
.
quant_method
.
apply
(
layer
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
buffers
=
torch
.
zeros
(
(
n
,
x
.
shape
[
0
],
layer
.
lora_a_stacked
[
0
].
shape
[
2
]),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
for
idx
in
range
(
n
):
layer
.
punica_wrapper
.
add_shrink
(
buffers
[
idx
],
x
,
layer
.
lora_a_stacked
[
idx
],
1.0
)
buffers
=
tensor_model_parallel_all_gather
(
buffers
)
layer
.
punica_wrapper
.
add_expand_packed_nslice
(
output
,
buffers
,
layer
.
lora_b_stacked
,
layer
.
bias_stacked
,
1.0
,
layer
.
output_slices
,
)
output
=
output
.
view
(
*
out_orig_shape
)
# now have column partitioned and packed output
return
output
class
MergedColumnParallelLinearWithShardedLoRA
(
MergedColumnParallelLinearWithLoRA
):
"""
...
...
@@ -162,8 +141,9 @@ class MergedColumnParallelLinearWithShardedLoRA(
]
return
lora_a
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
...
...
@@ -195,31 +175,15 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
lora_a_stacked
.
shape
[
2
]
shard_size
=
self
.
lora_a_stacked
[
0
]
.
shape
[
2
]
start_idx
=
tp_rank
*
shard_size
lora_a
=
lora_a
[:,
start_idx
:
start_idx
+
shard_size
]
return
lora_a
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
buffer
=
torch
.
zeros
((
x
.
shape
[
0
],
self
.
lora_a_stacked
.
shape
[
2
]),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
self
.
punica_wrapper
.
add_shrink
(
buffer
,
x
,
self
.
lora_a_stacked
,
1.0
)
buffer
=
tensor_model_parallel_all_gather
(
buffer
)
self
.
punica_wrapper
.
add_expand
(
output
,
buffer
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
add_input
=
True
)
# now have column partitioned output
output
=
output
.
view
(
*
out_orig_shape
)
return
output
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
@
_fully_sharded_can_replace
...
...
@@ -260,8 +224,9 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
]
return
lora_a
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
...
...
@@ -294,7 +259,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
shard_size
=
self
.
lora_b_stacked
.
shape
[
2
]
shard_size
=
self
.
lora_b_stacked
[
0
]
.
shape
[
2
]
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
...
...
@@ -303,20 +268,24 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
bias
is
None
:
return
bias
shard_size
=
self
.
bias_stacked
.
shape
[
2
]
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
shard_size
=
self
.
lora_bias_stacked
[
0
].
shape
[
2
]
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
bias
=
bias
[
start_idx
:
end_idx
]
return
bias
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
buffer
=
torch
.
zeros
(
(
x
.
shape
[
0
],
self
.
lora_a_stacked
.
shape
[
2
]),
(
self
.
n_slices
,
x
.
shape
[
0
],
self
.
lora_a_stacked
[
0
]
.
shape
[
2
]),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
...
...
@@ -330,12 +299,18 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
shard_size
=
self
.
lora_b_stacked
.
shape
[
2
]
start_idx
=
self
.
tp_rank
*
shard_size
self
.
punica_wrapper
.
add_expand_slice
(
output
,
buffer
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
start_idx
,
shard_size
)
# NOTE offset are based on the rank.
shard_size
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
offset_start
=
self
.
tp_rank
*
shard_size
self
.
punica_wrapper
.
add_expand
(
output
,
buffer
,
self
.
lora_b_stacked
,
self
.
lora_bias_stacked
,
self
.
output_slices
,
offset_start
=
offset_start
,
add_input
=
True
,
)
output
=
output
.
view
(
*
out_orig_shape
)
return
output
...
...
vllm/lora/layers.py
View file @
571da8fc
# pylint: disable=unused-argument
import
math
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
cast
import
torch
import
torch.nn
as
nn
...
...
@@ -18,11 +18,14 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_gather
)
from
vllm.distributed.utils
import
divide
from
vllm.lora.punica
import
PunicaWrapper
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
(
LinearScalingRotaryEmbedding
,
RotaryEmbedding
)
...
...
@@ -249,13 +252,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
full_lora_a_embeddings
.
shape
[
1
],
-
1
,
)
# Embedding layer only need expand op
self
.
punica_wrapper
.
add_expand
(
full_output
,
full_lora_a_embeddings
,
self
.
lora_b_stacked
,
bias_all
=
None
,
add_input
=
True
)
self
.
punica_wrapper
.
add_lora_embedding
(
full_output
,
full_lora_a_embeddings
,
self
.
lora_b_stacked
,
add_input
=
True
)
return
full_output
.
view_as
(
full_output_org
)
@
classmethod
...
...
@@ -269,14 +269,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
return
type
(
source_layer
)
is
VocabParallelEmbedding
class
ReplicatedLinea
rWithLoRA
(
BaseLayerWithLoRA
):
class
BaseLinearLaye
rWithLoRA
(
BaseLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
ReplicatedLinear
)
->
None
:
def
__init__
(
self
,
base_layer
:
LinearBase
)
:
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
input_size
=
self
.
base_layer
.
input_size
self
.
output_size
=
self
.
base_layer
.
output_size
self
.
device
=
_get_lora_device
(
self
.
base_layer
)
self
.
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]]
=
None
self
.
output_slices
:
Tuple
[
int
,
...]
self
.
tp_size
:
int
self
.
output_size
:
int
self
.
n_slices
:
int
def
create_lora_weights
(
self
,
...
...
@@ -285,39 +290,64 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
self
.
lora_config
=
lora_config
lora_a_output_size
=
lora_config
.
max_lora_rank
self
.
lora_a_stacked
=
torch
.
zeros
(
max_loras
,
1
,
lora_a_output_size
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
self
.
lora_b_stacked
=
torch
.
zeros
(
max_loras
,
1
,
self
.
output_size
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
if
lora_config
.
bias_enabled
:
self
.
bias_stacked
=
torch
.
zeros
(
#
if
isinstance
(
self
.
base_layer
,
ReplicatedLinear
):
lora_a_out_size
=
lora_config
.
max_lora_rank
lora_b_out_size
=
self
.
output_size
elif
isinstance
(
self
.
base_layer
,
ColumnParallelLinear
):
lora_a_out_size
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
lora_b_out_size
=
self
.
output_size
elif
isinstance
(
self
.
base_layer
,
RowParallelLinear
):
lora_a_out_size
=
lora_config
.
max_lora_rank
lora_b_out_size
=
(
self
.
output_size
if
not
lora_config
.
fully_sharded_loras
else
divide
(
self
.
output_size
,
self
.
tp_size
))
else
:
raise
NotImplementedError
self
.
lora_a_stacked
=
tuple
(
torch
.
zeros
(
max_loras
,
1
,
self
.
output_size
,
lora_a_out_size
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
else
:
self
.
bias_stacked
=
None
)
for
_
in
range
(
self
.
n_slices
))
self
.
lora_b_stacked
=
tuple
(
torch
.
zeros
(
max_loras
,
1
,
lora_b_out_size
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
n_slices
))
if
lora_config
.
bias_enabled
:
lora_bias_out_size
=
lora_b_out_size
self
.
lora_bias_stacked
=
tuple
(
torch
.
zeros
(
max_loras
,
1
,
lora_bias_out_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
n_slices
))
self
.
output_slices
=
(
self
.
lora_b_stacked
[
0
].
shape
[
2
],
)
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
if
self
.
lora_config
.
bias_enabled
:
self
.
bias_stacked
[
index
]
=
0
for
s_index
in
range
(
self
.
n_slices
):
self
.
lora_a_stacked
[
s_index
][
index
]
=
0
self
.
lora_b_stacked
[
s_index
][
index
]
=
0
if
self
.
lora_config
.
bias_enabled
:
# Make mypy happy
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
[
s_index
][
index
]
=
0
def
set_lora
(
self
,
...
...
@@ -325,29 +355,56 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
lora_
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
self
.
reset_lora
(
index
)
# Except for QKVParallelLinearWithLora and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
assert
(
len
(
self
.
lora_a_stacked
)
==
len
(
self
.
lora_b_stacked
)
==
self
.
n_slices
==
1
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
if
bias
is
not
None
:
self
.
bias_stacked
[
index
,
0
,
:
bias
.
shape
[
0
]].
copy_
(
bias
.
T
,
non_blocking
=
True
)
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
lora_bias
is
not
None
:
lora_bias
=
self
.
slice_bias
(
lora_bias
)
self
.
lora_a_stacked
[
0
][
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
0
][
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
if
lora_bias
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
assert
len
(
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
[
0
][
index
,
0
,
:
lora_bias
.
shape
[
0
]].
copy_
(
lora_bias
.
T
,
non_blocking
=
True
)
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
)
self
.
punica_wrapper
.
add_lora_linear
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
self
.
lora_bias_stacked
,
1.0
,
self
.
output_slices
)
return
output
class
ReplicatedLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
ReplicatedLinear
)
->
None
:
super
().
__init__
(
base_layer
,
)
# To ensure interface compatibility, set to 1 always.
self
.
tp_size
=
1
self
.
output_size
=
self
.
base_layer
.
output_size
self
.
n_slices
=
1
def
forward
(
self
,
input_
):
"""Forward of ReplicatedLinearWithLoRA
...
...
@@ -380,73 +437,26 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
return
type
(
source_layer
)
is
ReplicatedLinear
class
ColumnParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
class
ColumnParallelLinearWithLoRA
(
BaseL
inearL
ayerWithLoRA
):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
There are two types for the `base_layer`:
1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
"""
def
__init__
(
self
,
base_layer
:
ColumnParallelLinear
)
->
None
:
super
().
__init__
()
super
().
__init__
(
base_layer
)
# The base_layer type is ColumnParallelLinear or
# MergedColumnParallelLinear, their weight sharding logic is
# inconsistent when TP is greater than 1.
self
.
is_merged_col_linear
=
type
(
base_layer
)
is
MergedColumnParallelLinear
self
.
base_layer
=
base_layer
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size
=
self
.
base_layer
.
input_size
self
.
output_size
=
self
.
base_layer
.
output_size_per_partition
self
.
device
=
_get_lora_device
(
self
.
base_layer
)
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
self
.
lora_a_stacked
=
torch
.
zeros
(
max_loras
,
1
,
lora_a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
self
.
lora_b_stacked
=
torch
.
zeros
(
max_loras
,
1
,
self
.
output_size
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
if
lora_config
.
bias_enabled
:
self
.
bias_stacked
=
torch
.
zeros
(
max_loras
,
1
,
self
.
output_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
else
:
self
.
bias_stacked
=
None
self
.
output_dim
=
self
.
lora_b_stacked
.
shape
[
2
]
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
if
self
.
lora_config
.
bias_enabled
:
self
.
bias_stacked
[
index
]
=
0
# There is only one LoRA layer
self
.
n_slices
=
1
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
lora_a
...
...
@@ -485,40 +495,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
bias
=
bias
[
start_idx
:
end_idx
]
return
bias
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
bias
=
self
.
slice_bias
(
bias
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
if
bias
is
not
None
:
self
.
bias_stacked
[
index
,
0
,
:
bias
.
shape
[
0
]].
copy_
(
bias
.
T
,
non_blocking
=
True
)
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
)
return
output
def
forward
(
self
,
input_
):
"""Forward of ColumnParallelLinear
...
...
@@ -568,6 +544,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
__init__
(
self
,
base_layer
:
MergedColumnParallelLinear
)
->
None
:
super
().
__init__
(
base_layer
)
# There are two LoRA layers
self
.
n_slices
=
len
(
self
.
base_layer
.
output_sizes
)
def
create_lora_weights
(
self
,
...
...
@@ -575,9 +553,13 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
"""
The main reason for overriding this function is to enhance code
maintainability.
"""
self
.
lora_config
=
lora_config
n_slices
=
2
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
n_slices
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
self
.
n_slices
==
2
and
self
.
base_layer
.
output_sizes
[
0
]
==
self
.
base_layer
.
output_sizes
[
1
]):
raise
ValueError
(
...
...
@@ -598,7 +580,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
n_slices
))
)
for
_
in
range
(
self
.
n_slices
))
self
.
lora_b_stacked
=
tuple
(
torch
.
zeros
(
max_loras
,
...
...
@@ -607,30 +589,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
n_slices
))
)
for
_
in
range
(
self
.
n_slices
))
if
lora_config
.
bias_enabled
:
self
.
bias_stacked
=
tuple
(
self
.
lora_
bias_stacked
=
tuple
(
torch
.
zeros
(
max_loras
,
1
,
self
.
output_size
//
2
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
n_slices
))
else
:
self
.
bias_stacked
=
None
)
for
_
in
range
(
self
.
n_slices
))
self
.
output_dim
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
self
.
output_slices
=
(
self
.
output_dim
,
self
.
output_dim
)
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
0
][
index
]
=
0
self
.
lora_a_stacked
[
1
][
index
]
=
0
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
if
self
.
lora_config
.
bias_enabled
:
self
.
bias_stacked
[
0
][
index
]
=
0
self
.
bias_stacked
[
1
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
...
...
@@ -668,15 +639,15 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
lora_
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
bias
is
not
None
:
bias
=
self
.
slice_bias
(
bias
)
if
lora_
bias
is
not
None
:
lora_
bias
=
self
.
slice_bias
(
lora_
bias
)
if
lora_a
[
0
]
is
not
None
:
self
.
lora_a_stacked
[
0
][
...
...
@@ -685,10 +656,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
lora_b_stacked
[
0
][
index
,
0
,
:
lora_b
[
0
].
shape
[
1
],
:
lora_b
[
0
].
shape
[
0
]].
copy_
(
lora_b
[
0
].
T
,
non_blocking
=
True
)
if
bias
is
not
None
and
bias
[
0
]
is
not
None
:
self
.
bias_stacked
[
0
][
index
,
0
,
:
bias
[
0
].
shape
[
0
]].
copy_
(
bias
[
0
].
T
,
non_blocking
=
True
)
if
lora_bias
is
not
None
and
lora_bias
[
0
]
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
[
0
][
index
,
0
,
:
lora_bias
[
0
].
shape
[
0
]].
copy_
(
lora_bias
[
0
].
T
,
non_blocking
=
True
)
if
lora_a
[
1
]
is
not
None
:
self
.
lora_a_stacked
[
1
][
index
,
0
,
:
lora_a
[
1
].
shape
[
1
],
:
lora_a
[
1
].
shape
[
0
]].
copy_
(
...
...
@@ -696,18 +668,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
lora_b_stacked
[
1
][
index
,
0
,
:
lora_b
[
1
].
shape
[
1
],
:
lora_b
[
1
].
shape
[
0
]].
copy_
(
lora_b
[
1
].
T
,
non_blocking
=
True
)
if
bias
is
not
None
and
bias
[
1
]
is
not
None
:
self
.
bias_stacked
[
1
][
index
,
0
,
:
bias
[
1
].
shape
[
0
]].
copy_
(
bias
[
1
].
T
,
non_blocking
=
True
)
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
punica_wrapper
.
add_lora_packed_nslice
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
,
(
self
.
output_dim
,
self
.
output_dim
))
return
output
if
lora_bias
is
not
None
and
lora_bias
[
1
]
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
[
1
][
index
,
0
,
:
lora_bias
[
1
].
shape
[
0
]].
copy_
(
lora_bias
[
1
].
T
,
non_blocking
=
True
)
@
classmethod
@
_not_fully_sharded_can_replace
...
...
@@ -737,7 +702,6 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def
__init__
(
self
,
base_layer
:
QKVParallelLinear
)
->
None
:
super
().
__init__
(
base_layer
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
q_proj_total_size
=
(
self
.
base_layer
.
total_num_heads
*
self
.
base_layer
.
head_size
)
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
...
...
@@ -746,6 +710,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
base_layer
.
head_size
)
self
.
kv_proj_total_size
=
(
self
.
base_layer
.
total_num_kv_heads
*
self
.
base_layer
.
head_size
)
# There is only one LoRA layer
self
.
n_slices
=
1
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tp_rank
=
get_tensor_model_parallel_rank
()
...
...
@@ -780,32 +746,6 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
bias
=
torch
.
cat
([
bias_q
,
bias_k
,
bias_v
],
dim
=
1
)
return
bias
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
bias
is
not
None
:
bias
=
self
.
slice_bias
(
bias
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
if
bias
is
not
None
:
self
.
bias_stacked
[
index
,
0
,
:
bias
.
shape
[
0
]].
copy_
(
bias
.
T
,
non_blocking
=
True
)
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
...
...
@@ -828,6 +768,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def
__init__
(
self
,
base_layer
:
QKVParallelLinear
)
->
None
:
super
().
__init__
(
base_layer
)
# There are three LoRA layer.
self
.
n_slices
=
len
(
self
.
base_layer
.
output_sizes
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
def
create_lora_weights
(
self
,
...
...
@@ -835,9 +779,16 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
"""
The main reason for overloading this function is to handle inconsistent
weight dimensions in qkv lora.
"""
self
.
lora_config
=
lora_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
self
.
n_slices
==
3
):
raise
ValueError
(
"LoRAColumnParallelLinear3Slice requires 3 slices."
)
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
self
.
base_layer
.
head_size
)
self
.
kv_proj_shard_size
=
(
self
.
base_layer
.
num_kv_heads
*
...
...
@@ -902,7 +853,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
),
)
if
lora_config
.
bias_enabled
:
self
.
bias_stacked
=
(
self
.
lora_
bias_stacked
=
(
torch
.
zeros
(
max_loras
,
1
,
...
...
@@ -925,9 +876,6 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
device
=
self
.
device
,
),
)
else
:
self
.
bias_stacked
=
None
self
.
output_slices
=
(
self
.
q_proj_shard_size
,
self
.
kv_proj_shard_size
,
...
...
@@ -939,18 +887,6 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_a_stacked
[
1
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
if
self
.
lora_config
.
bias_enabled
:
self
.
bias_stacked
[
0
][
index
]
=
0
self
.
bias_stacked
[
1
][
index
]
=
0
self
.
bias_stacked
[
2
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
...
...
@@ -1000,15 +936,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
lora_
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
bias
is
not
None
:
bias
=
self
.
slice_bias
(
bias
)
if
lora_
bias
is
not
None
:
lora_
bias
=
self
.
slice_bias
(
lora_
bias
)
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
]
...
...
@@ -1039,26 +975,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index
,
0
,
:
lora_a
[
2
].
shape
[
1
],
:
lora_a
[
2
].
shape
[
0
]].
copy_
(
lora_a
[
2
].
T
,
non_blocking
=
True
)
if
bias
is
not
None
:
if
bias
[
0
]
is
not
None
:
self
.
bias_stacked
[
0
][
index
,
0
,
:
bias
[
0
].
shape
[
0
]].
copy_
(
bias
[
0
].
T
,
non_blocking
=
True
)
if
bias
[
1
]
is
not
None
:
self
.
bias_stacked
[
1
][
index
,
0
,
:
bias
[
1
].
shape
[
0
]].
copy_
(
bias
[
1
].
T
,
non_blocking
=
True
)
if
bias
[
2
]
is
not
None
:
self
.
bias_stacked
[
2
][
index
,
0
,
:
bias
[
2
].
shape
[
0
]].
copy_
(
bias
[
2
].
T
,
non_blocking
=
True
)
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
punica_wrapper
.
add_lora_packed_nslice
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
,
self
.
output_slices
)
return
output
if
lora_bias
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
Tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
if
lora_bias
[
0
]
is
not
None
:
self
.
lora_bias_stacked
[
0
][
index
,
0
,
:
lora_bias
[
0
].
shape
[
0
]].
copy_
(
lora_bias
[
0
].
T
,
non_blocking
=
True
)
if
lora_bias
[
1
]
is
not
None
:
self
.
lora_bias_stacked
[
1
][
index
,
0
,
:
lora_bias
[
1
].
shape
[
0
]].
copy_
(
lora_bias
[
1
].
T
,
non_blocking
=
True
)
if
lora_bias
[
2
]
is
not
None
:
self
.
lora_bias_stacked
[
2
][
index
,
0
,
:
lora_bias
[
2
].
shape
[
0
]].
copy_
(
lora_bias
[
2
].
T
,
non_blocking
=
True
)
@
classmethod
@
_not_fully_sharded_can_replace
...
...
@@ -1073,76 +1007,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
and
len
(
packed_modules_list
)
==
3
)
class
RowParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
class
RowParallelLinearWithLoRA
(
BaseL
inearL
ayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
super
().
__init__
()
self
.
base_layer
=
base_layer
super
().
__init__
(
base_layer
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
# reset input_size
self
.
input_size
=
self
.
base_layer
.
input_size_per_partition
self
.
output_size
=
self
.
base_layer
.
output_size
self
.
device
=
_get_lora_device
(
self
.
base_layer
)
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
lora_a_stacked
=
torch
.
zeros
(
(
max_loras
,
1
,
lora_config
.
max_lora_rank
,
self
.
input_size
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
tp_size
=
get_tensor_model_parallel_world_size
()
lora_b_output_size_per_partition
=
(
self
.
output_size
if
not
lora_config
.
fully_sharded_loras
else
divide
(
self
.
output_size
,
tp_size
))
self
.
lora_b_stacked
=
torch
.
zeros
(
(
max_loras
,
1
,
lora_b_output_size_per_partition
,
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
if
lora_config
.
bias_enabled
:
self
.
bias_stacked
=
torch
.
zeros
(
(
max_loras
,
1
,
self
.
output_size
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
else
:
self
.
bias_stacked
=
None
# Lazily initialized
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
if
self
.
lora_config
.
bias_enabled
:
self
.
bias_stacked
[
index
]
=
0
# There is only one LoRA layer.
self
.
n_slices
=
1
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
input_size
start_idx
=
tensor_model_parallel
_rank
*
shard_size
end_idx
=
(
tensor_model_parallel
_rank
+
1
)
*
shard_size
start_idx
=
self
.
tp
_rank
*
shard_size
end_idx
=
(
self
.
tp
_rank
+
1
)
*
shard_size
lora_a
=
lora_a
[
start_idx
:
end_idx
,
:]
return
lora_a
...
...
@@ -1152,40 +1035,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
bias
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
self
.
reset_lora
(
index
)
if
self
.
base_layer
.
tp_size
>
1
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
bias
is
not
None
:
bias
=
self
.
slice_bias
(
bias
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
if
bias
is
not
None
:
self
.
bias_stacked
[
index
,
0
,
:
bias
.
shape
[
0
]].
copy_
(
bias
.
T
,
non_blocking
=
True
)
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
)
return
output
def
forward
(
self
,
input_
):
"""Forward of RowParallelLinear
...
...
@@ -1203,10 +1052,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
input_parallel
=
input_
else
:
# TODO: simplify code below
tp_rank
=
get_tensor_model_parallel_rank
()
splitted_input
=
split_tensor_along_last_dim
(
input_
,
num_partitions
=
self
.
base_layer
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
input_parallel
=
splitted_input
[
self
.
tp_rank
].
contiguous
()
# Matrix multiply.
output_parallel
=
self
.
apply
(
input_parallel
)
...
...
vllm/lora/models.py
View file @
571da8fc
...
...
@@ -555,17 +555,17 @@ class LoRAModelManager(AdapterModelManager):
input_dim
,
output_dim
,
rank
,
module
.
lora_a_stacked
.
dtype
,
module
.
lora_a_stacked
[
0
]
.
dtype
,
"cpu"
,
embeddings_tensor_dim
=
embeddings_tensor_dim
,
bias_enabled
=
bias_enabled
)
else
:
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
module_name
,
module
.
lora_a_stacked
.
shape
[
-
1
],
module
.
lora_b_stacked
.
shape
[
-
2
],
module
.
lora_a_stacked
[
0
]
.
shape
[
-
1
],
module
.
lora_b_stacked
[
0
]
.
shape
[
-
2
],
rank
,
module
.
lora_a_stacked
.
dtype
,
module
.
lora_a_stacked
[
0
]
.
dtype
,
"cpu"
,
bias_enabled
=
bias_enabled
,
)
...
...
vllm/lora/punica.py
View file @
571da8fc
...
...
@@ -362,7 +362,7 @@ class PunicaWrapper:
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
def
shrink_prefill
(
def
_
shrink_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -380,7 +380,7 @@ class PunicaWrapper:
scale
,
)
def
shrink_decode
(
def
_
shrink_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -389,7 +389,7 @@ class PunicaWrapper:
):
bgmv_shrink
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
scale
)
def
expand_prefill
(
def
_
expand_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -407,7 +407,7 @@ class PunicaWrapper:
add_input
,
)
def
expand_decode
(
def
_
expand_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -416,7 +416,7 @@ class PunicaWrapper:
):
bgmv_expand
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
add_input
)
def
expand_slice_prefill
(
def
_
expand_slice_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -438,7 +438,7 @@ class PunicaWrapper:
add_input
,
)
def
expand_slice_decode
(
def
_
expand_slice_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -450,41 +450,35 @@ class PunicaWrapper:
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
y_slice_size
,
add_input
)
def
apply_bias
(
self
,
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
bias_stacked
:
torch
.
Tensor
,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
def
_apply_expand
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
=
True
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
org_output
=
output
output
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
])
indices
=
indices
.
view
(
-
1
)
bias_stacked
=
bias_stacked
.
view
(
-
1
,
bias_stacked
.
shape
[
-
1
])
bias_stacked
=
bias_stacked
[
indices
]
bias_stacked
[
indices
==
-
1
]
=
0
output
+=
bias_stacked
return
output
.
view_as
(
org_output
)
expand_slice_fun
:
Callable
=
(
self
.
_expand_slice_prefill
if
self
.
is_prefill
else
self
.
_expand_slice_decode
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
def
apply_bias
_packed_nslice
(
def
_
apply_bias
(
self
,
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output_slices
:
Tuple
[
int
,
...],
bias_stacked
:
Tuple
[
Optional
[
torch
.
Tensor
],
...],
lora_
bias_stacked
:
Tuple
[
Optional
[
torch
.
Tensor
],
...],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
lora_
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
...
...
@@ -496,7 +490,7 @@ class PunicaWrapper:
offset_left
=
0
for
slice_idx
,
slice
in
enumerate
(
output_slices
):
bias
=
bias_stacked
[
slice_idx
]
bias
=
lora_
bias_stacked
[
slice_idx
]
if
bias
is
not
None
:
bias
=
bias
.
view
(
-
1
,
bias
.
shape
[
-
1
])
bias
=
bias
[
indices
]
...
...
@@ -506,7 +500,7 @@ class PunicaWrapper:
return
output
.
view_as
(
org_output
)
def
add
_shrink
(
def
_apply
_shrink
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -517,188 +511,215 @@ class PunicaWrapper:
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the shrink_decode function
prefill stage, and the `
_
shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the
_
shrink_decode function
should be called.
"""
shrink_fun
:
Callable
=
(
self
.
shrink_prefill
if
self
.
is_prefill
else
self
.
shrink_decode
)
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
shrink_fun
:
Callable
=
(
self
.
_shrink_prefill
if
self
.
is_prefill
else
self
.
_shrink_decode
)
shrink_fun
(
y
,
x
,
w_t_all
,
scale
)
y
=
y
.
view_as
(
y_org
)
def
add_
expand
(
def
add_
shrink
(
self
,
y
:
torch
.
Tensor
,
y
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
bias_all
:
Optional
[
torch
.
Tensor
],
add_input
:
bool
=
True
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
scale
:
float
,
):
"""
Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the
GEMM of lora'b.
When `is_prefill` is true, it indicates that it is currently the
prefill stage, and the `expand_prefill` function should be called.
Otherwise, it is the decode stage, and the expand_decode function
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
if
bias_all
is
not
None
:
y
=
self
.
apply_bias
(
self
.
token_lora_indices
,
y
,
bias_all
)
expand_fun
:
Callable
=
(
self
.
expand_prefill
if
self
.
is_prefill
else
self
.
expand_decode
)
expand_fun
(
y
,
x
,
w_t_all
,
add_input
)
def
add_expand_slice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
bias_all
:
Optional
[
torch
.
Tensor
],
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
=
True
):
"""
Similar to `add_expand`
"""
if
bias_all
is
not
None
:
y
=
self
.
apply_bias
(
self
.
token_lora_indices
,
y
,
bias_all
)
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
expand_slice_fun
:
Callable
=
(
self
.
expand_slice_prefill
if
self
.
is_prefill
else
self
.
expand_slice_decode
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
# TODO fuse these kernels
for
slice_idx
in
range
(
len
(
lora_a_stacked
)):
self
.
_apply_shrink
(
y
[
slice_idx
],
x
,
lora_a_stacked
[
slice_idx
],
scale
)
def
add_expand_packed_nslice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
scale
:
float
,
output_slices
:
Tuple
[
int
,
...])
->
None
:
"""
Similar to `add_expand`
def
add_expand
(
self
,
y
:
torch
.
Tensor
,
x
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
offset_start
:
int
=
0
,
add_input
=
True
,
)
->
None
:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
offset_left
=
0
if
bias_stacked
is
not
None
:
self
.
apply_bias
_packed_nslice
(
self
.
token_lora_indices
,
y
,
output_slices
,
bias_stacked
)
offset_left
=
offset_start
if
lora_
bias_stacked
is
not
None
:
self
.
_
apply_bias
(
self
.
token_lora_indices
,
y
,
output_slices
,
lora_
bias_stacked
)
for
slice_idx
in
range
(
len
(
lora_b_stacked
)):
self
.
add_expand_slice
(
y
,
x
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
None
,
offset_left
,
output_slices
[
slice_idx
],
add_input
=
True
)
self
.
_apply_expand
(
y
,
x
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
offset_left
,
output_slices
[
slice_idx
],
add_input
=
add_input
,
)
offset_left
+=
output_slices
[
slice_idx
]
y
=
y
.
view_as
(
y_org
)
def
add_lora
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
bias_all
:
Optional
[
torch
.
Tensor
],
scale
:
float
,
y_offset
:
Optional
[
int
]
=
None
,
y_slice_size
:
Optional
[
int
]
=
None
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
def
add_lora_embedding
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
):
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
"""
# Embedding layer only need expand op
expand_fun
:
Callable
=
(
self
.
_expand_prefill
if
self
.
is_prefill
else
self
.
_expand_decode
)
expand_fun
(
y
,
x
,
lora_b_stacked
,
add_input
)
def
add_lora_linear
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
scale
:
float
,
output_slices
:
Tuple
[
int
,
...],
*
,
buffer
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]]
=
None
)
->
None
:
"""
Applicable to linear-related lora.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)+bias[i]
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor):
Output tensor. Will be changed in-place.
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
wa_t_all (
torch.Tensor): lora_a's weight
wb_t_all (
torch.Tensor): lora_b's weight
bias_all: (
torch.Tensor): lora's bias
lora_a_stacked (Tuple[
torch.Tensor
, ...]
): lora_a's weight
.
lora_b_stacked (Tuple[
torch.Tensor
, ...]
): lora_b's weight
.
lora_bias_stacked (Optional[Tuple[
torch.Tensor
, ...]]
): lora's bias
.
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
r
=
wb_t_all
.
size
(
-
1
)
assert
len
(
lora_a_stacked
)
==
len
(
lora_b_stacked
)
==
len
(
output_slices
)
if
lora_bias_stacked
is
not
None
:
assert
len
(
lora_bias_stacked
)
==
len
(
output_slices
)
y
=
self
.
_apply_bias
(
self
.
token_lora_indices
,
y
,
output_slices
,
lora_bias_stacked
)
if
buffer
is
None
:
r
=
lora_b_stacked
[
0
].
size
(
-
1
)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
if
bias_all
is
not
None
:
y
=
self
.
apply_bias
(
self
.
token_lora_indices
,
y
,
bias_all
)
self
.
add_shrink
(
buffer
,
x
,
wa_t_all
,
scale
)
if
y_offset
is
None
and
y_slice_size
is
None
:
self
.
add_expand
(
y
,
buffer
,
wb_t_all
,
bias_all
=
None
,
add_input
=
True
)
else
:
self
.
add_expand_slice
(
y
,
buffer
,
wb_t_all
,
None
,
y_offset
,
y_slice_size
,
add_input
=
True
)
y
=
y
.
view_as
(
y_org
)
def
add_lora_packed_nslice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
bias_all
:
Tuple
[
Optional
[
torch
.
Tensor
],
...],
scale
:
float
,
output_slices
:
Tuple
[
int
,
...])
->
None
:
"""
Applies lora to each input. Similar to add_lora, This method is
used for layers that are composed of multiple sublayers
(slices) packed together.
"""
y_org
=
y
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
offset_left
=
0
if
bias_all
is
not
None
:
y
=
self
.
apply_bias_packed_nslice
(
self
.
token_lora_indices
,
y
,
output_slices
,
bias_all
)
# TODO fuse these kernels
for
slice_idx
in
range
(
len
(
output_slices
)):
self
.
add_lora
(
y
,
x
,
lora_a_stacked
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
None
,
scale
,
offset_left
,
output_slices
[
slice_idx
])
offset_left
+=
output_slices
[
slice_idx
]
y
=
y
.
view_as
(
y_org
)
buffer
=
tuple
(
torch
.
zeros
(
(
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
for
_
in
range
(
len
(
output_slices
)))
self
.
add_shrink
(
buffer
,
x
,
lora_a_stacked
,
scale
)
self
.
add_expand
(
y
,
buffer
,
lora_b_stacked
,
None
,
output_slices
,
add_input
=
True
)
def
add_lora_logits
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
lora_a_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
scale
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
"""
LogitsProcessorWithLoRA always using bgmv
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
r
=
wb_t_all
.
size
(
-
1
)
r
=
lora_b_stacked
.
size
(
-
1
)
if
buffer
is
None
:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
bgmv_shrink
(
x
,
wa_t_all
,
buffer
,
self
.
sampler_indices
,
scale
)
bgmv_expand
(
buffer
,
wb_t_all
,
y
,
self
.
sampler_indices
,
add_inputs
=
True
)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink
(
x
,
lora_a_stacked
,
buffer
,
self
.
sampler_indices
,
scale
)
bgmv_expand
(
buffer
,
lora_b_stacked
,
y
,
self
.
sampler_indices
,
add_inputs
=
True
)
y
=
y
.
view_as
(
y_org
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment