Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b0e3ec9
Unverified
Commit
9b0e3ec9
authored
Sep 24, 2024
by
Jee Jee Li
Committed by
GitHub
Sep 23, 2024
Browse files
[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)
parent
86e9c8df
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
64 additions
and
38 deletions
+64
-38
tests/lora/test_punica_sizes.py
tests/lora/test_punica_sizes.py
+5
-0
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+5
-0
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+1
-1
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+1
-1
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+10
-6
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+11
-7
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+10
-6
vllm/lora/punica.py
vllm/lora/punica.py
+21
-17
No files found.
tests/lora/test_punica_sizes.py
View file @
9b0e3ec9
...
...
@@ -169,6 +169,7 @@ def test_punica_sgmv(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -183,6 +184,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
scaling
,
)
else
:
...
...
@@ -195,6 +197,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
...
...
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
slice_offset
,
hidden_size
,
add_inputs
=
True
,
...
...
tests/lora/test_punica_variation.py
View file @
9b0e3ec9
...
...
@@ -84,6 +84,7 @@ def test_punica_sgmv(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -98,6 +99,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
scaling
,
)
else
:
...
...
@@ -110,6 +112,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
...
...
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
slice_offset
,
hidden_size
,
add_inputs
=
True
,
...
...
vllm/lora/ops/bgmv_expand.py
View file @
9b0e3ec9
...
...
@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False
.
adds the final lora
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
vllm/lora/ops/bgmv_expand_slice.py
View file @
9b0e3ec9
...
...
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_offs
e
t (int): output_tensor's offs
e
t
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
...
...
vllm/lora/ops/sgmv_expand.py
View file @
9b0e3ec9
...
...
@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
add_inputs
:
bool
=
False
,
)
->
None
:
"""
...
...
@@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The number of tokens in the batch. Used to verify that the
number of tokens in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
...
...
@@ -134,6 +137,7 @@ def _sgmv_expand(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_expand_slice.py
View file @
9b0e3ec9
...
...
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
...
...
@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int):
The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The number of tokens in the batch. Used to verify that the
number of tokens in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional):
Defaults to False
.
adds the final lora
results to the output.
.
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_shrink.py
View file @
9b0e3ec9
...
...
@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
scaling
:
float
,
)
->
None
:
"""
...
...
@@ -120,17 +121,19 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
scaling (float): Scaling factor.
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The number of tokens in the batch. Used to verify that the
number of tokens in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
...
...
@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/punica.py
View file @
9b0e3ec9
...
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
def
compute_meta
(
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
bool
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
,
bool
]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
...
...
@@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
max_length
=
seq_length_tensor
.
max
().
item
()
token_nums
=
seq_length_tensor
.
sum
().
item
()
batch_size
=
lora_indices_tensor
.
size
(
0
)
no_lora
=
False
# -1 means no lora should be applied. Use `no_lora` to determine whether
...
...
@@ -52,7 +52,7 @@ def compute_meta(
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
no_lora
=
True
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
# TODO see if this can be vectorized
...
...
@@ -178,7 +178,7 @@ def convert_mapping(
class
PunicaWrapper
:
"""
PunicaWrapper is designed to manage and provide metadata for the punica
kernel. The main function
is to maintain the state information for
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel.
"""
...
...
@@ -216,6 +216,7 @@ class PunicaWrapper:
dtype
=
torch
.
long
,
device
=
device
)
self
.
max_length
:
int
=
0
self
.
token_nums
:
int
=
0
self
.
batch_size
:
int
=
-
1
self
.
is_prefill
=
False
self
.
no_lora
=
False
...
...
@@ -276,13 +277,13 @@ class PunicaWrapper:
long_lora_offsets_tensor
)
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
b_seq_start_tensor
)
...
...
@@ -291,25 +292,28 @@ class PunicaWrapper:
lora_indices_tensor
)
self
.
batch_size
=
batch_size
self
.
max_length
=
max_length
self
.
token_nums
=
token_nums
self
.
no_lora
=
no_lora
@
property
def
prefill_metadata
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
]:
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
1. seq_start_locs: Tensor of sequence start positions
.
2. seq_lengths: Tensor of sequence lengths
.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
batch_size
,
self
.
max_length
)
self
.
batch_size
,
self
.
max_length
,
self
.
token_nums
)
@
property
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
...
...
@@ -324,7 +328,7 @@ class PunicaWrapper:
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA
.
"""
sampler_indices_len
=
self
.
indices_len
[
1
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
...
...
@@ -332,7 +336,7 @@ class PunicaWrapper:
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices
.
"""
indices_padded_len
=
self
.
indices_len
[
2
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
...
...
@@ -341,7 +345,7 @@ class PunicaWrapper:
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA
.
"""
embeddings_indices_len
=
self
.
indices_len
[
3
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
...
...
@@ -350,7 +354,7 @@ class PunicaWrapper:
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora
.
"""
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
...
...
@@ -524,7 +528,7 @@ class PunicaWrapper:
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice.
.
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org
=
y
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment