Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca871491
Unverified
Commit
ca871491
authored
Dec 10, 2024
by
Jee Jee Li
Committed by
GitHub
Dec 09, 2024
Browse files
[Misc][LoRA] Abstract PunicaWrapper (#10955)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
3b61cb45
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1058 additions
and
24 deletions
+1058
-24
tests/lora/test_layers.py
tests/lora/test_layers.py
+33
-16
vllm/lora/layers.py
vllm/lora/layers.py
+3
-4
vllm/lora/models.py
vllm/lora/models.py
+4
-4
vllm/lora/punica_wrapper/__init__.py
vllm/lora/punica_wrapper/__init__.py
+7
-0
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+480
-0
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+358
-0
vllm/lora/punica_wrapper/punica_selector.py
vllm/lora/punica_wrapper/punica_selector.py
+14
-0
vllm/lora/punica_wrapper/utils.py
vllm/lora/punica_wrapper/utils.py
+159
-0
No files found.
tests/lora/test_layers.py
View file @
ca871491
...
@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
...
@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
# yapf: enable
# yapf: enable
from
vllm.lora.models
import
(
LongContextLoRAContext
,
LoRALayerWeights
,
from
vllm.lora.models
import
(
LongContextLoRAContext
,
LoRALayerWeights
,
PackedLoRALayerWeights
)
PackedLoRALayerWeights
)
from
vllm.lora.punica
import
P
unica
W
rapper
from
vllm.lora.punica
_wrapper
import
get_p
unica
_w
rapper
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -48,11 +48,12 @@ TOLERANCES = {
...
@@ -48,11 +48,12 @@ TOLERANCES = {
torch
.
float32
:
(
5e-3
,
5e-3
),
torch
.
float32
:
(
5e-3
,
5e-3
),
torch
.
bfloat16
:
(
3e-2
,
2e-2
),
torch
.
bfloat16
:
(
3e-2
,
2e-2
),
}
}
CUDA_DEVICES
=
[
# TODO: Modify this based on platform
DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
#
W
e will launch different triton kernels between the prefill and decode
#
For GPU, w
e will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES
=
[
True
,
False
]
STAGES
=
[
True
,
False
]
...
@@ -192,9 +193,18 @@ def create_random_inputs(
...
@@ -192,9 +193,18 @@ def create_random_inputs(
return
inputs
,
index_mapping
,
prompt_mapping
return
inputs
,
index_mapping
,
prompt_mapping
def
check_punica_wrapper
(
punica_wrapper
)
->
bool
:
if
current_platform
.
is_cuda_alike
():
from
vllm.lora.punica_wrapper.punica_gpu
import
PunicaWrapperGPU
return
type
(
punica_wrapper
)
is
PunicaWrapperGPU
else
:
return
False
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
...
@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
)
...
@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# @pytest.mark.skip(
# @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
# reason="Fails when loras are in any slot other than the first.")
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
...
@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
)
...
@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
256512
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
256512
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
,
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
,
...
@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
)
...
@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
,
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
,
...
@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
...
@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
...
@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"orientation"
,
[
"row"
,
"column"
])
@
pytest
.
mark
.
parametrize
(
"orientation"
,
[
"row"
,
"column"
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
...
@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"repeats"
,
[
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"repeats"
,
[
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
...
@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...
@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
...
@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seed
=
0
seed
=
0
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
...
vllm/lora/layers.py
View file @
ca871491
...
@@ -17,7 +17,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -17,7 +17,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_gather
)
tensor_model_parallel_gather
)
from
vllm.distributed.utils
import
divide
from
vllm.distributed.utils
import
divide
from
vllm.lora.punica
import
PunicaWrapper
# yapf: disable
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
LinearBase
,
...
@@ -33,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -33,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding
)
VocabParallelEmbedding
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
pass
from
vllm.lora.punica_wrapper
import
PunicaWrapperBase
def
_get_lora_device
(
base_layer
:
nn
.
Module
)
->
torch
.
device
:
def
_get_lora_device
(
base_layer
:
nn
.
Module
)
->
torch
.
device
:
...
@@ -115,9 +114,9 @@ class BaseLayerWithLoRA(nn.Module):
...
@@ -115,9 +114,9 @@ class BaseLayerWithLoRA(nn.Module):
def
set_mapping
(
def
set_mapping
(
self
,
self
,
punica_wrapper
:
PunicaWrapper
,
punica_wrapper
,
):
):
self
.
punica_wrapper
:
PunicaWrapper
=
punica_wrapper
self
.
punica_wrapper
:
PunicaWrapper
Base
=
punica_wrapper
@
classmethod
@
classmethod
def
can_replace_layer
(
def
can_replace_layer
(
...
...
vllm/lora/models.py
View file @
ca871491
...
@@ -21,7 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
...
@@ -21,7 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora
,
LinearScalingRotaryEmbeddingWithLora
,
LoRAMapping
)
LoRAMapping
)
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.punica
import
P
unica
W
rapper
from
vllm.lora.punica
_wrapper
import
get_p
unica
_w
rapper
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
is_regex_target_modules
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
replace_submodule
)
parse_fine_tuned_lora_name
,
replace_submodule
)
...
@@ -331,9 +331,9 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -331,9 +331,9 @@ class LoRAModelManager(AdapterModelManager):
self
.
lora_index_to_id
:
List
[
Optional
[
int
]]
=
[
None
]
*
self
.
lora_slots
self
.
lora_index_to_id
:
List
[
Optional
[
int
]]
=
[
None
]
*
self
.
lora_slots
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
punica_wrapper
=
P
unica
W
rapper
(
max_num_batched_tokens
,
self
.
punica_wrapper
=
get_p
unica
_w
rapper
(
max_num_batched_tokens
,
max_batches
=
self
.
max_num_seqs
,
max_batches
=
self
.
max_num_seqs
,
device
=
self
.
device
)
device
=
self
.
device
)
# Scaling factor -> offset to the sin_cos_cache to it.
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
# Used for long context lora.
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
...
...
vllm/lora/punica_wrapper/__init__.py
0 → 100644
View file @
ca871491
from
vllm.lora.punica_wrapper.punica_base
import
PunicaWrapperBase
from
vllm.lora.punica_wrapper.punica_selector
import
get_punica_wrapper
__all__
=
[
"PunicaWrapperBase"
,
"get_punica_wrapper"
,
]
vllm/lora/punica.py
→
vllm/lora/punica
_wrapper/punica_base
.py
View file @
ca871491
...
@@ -5,19 +5,12 @@ Punica: Multi-Tenant LoRA Serving.
...
@@ -5,19 +5,12 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
https://arxiv.org/abs/2310.18547
"""
"""
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
,
Tuple
,
Union
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
vllm.triton_utils
import
HAS_TRITON
from
.utils
import
compute_meta
,
convert_mapping
if
HAS_TRITON
:
from
vllm.lora.ops.bgmv_expand
import
bgmv_expand
from
vllm.lora.ops.bgmv_expand_slice
import
bgmv_expand_slice
from
vllm.lora.ops.bgmv_shrink
import
bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
# avoid circuit import
# avoid circuit import
...
@@ -25,166 +18,117 @@ if TYPE_CHECKING:
...
@@ -25,166 +18,117 @@ if TYPE_CHECKING:
from
vllm.lora.models
import
LongContextLoRAContext
from
vllm.lora.models
import
LongContextLoRAContext
def
compute_meta
(
class
PunicaWrapperABC
(
ABC
):
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
,
bool
]:
"""
"""
Get the information required for the sgmv kernel. With the features:
PunicaWrapper ABC.
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
"""
lora_indices_tensor
,
seq_length_tensor
=
torch
.
unique_consecutive
(
@
abstractmethod
token_lora_tensor
,
return_counts
=
True
)
def
update_metadata
(
cum_result
=
torch
.
cumsum
(
seq_length_tensor
,
dim
=
0
)
self
,
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
mapping
:
"LoRAMapping"
,
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
lora_index_to_id
:
List
[
Optional
[
int
]],
max_length
=
seq_length_tensor
.
max
().
item
()
max_loras
:
int
,
token_nums
=
seq_length_tensor
.
sum
().
item
()
vocab_size
:
int
,
batch_size
=
lora_indices_tensor
.
size
(
0
)
extra_vocab_size
:
int
,
no_lora
=
False
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
# -1 means no lora should be applied. Use `no_lora` to determine whether
**
kwargs
,
# the current step requires LoRA. If LoRA is not needed, the prefill stage
)
->
None
:
# does not need to launch the triton kernel, which can improve performance
"""
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
Update the lora-related metadata
no_lora
=
True
"""
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
raise
NotImplementedError
batch_size
,
max_length
,
token_nums
,
no_lora
)
@
abstractmethod
def
add_shrink
(
# TODO see if this can be vectorized
self
,
def
convert_mapping
(
y
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
mapping
:
"LoRAMapping"
,
x
:
torch
.
Tensor
,
lora_index_to_id
:
List
[
Optional
[
int
]],
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
max_loras
:
int
,
scale
:
float
,
vocab_size
:
int
,
**
kwargs
,
extra_vocab_size
:
int
,
)
->
None
:
device
:
torch
.
device
,
"""
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
Performs GEMM for multiple slices of lora_a.
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
"""
Optional
[
torch
.
Tensor
],
List
[
int
]]:
"""Converts LoRAMapping to index tensors.
raise
NotImplementedError
Args:
@
abstractmethod
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
def
add_expand
(
lora_index_to_id: List mapping LoRA ids to LoRA indices.
self
,
max_loras: Maximum number of LoRAs.
y
:
torch
.
Tensor
,
vocab_size: Model vocab size.
x
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
extra_vocab_size: Extra vocab size each LoRA can have.
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
long_lora_context: Passed if there are long context lora in a batch.
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
Returns:
offset_start
:
int
=
0
,
A tuple of tensors:
add_input
=
True
,
base_indices: Tensor of shape [batch_size] mapping batch rows to
**
kwargs
,
LoRA indices.
)
->
None
:
sampler_indices: Tensor of shape [batch_size] mapping requests to
"""
LoRA indices for sampler. For generation, this will be the
Performs GEMM and bias addition for multiple slices of lora_b.
same as base_indicies. For prefill, this will map requests
"""
to LoRA indices.
raise
NotImplementedError
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
@
abstractmethod
Same as sampler_indicies, but -1 is replaced with
def
add_lora_embedding
(
max_loras.
self
,
embeddings_indices: Tensor of shape [2, batch_size] mapping
y
:
torch
.
Tensor
,
requests to embedding indices. First row is for embeddings
x
:
torch
.
Tensor
,
added by the LoRAs, second row is for the LoRA.lora_a
lora_b_stacked
:
torch
.
Tensor
,
embeddings.
add_input
:
bool
=
True
,
long_lora_indices: Tensor of shape [batch_size] mapping
**
kwargs
,
requests to RoPE offsets and rot dims for long LoRAs.
)
->
None
:
None if long context lora doesn't exist.
"""
indices_len: List of lengths of the above tensors. It contains
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
(base_indices, sampler_indices, sampler_indices_padded,
and this layer only requires the expand operation.
embeddings_indices, long_lora_indices).
"""
"""
raise
NotImplementedError
index_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
index_mapping_indices
.
copy
()
@
abstractmethod
lora_indices
=
index_mapping_indices
.
copy
()
def
add_lora_linear
(
self
,
long_lora_offsets
:
Optional
[
torch
.
Tensor
]
=
None
y
:
torch
.
Tensor
,
if
long_lora_context
:
x
:
torch
.
Tensor
,
long_lora_offsets
=
torch
.
zeros
(
len
(
index_mapping_indices
),
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
device
=
device
,
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
dtype
=
torch
.
long
)
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
prompt_mapping
:
List
[
int
]
=
[
scale
:
float
,
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
output_slices
:
Tuple
[
int
,
...],
for
x
in
mapping
.
prompt_mapping
*
,
]
buffer
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]]
=
None
,
lora_idx
=
None
**
kwargs
)
->
None
:
for
i
in
range
(
len
(
index_mapping_indices
)):
"""
# TODO index can be slow. optimize
Applicable to linear-related lora.
lora_idx
=
(
lora_index_to_id
.
index
(
index_mapping_indices
[
i
])
"""
if
index_mapping_indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_indices
[
i
]
>
0
else
0
raise
NotImplementedError
lora_indices
[
i
]
=
lora_idx
if
long_lora_context
:
@
abstractmethod
assert
long_lora_offsets
is
not
None
def
add_lora_logits
(
self
,
lora_offset
:
int
=
long_lora_context
.
offsets_by_lora_id
.
get
(
y
:
torch
.
Tensor
,
index_mapping_indices
[
i
],
0
)
x
:
torch
.
Tensor
,
long_lora_offsets
[
i
]
=
lora_offset
lora_a_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
indices_list
:
List
[
Union
[
List
[
int
],
torch
.
Tensor
]]
=
[
scale
,
index_mapping_indices
,
*
,
lora_indices
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
embedding_indices
,
**
kwargs
)
->
None
:
]
"""
if
long_lora_context
:
Applies lora specifically for LogitsProcessorWithLoRA.
assert
long_lora_offsets
is
not
None
"""
indices_list
.
append
(
long_lora_offsets
)
raise
NotImplementedError
indices
=
torch
.
tensor
(
indices_list
,
dtype
=
torch
.
long
,
device
=
device
)
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
dtype
=
torch
.
long
,
class
PunicaWrapperBase
(
PunicaWrapperABC
):
device
=
device
)
embeddings_indices
=
torch
.
stack
([
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
),
])
embeddings_indices
[
embeddings_indices
==
-
1
]
=
max_loras
-
1
base_indices
=
indices
[
1
]
sampler_indices
=
prompt_mapping_tensor
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
=
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
device
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
))
long_lora_indices
=
None
long_lora_indices_len
:
Optional
[
int
]
=
None
if
long_lora_context
:
long_lora_indices
=
indices
[
3
]
long_lora_indices_len
=
long_lora_indices
.
shape
[
-
1
]
# Contain length of indices tensors. Used to index into each tensor.
indices_len
=
[
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
],
]
if
long_lora_indices_len
is
not
None
:
indices_len
.
append
(
long_lora_indices_len
)
else
:
# If long_lora doesn't exist,append None
indices_len
.
append
(
None
)
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_indices
,
indices_len
,
)
class
PunicaWrapper
:
"""
"""
PunicaWrapper is designed to manage and provide metadata for the punica
PunicaWrapper
Base
is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica
kernel
.
Multi-LoRA, and to provide the interface for the punica.
"""
"""
def
__init__
(
self
,
max_num_batched_tokens
:
int
,
max_batches
:
int
,
def
__init__
(
self
,
max_num_batched_tokens
:
int
,
max_batches
:
int
,
device
:
Union
[
torch
.
device
,
str
]):
device
:
Union
[
torch
.
device
,
str
]
,
**
kwargs
):
self
.
_token_lora_indices
=
torch
.
empty
(
max_num_batched_tokens
,
self
.
_token_lora_indices
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
...
@@ -223,26 +167,6 @@ class PunicaWrapper:
...
@@ -223,26 +167,6 @@ class PunicaWrapper:
self
.
is_prefill
=
False
self
.
is_prefill
=
False
self
.
no_lora
=
False
self
.
no_lora
=
False
def
update_metadata
(
self
,
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
List
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
):
self
.
_update_base_metadata
(
mapping
,
lora_index_to_id
,
max_loras
,
vocab_size
,
extra_vocab_size
,
long_lora_context
)
if
mapping
.
is_prefill
:
# Update metadata required for prefill-related operators.
self
.
_update_prefill_metada
(
self
.
token_lora_indices
)
self
.
is_prefill
=
True
else
:
self
.
is_prefill
=
False
def
_update_base_metadata
(
def
_update_base_metadata
(
self
,
self
,
mapping
:
"LoRAMapping"
,
mapping
:
"LoRAMapping"
,
...
@@ -298,6 +222,38 @@ class PunicaWrapper:
...
@@ -298,6 +222,38 @@ class PunicaWrapper:
self
.
token_nums
=
token_nums
self
.
token_nums
=
token_nums
self
.
no_lora
=
no_lora
self
.
no_lora
=
no_lora
def
_apply_bias
(
self
,
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output_slices
:
Tuple
[
int
,
...],
lora_bias_stacked
:
Tuple
[
Optional
[
torch
.
Tensor
],
...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output
=
output
output
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
])
indices
=
indices
.
view
(
-
1
)
offset_left
=
0
for
slice_idx
,
slice
in
enumerate
(
output_slices
):
bias
=
lora_bias_stacked
[
slice_idx
]
if
bias
is
not
None
:
bias
=
bias
.
view
(
-
1
,
bias
.
shape
[
-
1
])
bias
=
bias
[
indices
]
bias
[
indices
==
-
1
]
=
0
output
[:,
offset_left
:
offset_left
+
slice
]
+=
bias
offset_left
+=
slice
return
output
.
view_as
(
org_output
)
@
property
@
property
def
prefill_metadata
(
def
prefill_metadata
(
self
self
...
@@ -362,180 +318,33 @@ class PunicaWrapper:
...
@@ -362,180 +318,33 @@ class PunicaWrapper:
long_lora_len
=
self
.
indices_len
[
4
]
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
def
_shrink_prefill
(
def
update_metadata
(
self
,
self
,
y
:
torch
.
Tensor
,
mapping
:
"LoRAMapping"
,
x
:
torch
.
Tensor
,
lora_index_to_id
:
List
[
Optional
[
int
]],
w_t_all
:
torch
.
Tensor
,
max_loras
:
int
,
scale
:
float
,
vocab_size
:
int
,
):
extra_vocab_size
:
int
,
#No LoRA request, so return directly
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
if
self
.
no_lora
:
**
kwargs
):
return
sgmv_shrink
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
scale
,
)
def
_shrink_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
,
):
bgmv_shrink
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
scale
)
def
_expand_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_expand
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
add_input
,
)
def
_expand_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
):
bgmv_expand
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
add_input
)
def
_expand_slice_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_expand_slice
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
y_offset
,
y_slice_size
,
add_input
,
)
def
_expand_slice_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
):
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
y_slice_size
,
add_input
)
def
_apply_expand
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
=
True
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun
:
Callable
=
(
self
.
_expand_slice_prefill
if
self
.
is_prefill
else
self
.
_expand_slice_decode
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
def
_apply_bias
(
self
,
indices
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output_slices
:
Tuple
[
int
,
...],
lora_bias_stacked
:
Tuple
[
Optional
[
torch
.
Tensor
],
...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output
=
output
output
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
])
indices
=
indices
.
view
(
-
1
)
offset_left
=
0
for
slice_idx
,
slice
in
enumerate
(
output_slices
):
bias
=
lora_bias_stacked
[
slice_idx
]
if
bias
is
not
None
:
bias
=
bias
.
view
(
-
1
,
bias
.
shape
[
-
1
])
bias
=
bias
[
indices
]
bias
[
indices
==
-
1
]
=
0
output
[:,
offset_left
:
offset_left
+
slice
]
+=
bias
offset_left
+=
slice
return
output
.
view_as
(
org_output
)
def
_apply_shrink
(
self
.
_update_base_metadata
(
mapping
,
lora_index_to_id
,
max_loras
,
self
,
vocab_size
,
extra_vocab_size
,
y
:
torch
.
Tensor
,
long_lora_context
)
x
:
torch
.
Tensor
,
if
mapping
.
is_prefill
:
w_t_all
:
torch
.
Tensor
,
# Update metadata required for prefill-related operators.
scale
:
float
,
self
.
_update_prefill_metada
(
self
.
token_lora_indices
)
):
self
.
is_prefill
=
True
"""
else
:
Perform the ` y+=x@w_t_all` computation, which is suitable for the
self
.
is_prefill
=
False
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
shrink_fun
:
Callable
=
(
self
.
_shrink_prefill
if
self
.
is_prefill
else
self
.
_shrink_decode
)
shrink_fun
(
y
,
x
,
w_t_all
,
scale
)
y
=
y
.
view_as
(
y_org
)
def
add_shrink
(
@
abstractmethod
self
,
def
add_shrink
(
self
,
y
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
y
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
x
:
torch
.
Tensor
,
scale
:
float
,
**
kwargs
)
->
None
:
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
scale
:
float
,
):
"""
"""
Performs GEMM for multiple slices of lora_a.
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
Semantics:
for i in range(len(lora_a_stacked)):
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
y[i] += (x @ lora_a_stacked[i]) * scale
...
@@ -545,24 +354,21 @@ class PunicaWrapper:
...
@@ -545,24 +354,21 @@ class PunicaWrapper:
x (torch.Tensor): Input tensor
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
scale (float): Scaling factor for the operation
"""
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
# TODO fuse these kernels
for
slice_idx
in
range
(
len
(
lora_a_stacked
)):
self
.
_apply_shrink
(
y
[
slice_idx
],
x
,
lora_a_stacked
[
slice_idx
],
scale
)
def
add_expand
(
"""
self
,
# TODO: implement it based on torch ops
y
:
torch
.
Tensor
,
raise
NotImplementedError
x
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
@
abstractmethod
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
def
add_expand
(
self
,
output_slices
:
Tuple
[
int
,
...],
y
:
torch
.
Tensor
,
offset_start
:
int
=
0
,
x
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
add_input
=
True
,
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
)
->
None
:
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
offset_start
:
int
=
0
,
add_input
=
True
,
**
kwargs
)
->
None
:
"""
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Performs GEMM and bias addition for multiple slices of lora_b.
...
@@ -581,35 +387,21 @@ class PunicaWrapper:
...
@@ -581,35 +387,21 @@ class PunicaWrapper:
bias's weight
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
add_input (bool): Defaults to True.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
offset_left
=
offset_start
if
lora_bias_stacked
is
not
None
:
self
.
_apply_bias
(
self
.
token_lora_indices
,
y
,
output_slices
,
lora_bias_stacked
)
for
slice_idx
in
range
(
len
(
lora_b_stacked
)):
self
.
_apply_expand
(
y
,
x
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
offset_left
,
output_slices
[
slice_idx
],
add_input
=
add_input
,
)
offset_left
+=
output_slices
[
slice_idx
]
y
=
y
.
view_as
(
y_org
)
def
add_lora_embedding
(
"""
self
,
# TODO: implement it based on torch ops
y
:
torch
.
Tensor
,
raise
NotImplementedError
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
@
abstractmethod
add_input
:
bool
=
True
,
def
add_lora_embedding
(
self
,
):
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
**
kwargs
)
->
None
:
"""
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
and this layer only requires the expand operation.
Semantics:
Semantics:
y += x @ lora_b_stacked
y += x @ lora_b_stacked
...
@@ -618,25 +410,22 @@ class PunicaWrapper:
...
@@ -618,25 +410,22 @@ class PunicaWrapper:
x (torch.Tensor): Input tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
add_input (bool): Default to True.
"""
"""
# TODO: implement it based on torch ops
raise
NotImplementedError
# Embedding layer only need expand op
@
abstractmethod
expand_fun
:
Callable
=
(
self
.
_expand_prefill
def
add_lora_linear
(
self
,
if
self
.
is_prefill
else
self
.
_expand_decode
)
y
:
torch
.
Tensor
,
expand_fun
(
y
,
x
,
lora_b_stacked
,
add_input
)
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
def
add_lora_linear
(
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
self
,
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
y
:
torch
.
Tensor
,
scale
:
float
,
x
:
torch
.
Tensor
,
output_slices
:
Tuple
[
int
,
...],
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
*
,
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
buffer
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]]
=
None
,
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
**
kwargs
)
->
None
:
scale
:
float
,
output_slices
:
Tuple
[
int
,
...],
*
,
buffer
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]]
=
None
)
->
None
:
"""
"""
Applicable to linear-related lora.
Applicable to linear-related lora.
...
@@ -659,29 +448,10 @@ class PunicaWrapper:
...
@@ -659,29 +448,10 @@ class PunicaWrapper:
output_slices (Tuple[int, ...]): Every slice's size.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
"""
# TODO: implement it based on torch ops
raise
NotImplementedError
assert
len
(
lora_a_stacked
)
==
len
(
lora_b_stacked
)
==
len
(
output_slices
)
@
abstractmethod
if
lora_bias_stacked
is
not
None
:
assert
len
(
lora_bias_stacked
)
==
len
(
output_slices
)
y
=
self
.
_apply_bias
(
self
.
token_lora_indices
,
y
,
output_slices
,
lora_bias_stacked
)
if
buffer
is
None
:
r
=
lora_b_stacked
[
0
].
size
(
-
1
)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
tuple
(
torch
.
zeros
(
(
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
for
_
in
range
(
len
(
output_slices
)))
self
.
add_shrink
(
buffer
,
x
,
lora_a_stacked
,
scale
)
self
.
add_expand
(
y
,
buffer
,
lora_b_stacked
,
None
,
output_slices
,
add_input
=
True
)
def
add_lora_logits
(
self
,
def
add_lora_logits
(
self
,
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -689,7 +459,8 @@ class PunicaWrapper:
...
@@ -689,7 +459,8 @@ class PunicaWrapper:
lora_b_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
scale
,
scale
,
*
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
)
->
None
:
"""
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Applies lora specifically for LogitsProcessorWithLoRA.
...
@@ -704,22 +475,6 @@ class PunicaWrapper:
...
@@ -704,22 +475,6 @@ class PunicaWrapper:
lora_b_stacked (torch.Tensor):lora_b's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
buffer (Optional[torch.Tensor]):Default to None.
"""
"""
y_org
=
y
# TODO: implement it based on torch ops
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
raise
NotImplementedError
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
r
=
lora_b_stacked
.
size
(
-
1
)
if
buffer
is
None
:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink
(
x
,
lora_a_stacked
,
buffer
,
self
.
sampler_indices
,
scale
)
bgmv_expand
(
buffer
,
lora_b_stacked
,
y
,
self
.
sampler_indices
,
add_inputs
=
True
)
y
=
y
.
view_as
(
y_org
)
vllm/lora/punica_wrapper/punica_gpu.py
0 → 100644
View file @
ca871491
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from
typing
import
Callable
,
Optional
,
Tuple
,
Union
,
final
import
torch
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.lora.ops.bgmv_expand
import
bgmv_expand
from
vllm.lora.ops.bgmv_expand_slice
import
bgmv_expand_slice
from
vllm.lora.ops.bgmv_shrink
import
bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
.punica_base
import
PunicaWrapperBase
@final
class PunicaWrapperGPU(PunicaWrapperBase):
    """
    GPU implementation of the punica wrapper interface.

    Holds the Multi-LoRA state maintained by `PunicaWrapperBase` and forwards
    the LoRA computations to the punica triton kernels: the `sgmv_*` kernels
    during prefill and the `bgmv_*` kernels during decode.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the base.
        super().__init__(max_num_batched_tokens, max_batches, device)

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Prefill-stage shrink: launch `sgmv_shrink` unless no LoRA is
        active for this step."""
        if not self.no_lora:
            sgmv_shrink(
                x,
                w_t_all,
                y,
                *self.prefill_metadata,
                scale,
            )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Decode-stage shrink: launch `bgmv_shrink` with the per-token LoRA
        indices."""
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        """Prefill-stage expand: launch `sgmv_expand` unless no LoRA is
        active for this step."""
        if not self.no_lora:
            sgmv_expand(
                x,
                w_t_all,
                y,
                *self.prefill_metadata,
                add_input,
            )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        """Decode-stage expand: launch `bgmv_expand` with the per-token LoRA
        indices."""
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        """Prefill-stage sliced expand: launch `sgmv_expand_slice` unless no
        LoRA is active for this step."""
        if not self.no_lora:
            sgmv_expand_slice(
                x,
                w_t_all,
                y,
                *self.prefill_metadata,
                y_offset,
                y_slice_size,
                add_input,
            )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        """Decode-stage sliced expand: launch `bgmv_expand_slice` with the
        per-token LoRA indices."""
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool = True,
    ):
        """
        Perform the `y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
        computation, which is suitable for the GEMM of lora'b.

        Dispatches to the prefill variant when `is_prefill` is set and to the
        decode variant otherwise.
        """
        if self.is_prefill:
            self._expand_slice_prefill(y, x, w_t_all, y_offset, y_slice_size,
                                       add_input)
        else:
            self._expand_slice_decode(y, x, w_t_all, y_offset, y_slice_size,
                                      add_input)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, scale: float):
        """
        Perform the `y+=x@w_t_all` computation, which is suitable for the
        GEMM of lora'a.

        During the prefill stage (`is_prefill` set) `_shrink_prefill` is
        used; during the decode stage `_shrink_decode` is used.
        """
        # The kernels receive a 2-D view of `y`; a view shares storage with
        # the caller's tensor, so the result is written in place.
        y_2d = y.view(-1, y.shape[-1])
        if self.is_prefill:
            self._shrink_prefill(y_2d, x, w_t_all, scale)
        else:
            self._shrink_decode(y_2d, x, w_t_all, scale)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.

        The prefill or decode shrink kernel is selected per slice according
        to `is_prefill`.

        Semantics:
        for i in range(len(lora_a_stacked)):
            y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        x_2d = x.view(-1, x.shape[-1])
        # TODO fuse these kernels
        for slice_idx, lora_a in enumerate(lora_a_stacked):
            self._apply_shrink(y[slice_idx], x_2d, lora_a, scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_input=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (Tuple[int, ...]): Every slice's size
            offset_start (int): Column offset of the first slice.
                Defaults to 0.
            add_input (bool): Defaults to True.
        """
        # 2-D view sharing storage with the caller's `y`.
        y_2d = y.view(-1, y.shape[-1])
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y_2d, output_slices,
                             lora_bias_stacked)
        column = offset_start
        for slice_idx, lora_b in enumerate(lora_b_stacked):
            width = output_slices[slice_idx]
            self._apply_expand(
                y_2d,
                x[slice_idx],
                lora_b,
                column,
                width,
                add_input=add_input,
            )
            column += width

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_input: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_input (bool): Default to True.
        """
        # Embedding layer only need expand op
        if self.is_prefill:
            self._expand_prefill(y, x, lora_b_stacked, add_input)
        else:
            self._expand_decode(y, x, lora_b_stacked, add_input)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)+lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """
        num_slices = len(output_slices)
        assert len(lora_a_stacked) == len(lora_b_stacked) == num_slices

        # The bias is applied here, so `add_expand` below is given None.
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == num_slices
            y = self._apply_bias(self.token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            rank = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default ,refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = tuple(
                torch.zeros((x.size(0), rank),
                            dtype=torch.float32,
                            device=x.device) for _ in output_slices)

        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_input=True,
                        **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor):lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]):Default to None.
        """
        # 2-D views sharing storage with the caller's tensors.
        y_2d = y.view(-1, y.shape[-1])
        x_2d = x.view(-1, x.shape[-1])
        rank = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default ,refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x_2d.size(0), rank),
                                 dtype=torch.float32,
                                 device=x_2d.device)
        # LogitsProcessorWithLoRA always using bgmv.
        bgmv_shrink(x_2d, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y_2d,
                    self.sampler_indices,
                    add_inputs=True)
vllm/lora/punica_wrapper/punica_selector.py
0 → 100644
View file @
ca871491
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_info_once
from
.punica_base
import
PunicaWrapperBase
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
    """Build the punica wrapper for the current platform.

    All positional and keyword arguments are forwarded unchanged to the
    wrapper's constructor.

    Raises:
        NotImplementedError: If the current platform is not CUDA-alike.
    """
    if not current_platform.is_cuda_alike():
        raise NotImplementedError

    # Lazy import to avoid ImportError
    from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

    print_info_once("Using PunicaWrapperGPU.")
    return PunicaWrapperGPU(*args, **kwargs)
vllm/lora/punica_wrapper/utils.py
0 → 100644
View file @
ca871491
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
import
torch
if
TYPE_CHECKING
:
# avoid circular import
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.models
import
LongContextLoRAContext
def compute_meta(
    token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
    """
    Build the metadata required by the sgmv kernel.

    Consecutive tokens that use the same LoRA are merged into a single run,
    which improves sgmv kernel inference performance. The result has to be
    recomputed from the input at the beginning of each prefill stage
    inference, but only once.

    Returns:
        A tuple `(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
        batch_size, max_length, token_nums, no_lora)`: the start offset,
        length and LoRA index of every run, the number of runs, the longest
        run, the total token count, and whether the step needs no LoRA.
    """
    run_lora_ids, run_lengths = torch.unique_consecutive(token_lora_tensor,
                                                         return_counts=True)

    # Start of each run = exclusive prefix sum of the run lengths.
    run_starts = torch.zeros_like(run_lengths)
    run_starts[1:] = torch.cumsum(run_lengths, dim=0)[:-1]

    longest_run = run_lengths.max().item()
    total_tokens = run_lengths.sum().item()
    num_runs = run_lora_ids.size(0)

    # -1 means no lora should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the triton kernel, which can improve performance
    no_lora = num_runs == 1 and bool(run_lora_ids == -1)

    return (run_starts, run_lengths, run_lora_ids, num_runs, longest_run,
            total_tokens, no_lora)
# TODO see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    device: torch.device,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA slot indices to LoRA ids
            (None for an empty slot).
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        device: Device on which the returned tensors are created.
        long_lora_context: Passed if there are long context lora in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices (-1 for rows whose LoRA id is not positive).
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] derived from
                sampler_indices: -1 is replaced with max_loras - 1, then
                entry i becomes i + index * batch_size.
            embeddings_indices: Tensor of shape [2, batch_size]. The first
                row is the per-row LoRA index (0 when no LoRA) times
                extra_vocab_size; the second row is the same index times
                (vocab_size + extra_vocab_size).
            long_lora_indices: Tensor of shape [batch_size] holding the
                per-row offset looked up in
                long_lora_context.offsets_by_lora_id (0 if absent).
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices); the last entry is
                None when long_lora_context is not given.

    Raises:
        ValueError: If a positive LoRA id in the mapping does not appear in
            lora_index_to_id.
    """
    # Resolve ids through a table built once instead of calling list.index
    # for every token. setdefault keeps the first slot of a repeated id,
    # which is what list.index would have returned.
    slot_by_lora_id = {}
    for slot, known_id in enumerate(lora_index_to_id):
        if known_id is not None:
            slot_by_lora_id.setdefault(known_id, slot)

    def _slot_of(lora_id: int) -> int:
        # Non-positive ids mean "no LoRA" and map to -1.
        if not lora_id > 0:
            return -1
        try:
            return slot_by_lora_id[lora_id]
        except KeyError:
            # Keep the exception type list.index raised for unknown ids.
            raise ValueError(f"{lora_id} is not in list") from None

    index_mapping_indices: List[int] = list(mapping.index_mapping)
    prompt_mapping: List[int] = [_slot_of(x) for x in mapping.prompt_mapping]

    lora_indices: List[int] = [_slot_of(x) for x in index_mapping_indices]
    # Rows without a LoRA use slot 0 for the embedding lookup.
    embedding_indices: List[int] = [
        slot if lora_id > 0 else 0
        for lora_id, slot in zip(index_mapping_indices, lora_indices)
    ]

    indices_list: List[List[int]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        # Collected as plain ints so that everything is moved to `device` in
        # the single torch.tensor call below.
        offsets_by_lora_id = long_lora_context.offsets_by_lora_id
        indices_list.append(
            [offsets_by_lora_id.get(x, 0) for x in index_mapping_indices])

    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         dtype=torch.long,
                                         device=device)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))

    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]

    # Contain length of indices tensors. Used to index into each tensor.
    # The last entry is None when long_lora doesn't exist.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
        long_lora_indices_len,
    ]

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment