Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b0e3ec9
Unverified
Commit
9b0e3ec9
authored
Sep 24, 2024
by
Jee Jee Li
Committed by
GitHub
Sep 23, 2024
Browse files
[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)
parent
86e9c8df
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
64 additions
and
38 deletions
+64
-38
tests/lora/test_punica_sizes.py
tests/lora/test_punica_sizes.py
+5
-0
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+5
-0
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+1
-1
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+1
-1
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+10
-6
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+11
-7
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+10
-6
vllm/lora/punica.py
vllm/lora/punica.py
+21
-17
No files found.
tests/lora/test_punica_sizes.py
View file @
9b0e3ec9
...
@@ -169,6 +169,7 @@ def test_punica_sgmv(
...
@@ -169,6 +169,7 @@ def test_punica_sgmv(
device
,
device
,
)
)
max_seq_length
=
seq_len_tensor
.
max
()
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
else
:
...
@@ -183,6 +184,7 @@ def test_punica_sgmv(
...
@@ -183,6 +184,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
lora_indices_tensor
,
batches
,
batches
,
max_seq_length
,
max_seq_length
,
token_nums
,
scaling
,
scaling
,
)
)
else
:
else
:
...
@@ -195,6 +197,7 @@ def test_punica_sgmv(
...
@@ -195,6 +197,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
lora_indices_tensor
,
batches
,
batches
,
max_seq_length
,
max_seq_length
,
token_nums
,
add_inputs
=
True
,
add_inputs
=
True
,
)
)
ref_torch_groupgemm
(
ref_torch_groupgemm
(
...
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
...
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
device
,
device
,
)
)
max_seq_length
=
seq_len_tensor
.
max
()
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
else
:
...
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
...
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
lora_indices_tensor
,
lora_indices_tensor
,
batches
,
batches
,
max_seq_length
,
max_seq_length
,
token_nums
,
slice_offset
,
slice_offset
,
hidden_size
,
hidden_size
,
add_inputs
=
True
,
add_inputs
=
True
,
...
...
tests/lora/test_punica_variation.py
View file @
9b0e3ec9
...
@@ -84,6 +84,7 @@ def test_punica_sgmv(
...
@@ -84,6 +84,7 @@ def test_punica_sgmv(
device
,
device
,
)
)
max_seq_length
=
seq_len_tensor
.
max
()
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
else
:
...
@@ -98,6 +99,7 @@ def test_punica_sgmv(
...
@@ -98,6 +99,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
lora_indices_tensor
,
batches
,
batches
,
max_seq_length
,
max_seq_length
,
token_nums
,
scaling
,
scaling
,
)
)
else
:
else
:
...
@@ -110,6 +112,7 @@ def test_punica_sgmv(
...
@@ -110,6 +112,7 @@ def test_punica_sgmv(
lora_indices_tensor
,
lora_indices_tensor
,
batches
,
batches
,
max_seq_length
,
max_seq_length
,
token_nums
,
add_inputs
=
True
,
add_inputs
=
True
,
)
)
ref_torch_groupgemm
(
ref_torch_groupgemm
(
...
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
...
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
device
,
device
,
)
)
max_seq_length
=
seq_len_tensor
.
max
()
max_seq_length
=
seq_len_tensor
.
max
()
token_nums
=
seq_len_tensor
.
sum
().
item
()
if
isinstance
(
max_seq_length
,
tuple
):
if
isinstance
(
max_seq_length
,
tuple
):
max_seq_length
=
max_seq_length
[
0
].
item
()
max_seq_length
=
max_seq_length
[
0
].
item
()
else
:
else
:
...
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
...
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
lora_indices_tensor
,
lora_indices_tensor
,
batches
,
batches
,
max_seq_length
,
max_seq_length
,
token_nums
,
slice_offset
,
slice_offset
,
hidden_size
,
hidden_size
,
add_inputs
=
True
,
add_inputs
=
True
,
...
...
vllm/lora/ops/bgmv_expand.py
View file @
9b0e3ec9
...
@@ -100,7 +100,7 @@ def _bgmv_expand(
...
@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
corresponding to each batch, An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
add_inputs (bool, optional): Defaults to False
.
adds the final lora
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
results to the output.
"""
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
vllm/lora/ops/bgmv_expand_slice.py
View file @
9b0e3ec9
...
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
...
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
corresponding to each batch, An index of -1 means no lora should be
applied.
applied.
slice_offst (int): output_tensor's offst
slice_offs
e
t (int): output_tensor's offs
e
t
slice_size (int): current output_tensor's size
slice_size (int): current output_tensor's size
batches (int): batch size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
add_inputs (bool, optional): Defaults to False.
...
...
vllm/lora/ops/sgmv_expand.py
View file @
9b0e3ec9
...
@@ -106,6 +106,7 @@ def _sgmv_expand(
...
@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
batches
:
int
,
max_seq_length
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
add_inputs
:
bool
=
False
,
add_inputs
:
bool
=
False
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -115,17 +116,19 @@ def _sgmv_expand(
...
@@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences
in the batch
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
corresponding to each batch. An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences in the
in the batch
batch.
add_inputs (bool, optional): Defaults to False. adds the final lora
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
results to the output.
"""
"""
...
@@ -134,6 +137,7 @@ def _sgmv_expand(
...
@@ -134,6 +137,7 @@ def _sgmv_expand(
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
]
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_expand_slice.py
View file @
9b0e3ec9
...
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
...
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
batches
:
int
,
max_seq_length
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
slice_offset
:
int
,
slice_offset
:
int
,
slice_size
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
add_inputs
:
bool
=
False
,
...
@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
...
@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences
in the batch
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
corresponding to each batch. An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
max_seq_length (int):
The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences
in the batch
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
slice_size (int): current output_tensor's size
add_inputs (bool, optional):
Defaults to False
.
adds the final lora
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
.
results to the output.
"""
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
...
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
]
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_shrink.py
View file @
9b0e3ec9
...
@@ -110,6 +110,7 @@ def _sgmv_shrink(
...
@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
batches
:
int
,
max_seq_length
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
scaling
:
float
,
scaling
:
float
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -120,17 +121,19 @@ def _sgmv_shrink(
...
@@ -120,17 +121,19 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4].
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences
in the batch
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
corresponding to each batch. An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences in the
in the batch
batch.
scaling (float): Scaling factor.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
"""
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
...
@@ -138,6 +141,7 @@ def _sgmv_shrink(
...
@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
]
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/punica.py
View file @
9b0e3ec9
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
def
compute_meta
(
def
compute_meta
(
token_lora_tensor
:
torch
.
Tensor
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
bool
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
,
bool
]:
"""
"""
Get the information required for the sgmv kernel. With the features:
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
1. If consecutive requests in the batch use the same LoRA, this function
...
@@ -43,7 +43,7 @@ def compute_meta(
...
@@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
max_length
=
seq_length_tensor
.
max
().
item
()
max_length
=
seq_length_tensor
.
max
().
item
()
token_nums
=
seq_length_tensor
.
sum
().
item
()
batch_size
=
lora_indices_tensor
.
size
(
0
)
batch_size
=
lora_indices_tensor
.
size
(
0
)
no_lora
=
False
no_lora
=
False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# -1 means no lora should be applied. Use `no_lora` to determine whether
...
@@ -52,7 +52,7 @@ def compute_meta(
...
@@ -52,7 +52,7 @@ def compute_meta(
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
no_lora
=
True
no_lora
=
True
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
# TODO see if this can be vectorized
# TODO see if this can be vectorized
...
@@ -178,7 +178,7 @@ def convert_mapping(
...
@@ -178,7 +178,7 @@ def convert_mapping(
class
PunicaWrapper
:
class
PunicaWrapper
:
"""
"""
PunicaWrapper is designed to manage and provide metadata for the punica
PunicaWrapper is designed to manage and provide metadata for the punica
kernel. The main function
is to maintain the state information for
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel.
Multi-LoRA, and to provide the interface for the punica kernel.
"""
"""
...
@@ -216,6 +216,7 @@ class PunicaWrapper:
...
@@ -216,6 +216,7 @@ class PunicaWrapper:
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
self
.
max_length
:
int
=
0
self
.
max_length
:
int
=
0
self
.
token_nums
:
int
=
0
self
.
batch_size
:
int
=
-
1
self
.
batch_size
:
int
=
-
1
self
.
is_prefill
=
False
self
.
is_prefill
=
False
self
.
no_lora
=
False
self
.
no_lora
=
False
...
@@ -276,13 +277,13 @@ class PunicaWrapper:
...
@@ -276,13 +277,13 @@ class PunicaWrapper:
long_lora_offsets_tensor
)
long_lora_offsets_tensor
)
else
:
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
batch_size
,
max_length
,
token_nums
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
b_seq_start_tensor
)
b_seq_start_tensor
)
...
@@ -291,25 +292,28 @@ class PunicaWrapper:
...
@@ -291,25 +292,28 @@ class PunicaWrapper:
lora_indices_tensor
)
lora_indices_tensor
)
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
max_length
=
max_length
self
.
max_length
=
max_length
self
.
token_nums
=
token_nums
self
.
no_lora
=
no_lora
self
.
no_lora
=
no_lora
@
property
@
property
def
prefill_metadata
(
def
prefill_metadata
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
]:
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
]:
"""
"""
This property provides a convenient way to access the necessary
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
1. seq_start_locs: Tensor of sequence start positions
.
2. seq_lengths: Tensor of sequence lengths
2. seq_lengths: Tensor of sequence lengths
.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
"""
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
batch_size
,
self
.
max_length
)
self
.
batch_size
,
self
.
max_length
,
self
.
token_nums
)
@
property
@
property
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
...
@@ -324,7 +328,7 @@ class PunicaWrapper:
...
@@ -324,7 +328,7 @@ class PunicaWrapper:
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
"""
"""
This property is used to access the lora indices specifically for
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA
.
"""
"""
sampler_indices_len
=
self
.
indices_len
[
1
]
sampler_indices_len
=
self
.
indices_len
[
1
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
...
@@ -332,7 +336,7 @@ class PunicaWrapper:
...
@@ -332,7 +336,7 @@ class PunicaWrapper:
@
property
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices
.
"""
"""
indices_padded_len
=
self
.
indices_len
[
2
]
indices_padded_len
=
self
.
indices_len
[
2
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
...
@@ -341,7 +345,7 @@ class PunicaWrapper:
...
@@ -341,7 +345,7 @@ class PunicaWrapper:
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
"""
"""
This property provides access to the indices used for lora embeddings,
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA
.
"""
"""
embeddings_indices_len
=
self
.
indices_len
[
3
]
embeddings_indices_len
=
self
.
indices_len
[
3
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
...
@@ -350,7 +354,7 @@ class PunicaWrapper:
...
@@ -350,7 +354,7 @@ class PunicaWrapper:
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
"""
This property provides access to the indices used for long context
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora
.
"""
"""
long_lora_len
=
self
.
indices_len
[
4
]
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
...
@@ -524,7 +528,7 @@ class PunicaWrapper:
...
@@ -524,7 +528,7 @@ class PunicaWrapper:
scale (float): Scaling factor.
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice.
.
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
"""
y_org
=
y
y_org
=
y
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment