Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b0e3ec9
Unverified
Commit
9b0e3ec9
authored
Sep 24, 2024
by
Jee Jee Li
Committed by
GitHub
Sep 23, 2024
Browse files
[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)
parent
86e9c8df
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
64 additions
and
38 deletions
+64
-38
tests/lora/test_punica_sizes.py
tests/lora/test_punica_sizes.py
+5
-0
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+5
-0
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+1
-1
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+1
-1
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+10
-6
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+11
-7
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+10
-6
vllm/lora/punica.py
vllm/lora/punica.py
+21
-17
No files found.
tests/lora/test_punica_sizes.py
View file @
9b0e3ec9
...
...
@@ -169,6 +169,7 @@ def test_punica_sgmv(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -183,6 +184,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
scaling
,
)
else
:
...
...
@@ -195,6 +197,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
...
...
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
slice_offset
,
hidden_size
,
add_inputs
=
True
,
...
...
tests/lora/test_punica_variation.py
View file @
9b0e3ec9
...
...
@@ -84,6 +84,7 @@ def test_punica_sgmv(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -98,6 +99,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
scaling
,
)
else
:
...
...
@@ -110,6 +112,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
...
...
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
device
,
)
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
...
...
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
lora_indices_tensor
,
batches
,
max_seq_length
,
token_nums
,
slice_offset
,
hidden_size
,
add_inputs
=
True
,
...
...
vllm/lora/ops/bgmv_expand.py
View file @
9b0e3ec9
...
...
@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False
.
adds the final lora
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
vllm/lora/ops/bgmv_expand_slice.py
View file @
9b0e3ec9
...
...
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_offs
e
t (int): output_tensor's offs
e
t
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
...
...
vllm/lora/ops/sgmv_expand.py
View file @
9b0e3ec9
...
...
@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
add_inputs
:
bool
=
False
,
)
->
None
:
"""
...
...
@@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
...
...
@@ -134,6 +137,7 @@ def _sgmv_expand(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_expand_slice.py
View file @
9b0e3ec9
...
...
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
...
...
@@ -124,9 +125,9 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
...
...
@@ -134,10 +135,12 @@ def _sgmv_expand_slice(
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional):
Defaults to False
.
adds the final lora
results to the output.
.
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_shrink.py
View file @
9b0e3ec9
...
...
@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
scaling
:
float
,
)
->
None
:
"""
...
...
@@ -120,16 +121,18 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
...
...
@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/punica.py
View file @
9b0e3ec9
...
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
def
compute_meta
(
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
bool
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
,
bool
]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
...
...
@@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
max_length
=
seq_length_tensor
.
max
().
item
()
token_nums
=
seq_length_tensor
.
sum
().
item
()
batch_size
=
lora_indices_tensor
.
size
(
0
)
no_lora
=
False
# -1 means no lora should be applied. Use `no_lora` to determine whether
...
...
@@ -52,7 +52,7 @@ def compute_meta(
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
no_lora
=
True
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
# TODO see if this can be vectorized
...
...
@@ -216,6 +216,7 @@ class PunicaWrapper:
dtype
=
torch
.
long
,
device
=
device
)
self
.
max_length
:
int
=
0
self
.
token_nums
:
int
=
0
self
.
batch_size
:
int
=
-
1
self
.
is_prefill
=
False
self
.
no_lora
=
False
...
...
@@ -276,13 +277,13 @@ class PunicaWrapper:
long_lora_offsets_tensor
)
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
b_seq_start_tensor
)
...
...
@@ -291,25 +292,28 @@ class PunicaWrapper:
lora_indices_tensor
)
self
.
batch_size
=
batch_size
self
.
max_length
=
max_length
self
.
token_nums
=
token_nums
self
.
no_lora
=
no_lora
@
property
def
prefill_metadata
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
]:
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
1. seq_start_locs: Tensor of sequence start positions
.
2. seq_lengths: Tensor of sequence lengths
.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
batch_size
,
self
.
max_length
)
self
.
batch_size
,
self
.
max_length
,
self
.
token_nums
)
@
property
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
...
...
@@ -324,7 +328,7 @@ class PunicaWrapper:
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA
.
"""
sampler_indices_len
=
self
.
indices_len
[
1
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
...
...
@@ -332,7 +336,7 @@ class PunicaWrapper:
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices
.
"""
indices_padded_len
=
self
.
indices_len
[
2
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
...
...
@@ -341,7 +345,7 @@ class PunicaWrapper:
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA
.
"""
embeddings_indices_len
=
self
.
indices_len
[
3
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
...
...
@@ -350,7 +354,7 @@ class PunicaWrapper:
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora
.
"""
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
...
...
@@ -524,7 +528,7 @@ class PunicaWrapper:
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice.
.
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org
=
y
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment