Commit 51679bbd authored by zhuwenwen

resolve merge conflicts

parents 4095d0db 1af090b5
from typing import Dict, List, Sequence, Tuple, Optional
from vllm.block import BlockTable
class Prefix:
"""Data and states associated with a prefix of prompt tokens for multiple
sequence groups.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
token_ids: The token ids of the prefix.
block_size: The block size of the executed model.
"""
def __init__(
self,
token_ids: Sequence[int],
block_size: int,
) -> None:
self.token_ids = tuple(token_ids)
self.block_size = block_size
self.length = len(token_ids)
self.hash = hash(token_ids)
assert self.length % block_size == 0
self.block_table: Optional[BlockTable] = None
self.computed = False
@property
def allocated(self) -> bool:
return self.block_table is not None
def get_num_blocks(self) -> int:
return self.length // self.block_size
def get_block_numbers(self) -> List[int]:
return [block.block_number for block in self.block_table]
def get_length(self) -> int:
return self.length
def __hash__(self) -> int:
return self.hash
def set_block_table(self, block_table: BlockTable) -> None:
self.block_table = block_table.copy()
class PrefixPool:
"""Manages all the prompt prefixes.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
block_size: The block size of the executed model.
Attributes:
prefixes: A list of all the prefixes.
block_size: The block size of the executed model.
"""
def __init__(
self,
block_size: int,
) -> None:
# TODO(zhuohan): Add a capacity limit to the prefix pool.
self.prefixes: Dict[int, Prefix] = {}
self.block_size = block_size
def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
new_length = len(token_ids) // self.block_size * self.block_size
return tuple(token_ids[:new_length])
def add_or_get_prefix(self, token_ids: Sequence[int],
lora_int_id: int) -> Optional[Prefix]:
token_ids = self._truncate_token_ids(token_ids)
if len(token_ids) == 0:
# Prefix is empty.
return None
prefix = Prefix(token_ids, self.block_size)
prefix_hash = hash((prefix, lora_int_id))
if prefix_hash not in self.prefixes:
self.prefixes[prefix_hash] = prefix
return self.prefixes[prefix_hash]
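For illustration only (not part of this diff), a minimal sketch of how the pool resolves prefixes, assuming a block size of 4; token ids are truncated to a multiple of the block size and entries are keyed by (prefix, lora_int_id):
# Illustrative sketch, not part of this commit.
pool = PrefixPool(block_size=4)
prefix = pool.add_or_get_prefix([11, 12, 13, 14, 15, 16], lora_int_id=0)
assert prefix is not None and prefix.get_length() == 4  # 6 tokens -> 4
# The same truncated prefix with the same LoRA id returns the cached object.
assert pool.add_or_get_prefix([11, 12, 13, 14], lora_int_id=0) is prefix
# Fewer tokens than one block yields no reusable prefix.
assert pool.add_or_get_prefix([11, 12], lora_int_id=0) is None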
......@@ -108,7 +108,7 @@ class SamplingParams:
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
......@@ -183,7 +183,7 @@ class SamplingParams:
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.logprobs is not None and self.logprobs < 0:
......
......@@ -4,7 +4,9 @@ import enum
from typing import Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest
PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]]
......@@ -105,6 +107,7 @@ class Sequence:
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
lora_request: LoRA request.
"""
def __init__(
......@@ -113,10 +116,12 @@ class Sequence:
prompt: str,
prompt_token_ids: List[int],
block_size: int,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.block_size = block_size
self.lora_request = lora_request
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
......@@ -133,6 +138,10 @@ class Sequence:
# Input + output tokens
self.tokens: Optional[List[str]] = None
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks),
......@@ -228,6 +237,8 @@ class SequenceGroup:
seqs: The list of sequences.
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""
def __init__(
......@@ -236,11 +247,15 @@ class SequenceGroup:
seqs: List[Sequence],
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params
self.arrival_time = arrival_time
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None
@property
......@@ -255,6 +270,10 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
......@@ -327,7 +346,6 @@ class SequenceGroup:
class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
Args:
request_id: The ID of the request.
is_prompt: Whether the request is at prompt stage.
......@@ -335,6 +353,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""
def __init__(
......@@ -344,12 +364,20 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
class SequenceOutput:
......
import ray
from vllm.config import ParallelConfig
from vllm.utils import get_open_port
from vllm.worker.worker import init_distributed_environment
def init_test_distributed_environment(
pipeline_parallel_size: int,
tensor_parallel_size: int,
rank: int,
distributed_init_port: str,
) -> None:
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(parallel_config, rank,
distributed_init_method)
def multi_process_tensor_parallel(
tensor_parallel_size: int,
test_target,
) -> None:
# Using Ray (rather than multiprocessing) makes it easier to debug
# failures in the spawned test workers.
ray.init()
distributed_init_port = get_open_port()
refs = []
for rank in range(tensor_parallel_size):
refs.append(
test_target.remote(tensor_parallel_size, rank,
distributed_init_port))
ray.get(refs)
ray.shutdown()
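For illustration only (not part of this diff), a hypothetical test_target that multi_process_tensor_parallel could launch; the name all_reduce_test_worker and its body are placeholders:
# Illustrative sketch, not part of this commit.
@ray.remote(num_gpus=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str) -> None:
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    # ... run a collective op here and assert on the result ...

# multi_process_tensor_parallel(2, all_reduce_test_worker)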
......@@ -4,6 +4,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *
logger = init_logger(__name__)
......@@ -65,6 +67,84 @@ def get_tokenizer(
return tokenizer
def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[PreTrainedTokenizer]:
if lora_request is None:
return None
try:
tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
**kwargs)
except OSError as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
f"No tokenizer found in {lora_request.lora_local_path}, "
"using base model tokenizer instead. "
f"(Exception: {str(e)})")
tokenizer = None
return tokenizer
get_lora_tokenizer_async = make_async(get_lora_tokenizer)
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
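For illustration only (not part of this diff), a minimal sketch of using TokenizerGroup with a per-adapter tokenizer; the model name and adapter path below are placeholders:
# Illustrative sketch, not part of this commit.
group = TokenizerGroup(tokenizer_id="facebook/opt-125m",  # placeholder model
                       enable_lora=True,
                       max_num_seqs=256,  # also caps the LoRA tokenizer LRU cache
                       max_input_length=None)
base_ids = group.encode("Hello world")
lora_req = LoRARequest(lora_name="my-adapter",  # hypothetical adapter
                       lora_int_id=1,
                       lora_local_path="/path/to/adapter")
# Uses the adapter's tokenizer if one exists, otherwise falls back to the base.
lora_ids = group.encode("Hello world", lora_request=lora_req)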
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],
......
import enum
import os
import socket
import subprocess
import uuid
from platform import uname
from typing import List, Tuple, Union
from packaging.version import parse, Version
import psutil
import torch
import asyncio
from functools import partial
from typing import (
Awaitable,
Callable,
TypeVar,
)
from collections import OrderedDict
from typing import Any, Hashable, Optional
from vllm._C import cuda_utils
from vllm.logger import init_logger
T = TypeVar("T")
logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8_e5m2": torch.uint8,
}
class Device(enum.Enum):
......@@ -30,16 +51,83 @@ class Counter:
self.counter = 0
class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> Any:
return self.get(key)
def __setitem__(self, key: Hashable, value: Any) -> None:
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
self.pop(key)
def touch(self, key: Hashable) -> None:
self.cache.move_to_end(key)
def get(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
return value
def put(self, key: Hashable, value: Any) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: Any):
pass
def remove_oldest(self):
if not self.cache:
return
key, value = self.cache.popitem(last=False)
self._on_remove(key, value)
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
run_on_remove = key in self.cache
value = self.cache.pop(key, default_value)
if run_on_remove:
self._on_remove(key, value)
return value
def clear(self):
while len(self.cache) > 0:
self.remove_oldest()
self.cache.clear()
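For illustration only (not part of this diff), a minimal sketch of the LRUCache eviction behavior:
# Illustrative sketch, not part of this commit.
cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")           # "a" becomes the most recently used entry
cache.put("c", 3)        # evicts "b", the least recently used entry
assert "a" in cache and "c" in cache and "b" not in cache
assert cache.get("b", default_value=-1) == -1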
def is_hip() -> bool:
return torch.version.hip is not None
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils
max_shared_mem = cuda_utils.get_max_shared_memory_per_block_device_attribute(
gpu)
# A value of 0 would make MAX_SEQ_LEN negative and test_attention.py would fail.
assert max_shared_mem > 0, "max_shared_mem can not be zero"
return int(max_shared_mem)
......@@ -57,8 +145,30 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower()
def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
asyncio event loop.
The code in this function needs to be thread safe.
"""
def _async_wrapper(*args, **kwargs) -> asyncio.Future:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=None, func=p_func)
return _async_wrapper
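For illustration only (not part of this diff), a minimal sketch of wrapping a blocking call with make_async; _blocking_tokenize is a placeholder for slow, CPU-bound work:
# Illustrative sketch, not part of this commit.
def _blocking_tokenize(text: str) -> List[str]:
    return text.split()

tokenize_async = make_async(_blocking_tokenize)

async def _demo() -> None:
    # Runs in the default executor, so the event loop is not blocked.
    tokens = await tokenize_async("hello world")
    assert tokens == ["hello", "world"]
# asyncio.run(_demo())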
def get_ip() -> str:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://{ip}:{port}"
def get_open_port() -> int:
......@@ -69,3 +179,99 @@ def get_open_port() -> int:
def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def get_nvcc_cuda_version() -> Version:
cuda_home = os.environ.get('CUDA_HOME')
if not cuda_home:
cuda_home = '/usr/local/cuda'
logger.info(
f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
)
nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
return nvcc_cuda_version
def _generate_random_fp8_e5m2(
tensor: torch.Tensor,
low: float,
high: float,
) -> None:
# NOTE(zhaoyang): Due to the NaN and Inf representations of the fp8 data type,
# directly generating random fp8 data with torch.randint may produce Inf or NaN.
# For example, s.11111.00 in fp8e5m2 format represents Inf.
#     | E4M3       | E5M2
# ----|------------|-------------------
# Inf | N/A        | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
from vllm._C import cache_ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high)
cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
del tensor_tmp
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
if isinstance(model_dtype, str):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
elif isinstance(model_dtype, torch.dtype):
torch_dtype = model_dtype
else:
raise ValueError(f"Invalid model dtype: {model_dtype}")
elif cache_dtype in ["half", "bfloat16", "float"]:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
elif cache_dtype == "fp8_e5m2":
torch_dtype = torch.uint8
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
elif isinstance(cache_dtype, torch.dtype):
torch_dtype = cache_dtype
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(key_cache, -scale, scale)
key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(value_cache, -scale, scale)
value_caches.append(value_cache)
return key_caches, value_caches
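For illustration only (not part of this diff), a minimal sketch of allocating small random KV caches for a toy 2-layer model; this assumes a CUDA device is available:
# Illustrative sketch, not part of this commit.
key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=16,
    block_size=16,
    num_layers=2,
    num_heads=4,
    head_size=64,
    cache_dtype="auto",     # fall back to model_dtype below
    model_dtype="half",
)
assert len(key_caches) == len(value_caches) == 2
assert key_caches[0].dtype == torch.float16
assert value_caches[0].shape == (16, 4, 64, 16)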
......@@ -6,7 +6,7 @@ import torch
from vllm._C import cache_ops
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
......@@ -34,12 +34,16 @@ class CacheEngine:
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
self.num_cpu_blocks = cache_config.num_cpu_blocks
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Initialize the cache.
self.gpu_cache = self.allocate_gpu_cache()
self.cpu_cache = self.allocate_cpu_cache()
......@@ -142,6 +146,7 @@ class CacheEngine:
@staticmethod
def get_cache_block_size(
block_size: int,
cache_dtype: str,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
......@@ -152,7 +157,11 @@ class CacheEngine:
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
if cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
dtype_size = _get_dtype_size(dtype)
return dtype_size * total
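For illustration only (not part of this diff), the block-size arithmetic for a Llama-7B-like configuration (32 layers, 32 KV heads, head_size 128, block_size 16):
# Illustrative arithmetic, not part of this commit.
#   key_cache_block   = 16 * 32 * 128            = 65,536 elements
#   value_cache_block = key_cache_block          = 65,536 elements
#   total             = 32 * (65,536 + 65,536)   = 4,194,304 elements
#   fp16 cache        = 4,194,304 * 2 bytes      = 8 MiB per block
# With cache_dtype="fp8_e5m2" (1 byte per element) this drops to 4 MiB per block.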
......
import time
from typing import Dict, List, Optional, Tuple, Set, Union
import numpy as np
import torch
import torch.nn as nn
from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils import custom_all_reduce
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
......@@ -30,19 +35,24 @@ class ModelRunner:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window()
if model_config is not None else None)
self.device = torch.device(torch.cuda.current_device())
self.model = None
self.block_size = None # Set after initial profiling.
self.lora_manager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
......@@ -59,9 +69,20 @@ class ModelRunner:
self.graph_block_tables = None # Set after initial profiling.
# cache in_wsl result
self.in_wsl = in_wsl()
self.kv_cache_dtype = kv_cache_dtype
def load_model(self) -> None:
self.model = get_model(self.model_config, self.lora_config)
vocab_size = self.model.config.vocab_size
if self.lora_config:
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, vocab_size,
self.lora_config, self.device)
self.model = self.lora_manager.create_lora_manager(self.model)
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
......@@ -74,13 +95,20 @@ class ModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = []
context_lens: List[int] = []
subquery_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
......@@ -91,11 +119,34 @@ class ModelRunner:
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
prefix_len = 0
prefix = seq_group_metadata.prefix
if prefix is not None and prefix.computed:
prefix_len = prefix.get_length()
prompt_tokens = prompt_tokens[prefix_len:]
prefix_block_tables.append(prefix.get_block_numbers())
else:
prefix_block_tables.append([])
# actual prompt lens
context_lens.append(prefix_len)
subquery_lens.append(prompt_len - prefix_len)
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(
list(range(prefix_len, prefix_len + len(prompt_tokens))))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping.append([lora_id] * prompt_len)
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
......@@ -113,8 +164,11 @@ class ModelRunner:
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert prefix_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prefix_len, prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
......@@ -124,7 +178,7 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
max_prompt_len = max(subquery_lens)
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
......@@ -137,32 +191,70 @@ class ModelRunner:
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device='cuda')
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad(
prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device='cuda')
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device='cuda')
input_metadata = InputMetadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens_tensor,
max_seq_len=max_prompt_len,
start_loc=start_loc_tensor,
max_context_len=None,
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
......@@ -181,6 +273,8 @@ class ModelRunner:
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
lora_index_mapping.append([lora_id])
lora_prompt_mapping.append(lora_id)
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
......@@ -235,28 +329,39 @@ class ModelRunner:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device="cuda")
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = _make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device="cuda",
)
lora_index_mapping = [
_pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
]
input_metadata = InputMetadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
max_seq_len=None,
start_loc=None,
max_context_len=max_context_len,
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
......@@ -264,7 +369,7 @@ class ModelRunner:
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
......@@ -272,10 +377,11 @@ class ModelRunner:
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
......@@ -285,10 +391,10 @@ class ModelRunner:
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += max_subquery_len
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
......@@ -326,115 +432,86 @@ class ModelRunner:
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
Set[int], LoRAMapping]:
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, input_metadata,
lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
subquery_lens = None
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
if self.lora_config:
flat_lora_index_mapping = [
item for sublist in lora_index_mapping for item in sublist
]
lora_mapping = LoRAMapping(
flat_lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"is_prompt": input_metadata.is_prompt,
"slot_mapping": input_metadata.slot_mapping,
"prompt_lens": input_metadata.prompt_lens,
"max_seq_len": input_metadata.max_seq_len,
"start_loc": input_metadata.start_loc,
"max_context_len": input_metadata.max_context_len,
"context_lens": input_metadata.context_lens,
"block_tables": input_metadata.block_tables,
"use_cuda_graph": input_metadata.use_cuda_graph,
"kv_cache_dtype": input_metadata.kv_cache_dtype,
"selected_token_indices":
sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
}
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict["input_tokens"]
input_positions = metadata_dict["input_positions"]
lora_mapping = metadata_dict["lora_mapping"]
lora_requests = metadata_dict["lora_requests"]
input_metadata = InputMetadata(
is_prompt=metadata_dict["is_prompt"],
slot_mapping=metadata_dict["slot_mapping"],
prompt_lens=metadata_dict["prompt_lens"],
max_seq_len=metadata_dict["max_seq_len"],
start_loc=metadata_dict["start_loc"],
max_context_len=metadata_dict["max_context_len"],
context_lens=metadata_dict["context_lens"],
block_tables=metadata_dict["block_tables"],
use_cuda_graph=metadata_dict["use_cuda_graph"],
kv_cache_dtype=metadata_dict["kv_cache_dtype"],
)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
selected_token_indices=metadata_dict["selected_token_indices"],
categorized_sample_indices=None,
perform_sampling=False,
)
return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
@torch.inference_mode()
def execute_model(
......@@ -442,8 +519,12 @@ class ModelRunner:
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:
input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
self.prepare_input_tensors(seq_group_metadata_list))
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Execute the model.
if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
......@@ -472,6 +553,28 @@ class ModelRunner:
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests that will
# have unique LoRAs, and therefore the maximum amount of LoRA memory
# consumption. Create dummy LoRA requests to profile this worst case.
dummy_lora_requests = []
dummy_lora_requests_per_seq = []
if self.lora_config:
for idx in range(self.lora_config.max_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
......@@ -485,6 +588,8 @@ class ModelRunner:
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
)
seqs.append(seq)
......@@ -495,6 +600,32 @@ class ModelRunner:
torch.cuda.synchronize()
return
def remove_all_loras(self) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_loras(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_loras()
@torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None:
assert not self.model_config.enforce_eager
......@@ -504,7 +635,9 @@ class ModelRunner:
"use '--enforce-eager' in the CLI.")
logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
"If you are running out of memory, consider decreasing "
"`gpu_memory_utilization` or enforcing eager mode.")
"`gpu_memory_utilization` or enforcing eager mode. "
"You can also reduce the `max_num_seqs` as needed "
"to decrease memory usage.")
start_time = time.perf_counter()
# Prepare dummy inputs. These will be reused for all batch sizes.
......@@ -517,29 +650,47 @@ class ModelRunner:
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
with custom_all_reduce.capture():
for batch_size in reversed(batch_size_capture_list):
# Create dummy input_metadata.
input_metadata = InputMetadata(
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
max_seq_len=None,
start_loc=None,
max_context_len=self.max_context_len_to_capture,
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
kv_cache_dtype=self.kv_cache_dtype,
)
if self.lora_config:
lora_mapping = LoRAMapping(
[0] * batch_size,
[0] * batch_size,
)
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
kv_caches,
input_metadata,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
end_time = time.perf_counter()
elapsed_time = end_time - start_time
......
from typing import List, Dict
import copy
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker
class MultiStepWorker(Worker):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less.
The MultiStepWorker does not support cache swap operations or beam search.
Cache swap operations do not require large modifications. On the other hand,
beam search requires memory allocations during sequence forks and thus
requires more thought for MultiStepWorker support.
"""
@torch.inference_mode()
def execute_model_multi_step(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_steps: int,
) -> List[SamplerOutput]:
"""Run the model forward pass num_steps times. Returns the list of
sampler output, one per model forward pass.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
# Assert enough KV space for num_steps tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
# Run model num_steps times.
model_outputs = []
for _ in range(num_steps):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs
def _append_new_tokens(
self, model_output: SamplerOutput,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
for seq_group_metadata, sequence_group_outputs in zip(
seq_group_metadata_list, model_output):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
) -> List[SequenceGroupMetadata]:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
The multi-step worker must be able to append tokens to sequences after
a forward pass. This necessitates modification of the data structures
used by the worker. Since these data structures are shared with other
parts of vLLM, like the scheduler, we must take care not to introduce
unexpected side-effects.
When Ray is used to orchestrate worker processes (such as when the
tensor-parallel degree is >1), this is not a problem because the input
data structures will be serialized and created anew in the worker
process.
However, when Ray is not used to orchestrate the worker processes (such
as when the tensor-parallel degree is 1), this is a problem. We avoid
the problem by shallow-copying the input data structures (specifically,
the parts that will change in multiple steps).
"""
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list = []
for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change.
seq_group_metadata = copy.copy(old_seq_group_metadata)
new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata_list
def _assert_enough_kv_space(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
num_steps: int) -> None:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert self.model_runner.block_size is not None
for seq_group_metadata in seq_group_metadata_list:
# Only one seq_id is guaranteed because there is no beam search.
seq_id = list(seq_group_metadata.seq_data.keys())[0]
seq = seq_group_metadata.seq_data[seq_id]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len = seq.get_len() + num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots = final_seq_len - 1
# The allocated number of KV slots is the number of allocated blocks
# times the number of slots per block.
number_physical_blocks = len(
seq_group_metadata.block_tables[seq_id])
allocated_kv_slots = (number_physical_blocks *
self.model_runner.block_size)
if required_num_kv_slots > allocated_kv_slots:
request_id = seq_group_metadata.request_id
raise ValueError(
"The worker attempted to run "
f"{num_steps} times but found insufficient KV space for "
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
f"{required_num_kv_slots=}).")
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Tuple, Set, Optional
import torch
import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest
class Worker:
......@@ -33,6 +36,8 @@ class Worker:
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
......@@ -41,12 +46,17 @@ class Worker:
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config,
parallel_config,
scheduler_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
......@@ -71,9 +81,10 @@ class Worker:
_check_if_gpu_supports_dtype(self.model_config.dtype)
# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
if not self.parallel_config.disable_custom_all_reduce:
init_custom_ar()
# Initialize the model.
set_random_seed(self.model_config.seed)
......@@ -86,7 +97,16 @@ class Worker:
block_size: int,
gpu_memory_utilization: float,
cpu_swap_space: int,
cache_dtype: str,
) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated.
Args:
block_size: The size of the cache block.
gpu_memory_utilization: The fraction of the total GPU memory to use.
cpu_swap_space: The size of the CPU swap space in bytes.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
......@@ -102,13 +122,16 @@ class Worker:
peak_memory = total_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size(
block_size, cache_dtype, self.model_config, self.parallel_config)
num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks
......@@ -167,20 +190,21 @@ class Worker:
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
data = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_swap_in = data["blocks_to_swap_in"]
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
......@@ -190,8 +214,17 @@ class Worker:
self.gpu_cache)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
......@@ -218,8 +251,8 @@ def _init_distributed_environment(
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
......@@ -231,4 +264,6 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
f"{compute_capability[0]}.{compute_capability[1]}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")