Unverified Commit 8bddb735 authored by Akshat Tripathi's avatar Akshat Tripathi Committed by GitHub
Browse files

[Hardware][CPU] Multi-LoRA implementation for the CPU backend (#11100)


Signed-off-by: default avatarAkshat Tripathi <akshat@krai.ai>
Signed-off-by: default avatarOleg Mosalov <oleg@krai.ai>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarOleg Mosalov <oleg@krai.ai>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent f967e51f
from typing import Callable, Optional, Tuple, Union
import torch
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from .punica_base import PunicaWrapperBase
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperCPU(PunicaWrapperBase):
"""
PunicaWrapperCPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_inputs,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_inputs,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_inputs)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool = True,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, scale: float):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill
if self.is_prefill else self._expand_decode)
expand_fun(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_inputs=True,
**kwargs)
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)
...@@ -12,11 +12,11 @@ import torch ...@@ -12,11 +12,11 @@ import torch
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
if HAS_TRITON: if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.triton_ops import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.triton_ops import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.triton_ops import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.triton_ops import sgmv_expand
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.triton_ops import sgmv_shrink
from .punica_base import PunicaWrapperBase from .punica_base import PunicaWrapperBase
......
...@@ -12,6 +12,11 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: ...@@ -12,6 +12,11 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
logger.info_once("Using PunicaWrapperGPU.") logger.info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs) return PunicaWrapperGPU(*args, **kwargs)
elif current_platform.is_cpu():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
logger.info_once("Using PunicaWrapperCPU.")
return PunicaWrapperCPU(*args, **kwargs)
elif current_platform.is_hpu(): elif current_platform.is_hpu():
# Lazy import to avoid ImportError # Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
......
...@@ -2,8 +2,8 @@ import dataclasses ...@@ -2,8 +2,8 @@ import dataclasses
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
Union) TypeVar, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -12,10 +12,14 @@ from vllm.attention import AttentionMetadata, get_attn_backend ...@@ -12,10 +12,14 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap) MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData, from vllm.sequence import (IntermediateTensors, SequenceData,
...@@ -49,6 +53,8 @@ class ModelInputForCPU(ModelRunnerInputBase): ...@@ -49,6 +53,8 @@ class ModelInputForCPU(ModelRunnerInputBase):
virtual_engine: Optional[int] = None virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
...@@ -57,6 +63,8 @@ class ModelInputForCPU(ModelRunnerInputBase): ...@@ -57,6 +63,8 @@ class ModelInputForCPU(ModelRunnerInputBase):
"input_positions": self.input_positions, "input_positions": self.input_positions,
"token_type_ids": self.token_type_ids, "token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
...@@ -143,7 +151,11 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -143,7 +151,11 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
or runner.cache_config.enable_prefix_caching) or runner.cache_config.enable_prefix_caching)
self.model_input_cls = self.runner._model_input_cls self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.enable_lora = self.runner.lora_config is not None
self.input_data = ModelInputForCPUBuilder.ModelInputData( self.input_data = ModelInputForCPUBuilder.ModelInputData(
self.runner.model_config.uses_mrope) self.runner.model_config.uses_mrope)
self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
...@@ -183,15 +195,28 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -183,15 +195,28 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
attn_metadata = self.att_metadata_builder.build( attn_metadata = self.att_metadata_builder.build(
input_data.seq_lens, input_data.query_lens, -1, -1) input_data.seq_lens, input_data.query_lens, -1, -1)
return self.model_input_cls( is_prompt = (self.seq_group_metadata_list[0].is_prompt
input_tokens=input_tokens, if self.seq_group_metadata_list else None)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(seq.lora_request
for seq in self.seq_group_metadata_list
if seq.lora_request is not None)
lora_mapping = self._prepare_lora_input(
self.seq_group_metadata_list, is_prompt)
return self.model_input_cls(input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
seq_lens=input_data.seq_lens, seq_lens=input_data.seq_lens,
query_lens=input_data.query_lens, query_lens=input_data.query_lens,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs, multi_modal_kwargs=multi_modal_kwargs,
) lora_mapping=lora_mapping,
lora_requests=lora_requests)
def _build_input_data(self): def _build_input_data(self):
for seq_group_metadata in self.seq_group_metadata_list: for seq_group_metadata in self.seq_group_metadata_list:
...@@ -381,6 +406,24 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -381,6 +406,24 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self.input_data.multi_modal_placeholder_maps[modality].extend( self.input_data.multi_modal_placeholder_maps[modality].extend(
placeholder_map) placeholder_map)
def _prepare_lora_input(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
is_prefill: bool) -> LoRAMapping:
index_mapping = []
prompt_mapping = []
for seq in seq_group_metadata_list:
lora_id = seq.lora_int_id
query_len = seq.token_chunk_size
index_mapping += [lora_id] * query_len
prompt_mapping += [lora_id] * (
query_len if seq.sampling_params
and seq.sampling_params.prompt_logprobs is not None else 1)
return LoRAMapping(index_mapping=tuple(index_mapping),
prompt_mapping=tuple(prompt_mapping),
is_prefill=is_prefill)
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
""" """
...@@ -431,10 +474,41 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -431,10 +474,41 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
assert supports_lora(
self.model
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# It's necessary to distinguish between the max_position_embeddings
# of VLMs and LLMs.
if hasattr(self.model.config, "max_position_embeddings"):
max_pos_embeddings = self.model.config.max_position_embeddings
else:
max_pos_embeddings = (
self.model.config.text_config.max_position_embeddings)
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
def _prepare_model_input_tensors( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
...@@ -459,6 +533,37 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -459,6 +533,37 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
def vocab_size(self) -> int: def vocab_size(self) -> int:
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
...@@ -515,6 +620,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): ...@@ -515,6 +620,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
raise ValueError( raise ValueError(
"CPU worker does not support multi-step execution.") "CPU worker does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
model_executable = self.model model_executable = self.model
multimodal_kwargs = {} multimodal_kwargs = {}
......
"""A CPU worker class.""" """A CPU worker class."""
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Set, Tuple, Type
import torch import torch
import torch.distributed import torch.distributed
...@@ -11,14 +11,14 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ...@@ -11,14 +11,14 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput) WorkerInput)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -111,7 +111,7 @@ class CPUCacheEngine: ...@@ -111,7 +111,7 @@ class CPUCacheEngine:
return dtype_size * total return dtype_size * total
class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): class CPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket. """A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is Each worker is associated with a single CPU socket. The worker is
...@@ -266,6 +266,18 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -266,6 +266,18 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# Initialize the cache. # Initialize the cache.
self._init_cache_engine() self._init_cache_engine()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
"""Raise errors if the num_cpu_blocks is invalid. """Raise errors if the num_cpu_blocks is invalid.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment