Unverified Commit 28219864 authored by Jee Jee Li, committed by GitHub
Browse files

[Core] Modify the initialization parameters of the lora manager (#25249)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 6c117cff
...@@ -8,11 +8,12 @@ import torch ...@@ -8,11 +8,12 @@ import torch
from safetensors.torch import load_file from safetensors.torch import load_file
from torch import nn from torch import nn
from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA, from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA) RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager, from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager) LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
...@@ -435,10 +436,19 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, ...@@ -435,10 +436,19 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
target_modules=["layer1.dense1", "dense2"], target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE, lora_dtype=DEFAULT_DTYPE,
) )
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config,
lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_adapter_manager = LRUCacheWorkerLoRAManager( worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) worker_adapter_manager.max_num_seqs = 4
worker_adapter_manager.max_num_batched_tokens = 2
worker_adapter_manager.create_lora_manager(dummy_model) worker_adapter_manager.create_lora_manager(dummy_model)
mapping = LoRAMapping([], []) mapping = LoRAMapping([], [])
...@@ -517,10 +527,20 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, ...@@ -517,10 +527,20 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
max_cpu_loras=4, max_cpu_loras=4,
max_loras=4, max_loras=4,
lora_dtype=DEFAULT_DTYPE) lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = WorkerLoRAManager(
4, 2, dummy_model_gate_up.unpadded_vocab_size - model_config = ModelConfig(max_model_len=16)
lora_config.lora_extra_vocab_size, lora_config, device, vllm_config = VllmConfig(model_config=model_config,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_adapter_manager = WorkerLoRAManager(vllm_config, device,
EMBEDDING_MODULES,
EMBEDDING_PADDING_MODULES)
worker_adapter_manager.vocab_size = (
dummy_model_gate_up.unpadded_vocab_size -
lora_config.lora_extra_vocab_size)
worker_adapter_manager.create_lora_manager(dummy_model_gate_up) worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
dummy_lora_files = f"{tmp_path}/lora_adapter" dummy_lora_files = f"{tmp_path}/lora_adapter"
......
...@@ -9,7 +9,7 @@ from typing import Optional, Union ...@@ -9,7 +9,7 @@ from typing import Optional, Union
import torch import torch
from safetensors.torch import save_file from safetensors.torch import save_file
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
class DummyLoRAManager: class DummyLoRAManager:
......
...@@ -14,7 +14,7 @@ from torch import nn ...@@ -14,7 +14,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
......
...@@ -6,7 +6,7 @@ from typing import Any, Literal, Optional, Union ...@@ -6,7 +6,7 @@ from typing import Any, Literal, Optional, Union
import torch import torch
from vllm.config.lora import LoRAConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager, from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager) LRUCacheLoRAModelManager, create_lora_manager)
...@@ -27,25 +27,26 @@ class WorkerLoRAManager: ...@@ -27,25 +27,26 @@ class WorkerLoRAManager:
def __init__( def __init__(
self, self,
max_num_seqs: int, vllm_config: VllmConfig,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device, device: torch.device,
embedding_modules: dict[str, str], embedding_modules: dict[str, str],
embedding_padding_modules: list[str], embedding_padding_modules: list[str],
lora_model_cls: type[LoRAModel] = LoRAModel, lora_model_cls: type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None,
): ):
self._lora_model_cls = lora_model_cls self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
self.max_num_seqs = max_num_seqs self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = (
self.vocab_size = vocab_size vllm_config.scheduler_config.max_num_batched_tokens)
self.lora_config = lora_config self.vocab_size = vllm_config.model_config.get_vocab_size()
self.max_position_embeddings = max_position_embeddings self.lora_config = vllm_config.lora_config
# Use get_text_config() in case of multimodal models
text_config = vllm_config.model_config.hf_config.get_text_config()
self.max_position_embeddings = text_config.max_position_embeddings
self.device = device self.device = device
# Lazily initialized by create_lora_manager. # Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager self._adapter_manager: LoRAModelManager
......
...@@ -107,9 +107,8 @@ class CPUModelRunner(GPUModelRunner): ...@@ -107,9 +107,8 @@ class CPUModelRunner(GPUModelRunner):
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config: if self.lora_config:
self.model = self.load_lora_model(self.model, self.model_config, self.model = self.load_lora_model(self.model, self.vllm_config,
self.scheduler_config, self.device)
self.lora_config, self.device)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model return self.model
......
...@@ -2552,10 +2552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2552,10 +2552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model = model_loader.load_model( self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config) vllm_config=self.vllm_config, model_config=self.model_config)
if self.lora_config: if self.lora_config:
self.model = self.load_lora_model(self.model, self.model = self.load_lora_model(self.model, self.vllm_config,
self.model_config,
self.scheduler_config,
self.lora_config,
self.device) self.device)
if hasattr(self, "drafter"): if hasattr(self, "drafter"):
logger.info("Loading drafter model...") logger.info("Loading drafter model...")
......
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import ModelConfig, SchedulerConfig from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
...@@ -31,9 +31,7 @@ class LoRAModelRunnerMixin: ...@@ -31,9 +31,7 @@ class LoRAModelRunnerMixin:
LORA_WARMUP_RANK = 8 LORA_WARMUP_RANK = 8
def load_lora_model(self, model: nn.Module, model_config: ModelConfig, def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig,
scheduler_config: SchedulerConfig,
lora_config: LoRAConfig,
device: torch.device) -> nn.Module: device: torch.device) -> nn.Module:
if not supports_lora(model): if not supports_lora(model):
...@@ -44,19 +42,12 @@ class LoRAModelRunnerMixin: ...@@ -44,19 +42,12 @@ class LoRAModelRunnerMixin:
logger.warning("Regarding multimodal models, vLLM currently " logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.") "only supports adding LoRA to language model.")
# Use get_text_config() in case of multimodal models
text_config = model_config.hf_config.get_text_config()
# Add LoRA Manager to the Model Runner # Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
scheduler_config.max_num_seqs, vllm_config,
scheduler_config.max_num_batched_tokens,
model_config.get_vocab_size(),
lora_config,
device, device,
model.embedding_modules, model.embedding_modules,
model.embedding_padding_modules, model.embedding_padding_modules,
max_position_embeddings=text_config.max_position_embeddings,
) )
return self.lora_manager.create_lora_manager(model) return self.lora_manager.create_lora_manager(model)
......
...@@ -1178,9 +1178,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1178,9 +1178,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"or sharding the weights on more chips. " "or sharding the weights on more chips. "
f"See the detailed error: {e}") from e f"See the detailed error: {e}") from e
if self.lora_config is not None: if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config, model = self.load_lora_model(model, self.vllm_config, self.device)
self.scheduler_config,
self.lora_config, self.device)
replace_set_lora(model) replace_set_lora(model)
# Sync all pending XLA execution during model initialization and weight # Sync all pending XLA execution during model initialization and weight
......
...@@ -1078,20 +1078,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1078,20 +1078,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"Regarding multimodal models, vLLM currently " "Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.") "only supports adding LoRA to language model.")
# Use get_text_config() in case of multimodal models
text_config = self.model_config.hf_config.get_text_config()
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.vllm_config,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device, self.device,
self.model.embedding_modules, self.model.embedding_modules,
self.model.embedding_padding_modules, self.model.embedding_padding_modules,
max_position_embeddings=text_config.
max_position_embeddings,
) )
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment