Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca871491
Unverified
Commit
ca871491
authored
Dec 10, 2024
by
Jee Jee Li
Committed by
GitHub
Dec 09, 2024
Browse files
[Misc][LoRA] Abstract PunicaWrapper (#10955)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
3b61cb45
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1058 additions
and
24 deletions
+1058
-24
tests/lora/test_layers.py
tests/lora/test_layers.py
+33
-16
vllm/lora/layers.py
vllm/lora/layers.py
+3
-4
vllm/lora/models.py
vllm/lora/models.py
+4
-4
vllm/lora/punica_wrapper/__init__.py
vllm/lora/punica_wrapper/__init__.py
+7
-0
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+480
-0
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+358
-0
vllm/lora/punica_wrapper/punica_selector.py
vllm/lora/punica_wrapper/punica_selector.py
+14
-0
vllm/lora/punica_wrapper/utils.py
vllm/lora/punica_wrapper/utils.py
+159
-0
No files found.
tests/lora/test_layers.py
View file @
ca871491
...
@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
...
@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
# yapf: enable
# yapf: enable
from
vllm.lora.models
import
(
LongContextLoRAContext
,
LoRALayerWeights
,
from
vllm.lora.models
import
(
LongContextLoRAContext
,
LoRALayerWeights
,
PackedLoRALayerWeights
)
PackedLoRALayerWeights
)
from
vllm.lora.punica
import
P
unica
W
rapper
from
vllm.lora.punica
_wrapper
import
get_p
unica
_w
rapper
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -48,11 +48,12 @@ TOLERANCES = {
...
@@ -48,11 +48,12 @@ TOLERANCES = {
torch
.
float32
:
(
5e-3
,
5e-3
),
torch
.
float32
:
(
5e-3
,
5e-3
),
torch
.
bfloat16
:
(
3e-2
,
2e-2
),
torch
.
bfloat16
:
(
3e-2
,
2e-2
),
}
}
CUDA_DEVICES
=
[
# TODO: Modify this based on platform
DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
#
W
e will launch different triton kernels between the prefill and decode
#
For GPU, w
e will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES
=
[
True
,
False
]
STAGES
=
[
True
,
False
]
...
@@ -192,9 +193,18 @@ def create_random_inputs(
...
@@ -192,9 +193,18 @@ def create_random_inputs(
return
inputs
,
index_mapping
,
prompt_mapping
return
inputs
,
index_mapping
,
prompt_mapping
def
check_punica_wrapper
(
punica_wrapper
)
->
bool
:
if
current_platform
.
is_cuda_alike
():
from
vllm.lora.punica_wrapper.punica_gpu
import
PunicaWrapperGPU
return
type
(
punica_wrapper
)
is
PunicaWrapperGPU
else
:
return
False
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
...
@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
)
...
@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# @pytest.mark.skip(
# @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
# reason="Fails when loras are in any slot other than the first.")
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
...
@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
)
...
@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
256512
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
256512
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
,
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
,
...
@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
)
...
@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
,
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
,
...
@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
...
@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
...
@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"orientation"
,
[
"row"
,
"column"
])
@
pytest
.
mark
.
parametrize
(
"orientation"
,
[
"row"
,
"column"
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
...
@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"repeats"
,
[
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"repeats"
,
[
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias_enabled"
,
[
True
,
False
])
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
...
@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...
@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
...
@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seed
=
0
seed
=
0
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
get_punica_wrapper
(
8192
,
256
,
device
)
assert
check_punica_wrapper
(
punica_wrapper
)
max_loras
=
8
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
...
...
vllm/lora/layers.py
View file @
ca871491
...
@@ -17,7 +17,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -17,7 +17,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_gather
)
tensor_model_parallel_gather
)
from
vllm.distributed.utils
import
divide
from
vllm.distributed.utils
import
divide
from
vllm.lora.punica
import
PunicaWrapper
# yapf: disable
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
LinearBase
,
...
@@ -33,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -33,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding
)
VocabParallelEmbedding
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
pass
from
vllm.lora.punica_wrapper
import
PunicaWrapperBase
def
_get_lora_device
(
base_layer
:
nn
.
Module
)
->
torch
.
device
:
def
_get_lora_device
(
base_layer
:
nn
.
Module
)
->
torch
.
device
:
...
@@ -115,9 +114,9 @@ class BaseLayerWithLoRA(nn.Module):
...
@@ -115,9 +114,9 @@ class BaseLayerWithLoRA(nn.Module):
def
set_mapping
(
def
set_mapping
(
self
,
self
,
punica_wrapper
:
PunicaWrapper
,
punica_wrapper
,
):
):
self
.
punica_wrapper
:
PunicaWrapper
=
punica_wrapper
self
.
punica_wrapper
:
PunicaWrapper
Base
=
punica_wrapper
@
classmethod
@
classmethod
def
can_replace_layer
(
def
can_replace_layer
(
...
...
vllm/lora/models.py
View file @
ca871491
...
@@ -21,7 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
...
@@ -21,7 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora
,
LinearScalingRotaryEmbeddingWithLora
,
LoRAMapping
)
LoRAMapping
)
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.punica
import
P
unica
W
rapper
from
vllm.lora.punica
_wrapper
import
get_p
unica
_w
rapper
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
is_regex_target_modules
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
replace_submodule
)
parse_fine_tuned_lora_name
,
replace_submodule
)
...
@@ -331,7 +331,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -331,7 +331,7 @@ class LoRAModelManager(AdapterModelManager):
self
.
lora_index_to_id
:
List
[
Optional
[
int
]]
=
[
None
]
*
self
.
lora_slots
self
.
lora_index_to_id
:
List
[
Optional
[
int
]]
=
[
None
]
*
self
.
lora_slots
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
punica_wrapper
=
P
unica
W
rapper
(
max_num_batched_tokens
,
self
.
punica_wrapper
=
get_p
unica
_w
rapper
(
max_num_batched_tokens
,
max_batches
=
self
.
max_num_seqs
,
max_batches
=
self
.
max_num_seqs
,
device
=
self
.
device
)
device
=
self
.
device
)
# Scaling factor -> offset to the sin_cos_cache to it.
# Scaling factor -> offset to the sin_cos_cache to it.
...
...
vllm/lora/punica_wrapper/__init__.py
0 → 100644
View file @
ca871491
from
vllm.lora.punica_wrapper.punica_base
import
PunicaWrapperBase
from
vllm.lora.punica_wrapper.punica_selector
import
get_punica_wrapper
__all__
=
[
"PunicaWrapperBase"
,
"get_punica_wrapper"
,
]
vllm/lora/punica.py
→
vllm/lora/punica
_wrapper/punica_base
.py
View file @
ca871491
This diff is collapsed.
Click to expand it.
vllm/lora/punica_wrapper/punica_gpu.py
0 → 100644
View file @
ca871491
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from
typing
import
Callable
,
Optional
,
Tuple
,
Union
,
final
import
torch
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.lora.ops.bgmv_expand
import
bgmv_expand
from
vllm.lora.ops.bgmv_expand_slice
import
bgmv_expand_slice
from
vllm.lora.ops.bgmv_shrink
import
bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
.punica_base
import
PunicaWrapperBase
@
final
class
PunicaWrapperGPU
(
PunicaWrapperBase
):
"""
PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica triton kernel.
"""
def
__init__
(
self
,
max_num_batched_tokens
:
int
,
max_batches
:
int
,
device
:
Union
[
torch
.
device
,
str
],
**
kwargs
):
PunicaWrapperBase
.
__init__
(
self
,
max_num_batched_tokens
,
max_batches
,
device
)
def
_shrink_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_shrink
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
scale
,
)
def
_shrink_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
,
):
bgmv_shrink
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
scale
)
def
_expand_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_expand
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
add_input
,
)
def
_expand_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
):
bgmv_expand
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
add_input
)
def
_expand_slice_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_expand_slice
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
y_offset
,
y_slice_size
,
add_input
,
)
def
_expand_slice_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
):
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
y_slice_size
,
add_input
)
def
_apply_expand
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
=
True
,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun
:
Callable
=
(
self
.
_expand_slice_prefill
if
self
.
is_prefill
else
self
.
_expand_slice_decode
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
def
_apply_shrink
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
shrink_fun
:
Callable
=
(
self
.
_shrink_prefill
if
self
.
is_prefill
else
self
.
_shrink_decode
)
shrink_fun
(
y
,
x
,
w_t_all
,
scale
)
y
=
y
.
view_as
(
y_org
)
def
add_shrink
(
self
,
y
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
scale
:
float
,
**
kwargs
):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
# TODO fuse these kernels
for
slice_idx
in
range
(
len
(
lora_a_stacked
)):
self
.
_apply_shrink
(
y
[
slice_idx
],
x
,
lora_a_stacked
[
slice_idx
],
scale
)
def
add_expand
(
self
,
y
:
torch
.
Tensor
,
x
:
Union
[
Tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
offset_start
:
int
=
0
,
add_input
=
True
,
**
kwargs
)
->
None
:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
offset_left
=
offset_start
if
lora_bias_stacked
is
not
None
:
self
.
_apply_bias
(
self
.
token_lora_indices
,
y
,
output_slices
,
lora_bias_stacked
)
for
slice_idx
in
range
(
len
(
lora_b_stacked
)):
self
.
_apply_expand
(
y
,
x
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
offset_left
,
output_slices
[
slice_idx
],
add_input
=
add_input
,
)
offset_left
+=
output_slices
[
slice_idx
]
y
=
y
.
view_as
(
y_org
)
def
add_lora_embedding
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
**
kwargs
)
->
None
:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
"""
# Embedding layer only need expand op
expand_fun
:
Callable
=
(
self
.
_expand_prefill
if
self
.
is_prefill
else
self
.
_expand_decode
)
expand_fun
(
y
,
x
,
lora_b_stacked
,
add_input
)
def
add_lora_linear
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
...],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
scale
:
float
,
output_slices
:
Tuple
[
int
,
...],
*
,
buffer
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]]
=
None
,
**
kwargs
)
->
None
:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert
len
(
lora_a_stacked
)
==
len
(
lora_b_stacked
)
==
len
(
output_slices
)
if
lora_bias_stacked
is
not
None
:
assert
len
(
lora_bias_stacked
)
==
len
(
output_slices
)
y
=
self
.
_apply_bias
(
self
.
token_lora_indices
,
y
,
output_slices
,
lora_bias_stacked
)
if
buffer
is
None
:
r
=
lora_b_stacked
[
0
].
size
(
-
1
)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
tuple
(
torch
.
zeros
(
(
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
for
_
in
range
(
len
(
output_slices
)))
self
.
add_shrink
(
buffer
,
x
,
lora_a_stacked
,
scale
,
**
kwargs
)
self
.
add_expand
(
y
,
buffer
,
lora_b_stacked
,
None
,
output_slices
,
add_input
=
True
,
**
kwargs
)
def
add_lora_logits
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
scale
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
)
->
None
:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
r
=
lora_b_stacked
.
size
(
-
1
)
if
buffer
is
None
:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink
(
x
,
lora_a_stacked
,
buffer
,
self
.
sampler_indices
,
scale
)
bgmv_expand
(
buffer
,
lora_b_stacked
,
y
,
self
.
sampler_indices
,
add_inputs
=
True
)
y
=
y
.
view_as
(
y_org
)
vllm/lora/punica_wrapper/punica_selector.py
0 → 100644
View file @
ca871491
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_info_once
from
.punica_base
import
PunicaWrapperBase
def
get_punica_wrapper
(
*
args
,
**
kwargs
)
->
PunicaWrapperBase
:
if
current_platform
.
is_cuda_alike
():
# Lazy import to avoid ImportError
from
vllm.lora.punica_wrapper.punica_gpu
import
PunicaWrapperGPU
print_info_once
(
"Using PunicaWrapperGPU."
)
return
PunicaWrapperGPU
(
*
args
,
**
kwargs
)
else
:
raise
NotImplementedError
vllm/lora/punica_wrapper/utils.py
0 → 100644
View file @
ca871491
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
import
torch
if
TYPE_CHECKING
:
# avoid circuit import
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.models
import
LongContextLoRAContext
def
compute_meta
(
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
int
,
bool
]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
lora_indices_tensor
,
seq_length_tensor
=
torch
.
unique_consecutive
(
token_lora_tensor
,
return_counts
=
True
)
cum_result
=
torch
.
cumsum
(
seq_length_tensor
,
dim
=
0
)
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
max_length
=
seq_length_tensor
.
max
().
item
()
token_nums
=
seq_length_tensor
.
sum
().
item
()
batch_size
=
lora_indices_tensor
.
size
(
0
)
no_lora
=
False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
no_lora
=
True
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
token_nums
,
no_lora
)
# TODO see if this can be vectorized
def
convert_mapping
(
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
List
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
device
:
torch
.
device
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
List
[
int
]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
index_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
index_mapping_indices
.
copy
()
lora_indices
=
index_mapping_indices
.
copy
()
long_lora_offsets
:
Optional
[
torch
.
Tensor
]
=
None
if
long_lora_context
:
long_lora_offsets
=
torch
.
zeros
(
len
(
index_mapping_indices
),
device
=
device
,
dtype
=
torch
.
long
)
prompt_mapping
:
List
[
int
]
=
[
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
for
x
in
mapping
.
prompt_mapping
]
lora_idx
=
None
for
i
in
range
(
len
(
index_mapping_indices
)):
# TODO index can be slow. optimize
lora_idx
=
(
lora_index_to_id
.
index
(
index_mapping_indices
[
i
])
if
index_mapping_indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_indices
[
i
]
>
0
else
0
lora_indices
[
i
]
=
lora_idx
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
lora_offset
:
int
=
long_lora_context
.
offsets_by_lora_id
.
get
(
index_mapping_indices
[
i
],
0
)
long_lora_offsets
[
i
]
=
lora_offset
indices_list
:
List
[
Union
[
List
[
int
],
torch
.
Tensor
]]
=
[
index_mapping_indices
,
lora_indices
,
embedding_indices
,
]
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
indices_list
.
append
(
long_lora_offsets
)
indices
=
torch
.
tensor
(
indices_list
,
dtype
=
torch
.
long
,
device
=
device
)
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
embeddings_indices
=
torch
.
stack
([
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
),
])
embeddings_indices
[
embeddings_indices
==
-
1
]
=
max_loras
-
1
base_indices
=
indices
[
1
]
sampler_indices
=
prompt_mapping_tensor
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
=
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
device
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
))
long_lora_indices
=
None
long_lora_indices_len
:
Optional
[
int
]
=
None
if
long_lora_context
:
long_lora_indices
=
indices
[
3
]
long_lora_indices_len
=
long_lora_indices
.
shape
[
-
1
]
# Contain length of indices tensors. Used to index into each tensor.
indices_len
=
[
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
],
]
if
long_lora_indices_len
is
not
None
:
indices_len
.
append
(
long_lora_indices_len
)
else
:
# If long_lora doesn't exist,append None
indices_len
.
append
(
None
)
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_indices
,
indices_len
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment