Unverified commit 5157338e, authored by Jee Jee Li, committed by GitHub
Browse files

[Misc] Improve LoRA spelling (#13831)

parent e206b543
...@@ -89,7 +89,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, ...@@ -89,7 +89,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
sort_by_lora_id: bool, sort_by_lora_id: bool,
device: str) -> torch.Tensor: device: str) -> torch.Tensor:
""" """
All prompts are mapped to a Lora ID in range [0, num_active_loras). All prompts are mapped to a LoRA ID in range [0, num_active_loras).
where 0 refers to first lora, 1 refers to second lora and so on. where 0 refers to first lora, 1 refers to second lora and so on.
""" """
assert num_active_loras > 0 assert num_active_loras > 0
......
...@@ -170,7 +170,7 @@ Now, you can specify a base_model_name alongside the name and path using JSON fo ...@@ -170,7 +170,7 @@ Now, you can specify a base_model_name alongside the name and path using JSON fo
To provide backward compatibility, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case. To provide backward compatibility, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
## Lora model lineage in model card ## LoRA model lineage in model card
The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this: The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this:
......
...@@ -491,7 +491,7 @@ def test_prefill_schedule_max_lora(): ...@@ -491,7 +491,7 @@ def test_prefill_schedule_max_lora():
lora_path="abc")) lora_path="abc"))
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
# Add two more requests to verify lora is prioritized. # Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular # 0: LoRA, 1: LoRA, 2: regular, 3: regular
# In the first iteration, index 0, 2 is scheduled. # In the first iteration, index 0, 2 is scheduled.
# If a request is not scheduled because it hits max lora, it is # If a request is not scheduled because it hits max lora, it is
# prioritized. Verify that. # prioritized. Verify that.
......
...@@ -26,7 +26,7 @@ def serve_parser(): ...@@ -26,7 +26,7 @@ def serve_parser():
return make_arg_parser(parser) return make_arg_parser(parser)
### Tests for Lora module parsing ### Tests for LoRA module parsing
def test_valid_key_value_format(serve_parser): def test_valid_key_value_format(serve_parser):
# Test old format: name=path # Test old format: name=path
args = serve_parser.parse_args([ args = serve_parser.parse_args([
......
...@@ -8,8 +8,8 @@ import pytest ...@@ -8,8 +8,8 @@ import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse, from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest, LoadLoRAAdapterRequest,
UnloadLoraAdapterRequest) UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.serving_models import (BaseModelPath, from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels) OpenAIServingModels)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -51,7 +51,7 @@ async def test_serving_model_name(): ...@@ -51,7 +51,7 @@ async def test_serving_model_name():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_lora_adapter_success(): async def test_load_lora_adapter_success():
serving_models = await _async_serving_models_init() serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter", request = LoadLoRAAdapterRequest(lora_name="adapter",
lora_path="/path/to/adapter2") lora_path="/path/to/adapter2")
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
...@@ -62,7 +62,7 @@ async def test_load_lora_adapter_success(): ...@@ -62,7 +62,7 @@ async def test_load_lora_adapter_success():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields(): async def test_load_lora_adapter_missing_fields():
serving_models = await _async_serving_models_init() serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="", lora_path="") request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput" assert response.type == "InvalidUserInput"
...@@ -72,14 +72,14 @@ async def test_load_lora_adapter_missing_fields(): ...@@ -72,14 +72,14 @@ async def test_load_lora_adapter_missing_fields():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_lora_adapter_duplicate(): async def test_load_lora_adapter_duplicate():
serving_models = await _async_serving_models_init() serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1", request = LoadLoRAAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1") lora_path="/path/to/adapter1")
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format( assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1') lora_name='adapter1')
assert len(serving_models.lora_requests) == 1 assert len(serving_models.lora_requests) == 1
request = LoadLoraAdapterRequest(lora_name="adapter1", request = LoadLoRAAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1") lora_path="/path/to/adapter1")
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
...@@ -91,12 +91,12 @@ async def test_load_lora_adapter_duplicate(): ...@@ -91,12 +91,12 @@ async def test_load_lora_adapter_duplicate():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unload_lora_adapter_success(): async def test_unload_lora_adapter_success():
serving_models = await _async_serving_models_init() serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1", request = LoadLoRAAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1") lora_path="/path/to/adapter1")
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert len(serving_models.lora_requests) == 1 assert len(serving_models.lora_requests) == 1
request = UnloadLoraAdapterRequest(lora_name="adapter1") request = UnloadLoRAAdapterRequest(lora_name="adapter1")
response = await serving_models.unload_lora_adapter(request) response = await serving_models.unload_lora_adapter(request)
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1') lora_name='adapter1')
...@@ -106,7 +106,7 @@ async def test_unload_lora_adapter_success(): ...@@ -106,7 +106,7 @@ async def test_unload_lora_adapter_success():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields(): async def test_unload_lora_adapter_missing_fields():
serving_models = await _async_serving_models_init() serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
response = await serving_models.unload_lora_adapter(request) response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput" assert response.type == "InvalidUserInput"
...@@ -116,7 +116,7 @@ async def test_unload_lora_adapter_missing_fields(): ...@@ -116,7 +116,7 @@ async def test_unload_lora_adapter_missing_fields():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unload_lora_adapter_not_found(): async def test_unload_lora_adapter_not_found():
serving_models = await _async_serving_models_init() serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_models.unload_lora_adapter(request) response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.type == "NotFoundError" assert response.type == "NotFoundError"
......
...@@ -14,16 +14,16 @@ from vllm.config import LoRAConfig ...@@ -14,16 +14,16 @@ from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import ( from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA) RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora, LinearScalingRotaryEmbeddingWithLoRA,
LogitsProcessorWithLoRA, LoRAMapping, LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLora, QKVParallelLinearWithLoRA,
ReplicatedLinearWithLoRA, ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) VocabParallelEmbeddingWithLoRA)
...@@ -866,9 +866,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -866,9 +866,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
bias=False, bias=False,
params_dtype=torch.float16) params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = (MergedQKVParallelLinearWithLora(linear) lora_linear = (MergedQKVParallelLinearWithLoRA(linear)
if not fully_shard else if not fully_shard else
MergedQKVParallelLinearWithShardedLora(linear)) MergedQKVParallelLinearWithShardedLoRA(linear))
else: else:
linear = QKVParallelLinear(4096, linear = QKVParallelLinear(4096,
64, 64,
...@@ -876,9 +876,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -876,9 +876,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
bias=False, bias=False,
params_dtype=torch.float16) params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = QKVParallelLinearWithLora( lora_linear = QKVParallelLinearWithLoRA(
linear linear
) if not fully_shard else QKVParallelLinearWithShardedLora(linear) ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear)
@dataclass @dataclass
class FakeConfig: class FakeConfig:
...@@ -1024,7 +1024,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -1024,7 +1024,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
base, base,
is_neox_style, is_neox_style,
) )
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) lora_rope = LinearScalingRotaryEmbeddingWithLoRA(rope)
lora_rope.set_mapping(punica_wrapper) lora_rope.set_mapping(punica_wrapper)
lora_rope.create_lora_weights(max_loras, lora_config) lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base, linear_rope = get_rope(head_size, rotary_dim, max_position, base,
......
...@@ -8,7 +8,7 @@ import pytest ...@@ -8,7 +8,7 @@ import pytest
import vllm import vllm
from vllm import SamplingParams from vllm import SamplingParams
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLoRA
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding) LinearScalingRotaryEmbedding)
...@@ -151,7 +151,7 @@ def test_rotary_emb_replaced(dist_init): ...@@ -151,7 +151,7 @@ def test_rotary_emb_replaced(dist_init):
if "rotary_emb" in module_name: if "rotary_emb" in module_name:
if "base_layer" not in module_name: if "base_layer" not in module_name:
rotary_emb_count += 1 rotary_emb_count += 1
assert isinstance(module, LinearScalingRotaryEmbeddingWithLora) assert isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
else: else:
assert isinstance(module, LinearScalingRotaryEmbedding) assert isinstance(module, LinearScalingRotaryEmbedding)
# Llama 2 has 32 layers. # Llama 2 has 32 layers.
......
...@@ -1629,7 +1629,7 @@ class LLMEngine: ...@@ -1629,7 +1629,7 @@ class LLMEngine:
max_tokens_requests: List[int] = [] max_tokens_requests: List[int] = []
finished_reason_requests: List[str] = [] finished_reason_requests: List[str] = []
# Lora requests # LoRA requests
running_lora_adapters = dict( running_lora_adapters = dict(
collectionsCounter([ collectionsCounter([
running_request.lora_request.lora_name running_request.lora_request.lora_name
......
...@@ -53,7 +53,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -53,7 +53,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingResponse, EmbeddingResponse,
EmbeddingResponseData, EmbeddingResponseData,
ErrorResponse, ErrorResponse,
LoadLoraAdapterRequest, LoadLoRAAdapterRequest,
PoolingChatRequest, PoolingChatRequest,
PoolingCompletionRequest, PoolingCompletionRequest,
PoolingRequest, PoolingResponse, PoolingRequest, PoolingResponse,
...@@ -63,7 +63,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -63,7 +63,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeResponse, TokenizeResponse,
TranscriptionRequest, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponse,
UnloadLoraAdapterRequest) UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
...@@ -690,12 +690,12 @@ if envs.VLLM_TORCH_PROFILER_DIR: ...@@ -690,12 +690,12 @@ if envs.VLLM_TORCH_PROFILER_DIR:
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning( logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. " "LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!") "This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter", @router.post("/v1/load_lora_adapter",
dependencies=[Depends(validate_json_request)]) dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoraAdapterRequest, async def load_lora_adapter(request: LoadLoRAAdapterRequest,
raw_request: Request): raw_request: Request):
handler = models(raw_request) handler = models(raw_request)
response = await handler.load_lora_adapter(request) response = await handler.load_lora_adapter(request)
...@@ -707,7 +707,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -707,7 +707,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/unload_lora_adapter", @router.post("/v1/unload_lora_adapter",
dependencies=[Depends(validate_json_request)]) dependencies=[Depends(validate_json_request)])
async def unload_lora_adapter(request: UnloadLoraAdapterRequest, async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
raw_request: Request): raw_request: Request):
handler = models(raw_request) handler = models(raw_request)
response = await handler.unload_lora_adapter(request) response = await handler.unload_lora_adapter(request)
......
...@@ -1431,12 +1431,12 @@ class DetokenizeResponse(OpenAIBaseModel): ...@@ -1431,12 +1431,12 @@ class DetokenizeResponse(OpenAIBaseModel):
prompt: str prompt: str
class LoadLoraAdapterRequest(BaseModel): class LoadLoRAAdapterRequest(BaseModel):
lora_name: str lora_name: str
lora_path: str lora_path: str
class UnloadLoraAdapterRequest(BaseModel): class UnloadLoRAAdapterRequest(BaseModel):
lora_name: str lora_name: str
lora_int_id: Optional[int] = Field(default=None) lora_int_id: Optional[int] = Field(default=None)
......
...@@ -9,10 +9,10 @@ from typing import List, Optional, Union ...@@ -9,10 +9,10 @@ from typing import List, Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse, from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest, LoadLoRAAdapterRequest,
ModelCard, ModelList, ModelCard, ModelList,
ModelPermission, ModelPermission,
UnloadLoraAdapterRequest) UnloadLoRAAdapterRequest)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -88,7 +88,7 @@ class OpenAIServingModels: ...@@ -88,7 +88,7 @@ class OpenAIServingModels:
if self.static_lora_modules is None: if self.static_lora_modules is None:
return return
for lora in self.static_lora_modules: for lora in self.static_lora_modules:
load_request = LoadLoraAdapterRequest(lora_path=lora.path, load_request = LoadLoRAAdapterRequest(lora_path=lora.path,
lora_name=lora.name) lora_name=lora.name)
load_result = await self.load_lora_adapter( load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name) request=load_request, base_model_name=lora.base_model_name)
...@@ -140,7 +140,7 @@ class OpenAIServingModels: ...@@ -140,7 +140,7 @@ class OpenAIServingModels:
async def load_lora_adapter( async def load_lora_adapter(
self, self,
request: LoadLoraAdapterRequest, request: LoadLoRAAdapterRequest,
base_model_name: Optional[str] = None base_model_name: Optional[str] = None
) -> Union[ErrorResponse, str]: ) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request) error_check_ret = await self._check_load_lora_adapter_request(request)
...@@ -177,7 +177,7 @@ class OpenAIServingModels: ...@@ -177,7 +177,7 @@ class OpenAIServingModels:
async def unload_lora_adapter( async def unload_lora_adapter(
self, self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request error_check_ret = await self._check_unload_lora_adapter_request(request
) )
if error_check_ret is not None: if error_check_ret is not None:
...@@ -192,7 +192,7 @@ class OpenAIServingModels: ...@@ -192,7 +192,7 @@ class OpenAIServingModels:
return f"Success: LoRA adapter '{lora_name}' removed successfully." return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request( async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided # Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path: if not request.lora_name or not request.lora_path:
return create_error_response( return create_error_response(
...@@ -214,7 +214,7 @@ class OpenAIServingModels: ...@@ -214,7 +214,7 @@ class OpenAIServingModels:
async def _check_unload_lora_adapter_request( async def _check_unload_lora_adapter_request(
self, self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided # Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id: if not request.lora_name and not request.lora_int_id:
return create_error_response( return create_error_response(
......
...@@ -13,8 +13,8 @@ from vllm.distributed.communication_op import ( ...@@ -13,8 +13,8 @@ from vllm.distributed.communication_op import (
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA, from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLora, QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA) RowParallelLinearWithLoRA)
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -167,9 +167,9 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -167,9 +167,9 @@ class MergedColumnParallelLinearWithShardedLoRA(
) )
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
""" """
Differs from QKVParallelLinearWithLora by slicing the Differs from QKVParallelLinearWithLoRA by slicing the
LoRA A's also. LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
...@@ -202,9 +202,9 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): ...@@ -202,9 +202,9 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
) )
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
""" """
Differs from MergedQKVParallelLinearWithLora by slicing the Differs from MergedQKVParallelLinearWithLoRA by slicing the
LoRA A's also. LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
......
...@@ -363,7 +363,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -363,7 +363,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None, lora_bias: Optional[torch.Tensor] = None,
): ):
# Except for QKVParallelLinearWithLora and # Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will # store weights in a tuple of size 1. These two layers will
# override this function. # override this function.
...@@ -686,7 +686,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -686,7 +686,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
and len(packed_modules_list) == 2) and len(packed_modules_list) == 2)
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
""" """
ColumnParallelLinear layer that is specifically designed for ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chatglm3 and baichuan-7b, qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
...@@ -754,7 +754,7 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -754,7 +754,7 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
packed_modules_list) == 1 packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA): class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
"""MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj). (q_proj + k_proj + v_proj -> qkv_proj).
...@@ -1120,7 +1120,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1120,7 +1120,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
return False return False
class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
"""Implements RoPE-scaled embeddings with linear scaling for """Implements RoPE-scaled embeddings with linear scaling for
multiple LoRA adapters with a specialized kernel. multiple LoRA adapters with a specialized kernel.
......
...@@ -20,7 +20,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, ...@@ -20,7 +20,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA, from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora, LinearScalingRotaryEmbeddingWithLoRA,
LoRAMapping) LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
...@@ -201,7 +201,7 @@ class LoRAModel(AdapterModel): ...@@ -201,7 +201,7 @@ class LoRAModel(AdapterModel):
expected_lora_modules: Name of modules that are expected to be expected_lora_modules: Name of modules that are expected to be
replaced by lora. replaced by lora.
peft_helper: Loaded lora configuration information. peft_helper: Loaded lora configuration information.
lora_model_id: Lora model id. If not given, automatically set by lora_model_id: LoRA model id. If not given, automatically set by
a global counter. a global counter.
device: Device where the lora model is loaded. device: Device where the lora model is loaded.
dtype: dtype of the lora model weights. dtype: dtype of the lora model weights.
...@@ -480,9 +480,9 @@ class LoRAModelManager(AdapterModelManager): ...@@ -480,9 +480,9 @@ class LoRAModelManager(AdapterModelManager):
from_layer(module, self.lora_slots, self.lora_config, from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config)) packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle # LinearScalingRotaryEmbeddingWithLoRA is used to handle
# long context lora. Register relevant metadata. # long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
self.long_lora_context = LongContextLoRAContext( self.long_lora_context = LongContextLoRAContext(
new_module.scaling_factors, new_module.rotary_dim) new_module.scaling_factors, new_module.rotary_dim)
self.scaling_factor_to_offset = \ self.scaling_factor_to_offset = \
...@@ -527,7 +527,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -527,7 +527,7 @@ class LoRAModelManager(AdapterModelManager):
bias_enabled = self.lora_config.bias_enabled bias_enabled = self.lora_config.bias_enabled
if (not self._match_target_modules(module_name) if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA) or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora) or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
or self._filter_unsupported_mm_module(module_name)): or self._filter_unsupported_mm_module(module_name)):
continue continue
parts = module_name.split(".") parts = module_name.split(".")
......
...@@ -42,7 +42,7 @@ class PEFTHelper: ...@@ -42,7 +42,7 @@ class PEFTHelper:
def _validate_features(self) -> List[str]: def _validate_features(self) -> List[str]:
""" """
Check if there are any unsupported Lora features. Check if there are any unsupported LoRA features.
""" """
error_msg = [] error_msg = []
if self.modules_to_save: if self.modules_to_save:
......
...@@ -314,7 +314,7 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -314,7 +314,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
def long_lora_indices(self) -> torch.Tensor: def long_lora_indices(self) -> torch.Tensor:
""" """
This property provides access to the indices used for long context This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora. lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
""" """
long_lora_len = self.indices_len[4] long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len] return self._long_lora_indices[:long_lora_len]
......
...@@ -15,17 +15,17 @@ from vllm.logger import init_logger ...@@ -15,17 +15,17 @@ from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import ( from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA) RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below # being imported for _all_lora_classes below
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora, LinearScalingRotaryEmbeddingWithLoRA,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLora, QKVParallelLinearWithLoRA,
ReplicatedLinearWithLoRA, ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) VocabParallelEmbeddingWithLoRA)
...@@ -41,17 +41,17 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { ...@@ -41,17 +41,17 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLora, QKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
ReplicatedLinearWithLoRA, ReplicatedLinearWithLoRA,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA,
LinearScalingRotaryEmbeddingWithLora, LinearScalingRotaryEmbeddingWithLoRA,
} }
......
...@@ -6,10 +6,10 @@ from typing import List, Optional, Set, Tuple ...@@ -6,10 +6,10 @@ from typing import List, Optional, Set, Tuple
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposer from vllm.spec_decode.interfaces import SpeculativeProposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoRANotSupportedWorkerBase
class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer): class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer):
"""Interface for proposer workers""" """Interface for proposer workers"""
@abstractmethod @abstractmethod
......
...@@ -47,7 +47,7 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output, ...@@ -47,7 +47,7 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_sampled_token_logprobs, nvtx_range, get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.utils import resolve_obj_by_qualname from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -118,7 +118,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -118,7 +118,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo becomes valid # If the feature combo becomes valid
class SpecDecodeWorker(LoraNotSupportedWorkerBase): class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"""Worker which implements speculative decoding. """Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal Speculative decoding reduces decoding per-token latency by using a proposal
......
...@@ -21,7 +21,7 @@ ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -21,7 +21,7 @@ ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@dataclass @dataclass
class ArcticLoraConfig: class ArcticLoRAConfig:
lora_r: int = 64 lora_r: int = 64
lora_alpha: float = 16 lora_alpha: float = 16
shard_base_weights: bool = False shard_base_weights: bool = False
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment