Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1eaff278
Unverified
Commit
1eaff278
authored
Jul 19, 2025
by
Jee Jee Li
Committed by
GitHub
Jul 19, 2025
Browse files
[V0 deprecation] Remove long context LoRA (#21169)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
cf8cc326
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
35 additions
and
301 deletions
+35
-301
tests/lora/conftest.py
tests/lora/conftest.py
+0
-5
tests/lora/test_peft_helper.py
tests/lora/test_peft_helper.py
+4
-7
vllm/config.py
vllm/config.py
+1
-13
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-5
vllm/lora/layers.py
vllm/lora/layers.py
+0
-90
vllm/lora/models.py
vllm/lora/models.py
+10
-70
vllm/lora/peft_helper.py
vllm/lora/peft_helper.py
+0
-9
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+9
-36
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+5
-16
vllm/lora/punica_wrapper/punica_tpu.py
vllm/lora/punica_wrapper/punica_tpu.py
+0
-14
vllm/lora/punica_wrapper/utils.py
vllm/lora/punica_wrapper/utils.py
+5
-33
vllm/lora/utils.py
vllm/lora/utils.py
+0
-2
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+1
-1
No files found.
tests/lora/conftest.py
View file @
1eaff278
...
...
@@ -221,11 +221,6 @@ def phi2_lora_files():
return
snapshot_download
(
repo_id
=
"isotr0py/phi-2-test-sql-lora"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
long_context_lora_files_16k_1
():
return
snapshot_download
(
repo_id
=
"SangBinCho/long_context_16k_testing_1"
)
@
pytest
.
fixture
def
llama_2_7b_engine_extra_embeddings
():
cleanup_dist_env_and_memory
(
shutdown_ray
=
True
)
...
...
tests/lora/test_peft_helper.py
View file @
1eaff278
...
...
@@ -38,8 +38,8 @@ ERROR_CASES = [
]
def
test_peft_helper_pass
(
long_context
_lora_files
_16k_1
,
tmp_path
):
peft_helper
=
PEFTHelper
.
from_local_dir
(
long_context
_lora_files
_16k_1
,
def
test_peft_helper_pass
(
sql
_lora_files
,
tmp_path
):
peft_helper
=
PEFTHelper
.
from_local_dir
(
sql
_lora_files
,
max_position_embeddings
=
4096
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
16
,
max_cpu_loras
=
3
,
max_loras
=
2
)
peft_helper
.
validate_legal
(
lora_config
)
...
...
@@ -56,15 +56,12 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
"embed_tokens"
,
"lm_head"
,
]
assert
peft_helper
.
context_length
==
16384
assert
peft_helper
.
vllm_max_position_embeddings
==
4096
assert
peft_helper
.
vllm_long_context_scaling_factor
==
float
(
math
.
ceil
(
peft_helper
.
context_length
/
peft_helper
.
vllm_max_position_embeddings
))
# test RSLoRA
rslora_config
=
dict
(
use_rslora
=
True
)
test_dir
=
tmp_path
/
"test_rslora"
shutil
.
copytree
(
long_context
_lora_files
_16k_1
,
test_dir
)
shutil
.
copytree
(
sql
_lora_files
,
test_dir
)
# Load and modify configuration
config_path
=
test_dir
/
"adapter_config.json"
...
...
vllm/config.py
View file @
1eaff278
...
...
@@ -3014,12 +3014,7 @@ class LoRAConfig:
(added to the base model vocabulary)."""
lora_vocab_padding_size
:
ClassVar
[
int
]
=
current_platform
\
.
get_lora_vocab_padding_size
()
long_lora_scaling_factors
:
Optional
[
tuple
[
float
,
...]]
=
None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see e.g. Long LoRA) to allow for multiple LoRA adapters
trained with those scaling factors to be used at the same time. If not
specified, only adapters trained with the base model scaling factor are
allowed."""
default_mm_loras
:
Optional
[
dict
[
str
,
str
]]
=
None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
...
...
@@ -3052,7 +3047,6 @@ class LoRAConfig:
factors
.
append
(
self
.
lora_dtype
)
factors
.
append
(
self
.
lora_extra_vocab_size
)
factors
.
append
(
self
.
lora_vocab_padding_size
)
factors
.
append
(
self
.
long_lora_scaling_factors
)
factors
.
append
(
self
.
bias_enabled
)
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
...
...
@@ -3091,11 +3085,6 @@ class LoRAConfig:
elif
isinstance
(
self
.
lora_dtype
,
str
):
self
.
lora_dtype
=
getattr
(
torch
,
self
.
lora_dtype
)
def
verify_lora_support
(
self
):
if
self
.
long_lora_scaling_factors
is
not
None
and
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"V1 LoRA does not support long LoRA, please use V0."
)
@
config
@
dataclass
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
...
...
@@ -4564,7 +4553,6 @@ class VllmConfig:
if
self
.
lora_config
is
not
None
:
self
.
lora_config
.
verify_with_cache_config
(
self
.
cache_config
)
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_lora_support
()
if
self
.
prompt_adapter_config
is
not
None
:
self
.
prompt_adapter_config
.
verify_with_model_config
(
self
.
model_config
)
...
...
vllm/engine/arg_utils.py
View file @
1eaff278
...
...
@@ -358,8 +358,6 @@ class EngineArgs:
max_cpu_loras
:
Optional
[
int
]
=
LoRAConfig
.
max_cpu_loras
lora_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
LoRAConfig
.
lora_dtype
lora_extra_vocab_size
:
int
=
LoRAConfig
.
lora_extra_vocab_size
long_lora_scaling_factors
:
Optional
[
tuple
[
float
,
...]]
=
\
LoRAConfig
.
long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter
:
bool
=
False
max_prompt_adapters
:
int
=
PromptAdapterConfig
.
max_prompt_adapters
...
...
@@ -723,8 +721,6 @@ class EngineArgs:
"--lora-dtype"
,
**
lora_kwargs
[
"lora_dtype"
],
)
lora_group
.
add_argument
(
"--long-lora-scaling-factors"
,
**
lora_kwargs
[
"long_lora_scaling_factors"
])
lora_group
.
add_argument
(
"--max-cpu-loras"
,
**
lora_kwargs
[
"max_cpu_loras"
])
lora_group
.
add_argument
(
"--fully-sharded-loras"
,
...
...
@@ -1245,7 +1241,6 @@ class EngineArgs:
default_mm_loras
=
self
.
default_mm_loras
,
fully_sharded_loras
=
self
.
fully_sharded_loras
,
lora_extra_vocab_size
=
self
.
lora_extra_vocab_size
,
long_lora_scaling_factors
=
self
.
long_lora_scaling_factors
,
lora_dtype
=
self
.
lora_dtype
,
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
...
...
vllm/lora/layers.py
View file @
1eaff278
...
...
@@ -28,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
(
LinearScalingRotaryEmbedding
,
RotaryEmbedding
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.platforms
import
current_platform
...
...
@@ -1193,91 +1191,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
)
->
bool
:
# Special handling for the LogitsProcessor.
return
False
class
LinearScalingRotaryEmbeddingWithLoRA
(
BaseLayerWithLoRA
):
"""Implements RoPE-scaled embeddings with linear scaling for
multiple LoRA adapters with a specialized kernel.
Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
which can handle multi lora adapters in a specialized kernel.
"""
def
__init__
(
self
,
base_layer
:
RotaryEmbedding
)
->
None
:
super
().
__init__
()
self
.
base_layer
=
base_layer
@
property
def
scaling_factors
(
self
):
return
self
.
base_layer
.
scaling_factors
@
property
def
rotary_dim
(
self
):
return
self
.
base_layer
.
rotary_dim
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
scaling_factors
=
(
list
(
lora_config
.
long_lora_scaling_factors
)
if
lora_config
.
long_lora_scaling_factors
else
[])
base_scaling_factor
=
(
self
.
base_layer
.
scaling_factor
if
isinstance
(
self
.
base_layer
,
LinearScalingRotaryEmbedding
)
else
1.0
)
scaling_factors
=
sorted
(
list
(
set
([
base_scaling_factor
]
+
scaling_factors
)))
self
.
base_layer
=
LinearScalingRotaryEmbedding
(
self
.
base_layer
.
head_size
,
self
.
base_layer
.
rotary_dim
,
self
.
base_layer
.
max_position_embeddings
,
self
.
base_layer
.
base
,
self
.
base_layer
.
is_neox_style
,
scaling_factors
,
self
.
base_layer
.
dtype
,
)
def
reset_lora
(
self
,
index
:
int
):
...
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
...
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
base_layer
(
positions
,
query
,
key
,
offsets
=
self
.
punica_wrapper
.
long_lora_indices
,
)
@
property
def
scaling_factor_to_offset
(
self
)
->
dict
[
float
,
int
]:
return
self
.
base_layer
.
scaling_factor_to_offset
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
model_config
:
Optional
[
PretrainedConfig
],
)
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
return
(
type
(
source_layer
)
is
LinearScalingRotaryEmbedding
or
type
(
source_layer
)
is
RotaryEmbedding
)
def
extra_repr
(
self
)
->
str
:
return
self
.
base_layer
.
extra_repr
()
vllm/lora/models.py
View file @
1eaff278
...
...
@@ -4,7 +4,6 @@
import
math
import
os
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
regex
as
re
...
...
@@ -19,9 +18,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
remove_adapter
,
set_adapter_mapping
)
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
LinearScalingRotaryEmbeddingWithLoRA
,
LoRAMapping
)
from
vllm.lora.layers
import
BaseLayerWithLoRA
,
LoRAMapping
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.peft_helper
import
PEFTHelper
from
vllm.lora.punica_wrapper
import
get_punica_wrapper
...
...
@@ -43,18 +40,6 @@ logger = init_logger(__name__)
_GLOBAL_LORA_ID
=
0
@
dataclass
class
LongContextLoRAContext
:
"""Context for lora adapters that support long context."""
# The scaling factors to support long context lora fine tuned models.
scaling_factors
:
list
[
float
]
# dimension to apply rotary embedding.
rot_dim
:
int
# offsets to the sin_cos_cache for each lora_id loaded.
# This value is dynamically modified.
offsets_by_lora_id
:
dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
def
get_lora_id
():
global
_GLOBAL_LORA_ID
_GLOBAL_LORA_ID
+=
1
...
...
@@ -80,20 +65,16 @@ class LoRAModel(AdapterModel):
lora_model_id
:
int
,
rank
:
int
,
loras
:
dict
[
str
,
LoRALayerWeights
],
scaling_factor
:
Optional
[
float
]
=
None
,
)
->
None
:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
scaling_factor: Scaling factor to support long context lora model.
None if the lora is not tuned for long context support.
"""
self
.
id
=
lora_model_id
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
self
.
scaling_factor
=
scaling_factor
assert
(
lora_model_id
>
0
),
f
"a valid lora id should be greater than 0, got
{
self
.
id
}
"
...
...
@@ -192,10 +173,7 @@ class LoRAModel(AdapterModel):
for
lora
in
loras
.
values
():
lora
.
optimize
()
return
cls
(
lora_model_id
,
peft_helper
.
r
,
loras
,
scaling_factor
=
peft_helper
.
vllm_long_context_scaling_factor
)
return
cls
(
lora_model_id
,
peft_helper
.
r
,
loras
)
@
classmethod
def
from_local_checkpoint
(
...
...
@@ -360,24 +338,17 @@ class LoRAModelManager(AdapterModelManager):
self
.
max_num_batched_tokens
=
math
.
ceil
(
max_num_batched_tokens
/
8
)
*
8
self
.
lora_index_to_id
:
list
[
Optional
[
int
]]
=
[
None
]
*
self
.
lora_slots
self
.
vocab_size
=
vocab_size
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
punica_wrapper
=
get_punica_wrapper
(
max_num_batched_tokens
,
max_batches
=
self
.
max_num_seqs
,
device
=
self
.
device
,
max_loras
=
self
.
lora_config
.
max_loras
)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self
.
scaling_factor_to_offset
:
dict
[
float
,
int
]
=
{}
super
().
__init__
(
model
)
self
.
supported_lora_modules
=
get_supported_lora_modules
(
self
.
model
)
assert
self
.
supported_lora_modules
,
"No supported LoRA modules found in"
f
"
{
self
.
model
.
__class__
.
__name__
}
."
if
lora_config
.
long_lora_scaling_factors
:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self
.
supported_lora_modules
.
append
(
"rotary_emb"
)
self
.
packed_modules_mapping
=
get_packed_modules_mapping
(
self
.
model
)
# Used to indicate whether the model is a multimodal model
...
...
@@ -454,25 +425,9 @@ class LoRAModelManager(AdapterModelManager):
except
ValueError
:
pass
def
_set_long_lora_context
(
self
,
lora
:
LoRAModel
):
if
self
.
long_lora_context
is
None
:
return
if
lora
.
scaling_factor
is
None
:
return
if
(
lora
.
scaling_factor
not
in
self
.
scaling_factor_to_offset
):
raise
ValueError
(
f
"Long LoRA scaling factor
{
lora
.
scaling_factor
}
"
" has not been initialized."
)
offsets
=
self
.
scaling_factor_to_offset
.
get
(
lora
.
scaling_factor
)
if
offsets
:
self
.
long_lora_context
.
offsets_by_lora_id
[
lora
.
id
]
=
offsets
def
_add_adapter
(
self
,
lora
:
LoRAModel
):
self
.
_create_merged_loras_inplace
(
lora
)
self
.
_registered_adapters
[
lora
.
id
]
=
lora
self
.
_set_long_lora_context
(
lora
)
def
pin_adapter
(
self
,
lora_id
:
int
)
->
bool
:
"""Pin a LoRAModel in the manager cache."""
...
...
@@ -488,7 +443,6 @@ class LoRAModelManager(AdapterModelManager):
self
.
lora_slots
+
1
,
self
.
vocab_size
,
self
.
lora_config
.
lora_extra_vocab_size
,
self
.
long_lora_context
,
)
def
remove_all_adapters
(
self
):
...
...
@@ -528,13 +482,6 @@ class LoRAModelManager(AdapterModelManager):
from_layer
(
module
,
self
.
lora_slots
,
self
.
lora_config
,
packed_moduled_lst
,
self
.
model
.
config
))
# LinearScalingRotaryEmbeddingWithLoRA is used to handle
# long context lora. Register relevant metadata.
if
isinstance
(
new_module
,
LinearScalingRotaryEmbeddingWithLoRA
):
self
.
long_lora_context
=
LongContextLoRAContext
(
new_module
.
scaling_factors
,
new_module
.
rotary_dim
)
self
.
scaling_factor_to_offset
=
\
new_module
.
scaling_factor_to_offset
# (yard1): TODO make this more robust
if
"lm_head"
in
module_name
:
logits_processor_module_name
=
'logits_processor'
...
...
@@ -574,15 +521,13 @@ class LoRAModelManager(AdapterModelManager):
self
,
lora_id
:
int
,
rank
:
int
,
scaling_factor
:
Optional
[
float
],
embedding_modules
:
Optional
[
dict
[
str
,
str
]]
=
None
)
->
LoRAModel
:
"""Create zero-initialized LoRAModel for warmup."""
model
=
LoRAModel
(
lora_id
,
rank
,
{}
,
scaling_factor
)
model
=
LoRAModel
(
lora_id
,
rank
,
{})
for
module_name
,
module
in
self
.
model
.
named_modules
():
bias_enabled
=
self
.
lora_config
.
bias_enabled
if
(
not
self
.
_match_target_modules
(
module_name
)
or
not
isinstance
(
module
,
BaseLayerWithLoRA
)
or
isinstance
(
module
,
LinearScalingRotaryEmbeddingWithLoRA
)
or
self
.
_filter_unsupported_mm_module
(
module_name
)):
continue
parts
=
module_name
.
split
(
"."
)
...
...
@@ -723,11 +668,8 @@ class LoRAModelManager(AdapterModelManager):
self
.
_deactivate_adapter
)
def
add_adapter
(
self
,
adapter
:
LoRAModel
)
->
bool
:
logger
.
debug
(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s"
,
adapter
.
id
,
adapter
.
id
,
adapter
.
scaling_factor
)
logger
.
debug
(
"Adding lora. Model id: %d, "
"int id: %d"
,
adapter
.
id
,
adapter
.
id
)
return
add_adapter
(
adapter
,
self
.
_registered_adapters
,
self
.
capacity
,
self
.
_add_adapter
)
...
...
@@ -772,10 +714,8 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def
add_adapter
(
self
,
lora
:
LoRAModel
)
->
bool
:
"""Add a LoRAModel to the manager."""
logger
.
debug
(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s"
,
lora
.
id
,
lora
.
id
,
lora
.
scaling_factor
)
logger
.
debug
(
"Adding lora. Model id: %d, "
"int id: %d"
,
lora
.
id
,
lora
.
id
)
if
lora
.
id
not
in
self
.
_registered_adapters
:
self
.
_add_adapter
(
lora
)
was_added
=
True
...
...
vllm/lora/peft_helper.py
View file @
1eaff278
...
...
@@ -35,12 +35,9 @@ class PEFTHelper:
use_rslora
:
bool
=
field
(
default
=
False
)
# True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
use_dora
:
bool
=
field
(
default
=
False
)
# long context lora field
context_length
:
int
=
field
(
default
=
0
)
# Extra vllm field, start with 'vllm_' to avoid conflict
vllm_lora_scaling_factor
:
float
=
field
(
default
=
1.0
)
vllm_max_position_embeddings
:
Optional
[
int
]
=
field
(
default
=
False
)
vllm_long_context_scaling_factor
:
Optional
[
float
]
=
field
(
default
=
None
)
def
_validate_features
(
self
)
->
list
[
str
]:
"""
...
...
@@ -59,12 +56,6 @@ class PEFTHelper:
self
.
vllm_lora_scaling_factor
=
self
.
lora_alpha
/
math
.
sqrt
(
self
.
r
)
else
:
self
.
vllm_lora_scaling_factor
=
self
.
lora_alpha
/
self
.
r
if
self
.
context_length
:
if
self
.
vllm_max_position_embeddings
is
None
:
self
.
vllm_max_position_embeddings
=
self
.
context_length
self
.
vllm_long_context_scaling_factor
=
float
(
math
.
ceil
(
self
.
context_length
/
self
.
vllm_max_position_embeddings
))
@
classmethod
def
from_dict
(
cls
,
config_dict
:
dict
)
->
"PEFTHelper"
:
...
...
vllm/lora/punica_wrapper/punica_base.py
View file @
1eaff278
...
...
@@ -17,7 +17,6 @@ from .utils import compute_meta, convert_mapping
if
TYPE_CHECKING
:
# avoid circular import
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.models
import
LongContextLoRAContext
class
PunicaWrapperABC
(
ABC
):
...
...
@@ -33,7 +32,6 @@ class PunicaWrapperABC(ABC):
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
**
kwargs
,
)
->
None
:
"""
...
...
@@ -144,14 +142,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
_long_lora_indices
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
device
)
#
5
is the number of indices tensors.
#
4
is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
,long_lora_indices
self
.
indices_len
:
list
[
Optional
[
int
]]
=
[
None
]
*
5
# embeddings_indices
self
.
indices_len
:
list
[
Optional
[
int
]]
=
[
None
]
*
4
# these attributes are the information required for sgmv kernel
self
.
_seq_start_locs
=
torch
.
empty
(
max_batches
,
dtype
=
torch
.
long
,
...
...
@@ -176,14 +171,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
):
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_offsets_tensor
,
indices_len
,
)
=
convert_mapping
(
mapping
,
...
...
@@ -192,7 +185,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
vocab_size
,
extra_vocab_size
,
self
.
device
,
long_lora_context
,
)
self
.
_token_lora_indices
[:
base_indices
.
shape
[
0
]].
copy_
(
base_indices
)
self
.
_sampler_indices
[:
sampler_indices
.
shape
[
0
]].
copy_
(
sampler_indices
)
...
...
@@ -201,11 +193,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
self
.
_embeddings_indices
[:
embeddings_indices
.
shape
[
0
],
:
embeddings_indices
.
shape
[
1
]].
copy_
(
embeddings_indices
)
if
long_lora_offsets_tensor
is
not
None
:
self
.
_long_lora_indices
[:
long_lora_offsets_tensor
.
shape
[
0
]].
copy_
(
long_lora_offsets_tensor
)
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metadata
(
self
,
...
...
@@ -312,28 +300,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
embeddings_indices_len
=
self
.
indices_len
[
3
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
@
property
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
"""
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
def
update_metadata
(
self
,
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
list
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
**
kwargs
):
def
update_metadata
(
self
,
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
list
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
**
kwargs
):
self
.
_update_base_metadata
(
mapping
,
lora_index_to_id
,
max_loras
,
vocab_size
,
extra_vocab_size
,
long_lora_context
)
vocab_size
,
extra_vocab_size
)
if
mapping
.
is_prefill
:
# Update metadata required for prefill-related operators.
self
.
_update_prefill_metadata
(
self
.
token_lora_indices
)
...
...
vllm/lora/punica_wrapper/punica_gpu.py
View file @
1eaff278
...
...
@@ -7,7 +7,7 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
,
final
from
typing
import
Optional
,
Union
,
final
import
torch
...
...
@@ -21,10 +21,6 @@ if HAS_TRITON:
from
.punica_base
import
PunicaWrapperBase
if
TYPE_CHECKING
:
# avoid circular import
from
vllm.lora.models
import
LongContextLoRAContext
@
final
class
PunicaWrapperGPU
(
PunicaWrapperBase
):
...
...
@@ -55,20 +51,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_num_prompts
,
device
=
device
)
def
update_metadata
(
self
,
mapping
:
LoRAMapping
,
lora_index_to_id
:
list
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
**
kwargs
):
def
update_metadata
(
self
,
mapping
:
LoRAMapping
,
lora_index_to_id
:
list
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
**
kwargs
):
self
.
is_prefill
=
mapping
.
is_prefill
self
.
_update_base_metadata
(
mapping
,
lora_index_to_id
,
max_loras
,
vocab_size
,
extra_vocab_size
,
long_lora_context
)
vocab_size
,
extra_vocab_size
)
# Prepare cuda kernel metadata tensors
self
.
token_mapping_meta
.
prepare_tensors
(
self
.
token_lora_indices
)
...
...
vllm/lora/punica_wrapper/punica_tpu.py
View file @
1eaff278
...
...
@@ -14,7 +14,6 @@ from vllm.lora.punica_wrapper.utils import convert_mapping
if
TYPE_CHECKING
:
# avoid circular import
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.models
import
LongContextLoRAContext
from
.punica_base
import
PunicaWrapperBase
...
...
@@ -45,7 +44,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
self
.
_sampler_indices_padded
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
self
.
_embeddings_indices
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
self
.
_long_lora_indices
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
self
.
_lora_indices_per_batch
,
True
)
...
...
@@ -323,7 +321,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
):
# Make sure we don't accidentally collect outside operations
xm
.
mark_step
()
...
...
@@ -339,7 +336,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_offsets_tensor
,
indices_len
,
)
=
convert_mapping
(
mapping
,
...
...
@@ -348,7 +344,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
vocab_size
,
extra_vocab_size
,
"cpu"
,
long_lora_context
,
)
self
.
_token_lora_indices
=
self
.
_pad_to_shape
(
base_indices
,
self
.
_token_lora_indices
.
shape
,
...
...
@@ -362,15 +357,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self
.
_embeddings_indices
=
self
.
_pad_to_shape
(
embeddings_indices
,
self
.
_embeddings_indices
.
shape
,
dims
=
2
).
to
(
self
.
device
)
if
long_lora_offsets_tensor
is
not
None
:
self
.
_long_lora_indices
=
self
.
_pad_to_shape
(
long_lora_offsets_tensor
,
self
.
_long_lora_indices
.
shape
,
dims
=
1
).
to
(
self
.
device
)
else
:
zeroed
=
torch
.
zeros_like
(
self
.
_long_lora_indices
.
cpu
(),
dtype
=
torch
.
int32
)
self
.
_long_lora_indices
=
zeroed
.
to
(
self
.
device
)
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metadata
(
self
,
...
...
vllm/lora/punica_wrapper/utils.py
View file @
1eaff278
...
...
@@ -8,7 +8,6 @@ import torch
if
TYPE_CHECKING
:
# avoid circular import
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.models
import
LongContextLoRAContext
def
compute_meta
(
...
...
@@ -49,9 +48,7 @@ def convert_mapping(
vocab_size
:
int
,
extra_vocab_size
:
int
,
device
:
torch
.
device
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
list
[
int
]]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
int
]]:
"""Converts LoRAMapping to index tensors.
Args:
...
...
@@ -60,7 +57,6 @@ def convert_mapping(
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
...
...
@@ -78,21 +74,14 @@ def convert_mapping(
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices
, long_lora_indices
).
embeddings_indices).
"""
index_mapping_indices
:
list
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
index_mapping_indices
.
copy
()
lora_indices
=
index_mapping_indices
.
copy
()
long_lora_offsets
:
Optional
[
torch
.
Tensor
]
=
None
if
long_lora_context
:
long_lora_offsets
=
torch
.
zeros
(
len
(
index_mapping_indices
),
device
=
device
,
dtype
=
torch
.
long
)
prompt_mapping
:
list
[
int
]
=
[
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
for
x
in
mapping
.
prompt_mapping
...
...
@@ -104,20 +93,13 @@ def convert_mapping(
if
index_mapping_indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_indices
[
i
]
>
0
else
0
lora_indices
[
i
]
=
lora_idx
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
lora_offset
:
int
=
long_lora_context
.
offsets_by_lora_id
.
get
(
index_mapping_indices
[
i
],
0
)
long_lora_offsets
[
i
]
=
lora_offset
indices_list
:
list
[
Union
[
list
[
int
],
torch
.
Tensor
]]
=
[
index_mapping_indices
,
lora_indices
,
embedding_indices
,
]
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
indices_list
.
append
(
long_lora_offsets
)
indices
=
torch
.
tensor
(
indices_list
,
dtype
=
torch
.
long
,
device
=
device
)
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
dtype
=
torch
.
long
,
...
...
@@ -136,11 +118,7 @@ def convert_mapping(
sampler_indices_padded
=
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
device
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
))
long_lora_indices
=
None
long_lora_indices_len
:
Optional
[
int
]
=
None
if
long_lora_context
:
long_lora_indices
=
indices
[
3
]
long_lora_indices_len
=
long_lora_indices
.
shape
[
-
1
]
# Contain length of indices tensors. Used to index into each tensor.
indices_len
=
[
base_indices
.
shape
[
-
1
],
...
...
@@ -148,17 +126,11 @@ def convert_mapping(
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
],
]
if
long_lora_indices_len
is
not
None
:
indices_len
.
append
(
long_lora_indices_len
)
else
:
# If long_lora doesn't exist, append None
indices_len
.
append
(
None
)
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_indices
,
indices_len
,
)
vllm/lora/utils.py
View file @
1eaff278
...
...
@@ -22,7 +22,6 @@ from vllm.lora.fully_sharded_layers import (
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
LinearScalingRotaryEmbeddingWithLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLoRA
,
...
...
@@ -56,7 +55,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithShardedLoRA
,
RowParallelLinearWithShardedLoRA
,
LinearScalingRotaryEmbeddingWithLoRA
,
}
...
...
vllm/lora/worker_manager.py
View file @
1eaff278
...
...
@@ -154,7 +154,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
lora_request
.
lora_int_id
)
else
:
dummy_lora
=
self
.
_adapter_manager
.
create_dummy_lora
(
lora_request
.
lora_int_id
,
rank
,
1
,
self
.
embedding_modules
)
lora_request
.
lora_int_id
,
rank
,
self
.
embedding_modules
)
if
self
.
_cached_dummy_lora
is
None
:
self
.
_cached_dummy_lora
=
dummy_lora
return
self
.
_adapter_manager
.
add_adapter
(
dummy_lora
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment