Unverified Commit 9b945daa authored by Antoni Baum, committed by GitHub

[Experimental] Add multi-LoRA support (#1804)


Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
parent 9c1352eb
@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.utils import set_weight_attrs
 
+DEFAULT_VOCAB_PADDING_SIZE = 64
 
-def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
+
+def pad_vocab_size(vocab_size: int,
+                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
         params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
     """
 
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
-                 params_dtype: Optional[torch.dtype] = None):
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__()
 
         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
-        self.num_embeddings_padded = pad_vocab_size(num_embeddings)
+        self.org_vocab_size = org_num_embeddings or num_embeddings
+        self.num_embeddings_padded = pad_vocab_size(num_embeddings,
+                                                    padding_size)
         self.embedding_dim = embedding_dim
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -77,7 +86,7 @@ class VocabParallelEmbedding(torch.nn.Module):
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         parallel_dim = param.parallel_dim
-        assert loaded_weight.shape[parallel_dim] == self.num_embeddings
+        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
         loaded_weight = loaded_weight[self.vocab_start_index:self.
                                       vocab_end_index]
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
@@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding):
         embedding_dim: size of hidden state.
         bias: whether to use bias.
         params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
     """
 
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
                  bias: bool = False,
-                 params_dtype: Optional[torch.dtype] = None):
-        super().__init__(num_embeddings, embedding_dim, params_dtype)
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+        super().__init__(num_embeddings, embedding_dim, params_dtype,
+                         org_num_embeddings, padding_size)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
...
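For reference, the new padding helper simply rounds the (possibly LoRA-extended) vocabulary up to a multiple of the padding size; the larger LoRA-specific padding used further down keeps the LM head shape compatible with the LoRA kernels. A minimal standalone sketch (the concrete sizes are illustrative, not taken from this commit):

    DEFAULT_VOCAB_PADDING_SIZE = 64

    def pad_vocab_size(vocab_size: int,
                       pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
        """Round vocab_size up to the next multiple of pad_to."""
        return ((vocab_size + pad_to - 1) // pad_to) * pad_to

    assert pad_vocab_size(32000) == 32000              # already a multiple of 64
    assert pad_vocab_size(32003) == 32064              # rounded up to 64
    assert pad_vocab_size(32250, pad_to=256) == 32256  # assumed larger LoRA padding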
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib import contextlib
from typing import Type from typing import Optional, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig from vllm.config import ModelConfig, LoRAConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
...@@ -32,7 +32,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: ...@@ -32,7 +32,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
f"Supported architectures: {ModelRegistry.get_supported_archs()}") f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config) model_class = _get_model_architecture(model_config.hf_config)
# Get the (maybe quantized) linear method. # Get the (maybe quantized) linear method.
...@@ -62,7 +63,17 @@ def get_model(model_config: ModelConfig) -> nn.Module: ...@@ -62,7 +63,17 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
with torch.device("cuda"): with torch.device("cuda"):
model = model_class(model_config.hf_config, linear_method) if getattr(model_class, "supports_lora", False):
model = model_class(model_config.hf_config, linear_method,
lora_config)
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
else:
model = model_class(model_config.hf_config, linear_method)
if model_config.load_format == "dummy": if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
......
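The dispatch above keys off a plain `supports_lora` class attribute, so an architecture opts in simply by declaring it; any other model raises if a LoRA config is supplied. A minimal sketch of the check in isolation (class names here are placeholders, not from the commit):

    class PlainModel:                  # no LoRA support declared
        pass

    class LoRACapableModel:
        supports_lora = True           # the attribute get_model looks for

    def model_supports_lora(model_class) -> bool:
        # Same check as in get_model: a missing attribute means "unsupported".
        return getattr(model_class, "supports_lora", False)

    assert model_supports_lora(LoRACapableModel) is True
    assert model_supports_lora(PlainModel) is False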
@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -225,14 +226,19 @@ class LlamaModel(nn.Module):
         self,
         config: LlamaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            self.vocab_size,
             config.hidden_size,
+            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config, linear_method)
@@ -263,18 +269,31 @@ class LlamaModel(nn.Module):
 
 class LlamaForCausalLM(nn.Module):
+    supports_lora = True
 
     def __init__(
         self,
         config: LlamaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = LlamaModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
+        unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,
...
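The vocabulary arithmetic above is the heart of the change: the embedding table grows by `lora_extra_vocab_size` rows per adapter slot so every active LoRA can own its extra tokens, while the LM head only ever exposes one adapter's extra tokens (and is then padded with the LoRA-specific padding size). A rough worked example, assuming a 32,000-token base vocab, 256 extra tokens per adapter, and `max_loras = 4` (illustrative values, not mandated by the commit):

    base_vocab = 32000
    lora_extra_vocab_size = 256        # assumed value
    max_loras = 4                      # assumed value

    embed_rows = base_vocab + lora_extra_vocab_size * max_loras   # 33024 rows in embed_tokens
    unpadded_lm_head = base_vocab + lora_extra_vocab_size         # 32256 logits before padding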
@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -220,15 +221,20 @@ class MistralModel(nn.Module):
         self,
         config: MistralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            self.vocab_size,
             config.hidden_size,
+            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
             MistralDecoderLayer(config, linear_method)
@@ -259,18 +265,33 @@ class MistralModel(nn.Module):
 
 class MistralForCausalLM(nn.Module):
+    supports_lora = True
 
     def __init__(
        self,
         config: MistralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = MistralModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.model = MistralModel(config,
+                                  linear_method,
+                                  lora_config=lora_config)
+        unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,
...
@@ -195,10 +195,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 def destroy_model_parallel():
-    """Set the groups to none."""
+    """Set the groups to none and destroy them."""
     global _TENSOR_MODEL_PARALLEL_GROUP
+    if _TENSOR_MODEL_PARALLEL_GROUP:
+        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
     _TENSOR_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
+    if _PIPELINE_MODEL_PARALLEL_GROUP:
+        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
     _PIPELINE_GLOBAL_RANKS = None
@@ -2,6 +2,7 @@ from typing import List, Optional
 
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
                            SequenceStatus)
+from vllm.lora.request import LoRARequest
 
 
 class CompletionOutput:
@@ -16,6 +17,7 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        lora_request: The LoRA request that was used to generate the output.
     """
 
     def __init__(
@@ -26,6 +28,7 @@ class CompletionOutput:
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
         self.text = text
@@ -33,6 +36,7 @@ class CompletionOutput:
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.lora_request = lora_request
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -56,6 +60,7 @@ class RequestOutput:
         prompt_logprobs: The log probabilities to return per prompt token.
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
+        lora_request: The LoRA request that was used to generate the output.
     """
 
     def __init__(
@@ -66,6 +71,7 @@ class RequestOutput:
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
+        lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -73,6 +79,7 @@ class RequestOutput:
         self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished
+        self.lora_request = lora_request
 
     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -108,8 +115,13 @@ class RequestOutput:
         prompt_token_ids = seq_group.prompt_token_ids
         prompt_logprobs = seq_group.prompt_logprobs
         finished = seq_group.is_finished()
-        return cls(seq_group.request_id, prompt, prompt_token_ids,
-                   prompt_logprobs, outputs, finished)
+        return cls(seq_group.request_id,
+                   prompt,
+                   prompt_token_ids,
+                   prompt_logprobs,
+                   outputs,
+                   finished,
+                   lora_request=seq_group.lora_request)
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
@@ -117,4 +129,5 @@ class RequestOutput:
                 f"prompt_token_ids={self.prompt_token_ids}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
-                f"finished={self.finished})")
+                f"finished={self.finished}, "
+                f"lora_request={self.lora_request})")
@@ -74,13 +74,14 @@ class PrefixPool:
         new_length = len(token_ids) // self.block_size * self.block_size
         return tuple(token_ids[:new_length])
 
-    def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]:
+    def add_or_get_prefix(self, token_ids: Sequence[int],
+                          lora_int_id: int) -> Optional[Prefix]:
         token_ids = self._truncate_token_ids(token_ids)
         if len(token_ids) == 0:
             # Prefix is empty.
             return None
         prefix = Prefix(token_ids, self.block_size)
-        prefix_hash = hash(prefix)
+        prefix_hash = hash((prefix, lora_int_id))
         if prefix_hash not in self.prefixes:
             self.prefixes[prefix_hash] = prefix
         return self.prefixes[prefix_hash]
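Folding `lora_int_id` into the pool key means the same prompt prefix evaluated under two different adapters is cached as two separate entries, with id 0 reserved for the base model. A tiny sketch of the keying behaviour, using a plain tuple in place of the real Prefix object:

    prefix_tokens = (1, 2, 3, 4)           # stand-in for a Prefix
    key_base = hash((prefix_tokens, 0))    # base model, lora_int_id == 0
    key_lora = hash((prefix_tokens, 1))    # adapter with lora_int_id == 1

    # Same tokens, different adapter -> different cache slots
    # (hash collisions aside).
    assert key_base != key_lora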
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Union
 from vllm.block import LogicalTokenBlock
 from vllm.prefix import Prefix
 from vllm.sampling_params import SamplingParams
+from vllm.lora.request import LoRARequest
 
 PromptLogprobs = List[Optional[Dict[int, float]]]
 SampleLogprobs = List[Dict[int, float]]
@@ -106,6 +107,7 @@ class Sequence:
         prompt_token_ids: The token IDs of the prompt.
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
+        lora_request: LoRA request.
     """
 
     def __init__(
@@ -114,10 +116,12 @@ class Sequence:
         prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
+        lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
         self.prompt = prompt
         self.block_size = block_size
+        self.lora_request = lora_request
 
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
@@ -134,6 +138,10 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
 
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
@@ -229,6 +237,7 @@ class SequenceGroup:
         seqs: The list of sequences.
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
+        lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
 
@@ -238,12 +247,14 @@ class SequenceGroup:
         seqs: List[Sequence],
         sampling_params: SamplingParams,
         arrival_time: float,
+        lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.lora_request = lora_request
         self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
 
@@ -259,6 +270,10 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
@@ -338,6 +353,7 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
 
@@ -348,6 +364,7 @@ class SequenceGroupMetadata:
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
+        lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
@@ -355,8 +372,13 @@ class SequenceGroupMetadata:
         self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
+        self.lora_request = lora_request
         self.prefix = prefix
 
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
 
 class SequenceOutput:
     """The model output associated with a sequence.
...
@@ -4,6 +4,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.utils import make_async, LRUCache
 from vllm.transformers_utils.tokenizers import *
 
 logger = init_logger(__name__)
@@ -65,6 +67,84 @@ def get_tokenizer(
     return tokenizer
 
 
+def get_lora_tokenizer(lora_request: LoRARequest, *args,
+                       **kwargs) -> Optional[PreTrainedTokenizer]:
+    if lora_request is None:
+        return None
+    try:
+        tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
+                                  **kwargs)
+    except OSError as e:
+        # No tokenizer was found in the LoRA folder,
+        # use base model tokenizer
+        logger.warning(
+            f"No tokenizer found in {lora_request.lora_local_path}, "
+            "using base model tokenizer instead. "
+            f"(Exception: {str(e)})")
+        tokenizer = None
+    return tokenizer
+
+
+get_lora_tokenizer_async = make_async(get_lora_tokenizer)
+
+
+class TokenizerGroup:
+    """A group of tokenizers that can be used for LoRA adapters."""
+
+    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
+                 max_input_length: Optional[int], **tokenizer_config):
+        self.tokenizer_id = tokenizer_id
+        self.tokenizer_config = tokenizer_config
+        self.enable_lora = enable_lora
+        self.max_input_length = max_input_length
+        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
+        if enable_lora:
+            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
+        else:
+            self.lora_tokenizers = None
+
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
+        tokenizer = self.get_lora_tokenizer(lora_request)
+        return tokenizer.encode(prompt)
+
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
+        tokenizer = await self.get_lora_tokenizer_async(lora_request)
+        return tokenizer.encode(prompt)
+
+    def get_lora_tokenizer(
+            self,
+            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+        if not lora_request or not self.enable_lora:
+            return self.tokenizer
+        if lora_request.lora_int_id not in self.lora_tokenizers:
+            tokenizer = (get_lora_tokenizer(
+                lora_request, **self.tokenizer_config) or self.tokenizer)
+            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
+            return tokenizer
+        else:
+            return self.lora_tokenizers.get(lora_request.lora_int_id)
+
+    async def get_lora_tokenizer_async(
+            self,
+            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+        if not lora_request or not self.enable_lora:
+            return self.tokenizer
+        if lora_request.lora_int_id not in self.lora_tokenizers:
+            tokenizer = (await get_lora_tokenizer_async(
+                lora_request, **self.tokenizer_config) or self.tokenizer)
+            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
+            return tokenizer
+        else:
+            return self.lora_tokenizers.get(lora_request.lora_int_id)
+
+
 def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
...
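The per-adapter tokenizer lookup reduces to: use the base tokenizer when LoRA is disabled or no adapter is attached, otherwise load the adapter's own tokenizer once, fall back to the base tokenizer if the adapter folder ships none, and keep the result in a bounded LRU cache. A standalone sketch of that flow with stub objects instead of real Hugging Face tokenizers (all names below are illustrative):

    from collections import OrderedDict
    from typing import Optional

    class StubTokenizer:
        def __init__(self, name: str):
            self.name = name

    def load_adapter_tokenizer(lora_int_id: int) -> Optional[StubTokenizer]:
        # Pretend only adapter 1 ships its own tokenizer.
        return StubTokenizer(f"adapter-{lora_int_id}") if lora_int_id == 1 else None

    base = StubTokenizer("base")
    cache: "OrderedDict[int, StubTokenizer]" = OrderedDict()  # bounded LRU in the real code

    def tokenizer_for(lora_int_id: int) -> StubTokenizer:
        if lora_int_id == 0:               # no adapter attached
            return base
        if lora_int_id not in cache:
            cache[lora_int_id] = load_adapter_tokenizer(lora_int_id) or base
        return cache[lora_int_id]

    assert tokenizer_for(0) is base
    assert tokenizer_for(1).name == "adapter-1"
    assert tokenizer_for(2) is base        # fallback is cached for adapter 2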
@@ -7,6 +7,17 @@ from typing import List
 
 import psutil
 import torch
+import asyncio
+from functools import partial
+from typing import (
+    Awaitable,
+    Callable,
+    TypeVar,
+)
+from collections import OrderedDict
+from typing import Any, Hashable, Optional
+
+T = TypeVar("T")
 
 
 class Device(enum.Enum):
@@ -28,6 +39,69 @@ class Counter:
         self.counter = 0
 
 
+class LRUCache:
+
+    def __init__(self, capacity: int):
+        self.cache = OrderedDict()
+        self.capacity = capacity
+
+    def __contains__(self, key: Hashable) -> bool:
+        return key in self.cache
+
+    def __len__(self) -> int:
+        return len(self.cache)
+
+    def __getitem__(self, key: Hashable) -> Any:
+        return self.get(key)
+
+    def __setitem__(self, key: Hashable, value: Any) -> None:
+        self.put(key, value)
+
+    def __delitem__(self, key: Hashable) -> None:
+        self.pop(key)
+
+    def touch(self, key: Hashable) -> None:
+        self.cache.move_to_end(key)
+
+    def get(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
+        if key in self.cache:
+            value = self.cache[key]
+            self.cache.move_to_end(key)
+        else:
+            value = default_value
+        return value
+
+    def put(self, key: Hashable, value: Any) -> None:
+        self.cache[key] = value
+        self.cache.move_to_end(key)
+        self._remove_old_if_needed()
+
+    def _on_remove(self, key: Hashable, value: Any):
+        pass
+
+    def remove_oldest(self):
+        if not self.cache:
+            return
+        key, value = self.cache.popitem(last=False)
+        self._on_remove(key, value)
+
+    def _remove_old_if_needed(self) -> None:
+        while len(self.cache) > self.capacity:
+            self.remove_oldest()
+
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
+        run_on_remove = key in self.cache
+        value = self.cache.pop(key, default_value)
+        if run_on_remove:
+            self._on_remove(key, value)
+        return value
+
+    def clear(self):
+        while len(self.cache) > 0:
+            self.remove_oldest()
+        self.cache.clear()
+
+
 def is_hip() -> bool:
     return torch.version.hip is not None
@@ -59,6 +133,22 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()
 
 
+def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
+    """Take a blocking function, and run it in an executor thread.
+
+    This prevents the blocking function from blocking the
+    asyncio event loop.
+    The code in this function needs to be thread safe.
+    """
+
+    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
+        loop = asyncio.get_event_loop()
+        p_func = partial(func, *args, **kwargs)
+        return loop.run_in_executor(executor=None, func=p_func)
+
+    return _async_wrapper
+
+
 def get_ip() -> str:
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
...
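Both helpers are self-contained enough to exercise directly; a short usage sketch, assuming LRUCache and make_async exactly as added above:

    import asyncio

    cache = LRUCache(capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")        # touching "a" makes "b" the least recently used
    cache.put("c", 3)     # over capacity -> evicts "b"
    assert "b" not in cache and "a" in cache and "c" in cache

    def count_words(text: str) -> int:
        return len(text.split())      # stand-in for blocking work

    count_words_async = make_async(count_words)

    async def main() -> int:
        # Runs count_words in the default thread-pool executor without
        # blocking the event loop.
        return await count_words_async("multi lora support")

    assert asyncio.run(main()) == 3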
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Set, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
+from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
 from vllm.utils import in_wsl
 
 logger = init_logger(__name__)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 _PAD_SLOT_ID = -1
+LORA_WARMUP_RANK = 8
 # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
 # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
@@ -30,19 +34,23 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        lora_config: Optional[LoRAConfig],
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.lora_config = lora_config
         self.is_driver_worker = is_driver_worker
 
         # model_config can be None in tests/samplers/test_sampler.py.
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
         self.sliding_window = (model_config.get_sliding_window()
                                if model_config is not None else None)
+        self.device = torch.device(torch.cuda.current_device())
         self.model = None
         self.block_size = None  # Set after initial profiling.
+        self.lora_manager = None
 
         self.graph_runners: Dict[int, CUDAGraphRunner] = {}
         self.graph_memory_pool = None  # Set during graph capture.
@@ -61,7 +69,17 @@ class ModelRunner:
         self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
-        self.model = get_model(self.model_config)
+        self.model = get_model(self.model_config, self.lora_config)
+
+        vocab_size = self.model.config.vocab_size
+
+        if self.lora_config:
+            self.lora_manager = LRUCacheWorkerLoRAManager(
+                self.scheduler_config.max_num_seqs,
+                self.scheduler_config.max_num_batched_tokens +
+                self.scheduler_config.max_paddings, vocab_size,
+                self.lora_config, self.device)
+            self.model = self.lora_manager.create_lora_manager(self.model)
 
     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size
@@ -74,12 +92,15 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
-               List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
+               List[int], List[int], Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()
 
         prompt_lens: List[int] = []
         context_lens: List[int] = []
@@ -113,6 +134,17 @@ class ModelRunner:
             input_positions.append(
                 list(range(prefix_len, prefix_len + len(prompt_tokens))))
 
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)
+
+            lora_index_mapping.append([lora_id] * prompt_len)
+            lora_prompt_mapping.extend(
+                [lora_id] *
+                (prompt_len
+                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))
+
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
@@ -156,6 +188,10 @@ class ModelRunner:
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long)
+        lora_index_mapping = [
+            _pad_to_max(mapping, max_prompt_len, pad=0)
+            for mapping in lora_index_mapping
+        ]
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device='cuda')
@@ -188,23 +224,33 @@ class ModelRunner:
             use_cuda_graph=False,
         )
         return (input_tokens, input_positions, input_metadata, prompt_lens,
-                subquery_lens)
+                subquery_lens, lora_index_mapping, lora_prompt_mapping,
+                lora_requests)
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
+               Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
         block_tables: List[List[int]] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()
 
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)
+
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -223,6 +269,8 @@ class ModelRunner:
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
+                lora_index_mapping.append([lora_id])
+                lora_prompt_mapping.append(lora_id)
 
                 if self.sliding_window is not None:
                     sliding_window_blocks = (self.sliding_window //
@@ -287,6 +335,10 @@ class ModelRunner:
             device="cuda",
         )
 
+        lora_index_mapping = [
+            _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
+        ]
+
         input_metadata = InputMetadata(
             is_prompt=False,
             slot_mapping=slot_mapping,
@@ -298,7 +350,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
-        return input_tokens, input_positions, input_metadata
+        return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
 
     def _prepare_sample(
         self,
@@ -375,7 +427,8 @@ class ModelRunner:
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
+               Set[int], LoRAMapping]:
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
@@ -383,16 +436,29 @@ class ModelRunner:
             # Prepare input tensors.
             if is_prompt:
                 (input_tokens, input_positions, input_metadata, prompt_lens,
-                 subquery_lens) = self._prepare_prompt(seq_group_metadata_list)
+                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
+                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
             else:
-                (input_tokens, input_positions, input_metadata
-                 ) = self._prepare_decode(seq_group_metadata_list)
-                subquery_lens = None
+                (input_tokens, input_positions, input_metadata,
+                 lora_index_mapping, lora_prompt_mapping,
+                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                 prompt_lens = []
+                subquery_lens = None
             sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens)
+
+            if self.lora_config:
+                flat_lora_index_mapping = [
+                    item for sublist in lora_index_mapping for item in sublist
+                ]
+                lora_mapping = LoRAMapping(
+                    flat_lora_index_mapping,
+                    lora_prompt_mapping,
+                )
+            else:
+                lora_mapping = None
+
             # Broadcast the metadata.
             metadata_dict = {
                 "input_tokens": input_tokens,
@@ -408,12 +474,16 @@ class ModelRunner:
                 "use_cuda_graph": input_metadata.use_cuda_graph,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
+                "lora_requests": lora_requests,
+                "lora_mapping": lora_mapping,
             }
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict["input_tokens"]
            input_positions = metadata_dict["input_positions"]
+           lora_mapping = metadata_dict["lora_mapping"]
+           lora_requests = metadata_dict["lora_requests"]
            input_metadata = InputMetadata(
                is_prompt=metadata_dict["is_prompt"],
                slot_mapping=metadata_dict["slot_mapping"],
@@ -434,7 +504,7 @@ class ModelRunner:
                 perform_sampling=False,
             )
 
-        return input_tokens, input_positions, input_metadata, sampling_metadata
+        return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
 
     @torch.inference_mode()
     def execute_model(
@@ -442,8 +512,12 @@ class ModelRunner:
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        input_tokens, input_positions, input_metadata, sampling_metadata = (
-            self.prepare_input_tensors(seq_group_metadata_list))
+        input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
+            self.prepare_input_tensors(seq_group_metadata_list))
+
+        if self.lora_config:
+            self.set_active_loras(lora_requests, lora_mapping)
+
         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]
@@ -472,6 +546,28 @@ class ModelRunner:
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
 
+        # This represents the maximum number of different requests that will
+        # have unique LoRAs, and therefore the maximum memory consumption.
+        # Create dummy LoRA requests (each backed by a warmup LoRA of rank
+        # LORA_WARMUP_RANK) to profile that worst case.
+        dummy_lora_requests = []
+        dummy_lora_requests_per_seq = []
+        if self.lora_config:
+            for idx in range(self.lora_config.max_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_local_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
+
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
         seqs: List[SequenceGroupMetadata] = []
@@ -485,6 +581,8 @@ class ModelRunner:
                 seq_data={group_id: seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
             )
             seqs.append(seq)
@@ -495,6 +593,32 @@ class ModelRunner:
         torch.cuda.synchronize()
         return
 
+    def remove_all_loras(self) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_all_loras()
+
+    def set_active_loras(self, lora_requests: List[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.set_active_loras(lora_requests, lora_mapping)
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_loras()
+
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         assert not self.model_config.enforce_eager
@@ -541,6 +665,13 @@ class ModelRunner:
                 use_cuda_graph=True,
             )
 
+            if self.lora_config:
+                lora_mapping = LoRAMapping(
+                    [0] * batch_size,
+                    [0] * batch_size,
+                )
+                self.set_active_loras(set(), lora_mapping)
+
             graph_runner = CUDAGraphRunner(self.model)
             graph_runner.capture(
                 input_tokens[:batch_size],
...
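The per-token bookkeeping in _prepare_prompt/_prepare_decode boils down to flattening the nested lora_index_mapping before it is wrapped in a LoRAMapping. A small sketch with made-up adapter ids (two prompts, the second one served by adapter 2):

    lora_index_mapping = [[0, 0, 0], [2, 2]]   # one inner list per (padded) prompt
    lora_prompt_mapping = [0, 2]               # one entry per sequence that needs logits

    flat_lora_index_mapping = [
        item for sublist in lora_index_mapping for item in sublist
    ]
    assert flat_lora_index_mapping == [0, 0, 0, 2, 2]

    # LoRAMapping(flat_lora_index_mapping, lora_prompt_mapping) then tells the
    # LoRA kernels which adapter each token and each sampled sequence uses.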
"""A GPU worker class.""" """A GPU worker class."""
import gc
import os import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Tuple, Set, Optional
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
...@@ -15,6 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -15,6 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest
class Worker: class Worker:
...@@ -33,6 +35,7 @@ class Worker: ...@@ -33,6 +35,7 @@ class Worker:
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
...@@ -41,12 +44,16 @@ class Worker: ...@@ -41,12 +44,16 @@ class Worker:
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config, parallel_config, self.model_runner = ModelRunner(model_config,
scheduler_config, is_driver_worker) parallel_config,
scheduler_config,
lora_config=self.lora_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# self.init_cache_engine(). # self.init_cache_engine().
self.cache_config = None self.cache_config = None
...@@ -117,6 +124,9 @@ class Worker: ...@@ -117,6 +124,9 @@ class Worker:
num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0) num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks
...@@ -199,6 +209,15 @@ class Worker: ...@@ -199,6 +209,15 @@ class Worker:
self.gpu_cache) self.gpu_cache)
return output return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def _init_distributed_environment( def _init_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
......