Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b45f0d79
Unverified
Commit
b45f0d79
authored
Dec 03, 2024
by
Jee Jee Li
Committed by
GitHub
Dec 02, 2024
Browse files
[Misc][LoRA] Move the implementation of lora bias to punica.py (#10829)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
a4c4daf3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
156 additions
and
175 deletions
+156
-175
tests/lora/test_llama_tp.py
tests/lora/test_llama_tp.py
+27
-33
vllm/lora/fully_sharded_layers.py
vllm/lora/fully_sharded_layers.py
+12
-29
vllm/lora/layers.py
vllm/lora/layers.py
+12
-101
vllm/lora/punica.py
vllm/lora/punica.py
+105
-12
No files found.
tests/lora/test_llama_tp.py
View file @
b45f0d79
...
@@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
...
@@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return
generated_texts
return
generated_texts
@
fork_new_process_for_each_test
def
generate_and_test
(
llm
,
sql_lora_files
):
def
test_llama_lora
(
sql_lora_files
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
max_num_seqs
=
16
,
max_loras
=
4
,
tensor_parallel_size
=
1
)
print
(
"lora adapter created"
)
print
(
"lora adapter created"
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
0
)
==
EXPECTED_NO_LORA_OUTPUT
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
0
)
==
EXPECTED_NO_LORA_OUTPUT
...
@@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
...
@@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
print
(
"removing lora"
)
print
(
"removing lora"
)
@
fork_new_process_for_each_test
def
test_llama_lora
(
sql_lora_files
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
max_num_seqs
=
16
,
max_loras
=
4
,
tensor_parallel_size
=
1
)
generate_and_test
(
llm
,
sql_lora_files
)
@
fork_new_process_for_each_test
@
fork_new_process_for_each_test
def
test_llama_lora_warmup
(
sql_lora_files
):
def
test_llama_lora_warmup
(
sql_lora_files
):
"""Test that the LLM initialization works with a warmup LORA path and
"""Test that the LLM initialization works with a warmup LORA path and
...
@@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
...
@@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
max_loras
=
4
,
max_loras
=
4
,
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
)
)
generate_and_test
(
llm
,
sql_lora_files
)
print
(
"lora adapter created"
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
0
)
==
EXPECTED_NO_LORA_OUTPUT
print
(
"lora 1"
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
1
)
==
EXPECTED_LORA_OUTPUT
print
(
"no lora"
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
0
)
==
EXPECTED_NO_LORA_OUTPUT
print
(
"lora 2"
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
2
)
==
EXPECTED_LORA_OUTPUT
print
(
"removing lora"
)
@
multi_gpu_test
(
num_gpus
=
4
)
@
multi_gpu_test
(
num_gpus
=
4
)
...
@@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
...
@@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
fully_sharded_loras
=
True
,
fully_sharded_loras
=
True
,
)
)
print
(
"lora adapter created"
)
generate_and_test
(
llm
,
sql_lora_files
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
0
)
==
EXPECTED_NO_LORA_OUTPUT
print
(
"lora 1"
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
1
)
==
EXPECTED_LORA_OUTPUT
print
(
"no lora"
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
0
)
==
EXPECTED_NO_LORA_OUTPUT
print
(
"lora 2"
)
@
multi_gpu_test
(
num_gpus
=
4
)
assert
do_sample
(
llm
,
sql_lora_files
,
lora_id
=
2
)
==
EXPECTED_LORA_OUTPUT
@
fork_new_process_for_each_test
def
test_llama_lora_tp4_fully_sharded_enable_bias
(
sql_lora_files
):
print
(
"removing lora"
)
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
max_num_seqs
=
16
,
max_loras
=
4
,
tensor_parallel_size
=
4
,
fully_sharded_loras
=
True
,
enable_lora_bias
=
True
,
)
generate_and_test
(
llm
,
sql_lora_files
)
vllm/lora/fully_sharded_layers.py
View file @
b45f0d79
...
@@ -73,6 +73,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
...
@@ -73,6 +73,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
self
.
punica_wrapper
.
add_expand
(
output
,
self
.
punica_wrapper
.
add_expand
(
output
,
buffer
,
buffer
,
self
.
lora_b_stacked
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
add_input
=
True
)
add_input
=
True
)
# now have column partitioned output
# now have column partitioned output
...
@@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
...
@@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
layer
.
lora_a_stacked
[
idx
],
1.0
)
layer
.
lora_a_stacked
[
idx
],
1.0
)
buffers
=
tensor_model_parallel_all_gather
(
buffers
)
buffers
=
tensor_model_parallel_all_gather
(
buffers
)
left_offset
=
0
layer
.
punica_wrapper
.
add_expand_packed_nslice
(
for
idx
in
range
(
n
):
shard_size
=
layer
.
lora_b_stacked
[
idx
].
shape
[
2
]
if
layer
.
bias_stacked
is
not
None
:
bias
=
layer
.
bias_stacked
[
idx
]
if
bias
is
not
None
:
bias
=
bias
.
view
(
-
1
,
bias
.
shape
[
-
1
])
bias
=
bias
[
layer
.
punica_wrapper
.
token_lora_indices
]
bias
[
layer
.
punica_wrapper
.
token_lora_indices
==
-
1
]
=
0
output
[:,
left_offset
:
left_offset
+
shard_size
]
+=
bias
layer
.
punica_wrapper
.
add_expand_slice
(
output
,
output
,
buffers
[
idx
]
,
buffers
,
layer
.
lora_b_stacked
[
idx
]
,
layer
.
lora_b_stacked
,
left_offset
,
layer
.
bias_stacked
,
shard_size
,
1.0
,
add_input
=
True
,
layer
.
output_slices
,
)
)
left_offset
+=
shard_size
output
=
output
.
view
(
*
out_orig_shape
)
output
=
output
.
view
(
*
out_orig_shape
)
# now have column partitioned and packed output
# now have column partitioned and packed output
...
@@ -234,6 +222,7 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
...
@@ -234,6 +222,7 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
self
.
punica_wrapper
.
add_expand
(
output
,
self
.
punica_wrapper
.
add_expand
(
output
,
buffer
,
buffer
,
self
.
lora_b_stacked
,
self
.
lora_b_stacked
,
self
.
bias_all
,
add_input
=
True
)
add_input
=
True
)
# now have column partitioned output
# now have column partitioned output
output
=
output
.
view
(
*
out_orig_shape
)
output
=
output
.
view
(
*
out_orig_shape
)
...
@@ -350,15 +339,9 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
...
@@ -350,15 +339,9 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# reduced before being used
# reduced before being used
shard_size
=
self
.
lora_b_stacked
.
shape
[
2
]
shard_size
=
self
.
lora_b_stacked
.
shape
[
2
]
start_idx
=
self
.
tp_rank
*
shard_size
start_idx
=
self
.
tp_rank
*
shard_size
if
self
.
bias_stacked
is
not
None
:
bias
=
self
.
bias_stacked
.
view
(
-
1
,
self
.
bias_stacked
.
shape
[
-
1
])
bias
=
bias
[
self
.
punica_wrapper
.
token_lora_indices
]
bias
[
self
.
punica_wrapper
.
token_lora_indices
==
-
1
]
=
0
output
+=
bias
self
.
punica_wrapper
.
add_expand_slice
(
output
,
buffer
,
self
.
punica_wrapper
.
add_expand_slice
(
output
,
buffer
,
self
.
lora_b_stacked
,
start_idx
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
start_idx
,
shard_size
)
shard_size
)
output
=
output
.
view
(
*
out_orig_shape
)
output
=
output
.
view
(
*
out_orig_shape
)
return
output
return
output
...
...
vllm/lora/layers.py
View file @
b45f0d79
...
@@ -67,63 +67,6 @@ def _not_fully_sharded_can_replace(can_replace):
...
@@ -67,63 +67,6 @@ def _not_fully_sharded_can_replace(can_replace):
return
dec
return
dec
def
apply_bias
(
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
bias_stacked
:
torch
.
Tensor
,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output
=
output
output
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
])
indices
=
indices
.
view
(
-
1
)
bias_stacked
=
bias_stacked
.
view
(
-
1
,
bias_stacked
.
shape
[
-
1
])
bias_stacked
=
bias_stacked
[
indices
]
bias_stacked
[
indices
==
-
1
]
=
0
output
+=
bias_stacked
return
output
.
view_as
(
org_output
)
def
apply_bias_packed_nslice
(
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output_slices
:
Tuple
[
int
,
...],
bias_stacked
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output
=
output
output
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
])
indices
=
indices
.
view
(
-
1
)
offset_left
=
0
for
slice_idx
,
slice
in
enumerate
(
output_slices
):
bias
=
bias_stacked
[
slice_idx
]
if
bias
is
not
None
:
bias
=
bias
.
view
(
-
1
,
bias
.
shape
[
-
1
])
bias
=
bias
[
indices
]
bias
[
indices
==
-
1
]
=
0
output
[:,
offset_left
:
offset_left
+
slice
]
+=
bias
offset_left
+=
slice
return
output
.
view_as
(
org_output
)
@
dataclass
@
dataclass
class
LoRAMapping
(
AdapterMapping
):
class
LoRAMapping
(
AdapterMapping
):
is_prefill
:
bool
=
False
is_prefill
:
bool
=
False
...
@@ -311,6 +254,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -311,6 +254,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
punica_wrapper
.
add_expand
(
full_output
,
self
.
punica_wrapper
.
add_expand
(
full_output
,
full_lora_a_embeddings
,
full_lora_a_embeddings
,
self
.
lora_b_stacked
,
self
.
lora_b_stacked
,
bias_all
=
None
,
add_input
=
True
)
add_input
=
True
)
return
full_output
.
view_as
(
full_output_org
)
return
full_output
.
view_as
(
full_output_org
)
...
@@ -399,15 +343,9 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -399,15 +343,9 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
def
apply
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
if
self
.
bias_stacked
is
not
None
:
self
.
indices
=
self
.
punica_wrapper
.
token_lora_indices
output
=
apply_bias
(
self
.
indices
,
output
,
self
.
bias_stacked
,
)
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
1.0
)
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
)
return
output
return
output
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
...
@@ -576,15 +514,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -576,15 +514,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def
apply
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
if
self
.
bias_stacked
is
not
None
:
self
.
indices
=
self
.
punica_wrapper
.
token_lora_indices
output
=
apply_bias
(
self
.
indices
,
output
,
self
.
bias_stacked
,
)
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
1.0
)
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
)
return
output
return
output
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
...
@@ -687,8 +619,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -687,8 +619,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
for
_
in
range
(
n_slices
))
)
for
_
in
range
(
n_slices
))
else
:
else
:
self
.
bias_stacked
=
None
self
.
bias_stacked
=
None
self
.
output_dim
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
self
.
output_dim
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
self
.
output_slices
=
(
self
.
output_dim
,
self
.
output_dim
)
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
0
][
index
]
=
0
self
.
lora_a_stacked
[
0
][
index
]
=
0
...
@@ -772,17 +704,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -772,17 +704,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
apply
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
if
self
.
bias_stacked
is
not
None
:
self
.
indices
=
self
.
punica_wrapper
.
token_lora_indices
output
=
apply_bias_packed_nslice
(
self
.
indices
,
output
,
(
self
.
output_dim
,
self
.
output_dim
),
self
.
bias_stacked
,
)
self
.
punica_wrapper
.
add_lora_packed_nslice
(
self
.
punica_wrapper
.
add_lora_packed_nslice
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
1.0
,
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
(
self
.
output_dim
,
self
.
output_dim
))
self
.
bias_stacked
,
1.0
,
(
self
.
output_dim
,
self
.
output_dim
))
return
output
return
output
@
classmethod
@
classmethod
...
@@ -1129,17 +1053,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -1129,17 +1053,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def
apply
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
if
self
.
bias_stacked
is
not
None
:
self
.
indices
=
self
.
punica_wrapper
.
token_lora_indices
output
=
apply_bias_packed_nslice
(
self
.
indices
,
output
,
self
.
output_slices
,
self
.
bias_stacked
,
)
self
.
punica_wrapper
.
add_lora_packed_nslice
(
output
,
x
,
self
.
punica_wrapper
.
add_lora_packed_nslice
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
1.0
,
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
,
self
.
output_slices
)
self
.
output_slices
)
return
output
return
output
...
@@ -1264,15 +1181,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -1264,15 +1181,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
if
self
.
bias_stacked
is
not
None
:
self
.
indices
=
self
.
punica_wrapper
.
token_lora_indices
output
=
apply_bias
(
self
.
indices
,
output
,
self
.
bias_stacked
,
)
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
punica_wrapper
.
add_lora
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
1.0
)
self
.
lora_b_stacked
,
self
.
bias_stacked
,
1.0
)
return
output
return
output
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
...
...
vllm/lora/punica.py
View file @
b45f0d79
...
@@ -450,6 +450,62 @@ class PunicaWrapper:
...
@@ -450,6 +450,62 @@ class PunicaWrapper:
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
y_slice_size
,
add_input
)
y_slice_size
,
add_input
)
def
apply_bias
(
self
,
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
bias_stacked
:
torch
.
Tensor
,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output
=
output
output
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
])
indices
=
indices
.
view
(
-
1
)
bias_stacked
=
bias_stacked
.
view
(
-
1
,
bias_stacked
.
shape
[
-
1
])
bias_stacked
=
bias_stacked
[
indices
]
bias_stacked
[
indices
==
-
1
]
=
0
output
+=
bias_stacked
return
output
.
view_as
(
org_output
)
def
apply_bias_packed_nslice
(
self
,
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output_slices
:
Tuple
[
int
,
...],
bias_stacked
:
Tuple
[
Optional
[
torch
.
Tensor
],
...],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output
=
output
output
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
])
indices
=
indices
.
view
(
-
1
)
offset_left
=
0
for
slice_idx
,
slice
in
enumerate
(
output_slices
):
bias
=
bias_stacked
[
slice_idx
]
if
bias
is
not
None
:
bias
=
bias
.
view
(
-
1
,
bias
.
shape
[
-
1
])
bias
=
bias
[
indices
]
bias
[
indices
==
-
1
]
=
0
output
[:,
offset_left
:
offset_left
+
slice
]
+=
bias
offset_left
+=
slice
return
output
.
view_as
(
org_output
)
def
add_shrink
(
def
add_shrink
(
self
,
self
,
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
...
@@ -474,16 +530,19 @@ class PunicaWrapper:
...
@@ -474,16 +530,19 @@ class PunicaWrapper:
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
bias_all
:
Optional
[
torch
.
Tensor
],
add_input
:
bool
=
True
,
add_input
:
bool
=
True
,
):
):
"""
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
Perform the ` y+=x@w_t_all
+bias
` computation, which is suitable for the
GEMM of lora'b.
GEMM of lora'b.
When `is_prefill` is true, it indicates that it is currently the
When `is_prefill` is true, it indicates that it is currently the
prefill stage, and the `expand_prefill` function should be called.
prefill stage, and the `expand_prefill` function should be called.
Otherwise, it is the decode stage, and the expand_decode function
Otherwise, it is the decode stage, and the expand_decode function
should be called.
should be called.
"""
"""
if
bias_all
is
not
None
:
y
=
self
.
apply_bias
(
self
.
token_lora_indices
,
y
,
bias_all
)
expand_fun
:
Callable
=
(
self
.
expand_prefill
expand_fun
:
Callable
=
(
self
.
expand_prefill
if
self
.
is_prefill
else
self
.
expand_decode
)
if
self
.
is_prefill
else
self
.
expand_decode
)
...
@@ -493,23 +552,54 @@ class PunicaWrapper:
...
@@ -493,23 +552,54 @@ class PunicaWrapper:
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
bias_all
:
Optional
[
torch
.
Tensor
],
y_offset
:
Optional
[
int
],
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
=
True
):
add_input
:
bool
=
True
):
"""
"""
Similar to `add_expand`
Similar to `add_expand`
"""
"""
if
bias_all
is
not
None
:
y
=
self
.
apply_bias
(
self
.
token_lora_indices
,
y
,
bias_all
)
expand_slice_fun
:
Callable
=
(
self
.
expand_slice_prefill
expand_slice_fun
:
Callable
=
(
self
.
expand_slice_prefill
if
self
.
is_prefill
else
if
self
.
is_prefill
else
self
.
expand_slice_decode
)
self
.
expand_slice_decode
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
def
add_expand_packed_nslice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
scale
:
float
,
output_slices
:
Tuple
[
int
,
...])
->
None
:
"""
Similar to `add_expand`
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
offset_left
=
0
if
bias_stacked
is
not
None
:
self
.
apply_bias_packed_nslice
(
self
.
token_lora_indices
,
y
,
output_slices
,
bias_stacked
)
for
slice_idx
in
range
(
len
(
lora_b_stacked
)):
self
.
add_expand_slice
(
y
,
x
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
None
,
offset_left
,
output_slices
[
slice_idx
],
add_input
=
True
)
offset_left
+=
output_slices
[
slice_idx
]
y
=
y
.
view_as
(
y_org
)
def
add_lora
(
self
,
def
add_lora
(
self
,
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
bias_all
:
Optional
[
torch
.
Tensor
],
scale
:
float
,
scale
:
float
,
y_offset
:
Optional
[
int
]
=
None
,
y_offset
:
Optional
[
int
]
=
None
,
y_slice_size
:
Optional
[
int
]
=
None
,
y_slice_size
:
Optional
[
int
]
=
None
,
...
@@ -522,12 +612,13 @@ class PunicaWrapper:
...
@@ -522,12 +612,13 @@ class PunicaWrapper:
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
* scale
).squeeze(0)
).squeeze(0)
+bias[i]
Args:
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
x (torch.Tensor): Input tensor
wa_t_all (torch.Tensor): lora_a's weight
wa_t_all (torch.Tensor): lora_a's weight
wb_t_all (torch.Tensor): lora_b's weight
wb_t_all (torch.Tensor): lora_b's weight
bias_all: (torch.Tensor): lora's bias
scale (float): Scaling factor.
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
column of y.
...
@@ -544,27 +635,26 @@ class PunicaWrapper:
...
@@ -544,27 +635,26 @@ class PunicaWrapper:
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
x
.
device
)
device
=
x
.
device
)
if
bias_all
is
not
None
:
y
=
self
.
apply_bias
(
self
.
token_lora_indices
,
y
,
bias_all
)
self
.
add_shrink
(
buffer
,
x
,
wa_t_all
,
scale
)
self
.
add_shrink
(
buffer
,
x
,
wa_t_all
,
scale
)
if
y_offset
is
None
and
y_slice_size
is
None
:
if
y_offset
is
None
and
y_slice_size
is
None
:
self
.
add_expand
(
y
,
buffer
,
wb_t_all
,
add_input
=
True
)
self
.
add_expand
(
y
,
buffer
,
wb_t_all
,
bias_all
=
None
,
add_input
=
True
)
else
:
else
:
self
.
add_expand_slice
(
y
,
self
.
add_expand_slice
(
y
,
buffer
,
buffer
,
wb_t_all
,
wb_t_all
,
None
,
y_offset
,
y_offset
,
y_slice_size
,
y_slice_size
,
add_input
=
True
)
add_input
=
True
)
y
=
y
.
view_as
(
y_org
)
y
=
y
.
view_as
(
y_org
)
def
add_lora_packed_nslice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
def
add_lora_packed_nslice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
,
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
bias_all
:
Tuple
[
Optional
[
torch
.
Tensor
],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
scale
:
float
,
torch
.
Tensor
,
torch
.
Tensor
],
scale
:
float
,
output_slices
:
Tuple
[
int
,
...])
->
None
:
output_slices
:
Tuple
[
int
,
...])
->
None
:
"""
"""
Applies lora to each input. Similar to add_lora, This method is
Applies lora to each input. Similar to add_lora, This method is
...
@@ -575,10 +665,13 @@ class PunicaWrapper:
...
@@ -575,10 +665,13 @@ class PunicaWrapper:
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
offset_left
=
0
offset_left
=
0
if
bias_all
is
not
None
:
y
=
self
.
apply_bias_packed_nslice
(
self
.
token_lora_indices
,
y
,
output_slices
,
bias_all
)
# TODO fuse these kernels
# TODO fuse these kernels
for
slice_idx
in
range
(
len
(
output_slices
)):
for
slice_idx
in
range
(
len
(
output_slices
)):
self
.
add_lora
(
y
,
x
,
lora_a_stacked
[
slice_idx
],
self
.
add_lora
(
y
,
x
,
lora_a_stacked
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
scale
,
offset_left
,
lora_b_stacked
[
slice_idx
],
None
,
scale
,
offset_left
,
output_slices
[
slice_idx
])
output_slices
[
slice_idx
])
offset_left
+=
output_slices
[
slice_idx
]
offset_left
+=
output_slices
[
slice_idx
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment