Commit 99b471c2 authored by zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
...@@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig): ...@@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig):
f"got {alibi_scaling_type}") f"got {alibi_scaling_type}")
if (alibi_scaling_factor is not None if (alibi_scaling_factor is not None
and not isinstance(alibi_scaling_factor, float) and not isinstance(alibi_scaling_factor, float)
or alibi_scaling_factor <= 1.0): or (alibi_scaling_factor is not None
and alibi_scaling_factor <= 1.0)):
raise ValueError( raise ValueError(
f"`alibi_scaling`'s factor field must be a float > 1.0," f"`alibi_scaling`'s factor field must be a float > 1.0,"
f"got {alibi_scaling_factor}") f"got {alibi_scaling_factor}")
if (alibi_dynamic_scaling is not None if (alibi_dynamic_scaling is not None
and not isinstance(alibi_dynamic_scaling, int) and not isinstance(alibi_dynamic_scaling, int)
or alibi_dynamic_scaling <= 1): or (alibi_dynamic_scaling is not None
and alibi_dynamic_scaling <= 1)):
raise ValueError( raise ValueError(
f"`alibi_scaling`'s `train_seq_len` field must be an" f"`alibi_scaling`'s `train_seq_len` field must be an"
f"integer > 1, got {alibi_dynamic_scaling}") f"integer > 1, got {alibi_dynamic_scaling}")
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup) BaseTokenizerGroup)
...@@ -89,12 +87,15 @@ class Detokenizer: ...@@ -89,12 +87,15 @@ class Detokenizer:
prev_tokens.extend(next_iter_tokens) prev_tokens.extend(next_iter_tokens)
def decode_sequence_inplace(self, seq: Sequence, def decode_sequence_inplace(self, seq: Sequence,
prms: SamplingParams) -> None: prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation. """Decodes the new token for a sequence. In-place operation.
Args: Args:
seq: The sequence to decode. seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence. prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
""" """
all_input_ids = seq.get_token_ids() all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1] token_id_generated_this_iteration = all_input_ids[-1]
...@@ -148,10 +149,165 @@ class Detokenizer: ...@@ -148,10 +149,165 @@ class Detokenizer:
) )
sample_logprob.decoded_token = new_text sample_logprob.decoded_token = new_text
if seq.tokens is None: seq.tokens.extend(new_tokens)
seq.tokens = new_tokens
else:
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset seq.prefix_offset = prefix_offset
seq.read_offset = read_offset seq.read_offset = read_offset
seq.output_text += new_decoded_token_text seq.output_text += new_decoded_token_text
return len(new_decoded_token_text)
def _convert_tokens_to_string_with_added_encoders(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    """Convert tokens to text, emitting added-vocabulary tokens verbatim.

    Tokens present in ``tokenizer.get_added_vocab()`` are appended to the
    output unchanged. Each run of other tokens between them is decoded with
    ``tokenizer.convert_tokens_to_string``.

    Args:
        tokenizer: The tokenizer used to decode the tokens.
        output_tokens: The token strings to convert.
        skip_special_tokens: If True, tokens listed in
            ``tokenizer.all_special_tokens`` are dropped.
        spaces_between_special_tokens: If True, the decoded pieces are joined
            with a single space; otherwise they are concatenated directly.

    Returns:
        The joined text.
    """
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    # The added vocabulary does not change while we iterate, so fetch it once
    # instead of calling get_added_vocab() again for every token.
    added_vocab = tokenizer.get_added_vocab()
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in added_vocab:
            # Flush the pending run of regular tokens before the added token.
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    separator = " " if spaces_between_special_tokens else ""
    return separator.join(sub_texts)
# Number of trailing tokens used as the initial detokenization context.
# 5 is an arbitrary value that should work for all tokenizers
# (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Convert the tail of the prompt ids to tokens for incremental decoding.

    Only the last ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2`` prompt ids
    are converted; the rest of the prompt is not needed for incremental
    detokenization.

    Returns:
        A ``(tokens, prefix_offset, read_offset)`` tuple, where ``read_offset``
        is the number of converted tokens and ``prefix_offset`` lies
        ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET`` tokens before it
        (clamped at 0).
    """
    # Take two extra ids beyond the offset in case special tokens get skipped.
    tail_length = INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2
    tail_ids = prompt_ids[-tail_length:]
    new_tokens = tokenizer.convert_ids_to_tokens(
        tail_ids, skip_special_tokens=skip_special_tokens)

    read_offset = len(new_tokens)
    prefix_offset = read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
    if prefix_offset < 0:
        prefix_offset = 0
    return new_tokens, prefix_offset, read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Decode the newest token id into text, given the tokens seen so far.

    When `prev_tokens` is None (first call for a sequence), the tail of the
    prompt is converted to tokens first and the returned token list contains
    those prompt tokens plus the new one. On later calls only the tokens for
    the new id are returned.

    The text between `prefix_offset` and `read_offset` is decoded alongside
    the new token purely to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previously converted tokens, or None on the first
            call.
        prefix_offset: The prefix offset from the previous call.
        read_offset: The read offset from the previous call.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.

    Returns:
        A ``(new_tokens, new_text, prefix_offset, read_offset)`` tuple. If no
        text can be emitted yet, ``new_text`` is ``""`` and both offsets are
        returned unchanged; otherwise the offsets are advanced to
        ``read_offset`` and ``len(output_tokens)`` respectively.
    """
    latest_token_id = all_input_ids[-1]

    # A missing token history marks the first iteration for this sequence.
    first_call = prev_tokens is None
    if first_call:
        prev_tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer,
            all_input_ids[:-1],
            skip_special_tokens=skip_special_tokens,
        )
    assert prev_tokens is not None

    if latest_token_id < len(tokenizer):
        # Put the id in a list so skip_special_tokens is respected.
        new_tokens = tokenizer.convert_ids_to_tokens(
            [latest_token_id], skip_special_tokens=skip_special_tokens)
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]
    else:
        # An out-of-bounds id decodes to an empty string.
        new_tokens = [""]

    output_tokens = prev_tokens + new_tokens
    if first_call:
        # On the first iteration, hand back every token.
        new_tokens = output_tokens

    use_builtin_decode = tokenizer.is_fast or not tokenizer.get_added_vocab()

    def tokens_to_text(tokens: List[str]) -> str:
        if use_builtin_decode:
            return tokenizer.convert_tokens_to_string(tokens)
        return _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            tokens,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    prefix_text = tokens_to_text(output_tokens[prefix_offset:read_offset])
    new_text = tokens_to_text(output_tokens[prefix_offset:])

    # A trailing replacement char means a potentially unfinished byte sequence
    # from byte fallback tokenization, so hold the text back. If it's in the
    # middle, it's probably a real invalid id generated by the model.
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        return (new_tokens, new_text[len(prefix_text):], read_offset,
                len(output_tokens))
    return new_tokens, "", prefix_offset, read_offset
from typing import List, Optional, Tuple, Union import os
from typing import Optional, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast) PreTrainedTokenizerFast)
from vllm.config import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import * from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async from vllm.utils import make_async
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,7 +30,7 @@ def get_cached_tokenizer( ...@@ -28,7 +30,7 @@ def get_cached_tokenizer(
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer) tokenizer_len = len(tokenizer)
class CachedTokenizer(tokenizer.__class__): class CachedTokenizer(tokenizer.__class__): # type: ignore
@property @property
def all_special_ids(self): def all_special_ids(self):
...@@ -57,9 +59,26 @@ def get_tokenizer( ...@@ -57,9 +59,26 @@ def get_tokenizer(
tokenizer_mode: str = "auto", tokenizer_mode: str = "auto",
trust_remote_code: bool = False, trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface.""" """Gets a tokenizer for the given model name via Huggingface/modelscope."""
if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
# Only set the tokenizer here, model will be downloaded on the workers.
if not os.path.exists(tokenizer_name):
tokenizer_path = snapshot_download(
model_id=tokenizer_name,
cache_dir=download_dir,
revision=tokenizer_revision,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = tokenizer_path
if tokenizer_mode == "slow": if tokenizer_mode == "slow":
if kwargs.get("use_fast", False): if kwargs.get("use_fast", False):
raise ValueError( raise ValueError(
...@@ -126,157 +145,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, ...@@ -126,157 +145,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
get_lora_tokenizer_async = make_async(get_lora_tokenizer) get_lora_tokenizer_async = make_async(get_lora_tokenizer)
def _convert_tokens_to_string_with_added_encoders(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    """Decode tokens to text while passing added-vocabulary tokens through.

    Runs of tokens that are not in ``tokenizer.get_added_vocab()`` are decoded
    with ``tokenizer.convert_tokens_to_string``; tokens that are in the added
    vocabulary are copied to the output unchanged.

    Args:
        tokenizer: The tokenizer used to decode the tokens.
        output_tokens: The token strings to convert.
        skip_special_tokens: If True, tokens listed in
            ``tokenizer.all_special_tokens`` are dropped.
        spaces_between_special_tokens: If True, the decoded pieces are joined
            with a single space; otherwise they are concatenated directly.

    Returns:
        The joined text.
    """
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts = []
    current_sub_text = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    # Loop-invariant: look the added vocabulary up once rather than calling
    # get_added_vocab() for each token.
    added_vocab = tokenizer.get_added_vocab()
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in added_vocab:
            # Decode whatever regular tokens are pending, then emit the added
            # token as-is.
            if current_sub_text:
                sub_texts.append(
                    tokenizer.convert_tokens_to_string(current_sub_text))
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_texts.append(tokenizer.convert_tokens_to_string(current_sub_text))
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    return "".join(sub_texts)
# Size of the token window that seeds incremental detokenization.
# 5 is an arbitrary value that should work for all tokenizers
# (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Convert the end of the prompt to tokens and compute decode offsets.

    The whole prompt is not converted: only the ids from
    ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2`` before the end onwards
    are turned into token strings, which is all incremental detokenization
    needs.

    Returns:
        A ``(tokens, prefix_offset, read_offset)`` tuple. ``read_offset`` is
        the number of converted tokens; ``prefix_offset`` is
        ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET`` less, clamped at 0.
    """
    # Start a little earlier than the offset in case we have special tokens.
    window = INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2
    start = len(prompt_ids) - window
    if start < 0:
        start = 0
    new_tokens = tokenizer.convert_ids_to_tokens(
        prompt_ids[start:], skip_special_tokens=skip_special_tokens)

    read_offset = len(new_tokens)
    prefix_offset = max(
        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
    return new_tokens, prefix_offset, read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Incrementally detokenize the most recent id of a sequence.

    With `prev_tokens` set to None (the first call for a sequence), the tail
    of the prompt is converted to tokens first, and the returned token list
    holds those tokens plus the new one. Otherwise only the new token is
    returned.

    The offsets are necessary to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previously converted tokens, or None on the first
            call.
        prefix_offset: The prefix offset from the previous call.
        read_offset: The read offset from the previous call.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.

    Returns:
        A ``(new_tokens, new_text, prefix_offset, read_offset)`` tuple holding
        the new tokens, the newly decoded text (``""`` if nothing can be
        emitted yet), and the offsets to pass to the next call.
    """
    newest_id = all_input_ids[-1]

    # No token history means this is the first iteration for the sequence.
    starting_fresh = prev_tokens is None
    if starting_fresh:
        prev_tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer,
            all_input_ids[:-1],
            skip_special_tokens=skip_special_tokens,
        )

    if newest_id < len(tokenizer):
        # Put the id in a list so skip_special_tokens is respected.
        new_tokens = tokenizer.convert_ids_to_tokens(
            [newest_id], skip_special_tokens=skip_special_tokens)
    else:
        # An out-of-bounds id decodes to an empty string.
        new_tokens = [""]

    output_tokens = prev_tokens + new_tokens
    if starting_fresh:
        # On the first iteration, return all tokens.
        new_tokens = output_tokens

    plain_decode = tokenizer.is_fast or not tokenizer.get_added_vocab()

    def render(tokens):
        if plain_decode:
            return tokenizer.convert_tokens_to_string(tokens)
        return _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            tokens,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    # The prefix text is only decoded so it can be stripped from the front of
    # the full text below.
    prefix_text = render(output_tokens[prefix_offset:read_offset])
    new_text = render(output_tokens[prefix_offset:])

    # A replacement char at the end means a potential unfinished byte sequence
    # from byte fallback tokenization, so emit nothing yet. If it's in the
    # middle, it's probably a real invalid id generated by the model.
    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
        return new_tokens, "", prefix_offset, read_offset

    new_text = new_text[len(prefix_text):]
    return new_tokens, new_text, read_offset, len(output_tokens)
...@@ -11,7 +11,7 @@ if ray: ...@@ -11,7 +11,7 @@ if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool) RayTokenizerGroupPool)
else: else:
RayTokenizerGroupPool = None RayTokenizerGroupPool = None # type: ignore
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
......
...@@ -51,6 +51,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -51,6 +51,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
enable_lora=enable_lora, enable_lora=enable_lora,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_input_length=max_input_length, max_input_length=max_input_length,
**tokenizer_config,
) )
ray_tokenizer_group_cls = ray.remote( ray_tokenizer_group_cls = ray.remote(
...@@ -88,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -88,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
This is blocking. This is blocking.
""" """
self._ensure_queue_initialized() self._ensure_queue_initialized()
assert self._idle_actors is not None
if self._idle_actors.empty(): if self._idle_actors.empty():
raise RuntimeError("No idle actors available.") raise RuntimeError("No idle actors available.")
...@@ -119,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -119,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
This is non-blocking. This is non-blocking.
""" """
self._ensure_queue_initialized() self._ensure_queue_initialized()
assert self._idle_actors is not None
actor = await self._idle_actors.get() actor = await self._idle_actors.get()
try: try:
......
...@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__) ...@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = { # type: ignore
"vocab_file": {}, "vocab_file": {},
"tokenizer_file": {}, "tokenizer_file": {},
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore
class BaichuanTokenizer(PreTrainedTokenizer): class BaichuanTokenizer(PreTrainedTokenizer):
...@@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer): ...@@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
token = self.sp_model.IdToPiece(index) token = self.sp_model.IdToPiece(index)
return token return token
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens: List[str]):
"""Converts a sequence of tokens (string) in a single string.""" """Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = [] current_sub_tokens: List[str] = []
out_string = "" out_string = ""
prev_is_special = False prev_is_special = False
for i, token in enumerate(tokens): for i, token in enumerate(tokens):
...@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer): ...@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
`Tuple(str)`: Paths to the files saved. `Tuple(str)`: Paths to the files saved.
""" """
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) " raise ValueError(f"Vocabulary path ({save_directory}) "
"should be a directory") "should be a directory")
return
out_vocab_file = os.path.join( out_vocab_file = os.path.join(
save_directory, save_directory,
(filename_prefix + "-" if filename_prefix else "") + (filename_prefix + "-" if filename_prefix else "") +
......
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from typing import Dict, Optional from typing import Any, Dict, Optional
from uuid import uuid4 from uuid import uuid4
import cpuinfo import cpuinfo
...@@ -124,7 +124,7 @@ class UsageMessage: ...@@ -124,7 +124,7 @@ class UsageMessage:
def report_usage(self, def report_usage(self,
model_architecture: str, model_architecture: str,
usage_context: UsageContext, usage_context: UsageContext,
extra_kvs: Dict[str, any] = None) -> None: extra_kvs: Optional[Dict[str, Any]] = None) -> None:
t = Thread(target=self._report_usage_worker, t = Thread(target=self._report_usage_worker,
args=(model_architecture, usage_context, extra_kvs or {}), args=(model_architecture, usage_context, extra_kvs or {}),
daemon=True) daemon=True)
...@@ -132,13 +132,13 @@ class UsageMessage: ...@@ -132,13 +132,13 @@ class UsageMessage:
def _report_usage_worker(self, model_architecture: str, def _report_usage_worker(self, model_architecture: str,
usage_context: UsageContext, usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None: extra_kvs: Dict[str, Any]) -> None:
self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_usage_once(model_architecture, usage_context, extra_kvs)
self._report_continous_usage() self._report_continous_usage()
def _report_usage_once(self, model_architecture: str, def _report_usage_once(self, model_architecture: str,
usage_context: UsageContext, usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None: extra_kvs: Dict[str, Any]) -> None:
# Platform information # Platform information
if torch.cuda.is_available(): if torch.cuda.is_available():
device_property = torch.cuda.get_device_properties(0) device_property = torch.cuda.get_device_properties(0)
......
import asyncio import asyncio
import enum import enum
import gc import gc
import glob
import os import os
import socket import socket
import subprocess import subprocess
import uuid import uuid
import warnings import warnings
from collections import OrderedDict from collections import defaultdict
from functools import lru_cache, partial from functools import lru_cache, partial
from platform import uname from platform import uname
from typing import (Any, Awaitable, Callable, Generic, Hashable, List, from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Optional, Tuple, TypeVar, Union) Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
Union)
import psutil import psutil
import torch import torch
...@@ -23,9 +25,9 @@ logger = init_logger(__name__) ...@@ -23,9 +25,9 @@ logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half, "half": torch.half,
# "bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
# "fp8_e5m2": torch.uint8, # "fp8": torch.uint8,
} }
...@@ -51,7 +53,7 @@ class Counter: ...@@ -51,7 +53,7 @@ class Counter:
class LRUCache(Generic[T]): class LRUCache(Generic[T]):
def __init__(self, capacity: int): def __init__(self, capacity: int):
self.cache = OrderedDict[Hashable, T]() self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.capacity = capacity self.capacity = capacity
def __contains__(self, key: Hashable) -> bool: def __contains__(self, key: Hashable) -> bool:
...@@ -60,7 +62,7 @@ class LRUCache(Generic[T]): ...@@ -60,7 +62,7 @@ class LRUCache(Generic[T]):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.cache) return len(self.cache)
def __getitem__(self, key: Hashable) -> T: def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key) return self.get(key)
def __setitem__(self, key: Hashable, value: T) -> None: def __setitem__(self, key: Hashable, value: T) -> None:
...@@ -76,7 +78,7 @@ class LRUCache(Generic[T]): ...@@ -76,7 +78,7 @@ class LRUCache(Generic[T]):
key: Hashable, key: Hashable,
default_value: Optional[T] = None) -> Optional[T]: default_value: Optional[T] = None) -> Optional[T]:
if key in self.cache: if key in self.cache:
value = self.cache[key] value: Optional[T] = self.cache[key]
self.cache.move_to_end(key) self.cache.move_to_end(key)
else: else:
value = default_value value = default_value
...@@ -87,7 +89,7 @@ class LRUCache(Generic[T]): ...@@ -87,7 +89,7 @@ class LRUCache(Generic[T]):
self.cache.move_to_end(key) self.cache.move_to_end(key)
self._remove_old_if_needed() self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: T): def _on_remove(self, key: Hashable, value: Optional[T]):
pass pass
def remove_oldest(self): def remove_oldest(self):
...@@ -100,9 +102,11 @@ class LRUCache(Generic[T]): ...@@ -100,9 +102,11 @@ class LRUCache(Generic[T]):
while len(self.cache) > self.capacity: while len(self.cache) > self.capacity:
self.remove_oldest() self.remove_oldest()
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T: def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache run_on_remove = key in self.cache
value = self.cache.pop(key, default_value) value: Optional[T] = self.cache.pop(key, default_value)
if run_on_remove: if run_on_remove:
self._on_remove(key, value) self._on_remove(key, value)
return value return value
...@@ -117,6 +121,15 @@ def is_hip() -> bool: ...@@ -117,6 +121,15 @@ def is_hip() -> bool:
return torch.version.hip is not None return torch.version.hip is not None
@lru_cache(maxsize=None)
def is_cpu() -> bool:
    """Return True if the installed ``vllm`` version string contains "cpu".

    Returns False when no ``vllm`` distribution metadata can be found. The
    result is cached after the first call.
    """
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed_version = version("vllm")
    except PackageNotFoundError:
        return False
    return "cpu" in installed_version
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def is_neuron() -> bool: def is_neuron() -> bool:
try: try:
...@@ -150,6 +163,17 @@ def random_uuid() -> str: ...@@ -150,6 +163,17 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)
@lru_cache(maxsize=None)
def get_vllm_instance_id():
    """Return the id identifying this vLLM instance.

    The value of the VLLM_INSTANCE_ID environment variable is used when it is
    set; otherwise an id built from a random UUID is returned. The result is
    cached, so repeated calls within one process return the same id.

    Instance id represents an instance of the VLLM. All processes in the same
    instance should have the same instance id.
    """
    fallback_id = f"vllm-instance-{random_uuid()}"
    return os.environ.get("VLLM_INSTANCE_ID", fallback_id)
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def in_wsl() -> bool: def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071 # Reference: https://github.com/microsoft/WSL/issues/4071
...@@ -171,7 +195,43 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: ...@@ -171,7 +195,43 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
return _async_wrapper return _async_wrapper
def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    This function handles the case where some iterators finish before
    others. Each yielded value is a tuple ``(i, item)`` where ``i`` is the
    index of the iterator that produced ``item``.

    One task per input iterator is created immediately, so this must be
    called while an event loop is running.

    Args:
        *iterators: The asynchronous iterators to merge.

    Returns:
        An asynchronous iterator over ``(index, item)`` tuples.

    Raises:
        Exception: Any exception raised by one of the input iterators is
            re-raised from the merged iterator.
    """
    # Marker each producer enqueues exactly once, as its last queue entry.
    # The consumer counts these rather than polling a shared flag: a producer
    # that ends while the consumer is already blocked on an empty queue must
    # still wake it up, otherwise the consumer would wait forever.
    producer_done = object()
    queue: asyncio.Queue = asyncio.Queue()

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finally:
            # The queue is unbounded, so put_nowait cannot fail here.
            queue.put_nowait(producer_done)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        remaining = len(_tasks)
        try:
            while remaining:
                item = await queue.get()
                if item is producer_done:
                    remaining -= 1
                    continue
                if isinstance(item, Exception):
                    raise item
                yield item
            await asyncio.gather(*_tasks)
        finally:
            # On an error or an early exit of the consumer, stop producers
            # that are still running instead of leaving them in the
            # background.
            for task in _tasks:
                if not task.done():
                    task.cancel()

    return consumer()
def get_ip() -> str: def get_ip() -> str:
host_ip = os.environ.get("HOST_IP") host_ip = os.environ.get("HOST_IP")
if host_ip: if host_ip:
...@@ -223,8 +283,12 @@ def get_open_port() -> int: ...@@ -223,8 +283,12 @@ def get_open_port() -> int:
return s.getsockname()[1] return s.getsockname()[1]
def set_cuda_visible_devices(device_ids: List[int]) -> None: def update_environment_variables(envs: Dict[str, str]):
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
logger.warning(f"Overwriting environment variable {k} "
f"from '{os.environ[k]}' to '{v}'")
os.environ[k] = v
def chunk_list(lst, chunk_size): def chunk_list(lst, chunk_size):
...@@ -257,7 +321,7 @@ def get_nvcc_cuda_version() -> Optional[Version]: ...@@ -257,7 +321,7 @@ def get_nvcc_cuda_version() -> Optional[Version]:
return nvcc_cuda_version return nvcc_cuda_version
def _generate_random_fp8_e5m2( def _generate_random_fp8(
tensor: torch.tensor, tensor: torch.tensor,
low: float, low: float,
high: float, high: float,
...@@ -270,10 +334,10 @@ def _generate_random_fp8_e5m2( ...@@ -270,10 +334,10 @@ def _generate_random_fp8_e5m2(
#-----|-------------|------------------- #-----|-------------|-------------------
# Inf | N/A | s.11111.00 # Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11} # NaN | s.1111.111 | s.11111.{01,10,11}
from vllm._C import cache_ops from vllm import _custom_ops as ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high) tensor_tmp.uniform_(low, high)
cache_ops.convert_fp8_e5m2(tensor_tmp, tensor) ops.convert_fp8(tensor_tmp, tensor)
del tensor_tmp del tensor_tmp
...@@ -285,7 +349,7 @@ def create_kv_caches_with_random( ...@@ -285,7 +349,7 @@ def create_kv_caches_with_random(
head_size: int, head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = 0, seed: int = 0,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
...@@ -302,7 +366,7 @@ def create_kv_caches_with_random( ...@@ -302,7 +366,7 @@ def create_kv_caches_with_random(
raise ValueError(f"Invalid model dtype: {model_dtype}") raise ValueError(f"Invalid model dtype: {model_dtype}")
elif cache_dtype in ["half", "bfloat16", "float"]: elif cache_dtype in ["half", "bfloat16", "float"]:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
elif cache_dtype == "fp8_e5m2": elif cache_dtype == "fp8":
torch_dtype = torch.uint8 torch_dtype = torch.uint8
else: else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
...@@ -319,10 +383,10 @@ def create_kv_caches_with_random( ...@@ -319,10 +383,10 @@ def create_kv_caches_with_random(
key_cache = torch.empty(size=key_cache_shape, key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
device=device) device=device)
if cache_dtype == 'fp8_e5m2': if cache_dtype in ["auto", "half", "bfloat16", "float"]:
_generate_random_fp8_e5m2(key_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
key_cache.uniform_(-scale, scale) key_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
_generate_random_fp8(key_cache, -scale, scale)
else: else:
raise ValueError( raise ValueError(
f"Does not support key cache of type {cache_dtype}") f"Does not support key cache of type {cache_dtype}")
...@@ -334,10 +398,10 @@ def create_kv_caches_with_random( ...@@ -334,10 +398,10 @@ def create_kv_caches_with_random(
value_cache = torch.empty(size=value_cache_shape, value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
device=device) device=device)
if cache_dtype == 'fp8_e5m2': if cache_dtype in ["auto", "half", "bfloat16", "float"]:
_generate_random_fp8_e5m2(value_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
value_cache.uniform_(-scale, scale) value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
_generate_random_fp8(value_cache, -scale, scale)
else: else:
raise ValueError( raise ValueError(
f"Does not support value cache of type {cache_dtype}") f"Does not support value cache of type {cache_dtype}")
...@@ -362,6 +426,8 @@ def is_pin_memory_available() -> bool: ...@@ -362,6 +426,8 @@ def is_pin_memory_available() -> bool:
elif is_neuron(): elif is_neuron():
print_warning_once("Pin memory is not supported on Neuron.") print_warning_once("Pin memory is not supported on Neuron.")
return False return False
elif is_cpu():
return False
return True return True
...@@ -389,7 +455,7 @@ class CudaMemoryProfiler: ...@@ -389,7 +455,7 @@ class CudaMemoryProfiler:
gc.collect() gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int]: def str_to_int_tuple(s: str) -> Tuple[int, ...]:
"""Convert a string to a tuple of integers.""" """Convert a string to a tuple of integers."""
try: try:
return tuple(map(int, s.split(","))) return tuple(map(int, s.split(",")))
...@@ -438,3 +504,106 @@ def maybe_expand_dim(tensor: torch.Tensor, ...@@ -438,3 +504,106 @@ def maybe_expand_dim(tensor: torch.Tensor,
if tensor.ndim < target_dims: if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor return tensor
def merge_dicts(dict1: Dict[Any, List[Any]],
                dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
    """Combine two ``key -> list`` mappings into a new plain dict.

    For a key present in both inputs the resulting list holds the items
    from ``dict1`` first, followed by the items from ``dict2``. Neither
    input mapping nor any of its lists is modified.
    """
    combined = defaultdict(list)
    # Walking dict1 before dict2 is what puts dict1's items first.
    for source in (dict1, dict2):
        for key, items in source.items():
            combined[key].extend(items)
    return dict(combined)
def init_cached_hf_modules():
    """Run the Hugging Face dynamic-module initialisation on demand.

    ``transformers`` is imported inside the function, so the import only
    happens when this function is actually called.
    """
    import transformers.dynamic_module_utils as hf_dynamic_modules

    hf_dynamic_modules.init_hf_modules()
def nccl_integrity_check(filepath):
    """Verify that the NCCL library at ``filepath`` loads and return its version.

    A corrupted shared library can crash the whole process while it is
    being loaded, and that crash cannot be caught as a Python exception.
    The library is therefore probed with ``ldd`` in a child process first
    and only loaded through ``ctypes`` when ``ldd`` exits with status 0.

    Args:
        filepath: Path of the NCCL shared library to check.

    Returns:
        The integer that the library's ``ncclGetVersion`` writes out.

    Raises:
        RuntimeError: If ``ldd`` cannot be started or reports a failure for
            ``filepath``, or if ``ncclGetVersion`` returns a non-zero status.
    """
    import ctypes
    import subprocess

    # The path is passed as a single argv entry instead of being pasted
    # into a shell command line, so spaces or shell metacharacters in it
    # cannot alter (or bypass) the check. Only stdout is discarded; ldd's
    # diagnostics on stderr stay visible.
    try:
        exit_code = subprocess.run(
            ["ldd", str(filepath)],
            stdout=subprocess.DEVNULL,
            check=False,
        ).returncode
    except OSError:
        # `ldd` itself could not be executed; report it like a failed check.
        exit_code = -1
    if exit_code != 0:
        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")

    nccl = ctypes.CDLL(filepath)
    version = ctypes.c_int()
    nccl.ncclGetVersion.restype = ctypes.c_int
    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
    result = nccl.ncclGetVersion(ctypes.byref(version))
    # Explicit raise instead of `assert`, which is stripped under `-O`.
    if result != 0:
        raise RuntimeError(
            f"ncclGetVersion failed with status {result} for {filepath}")
    return version.value
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """Resolve a shared-library file name to a full path.

    ``lib_name`` is the complete file name, prefix and suffix included.
    The ``ldconfig`` cache listing is consulted first; the directories in
    ``LD_LIBRARY_PATH`` are searched only when the listing has no entry
    containing ``lib_name``. The first match is returned, and results are
    memoized per ``lib_name``.

    Raises:
        ValueError: If neither source yields a match.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    ldconfig_listing = subprocess.check_output(["/sbin/ldconfig",
                                                "-p"]).decode()

    # Each entry looks like
    #   libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    # so the resolved path is the last whitespace-separated field.
    candidates = []
    for entry in ldconfig_listing.splitlines():
        if lib_name in entry:
            candidates.append(entry.split()[-1])

    # Fall back to the user-defined search path only when ldconfig found
    # nothing.
    search_path = os.getenv("LD_LIBRARY_PATH")
    if not candidates and search_path:
        for directory in search_path.split(":"):
            full_path = os.path.join(directory, lib_name)
            if os.path.exists(full_path):
                candidates.append(full_path)

    if not candidates:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return candidates[0]
def find_nccl_library():
    """Pick the NCCL (or RCCL) shared library that should be loaded.

    Resolution order:
      1. The ``VLLM_NCCL_SO_PATH`` environment variable, when non-empty.
      2. On CUDA builds of torch: the first file matching
         ``~/.config/vllm/nccl/cu<major>/libnccl.so.*``, otherwise whatever
         ``find_library("libnccl.so.2")`` resolves to.
      3. On ROCm builds of torch: ``find_library("librccl.so.1")``.

    Raises:
        ValueError: If the environment variable is unset and torch reports
            neither a CUDA nor a HIP version.
    """
    so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
    if so_file:
        # An explicit user override wins over every automatic lookup.
        logger.info(
            f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}"
        )
        return so_file

    if torch.version.cuda is not None:
        # Prefer a vllm-managed nccl for this CUDA major version.
        cuda_major = torch.version.cuda.split(".")[0]
        pattern = os.path.expanduser(
            f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
        managed = glob.glob(pattern)
        vllm_nccl_path = managed[0] if managed else None
        so_file = vllm_nccl_path or find_library("libnccl.so.2")
    elif torch.version.hip is not None:
        so_file = find_library("librccl.so.1")
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info(f"Found nccl from library {so_file}")
    return so_file
...@@ -82,8 +82,7 @@ class CacheEngine: ...@@ -82,8 +82,7 @@ class CacheEngine:
@staticmethod @staticmethod
def get_cache_block_size( def get_cache_block_size(
block_size: int, cache_config: CacheConfig,
cache_dtype: str,
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
) -> int: ) -> int:
...@@ -91,13 +90,13 @@ class CacheEngine: ...@@ -91,13 +90,13 @@ class CacheEngine:
num_heads = model_config.get_num_kv_heads(parallel_config) num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config) num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = block_size * num_heads * head_size key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block) total = num_layers * (key_cache_block + value_cache_block)
if cache_dtype == "auto": if cache_config.cache_dtype == "auto":
dtype = model_config.dtype dtype = model_config.dtype
else: else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = _get_dtype_size(dtype) dtype_size = _get_dtype_size(dtype)
return dtype_size * total return dtype_size * total
......
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad, maybe_expand_dim
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
class CPUModelRunner:
    """Builds model inputs and runs forward pass and sampling for the CPU backend.

    A batch of `SequenceGroupMetadata` is flattened into token, position and
    slot-mapping tensors plus attention and sampling metadata. The loaded
    model is then invoked, and sampling happens only where the sampling
    metadata asks for it (the driver worker).
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Store the configs and select the attention backend.

        Extra positional and keyword arguments are accepted but not used
        anywhere in this constructor.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        # Fall back to a default DeviceConfig when none is supplied.
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model
        self.block_size: int  # Set after initial profiling.

    def load_model(self) -> None:
        """Load the model via `get_model` and keep it on `self.model`."""
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            vision_language_config=self.vision_language_config,
            lora_config=self.lora_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
        """Flatten a batch of prompt sequence groups into model inputs.

        Every group must be a prompt holding exactly one sequence (both are
        asserted).

        Returns:
            A tuple `(input_tokens, input_positions, attn_metadata,
            prompt_lens, multi_modal_input)`. `multi_modal_input` is None
            when no group carries multi-modal data.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        prompt_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            prompt_len = len(prompt_tokens)

            prompt_lens.append(prompt_len)  # Prompt token num
            # NOTE(review): all prompt tokens are appended here, while the
            # positions and slots below only cover [computed_len, prompt_len).
            # The lengths agree only when computed_len is 0 -- confirm that
            # this always holds on the CPU path.
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, prompt_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prompt_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                # Slot = block number * block size + offset inside the block.
                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            # Concatenate the per-group inputs along dim 0.
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        num_prompt_tokens = len(input_tokens)

        # The Python lists are rebound to tensors of the same name.
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            prompt_lens=prompt_lens,
            num_prefills=len(prompt_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            prefill_metadata=None,
            decode_metadata=None,
            max_context_len=None,
            context_lens=None,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build model inputs for a batch of decode steps.

        Every group must be a non-prompt group with `token_chunk_size == 1`
        (both asserted). Each sequence in each group contributes its last
        token.

        Returns:
            A tuple `(input_tokens, input_positions, attn_metadata)`.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                # With a sliding window the context length is capped at the
                # window size.
                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    # Keep only the trailing blocks that the window spans.
                    # NOTE(review): if sliding_window < block_size the count
                    # is 0 and `[-0:]` keeps the whole table -- confirm this
                    # cannot happen or is intended.
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_context_len = max(context_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        # Block tables may differ in length; pad them with 0 up to the
        # longest one.
        max_block_table_len = max(
            len(block_table) for block_table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            max_context_len=max_context_len,
            num_prefills=0,
            prefill_metadata=None,
            decode_metadata=None,
            context_lens=context_lens,
            block_tables=block_tables,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
        """Build the `SamplingMetadata` for a batch.

        For a prompt group the index of its last token is selected, preceded
        by all earlier prompt token indices when `prompt_logprobs` is
        requested. For a decode group one index per sequence is selected.
        `categorized_sample_indices` collects, per sampling type, pairs of
        two running counters advanced below.

        Args:
            seq_group_metadata_list: The batch, in the same order that was
                used to build the input tensors.
            prompt_lens: Prompt length per group; indexed by group position
                and only read for prompt groups.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        # Offset of the current group in the flattened token list.
        selected_token_start_idx = 0
        categorized_sample_indices: Dict[SamplingType,
                                         List[Tuple[int, int]]] = {
                                             t: []
                                             for t in SamplingType
                                         }
        categorized_sample_indices_start_idx = 0
        categorized_sampled_token_indices_start_idx = 0

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                subquery_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need sample, skip
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        (categorized_sample_indices_start_idx,
                         categorized_sampled_token_indices_start_idx))
                categorized_sample_indices_start_idx += 1
                categorized_sampled_token_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    # Select every prompt token except the last one ...
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + subquery_len - 1))
                # ... and always select the last prompt token.
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += subquery_len

                if sampling_params.seed is not None:
                    # A seeded generator is created at prompt time and kept
                    # on the group's state; later calls read it from there.
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx +
                                num_seqs)))
                categorized_sample_indices_start_idx += num_seqs
                categorized_sampled_token_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices = torch.tensor(selected_token_indices,
                                              dtype=torch.long)

        # Convert each per-type list of index pairs into an int tensor.
        categorized_sample_indices = {
            t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        # Merge the per-group seq_id -> SequenceData mappings into one dict.
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Optional[torch.Tensor]]:
        """Produce the model inputs on every worker.

        The driver worker builds the inputs from `seq_group_metadata_list`
        and broadcasts them. Other workers receive the broadcast, rebuild
        the attention metadata from it, and get a `SamplingMetadata` with
        `perform_sampling=False`.

        Returns:
            A tuple `(input_tokens, input_positions, attn_metadata,
            sampling_metadata, multi_modal_input)`. `multi_modal_input` is
            only ever set on the driver for prompt batches.
        """
        multi_modal_input = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            # Pop the non-attention entries; what remains is passed to
            # make_metadata as keyword arguments.
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, where requested, sample the next tokens.

        Returns:
            The sampler output, or None when the sampling metadata has
            `perform_sampling` set to False.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        # The image input is only passed when a vision-language config is set.
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output
"""A CPU worker class."""
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
logger = init_logger(__name__)
class CPUCacheEngine:
    """KV cache owner for the CPU backend.

    Allocates one cache tensor per layer on the CPU and exposes the block
    copy operation. Swapping blocks in or out is not supported.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # On the CPU backend CacheConfig.num_gpu_blocks actually carries the
        # number of CPU blocks, so the scheduler's KV cache management can
        # be reused unchanged.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        # "auto" means the cache uses the model's own dtype.
        configured_dtype = cache_config.cache_dtype
        if configured_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[configured_dtype]

        self.attn_backend = get_attn_backend(model_config.dtype)

        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Create one uninitialised CPU cache tensor per layer."""
        layer_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        return [
            torch.empty(layer_shape, dtype=self.dtype, device="cpu")
            for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Copy cache blocks as described by `src_to_dsts`."""
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block across all layers.

        A block stores a key part and an equally sized value part for every
        layer. With `cache_dtype == "auto"` the element size is taken from
        the model dtype.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        # Key and value parts hold the same number of elements.
        elements_per_part = block_size * num_heads * head_size
        total_elements = num_layers * (elements_per_part + elements_per_part)

        dtype = (model_config.dtype if cache_dtype == "auto" else
                 STR_DTYPE_TO_TORCH_DTYPE[cache_dtype])
        bytes_per_element = torch.tensor([], dtype=dtype).element_size()
        return bytes_per_element * total_elements
class CPUWorker(LoraNotSupportedWorkerBase):
    """Worker that runs (a partition of) the model on a single CPU socket.

    The worker owns the KV cache for its partition and drives model
    execution through a `CPUModelRunner`. With distributed inference every
    worker holds one partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        # Configuration objects.
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config

        # Placement of this worker in the distributed job.
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.model_runner = CPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker)

        # Declared here, assigned later by initialize_cache.
        self.cache_engine: CPUCacheEngine
        self.cpu_cache: List[torch.Tensor]

    def init_device(self) -> None:
        """Set up the distributed environment, then seed the RNGs."""
        self.init_distributed_environment()
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Compute how many KV blocks fit into the configured CPU cache space.

        vLLM assumes that a block which can be modified lives on the GPU, so
        the CPU block count is reported in the `num_gpu_blocks` position and
        `num_cpu_blocks` is reported as 0. That lets the existing scheduler
        be reused without generalising it to other devices.
        """
        # The block count is derived from cpu_kvcache_space_bytes and can
        # never be negative.
        block_bytes = self.get_cache_block_size_bytes()
        usable_blocks = max(
            int(self.cache_config.cpu_kvcache_space_bytes // block_bytes), 0)

        # Reported as (num_gpu_blocks, num_cpu_blocks).
        return usable_blocks, 0

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Create the KV cache; swappable CPU memory is not supported.

        This worker has no GPU, so `num_gpu_blocks` is the number of
        non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # The CPU cache plays the role of the 'gpu cache' so that the cache
        # management procedure can be reused.
        self._validate_num_cpu_blocks(num_gpu_blocks)
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = 0

        self._init_cache_engine()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise ValueError when `num_cpu_blocks` cannot serve the model."""
        if num_cpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
                             "initializing the engine.")

        # Number of tokens the whole cache can hold.
        token_capacity = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > token_capacity:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({token_capacity}). Try increasing "
                "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
                "initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Build the cache engine and hand its block size to the runner."""
        engine = CPUCacheEngine(self.cache_config, self.model_config,
                                self.parallel_config, self.device_config)
        self.cache_engine = engine
        self.cpu_cache = engine.cpu_cache
        self.model_runner.block_size = engine.block_size

        assert self.cpu_cache is not None

        # Populate the cache to warmup the memory
        for per_layer_cache in self.cpu_cache:
            per_layer_cache.fill_(0)

    def cache_copy(
        self,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Copy the given blocks; an empty mapping is a no-op."""
        if not blocks_to_copy:
            return
        self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> List[SamplerOutput]:
        """Apply pending block copies and run one model step.

        The driver validates its arguments and broadcasts the group count
        and the copy plan; other workers take both from the broadcast.
        Returns an empty list when there are no sequence groups, otherwise
        a one-element list with the model runner's output.
        """
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            # Swapping is unsupported, so both swap plans must be empty.
            assert len(blocks_to_swap_in) == 0
            assert len(blocks_to_swap_out) == 0
            broadcast_tensor_dict(
                {
                    "num_seq_groups": num_seq_groups,
                    "blocks_to_copy": blocks_to_copy,
                },
                src=0)
        else:
            received = broadcast_tensor_dict(src=0)
            num_seq_groups = received["num_seq_groups"]
            blocks_to_copy = received["blocks_to_copy"]

        assert blocks_to_copy is not None
        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        # CPU worker only supports single-step execution.
        return [
            self.model_runner.execute_model(seq_group_metadata_list,
                                            self.cpu_cache)
        ]

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        cache_cfg = self.cache_config
        return CPUCacheEngine.get_cache_block_size(cache_cfg.block_size,
                                                   cache_cfg.cache_dtype,
                                                   self.model_config,
                                                   self.parallel_config)
import contextlib import contextlib
import time import time
from typing import Dict, List, Optional, Set, Tuple from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, get_attn_backend)
SchedulerConfig, VisionLanguageConfig) from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
pynccl_utils)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
with_pynccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
is_pin_memory_available, make_tensor_with_pad, is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim) maybe_expand_dim)
...@@ -39,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ ...@@ -39,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
] ]
class PreparePromptMetadata(NamedTuple):
    """Prompt-side inputs returned by ``ModelRunner._prepare_prompt``.

    The field order is part of the contract: ``prepare_input_tensors``
    unpacks this tuple positionally, so fields must not be reordered.
    """
    # Token ids and their positions, kept as plain Python lists.
    input_tokens: List[int]
    input_positions: List[int]
    # None when there are no prompt requests (see ``empty``).
    attn_metadata: Optional[AttentionMetadataPerStage]
    prompt_lens: List[int]
    subquery_lens: List[int]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PreparePromptMetadata":
        """Return a record with empty containers and no metadata.

        Built through ``cls`` so that a subclass gets an instance of its
        own type rather than of the base class.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )
class PrepareDecodeMetadata(NamedTuple):
    """Decode-side inputs returned by ``ModelRunner._prepare_decode``.

    The field order is part of the contract: ``prepare_input_tensors``
    unpacks this tuple positionally, so fields must not be reordered.
    """
    # Token ids and their positions, kept as plain Python lists.
    input_tokens: List[int]
    input_positions: List[int]
    # Per-stage metadata: it is produced by ``attn_backend.make_metadata``
    # and later passed as ``decode_metadata=`` when the combined
    # ``AttentionMetadata`` is built. None when there are no decode
    # requests (see ``empty``).
    attn_metadata: Optional[AttentionMetadataPerStage]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PrepareDecodeMetadata":
        """Return a record with empty containers and no metadata.

        Built through ``cls`` so that a subclass gets an instance of its
        own type rather than of the base class.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )
# How batches are constructed.
class BatchType(IntEnum):
    """Kind of requests a batch contains: prefill, decode, or both."""
    PREFILL = 0  # Every batch is prefill.
    DECODE = 1  # Every batch is decode.
    MIXED = 2  # Batch is a mixture of prefill and decode.
class ModelRunner: class ModelRunner:
def __init__( def __init__(
...@@ -47,6 +107,7 @@ class ModelRunner: ...@@ -47,6 +107,7 @@ class ModelRunner:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
...@@ -56,6 +117,7 @@ class ModelRunner: ...@@ -56,6 +117,7 @@ class ModelRunner:
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py. # model_config can be None in tests/samplers/test_sampler.py.
...@@ -66,23 +128,17 @@ class ModelRunner: ...@@ -66,23 +128,17 @@ class ModelRunner:
if device_config is not None else DeviceConfig()) if device_config is not None else DeviceConfig())
self.device = self.device_config.device self.device = self.device_config.device
self.model = None # Set after load_model.
self.block_size = None # Set after initial profiling. self.lora_manager: LRUCacheWorkerLoRAManager = None
self.lora_manager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture. self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = ( self.max_context_len_to_capture = (
self.model_config.max_context_len_to_capture self.model_config.max_context_len_to_capture
if self.model_config is not None else 0) if self.model_config is not None else 0)
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling.
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.vision_language_config = vision_language_config self.vision_language_config = vision_language_config
...@@ -90,15 +146,28 @@ class ModelRunner: ...@@ -90,15 +146,28 @@ class ModelRunner:
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None) self.model_config.dtype if model_config is not None else None)
# Lazy initialization
self.model: torch.nn.Module # Set after load_model
self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables: torch.Tensor # Set after initial profiling.
def load_model(self) -> None: def load_model(self) -> None:
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
self.model = get_model( self.model = get_model(
self.model_config, model_config=self.model_config,
self.device_config, device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config) scheduler_config=self.scheduler_config,
)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info(f"Loading model weights took " logger.info(f"Loading model weights took "
...@@ -120,6 +189,26 @@ class ModelRunner: ...@@ -120,6 +189,26 @@ class ModelRunner:
self.model.embedding_padding_modules) self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
else:
raise RuntimeError("Using FP8 KV cache and scaling "
"factors provided but model "
f"{self.model.__class__} does not "
"support loading scaling factors.")
else:
logger.warn("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warn("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
def set_block_size(self, block_size: int) -> None: def set_block_size(self, block_size: int) -> None:
self.block_size = block_size self.block_size = block_size
...@@ -134,10 +223,7 @@ class ModelRunner: ...@@ -134,10 +223,7 @@ class ModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> PreparePromptMetadata:
List[int], List[int], List[int], Set[LoRARequest],
torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
...@@ -151,6 +237,9 @@ class ModelRunner: ...@@ -151,6 +237,9 @@ class ModelRunner:
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_input_list: List[torch.Tensor] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -160,7 +249,8 @@ class ModelRunner: ...@@ -160,7 +249,8 @@ class ModelRunner:
computed_block_nums = seq_group_metadata.computed_block_nums computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled and self.scheduler_config.chunked_prefill_enabled
and computed_block_nums is not None): and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError( raise RuntimeError(
"chunked prefill cannot be used with prefix caching " "chunked prefill cannot be used with prefix caching "
"now.") "now.")
...@@ -172,13 +262,8 @@ class ModelRunner: ...@@ -172,13 +262,8 @@ class ModelRunner:
# it contains output tokens. # it contains output tokens.
prefill_end = min(seq_data.get_len(), prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size) computed_len + token_chunk_size)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = len(prompt_tokens) prompt_len = prefill_end
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert prefill_end == seq_data.get_len()
prompt_lens.append(prompt_len) prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
...@@ -188,6 +273,14 @@ class ModelRunner: ...@@ -188,6 +273,14 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:] prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums) prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else: else:
prefix_block_tables.append([]) prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this # Right now, prefill start is always 0. However, this
...@@ -202,7 +295,6 @@ class ModelRunner: ...@@ -202,7 +295,6 @@ class ModelRunner:
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end))) input_positions.extend(list(range(computed_len, prefill_end)))
lora_id = seq_group_metadata.lora_int_id lora_id = seq_group_metadata.lora_int_id
if lora_id > 0: if lora_id > 0:
...@@ -250,20 +342,8 @@ class ModelRunner: ...@@ -250,20 +342,8 @@ class ModelRunner:
max_subquery_len = max(subquery_lens) max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens) max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0 assert max_subquery_len > 0
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
...@@ -315,11 +395,8 @@ class ModelRunner: ...@@ -315,11 +395,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens, prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor, prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len, max_subquery_len=max_subquery_len,
max_context_len=None, max_context_len=None,
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
...@@ -328,18 +405,25 @@ class ModelRunner: ...@@ -328,18 +405,25 @@ class ModelRunner:
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, return PreparePromptMetadata(
lora_requests, multi_modal_input) input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> PrepareDecodeMetadata:
List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
...@@ -349,6 +433,9 @@ class ModelRunner: ...@@ -349,6 +433,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
...@@ -407,25 +494,16 @@ class ModelRunner: ...@@ -407,25 +494,16 @@ class ModelRunner:
lora_index_mapping.append(0) lora_index_mapping.append(0)
batch_size = graph_batch_size batch_size = graph_batch_size
input_tokens = torch.tensor(input_tokens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.long, dtype=torch.int,
device=self.device) device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
if use_captured_graph: if use_captured_graph:
# When using cuda-graph all these tensors should be # When using cuda-graph all these tensors should be
# padded. # padded.
assert context_lens.shape[0] == input_tokens.shape[0] assert context_lens_tensor.shape[0] == len(input_tokens)
assert context_lens.shape[0] == input_positions.shape[0] assert context_lens_tensor.shape[0] == len(input_positions)
assert context_lens.shape[0] == slot_mapping.shape[0] assert context_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
...@@ -447,23 +525,26 @@ class ModelRunner: ...@@ -447,23 +525,26 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None, prompt_lens=None,
prompt_lens_tensor=None, prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None, max_subquery_len=None,
max_context_len=max_context_len, max_context_len=max_context_len,
max_prompt_len=None, max_prompt_len=None,
subquery_start_loc=None, subquery_start_loc=None,
seq_start_loc=None, seq_start_loc=None,
context_lens=context_lens, context_lens=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, return PrepareDecodeMetadata(
lora_index_mapping, lora_prompt_mapping, lora_requests) input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
)
def _prepare_sample( def _prepare_sample(
self, self,
...@@ -475,7 +556,11 @@ class ModelRunner: ...@@ -475,7 +556,11 @@ class ModelRunner:
selected_token_indices: List[int] = [] selected_token_indices: List[int] = []
generators: List[torch.Generator] = [] generators: List[torch.Generator] = []
selected_token_start_idx = 0 selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0 categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0
...@@ -493,10 +578,9 @@ class ModelRunner: ...@@ -493,10 +578,9 @@ class ModelRunner:
categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[ categorized_sample_indices[
sampling_params.sampling_type].append([ sampling_params.sampling_type].append(
categorized_sample_indices_start_idx, (categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx categorized_sampled_token_indices_start_idx))
])
categorized_sample_indices_start_idx += 1 categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1
...@@ -520,15 +604,16 @@ class ModelRunner: ...@@ -520,15 +604,16 @@ class ModelRunner:
categorized_sample_indices[ categorized_sample_indices[
sampling_params.sampling_type].extend( sampling_params.sampling_type].extend(
zip( list(
range( zip(
categorized_sample_indices_start_idx, range(
categorized_sample_indices_start_idx + categorized_sample_indices_start_idx,
num_seqs), categorized_sample_indices_start_idx +
range( num_seqs),
categorized_sampled_token_indices_start_idx, range(
categorized_sampled_token_indices_start_idx + categorized_sampled_token_indices_start_idx,
num_seqs))) categorized_sampled_token_indices_start_idx
+ num_seqs))))
categorized_sample_indices_start_idx += num_seqs categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs categorized_sampled_token_indices_start_idx += num_seqs
...@@ -565,30 +650,70 @@ class ModelRunner: ...@@ -565,30 +650,70 @@ class ModelRunner:
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]: Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or prefill_reqs = []
# all decodes. decode_reqs = []
is_prompt = seq_group_metadata_list[0].is_prompt for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors. # Prepare input tensors.
if is_prompt: (
(input_tokens, input_positions, attn_metadata, prompt_lens, input_tokens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, input_positions,
lora_requests, multi_modal_input prefill_attn_metadata,
) = self._prepare_prompt(seq_group_metadata_list) prompt_lens,
else: subquery_lens,
(input_tokens, input_positions, attn_metadata, lora_index_mapping,
lora_index_mapping, lora_prompt_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list) lora_requests,
prompt_lens = [] multi_modal_input,
subquery_lens = None slot_mapping,
multi_modal_input = None ) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens) subquery_lens)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config: if self.lora_config:
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
lora_index_mapping, lora_index_mapping,
...@@ -598,6 +723,16 @@ class ModelRunner: ...@@ -598,6 +723,16 @@ class ModelRunner:
lora_mapping = None lora_mapping = None
# Broadcast the metadata. # Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = { metadata_dict = {
"input_tokens": input_tokens, "input_tokens": input_tokens,
"input_positions": input_positions, "input_positions": input_positions,
...@@ -606,19 +741,50 @@ class ModelRunner: ...@@ -606,19 +741,50 @@ class ModelRunner:
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input, "multi_modal_input": multi_modal_input,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
} }
metadata_dict.update(attn_metadata.asdict_zerocopy()) if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else: else:
metadata_dict = broadcast_tensor_dict(src=0) metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens") input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions") input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop( selected_token_indices = metadata_dict.pop(
"selected_token_indices") "selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping") lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests") lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input") multi_modal_input = metadata_dict.pop("multi_modal_input")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict) num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
...@@ -629,6 +795,23 @@ class ModelRunner: ...@@ -629,6 +795,23 @@ class ModelRunner:
perform_sampling=False, perform_sampling=False,
) )
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping, sampling_metadata, lora_requests, lora_mapping,
multi_modal_input) multi_modal_input)
...@@ -636,7 +819,7 @@ class ModelRunner: ...@@ -636,7 +819,7 @@ class ModelRunner:
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata, (input_tokens, input_positions, attn_metadata, sampling_metadata,
...@@ -646,8 +829,10 @@ class ModelRunner: ...@@ -646,8 +829,10 @@ class ModelRunner:
if self.lora_config: if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping) self.set_active_loras(lora_requests, lora_mapping)
# Execute the model. # Currently cuda graph is only supported by the decode phase.
if attn_metadata.use_cuda_graph: prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[graph_batch_size]
else: else:
...@@ -748,7 +933,7 @@ class ModelRunner: ...@@ -748,7 +933,7 @@ class ModelRunner:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras() return self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: List[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: lora_mapping: LoRAMapping) -> None:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
...@@ -825,13 +1010,10 @@ class ModelRunner: ...@@ -825,13 +1010,10 @@ class ModelRunner:
# memory usage of CUDA graph. # memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata. # Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata( decode_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None, prompt_lens=None,
prompt_lens_tensor=None, prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None, max_subquery_len=None,
max_context_len=self.max_context_len_to_capture, max_context_len=self.max_context_len_to_capture,
max_prompt_len=None, max_prompt_len=None,
...@@ -840,6 +1022,14 @@ class ModelRunner: ...@@ -840,6 +1022,14 @@ class ModelRunner:
context_lens=context_lens[:batch_size], context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size], block_tables=block_tables[:batch_size],
use_cuda_graph=True, use_cuda_graph=True,
)
attn_metadata = AttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
...@@ -885,10 +1075,16 @@ class CUDAGraphRunner: ...@@ -885,10 +1075,16 @@ class CUDAGraphRunner:
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
self.model = model self.model = model
self.graph = None
self.input_buffers: Dict[str, torch.Tensor] = {} self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {}
self._graph: Optional[torch.cuda.CUDAGraph] = None
@property
def graph(self):
assert self._graph is not None
return self._graph
def capture( def capture(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -898,7 +1094,7 @@ class CUDAGraphRunner: ...@@ -898,7 +1094,7 @@ class CUDAGraphRunner:
memory_pool, memory_pool,
**kwargs, **kwargs,
) -> None: ) -> None:
assert self.graph is None assert self._graph is None
# Run the model once without capturing the graph. # Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the # This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune). # kernel launches for initial benchmarking (e.g., Triton autotune).
...@@ -915,8 +1111,8 @@ class CUDAGraphRunner: ...@@ -915,8 +1111,8 @@ class CUDAGraphRunner:
# Capture the graph. # Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements. # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph() self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
with _maybe_pynccl(): with _maybe_pynccl():
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
...@@ -933,8 +1129,8 @@ class CUDAGraphRunner: ...@@ -933,8 +1129,8 @@ class CUDAGraphRunner:
"positions": positions, "positions": positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.context_lens, "context_lens": attn_metadata.decode_metadata.context_lens,
"block_tables": attn_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} self.output_buffers = {"hidden_states": hidden_states}
return return
...@@ -955,10 +1151,10 @@ class CUDAGraphRunner: ...@@ -955,10 +1151,10 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True) non_blocking=True)
self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, self.input_buffers["context_lens"].copy_(
non_blocking=True) attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, self.input_buffers["block_tables"].copy_(
non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()
......
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
from torch import nn
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.neuron_model_loader import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available, from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
...@@ -34,9 +35,11 @@ class NeuronModelRunner: ...@@ -34,9 +35,11 @@ class NeuronModelRunner:
self.device_config = (device_config self.device_config = (device_config
if device_config is not None else DeviceConfig()) if device_config is not None else DeviceConfig())
self.device = self.device_config.device self.device = self.device_config.device
self.model = None
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
# Lazy initialization.
self.model: nn.Module # initialize after load_model.
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_neuron_model(self.model_config, self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
...@@ -147,7 +150,11 @@ class NeuronModelRunner: ...@@ -147,7 +150,11 @@ class NeuronModelRunner:
selected_token_indices: List[int] = [] selected_token_indices: List[int] = []
generators: List[torch.Generator] = [] generators: List[torch.Generator] = []
selected_token_start_idx = 0 selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0 categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0
...@@ -165,10 +172,9 @@ class NeuronModelRunner: ...@@ -165,10 +172,9 @@ class NeuronModelRunner:
categorized_sample_indices_start_idx += prompt_len - 1 categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[ categorized_sample_indices[
sampling_params.sampling_type].append([ sampling_params.sampling_type].append(
categorized_sample_indices_start_idx, (categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx categorized_sampled_token_indices_start_idx))
])
categorized_sample_indices_start_idx += 1 categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1
...@@ -237,7 +243,7 @@ class NeuronModelRunner: ...@@ -237,7 +243,7 @@ class NeuronModelRunner:
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
...@@ -259,7 +265,7 @@ class NeuronModelRunner: ...@@ -259,7 +265,7 @@ class NeuronModelRunner:
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata (input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list) ) = self.prepare_input_tensors(seq_group_metadata_list)
......
"""A Neuron worker class.""" """A Neuron worker class."""
from typing import List, Optional from typing import List, Tuple
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NeuronWorker: class NeuronWorker(LoraNotSupportedWorkerBase):
"""A worker class that executes the model on a group of neuron cores. """A worker class that executes the model on a group of neuron cores.
""" """
...@@ -21,11 +22,17 @@ class NeuronWorker: ...@@ -21,11 +22,17 @@ class NeuronWorker:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = NeuronModelRunner(model_config, parallel_config, self.model_runner = NeuronModelRunner(model_config, parallel_config,
scheduler_config, device_config) scheduler_config, device_config)
...@@ -37,16 +44,55 @@ class NeuronWorker: ...@@ -37,16 +44,55 @@ class NeuronWorker:
def load_model(self): def load_model(self):
self.model_runner.load_model() self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Report the KV-cache capacity as ``(num_gpu_blocks, num_cpu_blocks)``.

    ``num_gpu_blocks`` mirrors ``scheduler_config.max_num_seqs`` and
    ``num_cpu_blocks`` is always 0 because swapping is not yet supported.
    """
    # One block per sequence that can be processed in a single batch; this is
    # equivalent to scheduling without PagedAttention.
    device_blocks = self.scheduler_config.max_num_seqs

    # Swap is not yet supported with the Neuron backend, hence no CPU blocks.
    return device_blocks, 0
def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Record the KV-cache block counts on ``self.cache_config``.

    Only ``num_cpu_blocks == 0`` together with
    ``num_gpu_blocks == scheduler_config.max_num_seqs`` is accepted; any
    other combination trips an assertion.
    """
    # Different values are not tested, so they are rejected up front.
    assert num_cpu_blocks == 0
    expected_gpu_blocks = self.scheduler_config.max_num_seqs
    assert num_gpu_blocks == expected_gpu_blocks

    cache_config = self.cache_config
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = num_cpu_blocks
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Optional[SamplerOutput]: ) -> List[SamplerOutput]:
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if num_seq_groups == 0: if num_seq_groups == 0:
return {} return []
output = self.model_runner.execute_model(seq_group_metadata_list) output = self.model_runner.execute_model(seq_group_metadata_list)
return output
# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return [output]
def get_cache_block_size_bytes(self) -> int:
    """Return the size of one cache block in bytes (not yet implemented).

    Speculative decoding requires this value; until it is implemented the
    call always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import os import os
from typing import Dict, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
class Worker: class Worker(WorkerBase):
"""A worker class that executes (a partition of) the model on a GPU. """A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for Each worker is associated with a single GPU. The worker is responsible for
...@@ -35,26 +37,33 @@ class Worker: ...@@ -35,26 +37,33 @@ class Worker:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.vision_language_config = vision_language_config self.vision_language_config = vision_language_config
if self.vision_language_config: if self.vision_language_config:
assert not self.lora_config, ( assert not self.lora_config, (
...@@ -65,15 +74,16 @@ class Worker: ...@@ -65,15 +74,16 @@ class Worker:
parallel_config, parallel_config,
scheduler_config, scheduler_config,
device_config, device_config,
load_config=load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config) vision_language_config=vision_language_config,
)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# self.init_cache_engine(). # initialize_cache.
self.cache_config = None self.cache_engine: CacheEngine
self.cache_engine = None self.gpu_cache: List[torch.Tensor]
self.gpu_cache = None
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
...@@ -97,9 +107,9 @@ class Worker: ...@@ -97,9 +107,9 @@ class Worker:
raise RuntimeError( raise RuntimeError(
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank, init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method, self.distributed_init_method,
self.local_rank) self.local_rank)
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
...@@ -107,20 +117,17 @@ class Worker: ...@@ -107,20 +117,17 @@ class Worker:
self.model_runner.load_model() self.model_runner.load_model()
@torch.inference_mode() @torch.inference_mode()
def profile_num_available_blocks( def determine_num_available_blocks(self) -> Tuple[int, int]:
self, """Profiles the peak memory usage of the model to determine how many
block_size: int, KV blocks may be allocated without OOMs.
gpu_memory_utilization: float,
cpu_swap_space: int, The engine will first conduct a profiling of the existing memory usage.
cache_dtype: str, Then, it calculate the maximum possible number of GPU and CPU blocks
) -> Tuple[int, int]: that can be allocated with the remaining free memory.
"""Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated. .. tip::
You may limit the usage of GPU memory
Args: by adjusting the `gpu_memory_utilization` parameter.
block_size: The size of the cache block.
gpu_memory_utilization: The fraction of the total GPU memory to use.
cpu_swap_space: The size of the CPU swap space in bytes.
""" """
# Profile the memory usage of the model and get the maximum number of # Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory. # cache blocks that can be allocated with the remaining free memory.
...@@ -141,12 +148,12 @@ class Worker: ...@@ -141,12 +148,12 @@ class Worker:
"Error in memory profiling. This happens when the GPU memory was " "Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.") "not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes( cache_block_size = self.get_cache_block_size_bytes()
block_size, cache_dtype)
num_gpu_blocks = int( num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) // (total_gpu_memory * self.cache_config.gpu_memory_utilization -
cache_block_size) peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0) num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager: if self.model_runner.lora_manager:
...@@ -155,14 +162,30 @@ class Worker: ...@@ -155,14 +162,30 @@ class Worker:
torch.cuda.empty_cache() torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None: def initialize_cache(self, num_gpu_blocks: int,
self.cache_config = cache_config num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid(num_gpu_blocks,
self.cache_config.block_size,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._init_cache_engine()
self._warm_up_model()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config) self.parallel_config)
self.gpu_cache = self.cache_engine.gpu_cache self.gpu_cache = self.cache_engine.gpu_cache
self.model_runner.set_block_size(self.cache_engine.block_size) self.model_runner.set_block_size(self.cache_engine.block_size)
def warm_up_model(self) -> None: def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache) self.model_runner.capture_model(self.gpu_cache)
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by
...@@ -191,14 +214,16 @@ class Worker: ...@@ -191,14 +214,16 @@ class Worker:
blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]: num_lookahead_slots: int = 0,
) -> List[SamplerOutput]:
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None assert blocks_to_swap_out is not None
assert blocks_to_copy is not None assert blocks_to_copy is not None
data = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out, "blocks_to_swap_out": blocks_to_swap_out,
...@@ -212,15 +237,21 @@ class Worker: ...@@ -212,15 +237,21 @@ class Worker:
blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"] blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if num_seq_groups == 0: if num_seq_groups == 0:
return {} return []
output = self.model_runner.execute_model(seq_group_metadata_list, output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache) self.gpu_cache)
return output
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return [output]
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)
...@@ -239,40 +270,23 @@ class Worker: ...@@ -239,40 +270,23 @@ class Worker:
def vocab_size(self) -> int: def vocab_size(self) -> int:
return self.model_runner.vocab_size return self.model_runner.vocab_size
def get_cache_block_size_bytes(self, block_size: int, def get_cache_block_size_bytes(self) -> int:
cache_dtype: str) -> int:
"""Get the size of the KV cache block size in bytes. """Get the size of the KV cache block size in bytes.
""" """
return CacheEngine.get_cache_block_size(block_size, cache_dtype, return CacheEngine.get_cache_block_size(self.cache_config,
self.model_config, self.model_config,
self.parallel_config) self.parallel_config)
def init_distributed_environment( def init_worker_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
rank: int, rank: int,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
if torch.distributed.is_initialized(): init_distributed_environment(parallel_config.world_size, rank,
torch_world_size = torch.distributed.get_world_size() distributed_init_method, local_rank)
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
if pynccl_utils.is_initialized(): if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size() pynccl_world_size = pynccl_utils.get_world_size()
...@@ -284,17 +298,10 @@ def init_distributed_environment( ...@@ -284,17 +298,10 @@ def init_distributed_environment(
elif parallel_config.world_size > 1: elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size # NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1. # is 1.
pynccl_utils.init_process_group( # NOTE(kaichao): By default, pynccl will use information inside
world_size=parallel_config.world_size, # `parallel_state` for initialization.
local_rank=local_rank, pynccl_utils.init_process_group()
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
...@@ -302,6 +309,11 @@ def init_distributed_environment( ...@@ -302,6 +309,11 @@ def init_distributed_environment(
if not parallel_config.disable_custom_all_reduce: if not parallel_config.disable_custom_all_reduce:
init_custom_ar() init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
...@@ -315,3 +327,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): ...@@ -315,3 +327,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
f"{compute_capability[0]}.{compute_capability[1]}. " f"{compute_capability[0]}.{compute_capability[1]}. "
"You can use float16 instead by explicitly setting the" "You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.") "`dtype` flag in CLI, for example: --dtype=half.")
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    """Check that the allocated KV cache can hold a full-length sequence.

    Args:
        num_gpu_blocks: Number of GPU cache blocks that were determined.
        block_size: Multiplied by ``num_gpu_blocks`` to obtain the token
            capacity of the cache.
        max_model_len: Maximum sequence length of the model.

    Raises:
        ValueError: If ``num_gpu_blocks`` is not positive, or if
            ``max_model_len`` exceeds ``block_size * num_gpu_blocks``.
    """
    has_blocks = num_gpu_blocks > 0
    if not has_blocks:
        no_memory_msg = ("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
        raise ValueError(no_memory_msg)

    # Total number of tokens that fit into the allocated blocks.
    max_seq_len = block_size * num_gpu_blocks
    if max_model_len > max_seq_len:
        too_long_msg = (
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({max_seq_len}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")
        raise ValueError(too_long_msg)
import datetime
import importlib
import os
import tempfile
import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple
from vllm.logger import enable_trace_function_call, init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import get_vllm_instance_id, update_environment_variables
logger = init_logger(__name__)
class WorkerBase(ABC):
    """Abstract worker interface.

    It lets vLLM cleanly separate the worker implementations written for
    different hardware. Every method below is abstract and must be provided
    by a concrete worker.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Set up device state.

        Examples are loading the model or performing other on-device memory
        allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Work out how many KV-cache blocks can be used.

        Implementations are free to run profiling or apply other heuristics
        to size the caches.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple. ``num_gpu_blocks``
            counts the "active" blocks on the device, which can be appended
            to. ``num_cpu_blocks`` counts the "swapped" blocks in CPU memory,
            which cannot be appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Set up the KV cache using the given sizes, expressed in blocks."""
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> List[SamplerOutput]:
        """Run at least one model step on the given sequences.

        No step is executed when no sequences are provided.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Report the size, in bytes, of a single cache block.

        Used in speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """LoRA hook taking a ``LoRARequest``; the meaning of the returned
        bool is left to the implementation."""
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """LoRA hook taking a LoRA id; the meaning of the returned bool is
        left to the implementation."""
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """LoRA hook returning a set of ints; presumably LoRA ids, to be
        confirmed against the implementations."""
        raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase):
    """WorkerBase partial implementation for workers without LoRA support.

    Every LoRA method raises ``ValueError`` naming the concrete worker type.
    """

    def _lora_unsupported(self) -> ValueError:
        # Single place that builds the error shared by all LoRA methods.
        return ValueError(f"{type(self)} does not support LoRA")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise self._lora_unsupported()

    def remove_lora(self, lora_id: int) -> bool:
        raise self._lora_unsupported()

    def list_loras(self) -> Set[int]:
        raise self._lora_unsupported()
class WorkerWrapperBase:
    """Wrapper that lazily initializes the real worker.

    We first instantiate the wrapper, which only remembers the worker module
    and class name. Then we call `update_environment_variables`, and the
    real initialization happens in `init_worker`.
    """

    def __init__(self,
                 worker_module_name=None,
                 worker_class_name=None,
                 trust_remote_code: bool = False) -> None:
        """Remember which worker class to build later.

        Args:
            worker_module_name: Name of the module that is imported by
                `init_worker`.
            worker_class_name: Name of the attribute looked up on that module
                and called to construct the worker.
            trust_remote_code: When true, `init_cached_hf_modules` is run
                immediately.
        """
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Set by init_worker(); while it is None, execute_method() dispatches
        # to the wrapper itself.
        self.worker = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        """Apply `envs` via the module-level `update_environment_variables`.

        An existing CUDA_VISIBLE_DEVICES entry is removed first when `envs`
        also provides one.
        """
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """Build the actual worker, setting up function tracing if required.

        Arguments are passed to the worker class constructor.

        Raises:
            ValueError: If the worker module name or class name was never
                provided.
        """
        if self.worker_module_name is None or self.worker_class_name is None:
            # Fail with a clear message instead of the obscure AttributeError
            # that importlib raises for a None module name.
            raise ValueError(
                "worker_module_name and worker_class_name must both be set "
                "before calling init_worker")

        if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
            tmp_dir = tempfile.gettempdir()
            filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                        f"_thread_{threading.get_ident()}_"
                        f"at_{datetime.datetime.now()}.log").replace(" ", "_")
            log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
                                    filename)
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            enable_trace_function_call(log_path)

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Call `method` by name and return its result.

        The method is looked up on the wrapped worker once it exists, and on
        the wrapper itself before that. Any exception is logged and then
        re-raised.
        """
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # If the driver worker also executes methods, exceptions in the
            # remaining workers may cause a deadlock in RPC frameworks such
            # as Ray, see https://github.com/vllm-project/vllm/issues/3455.
            # Log the error so the user is informed and can resolve it.
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare raise re-raises the active exception unchanged.
            raise
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment