Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
...@@ -179,7 +179,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -179,7 +179,7 @@ class TP1DraftModelRunner(ModelRunner):
return False return False
# TODO: Add support for other attn backends # TODO: Add support for other attn backends
if self.attn_backend.get_name() != "flash-attn": if self.attn_backend.get_name() != "FLASH_ATTN":
return False return False
# TODO: Add support for LORA # TODO: Add support for LORA
......
...@@ -184,7 +184,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -184,7 +184,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if not disable_mqa_scorer: if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name( if scorer_worker.model_runner.attn_backend.get_name(
) != "flash-attn": ) != "FLASH_ATTN":
disable_mqa_scorer = True disable_mqa_scorer = True
logger.info( logger.info(
"[Speculative Decoding] Disabling MQA scorer as the " "[Speculative Decoding] Disabling MQA scorer as the "
......
...@@ -232,6 +232,68 @@ def get_config( ...@@ -232,6 +232,68 @@ def get_config(
return config return config
def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
    """Best-effort registration of remote-code HF config classes for
    by-value pickling.

    With ``trust_remote_code`` the config object is typically an instance of
    a class living in the dynamically generated ``transformers_modules``
    package. That package is not importable in spawned workers by default
    (and does not exist at all on other nodes), so pickling the config by
    reference breaks.

    This function asks cloudpickle (and ray's vendored cloudpickle, when ray
    is available) to serialize everything from ``transformers_modules`` by
    value, and registers a reducer so that ``multiprocessing`` pickles
    ``ModelConfig`` objects through cloudpickle. It only has an effect once
    the modules cache has been initialized.

    Any failure during registration is logged as a warning and swallowed.

    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs

    Args:
        trust_remote_code: Whether remote code is enabled; when falsy the
            function returns immediately.
    """
    if not trust_remote_code:
        return

    try:
        import transformers_modules as remote_modules
    except ImportError:
        logger.debug("Could not import transformers_modules used for remote"
                     " code. If remote code is not needed remove"
                     " `--trust-remote-code`.")
        return

    try:
        import cloudpickle

        cloudpickle.register_pickle_by_value(remote_modules)

        # ray vendors its own version of cloudpickle
        from vllm.executor.ray_utils import ray

        if ray:
            ray.cloudpickle.register_pickle_by_value(remote_modules)

        # multiprocessing pickles its arguments when the spawn start method
        # is used. Route ModelConfig through cloudpickle so the embedded
        # custom config instance survives even when the generated module
        # (and model) has a `.` in its name.
        import multiprocessing
        import pickle

        from vllm.config import ModelConfig

        def _cloudpickle_reduce(model_config: ModelConfig):
            payload = cloudpickle.dumps(model_config)
            return (pickle.loads, (payload, ))

        multiprocessing.reducer.register(ModelConfig, _cloudpickle_reduce)

    except Exception as exc:
        logger.warning(
            "Unable to register remote classes used by"
            " trust_remote_code with by-value serialization. This may"
            " lead to a later error. If remote code is not needed"
            " remove `--trust-remote-code`",
            exc_info=exc)
def load_params_config(model, revision) -> PretrainedConfig: def load_params_config(model, revision) -> PretrainedConfig:
# This function loads a params.json config which # This function loads a params.json config which
# should be used when loading models in mistral format # should be used when loading models in mistral format
......
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup) Sequence, SequenceGroup)
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup from .tokenizer_group import BaseTokenizerGroup
...@@ -88,7 +90,7 @@ class Detokenizer: ...@@ -88,7 +90,7 @@ class Detokenizer:
prefix_offset = next_iter_prefix_offset prefix_offset = next_iter_prefix_offset
read_offset = next_iter_read_offset read_offset = next_iter_read_offset
if prev_tokens is None: if prev_tokens is None:
prev_tokens = next_iter_tokens prev_tokens = next_iter_tokens.copy()
else: else:
prev_tokens.extend(next_iter_tokens) prev_tokens.extend(next_iter_tokens)
...@@ -161,167 +163,3 @@ class Detokenizer: ...@@ -161,167 +163,3 @@ class Detokenizer:
seq.output_text += new_decoded_token_text seq.output_text += new_decoded_token_text
return len(new_decoded_token_text) return len(new_decoded_token_text)
def _replace_none_with_empty(tokens: List[Optional[str]]):
    """Replace every ``None`` entry of ``tokens`` with ``""``, in place."""
    # Slice assignment keeps the caller's list object and only swaps its
    # contents.
    tokens[:] = ["" if token is None else token for token in tokens]
def _convert_tokens_to_string_with_added_encoders(
    tokenizer: AnyTokenizer,
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    """Convert tokens to text, keeping added-vocabulary tokens verbatim.

    Runs of ordinary tokens are decoded together through
    ``tokenizer.convert_tokens_to_string``; every token found in the
    tokenizer's added vocabulary is emitted as-is as its own piece.

    Args:
        tokenizer: The tokenizer to use.
        output_tokens: The tokens to convert.
        skip_special_tokens: Drop tokens listed in
            ``tokenizer.all_special_tokens``.
        spaces_between_special_tokens: Join the resulting pieces with a
            single space instead of concatenating them directly.

    Returns:
        The decoded text.
    """
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    # The added vocabulary does not change while we iterate, so fetch it once
    # instead of once per token.
    added_vocab = tokenizer.get_added_vocab()
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in added_vocab:
            # Flush the pending run of ordinary tokens before the added one.
            if current_sub_text:
                sub_texts.append(
                    tokenizer.convert_tokens_to_string(current_sub_text))
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_texts.append(tokenizer.convert_tokens_to_string(current_sub_text))
    separator = " " if spaces_between_special_tokens else ""
    return separator.join(sub_texts)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: AnyTokenizer,
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Convert the tail of the prompt ids to tokens for incremental
    detokenization.

    Only the last ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2`` prompt
    ids are converted; the rest of the prompt is not needed to detokenize
    the tokens generated afterwards.

    Returns:
        A tuple ``(tokens, prefix_offset, read_offset)`` where
        ``read_offset`` is the number of converted tokens and
        ``prefix_offset`` lies ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET``
        tokens before it (clamped at 0).
    """
    # Take a slightly larger window than the offset in case some of the
    # trailing ids are special tokens.
    tail_ids = prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:]
    tail_tokens = tokenizer.convert_ids_to_tokens(
        tail_ids, skip_special_tokens=skip_special_tokens)
    num_tokens = len(tail_tokens)
    # Out-of-vocab prompt ids come back as None; normalise them to "".
    _replace_none_with_empty(tail_tokens)  # type: ignore[arg-type]
    return (
        tail_tokens,
        max(num_tokens - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0),
        num_tokens,
    )
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: AnyTokenizer,
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenize the newest input id and return the text it contributes.

    On the first call for a sequence (``prev_tokens`` is None) the prompt
    tail is converted to tokens first and the returned token list contains
    those tokens plus the new one. On later calls only the new token(s) are
    returned.

    A window of already-emitted tokens (``prefix_offset:read_offset``) is
    decoded together with the new token so that tokenizer cleanup rules,
    which may add or drop a space depending on surrounding tokens, see the
    same context as a full decode would.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: Tokens from previous iterations, or None on the first
            iteration.
        prefix_offset: The prefix offset.
        read_offset: The read offset.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.

    Returns:
        ``(new_tokens, new_text, prefix_offset, read_offset)`` with the
        offsets to pass to the next iteration. ``new_text`` is ``""`` and
        the offsets are unchanged when no text can be emitted yet.
    """
    latest_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    first_call = prev_tokens is None
    if first_call:
        prev_tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer,
            all_input_ids[:-1],
            skip_special_tokens=skip_special_tokens)
    assert prev_tokens is not None

    if 0 <= latest_id < len(tokenizer):
        # Pass the id inside a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [latest_id], skip_special_tokens=skip_special_tokens)
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]
    else:
        # An out-of-bounds id contributes an empty token.
        new_tokens = [""]
    output_tokens = prev_tokens + new_tokens

    # On the first iteration the caller receives every token.
    if first_call:
        new_tokens = output_tokens

    # Choose the token -> text conversion once; it is used for both the
    # prefix window and the window including the new token.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        to_text = tokenizer.convert_tokens_to_string
    else:

        def to_text(tokens: List[str]) -> str:
            return _convert_tokens_to_string_with_added_encoders(
                tokenizer,
                tokens,
                skip_special_tokens=skip_special_tokens,
                spaces_between_special_tokens=spaces_between_special_tokens,
            )

    prefix_text = to_text(output_tokens[prefix_offset:read_offset])
    window_text = to_text(output_tokens[prefix_offset:])

    if len(window_text) > len(prefix_text) and not window_text.endswith("�"):
        return (new_tokens, window_text[len(prefix_text):], read_offset,
                len(output_tokens))

    # A trailing replacement character means a potentially unfinished byte
    # sequence from byte-fallback tokenization; hold the text back until it
    # completes. If it's in the middle, it's probably a real invalid id
    # generated by the model.
    return new_tokens, "", prefix_offset, read_offset
from typing import List, Optional, Tuple
from .tokenizer import AnyTokenizer
def _replace_none_with_empty(tokens: List[Optional[str]]):
    """Replace every ``None`` entry of ``tokens`` with ``""``, in place."""
    # Slice assignment keeps the caller's list object and only swaps its
    # contents.
    tokens[:] = ["" if token is None else token for token in tokens]
def _convert_tokens_to_string_with_added_encoders(
    tokenizer: AnyTokenizer,
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    """Convert tokens to text, keeping added-vocabulary tokens verbatim.

    Runs of ordinary tokens are decoded together through
    ``tokenizer.convert_tokens_to_string``; every token found in the
    tokenizer's added vocabulary is emitted as-is as its own piece.

    Args:
        tokenizer: The tokenizer to use.
        output_tokens: The tokens to convert.
        skip_special_tokens: Drop tokens listed in
            ``tokenizer.all_special_tokens``.
        spaces_between_special_tokens: Join the resulting pieces with a
            single space instead of concatenating them directly.

    Returns:
        The decoded text.
    """
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    # The added vocabulary does not change while we iterate, so fetch it once
    # instead of once per token.
    added_vocab = tokenizer.get_added_vocab()
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in added_vocab:
            # Flush the pending run of ordinary tokens before the added one.
            if current_sub_text:
                sub_texts.append(
                    tokenizer.convert_tokens_to_string(current_sub_text))
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_texts.append(tokenizer.convert_tokens_to_string(current_sub_text))
    separator = " " if spaces_between_special_tokens else ""
    return separator.join(sub_texts)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: AnyTokenizer,
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Convert the tail of the prompt ids to tokens for incremental
    detokenization.

    Only the last ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2`` prompt
    ids are converted; the rest of the prompt is not needed to detokenize
    the tokens generated afterwards.

    Returns:
        A tuple ``(tokens, prefix_offset, read_offset)`` where
        ``read_offset`` is the number of converted tokens and
        ``prefix_offset`` lies ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET``
        tokens before it (clamped at 0).
    """
    # Take a slightly larger window than the offset in case some of the
    # trailing ids are special tokens.
    tail_ids = prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:]
    tail_tokens = tokenizer.convert_ids_to_tokens(
        tail_ids, skip_special_tokens=skip_special_tokens)
    num_tokens = len(tail_tokens)
    # Out-of-vocab prompt ids come back as None; normalise them to "".
    _replace_none_with_empty(tail_tokens)  # type: ignore[arg-type]
    return (
        tail_tokens,
        max(num_tokens - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0),
        num_tokens,
    )
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: AnyTokenizer,
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenize the newest input id and return the text it contributes.

    On the first call for a sequence (``prev_tokens`` is None) the prompt
    tail is converted to tokens first and the returned token list contains
    those tokens plus the new one. On later calls only the new token(s) are
    returned.

    A window of already-emitted tokens (``prefix_offset:read_offset``) is
    decoded together with the new token so that tokenizer cleanup rules,
    which may add or drop a space depending on surrounding tokens, see the
    same context as a full decode would.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: Tokens from previous iterations, or None on the first
            iteration.
        prefix_offset: The prefix offset.
        read_offset: The read offset.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.

    Returns:
        ``(new_tokens, new_text, prefix_offset, read_offset)`` with the
        offsets to pass to the next iteration. ``new_text`` is ``""`` and
        the offsets are unchanged when no text can be emitted yet.
    """
    latest_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    first_call = prev_tokens is None
    if first_call:
        prev_tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer,
            all_input_ids[:-1],
            skip_special_tokens=skip_special_tokens)
    assert prev_tokens is not None

    if 0 <= latest_id < len(tokenizer):
        # Pass the id inside a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [latest_id], skip_special_tokens=skip_special_tokens)
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]
    else:
        # An out-of-bounds id contributes an empty token.
        new_tokens = [""]
    output_tokens = prev_tokens + new_tokens

    # On the first iteration the caller receives every token.
    if first_call:
        new_tokens = output_tokens

    # Choose the token -> text conversion once; it is used for both the
    # prefix window and the window including the new token.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        to_text = tokenizer.convert_tokens_to_string
    else:

        def to_text(tokens: List[str]) -> str:
            return _convert_tokens_to_string_with_added_encoders(
                tokenizer,
                tokens,
                skip_special_tokens=skip_special_tokens,
                spaces_between_special_tokens=spaces_between_special_tokens,
            )

    prefix_text = to_text(output_tokens[prefix_offset:read_offset])
    window_text = to_text(output_tokens[prefix_offset:])

    if len(window_text) > len(prefix_text) and not window_text.endswith("�"):
        return (new_tokens, window_text[len(prefix_text):], read_offset,
                len(output_tokens))

    # A trailing replacement character means a potentially unfinished byte
    # sequence from byte-fallback tokenization; hold the text back until it
    # completes. If it's in the middle, it's probably a real invalid id
    # generated by the model.
    return new_tokens, "", prefix_offset, read_offset
from functools import lru_cache
from typing import Any, cast from typing import Any, cast
...@@ -37,6 +38,9 @@ def get_processor( ...@@ -37,6 +38,9 @@ def get_processor(
return cast(ProcessorMixin, processor) return cast(ProcessorMixin, processor)
cached_get_processor = lru_cache(get_processor)
def get_image_processor( def get_image_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
......
...@@ -2,11 +2,12 @@ import os ...@@ -2,11 +2,12 @@ import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import HfApi, hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# yapf: disable # yapf: disable
from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import ( from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer) MistralTokenizer as PublicMistralTokenizer)
# yapf: enable # yapf: enable
...@@ -24,6 +25,26 @@ class Encoding: ...@@ -24,6 +25,26 @@ class Encoding:
input_ids: List[int] input_ids: List[int]
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
    """List the files of a model repo from the local Hugging Face cache.

    The repo's cache directory is derived from
    ``huggingface_hub.constants.HF_HUB_CACHE``. Files are looked up under
    ``snapshots/<commit>``. A named reference (a branch or tag, or ``main``
    when ``revision`` is None) is first resolved to its commit through the
    matching file under ``refs/``; a revision with no such file is used
    as-is.

    Args:
        repo_id: Repository id, e.g. ``"org/name"``.
        revision: Commit hash, branch or tag name; None means ``main``.

    Returns:
        The entries of the snapshot directory, or an empty list when the
        revision cannot be resolved or is not cached locally.
    """
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *repo_id.split("/")]))

    # Snapshots are keyed by commit, so a named ref has to be translated
    # through refs/<name> before it can be found under snapshots/.
    ref_name = "main" if revision is None else revision
    revision_file = os.path.join(repo_cache, "refs", ref_name)
    if os.path.isfile(revision_file):
        with open(revision_file) as file:
            # Tolerate a trailing newline or stray whitespace in the ref.
            revision = file.read().strip()

    if revision:
        revision_dir = os.path.join(repo_cache, "snapshots", revision)
        if os.path.isdir(revision_dir):
            return os.listdir(revision_dir)

    return []
def find_tokenizer_file(files: List[str]): def find_tokenizer_file(files: List[str]):
file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$") file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
...@@ -90,9 +111,16 @@ class MistralTokenizer: ...@@ -90,9 +111,16 @@ class MistralTokenizer:
@staticmethod @staticmethod
def _download_mistral_tokenizer_from_hf(tokenizer_name: str, def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
revision: Optional[str]) -> str: revision: Optional[str]) -> str:
api = HfApi() try:
repo_info = api.model_info(tokenizer_name) hf_api = HfApi()
files = [s.rfilename for s in repo_info.siblings] files = hf_api.list_repo_files(repo_id=tokenizer_name,
revision=revision)
except ConnectionError as exc:
files = list_local_repo_files(repo_id=tokenizer_name,
revision=revision)
if len(files) == 0:
raise exc
filename = find_tokenizer_file(files) filename = find_tokenizer_file(files)
...@@ -166,7 +194,7 @@ class MistralTokenizer: ...@@ -166,7 +194,7 @@ class MistralTokenizer:
tools: Optional[Dict[str, Any]] = None, tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]: **kwargs) -> List[int]:
last_message = messages[-1] last_message = cast(Dict[str, Any], messages[-1])
if last_message["role"] == "assistant": if last_message["role"] == "assistant":
last_message["prefix"] = True last_message["prefix"] = True
......
from importlib.util import find_spec from importlib.util import find_spec
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
HAS_TRITON = find_spec("triton") is not None # neuron has too old torch
HAS_TRITON = find_spec(
"triton") is not None and not current_platform.is_neuron()
if not HAS_TRITON: if not HAS_TRITON:
logger.info("Triton not installed; certain GPU-related functions" logger.info("Triton not installed; certain GPU-related functions"
......
...@@ -13,10 +13,12 @@ import subprocess ...@@ -13,10 +13,12 @@ import subprocess
import sys import sys
import tempfile import tempfile
import threading import threading
import time
import uuid import uuid
import warnings import warnings
import weakref import weakref
from asyncio import FIRST_COMPLETED, ensure_future from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
...@@ -316,15 +318,6 @@ def is_hip() -> bool: ...@@ -316,15 +318,6 @@ def is_hip() -> bool:
return torch.version.hip is not None return torch.version.hip is not None
@lru_cache(maxsize=None)
def is_cpu() -> bool:
    """Return True when the installed ``vllm`` version string contains
    ``"cpu"``; False otherwise or when ``vllm`` is not installed.

    The result is cached after the first call.
    """
    from importlib import metadata

    try:
        installed_version = metadata.version("vllm")
    except metadata.PackageNotFoundError:
        return False
    return "cpu" in installed_version
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def is_openvino() -> bool: def is_openvino() -> bool:
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
...@@ -334,15 +327,6 @@ def is_openvino() -> bool: ...@@ -334,15 +327,6 @@ def is_openvino() -> bool:
return False return False
@lru_cache(maxsize=None)
def is_neuron() -> bool:
    """Return True when ``transformers_neuronx`` can be imported.

    The result is cached after the first call.
    """
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def is_xpu() -> bool: def is_xpu() -> bool:
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
...@@ -436,6 +420,12 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]: ...@@ -436,6 +420,12 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper return _async_wrapper
def _next_task(iterator: AsyncGenerator[T, None],
               loop: AbstractEventLoop) -> Task:
    """Schedule retrieval of the next item of ``iterator`` as a task on
    ``loop`` and return that task."""
    # Can use anext() in python >= 3.10
    next_item = iterator.__anext__()
    return loop.create_task(next_item)  # type: ignore[arg-type]
async def iterate_with_cancellation( async def iterate_with_cancellation(
iterator: AsyncGenerator[T, None], iterator: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]], is_cancelled: Callable[[], Awaitable[bool]],
...@@ -444,19 +434,27 @@ async def iterate_with_cancellation( ...@@ -444,19 +434,27 @@ async def iterate_with_cancellation(
at least once per second to check for client cancellation. at least once per second to check for client cancellation.
""" """
# Can use anext() in python >= 3.10 loop = asyncio.get_running_loop()
awaits = [ensure_future(iterator.__anext__())]
awaits: List[Future[T]] = [_next_task(iterator, loop)]
next_cancel_check: float = 0
while True: while True:
done, pending = await asyncio.wait(awaits, timeout=1) done, pending = await asyncio.wait(awaits, timeout=1.5)
if await is_cancelled():
with contextlib.suppress(BaseException): # Check for cancellation at most once per second
awaits[0].cancel() time_now = time.time()
await iterator.aclose() if time_now >= next_cancel_check:
raise asyncio.CancelledError("client cancelled") if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
if done: if done:
try: try:
item = await awaits[0] item = await awaits[0]
awaits[0] = ensure_future(iterator.__anext__()) awaits[0] = _next_task(iterator, loop)
yield item yield item
except StopAsyncIteration: except StopAsyncIteration:
# we are done # we are done
...@@ -477,25 +475,29 @@ async def merge_async_iterators( ...@@ -477,25 +475,29 @@ async def merge_async_iterators(
to check for client cancellation. to check for client cancellation.
""" """
# Can use anext() in python >= 3.10 loop = asyncio.get_running_loop()
awaits = {
ensure_future(pair[1].__anext__()): pair awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
for pair in enumerate(iterators) timeout = None if is_cancelled is None else 1.5
} next_cancel_check: float = 0
timeout = None if is_cancelled is None else 1
try: try:
while awaits: while awaits:
done, pending = await asyncio.wait(awaits.keys(), done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED, return_when=FIRST_COMPLETED,
timeout=timeout) timeout=timeout)
if is_cancelled is not None and await is_cancelled(): if is_cancelled is not None:
raise asyncio.CancelledError("client cancelled") # Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
for d in done: for d in done:
pair = awaits.pop(d) pair = awaits.pop(d)
try: try:
item = await d item = await d
i, it = pair i, it = pair
awaits[ensure_future(it.__anext__())] = pair awaits[_next_task(it, loop)] = pair
yield i, item yield i, item
except StopAsyncIteration: except StopAsyncIteration:
pass pass
...@@ -775,10 +777,10 @@ def is_pin_memory_available() -> bool: ...@@ -775,10 +777,10 @@ def is_pin_memory_available() -> bool:
elif is_xpu(): elif is_xpu():
print_warning_once("Pin memory is not supported on XPU.") print_warning_once("Pin memory is not supported on XPU.")
return False return False
elif is_neuron(): elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.") print_warning_once("Pin memory is not supported on Neuron.")
return False return False
elif is_cpu() or is_openvino(): elif current_platform.is_cpu() or is_openvino():
return False return False
return True return True
...@@ -948,6 +950,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]: ...@@ -948,6 +950,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist] return [item for sublist in lists for item in sublist]
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None: def init_cached_hf_modules() -> None:
""" """
Lazy initialization of the Hugging Face modules. Lazy initialization of the Hugging Face modules.
...@@ -1033,10 +1037,54 @@ def identity(value: T) -> T: ...@@ -1033,10 +1037,54 @@ def identity(value: T) -> T:
F = TypeVar('F', bound=Callable[..., Any]) F = TypeVar('F', bound=Callable[..., Any])
def deprecate_args(
    start_index: int,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator factory that warns when arguments are passed positionally.

    A ``DeprecationWarning`` is emitted on each call of the decorated
    function in which positional parameters at index ``start_index`` or
    later are supplied positionally.

    Args:
        start_index: Index of the first positional parameter that should no
            longer be passed positionally.
        is_deprecated: A flag, or a zero-argument callable evaluated on
            every call, that switches the warning on or off.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that wraps the function without changing its result.
    """
    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    def wrapper(fn: F) -> F:
        # Names of the parameters that can be bound positionally, in
        # declaration order; computed once at decoration time.
        positional_names = [
            name for name, param in inspect.signature(fn).parameters.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            if not is_deprecated():
                return fn(*args, **kwargs)

            flagged = positional_names[start_index:len(args)]
            if flagged:
                parts = [
                    f"The positional arguments {flagged} are "
                    "deprecated and will be removed in a future update."
                ]
                if additional_message is not None:
                    parts.append(f"{additional_message}")

                warnings.warn(
                    DeprecationWarning(" ".join(parts)),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
def deprecate_kwargs( def deprecate_kwargs(
*kws: str, *kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True, is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None) -> Callable[[F], F]: additional_message: Optional[str] = None,
) -> Callable[[F], F]:
deprecated_kws = set(kws) deprecated_kws = set(kws)
if not callable(is_deprecated): if not callable(is_deprecated):
...@@ -1442,3 +1490,24 @@ class AtomicCounter: ...@@ -1442,3 +1490,24 @@ class AtomicCounter:
@property @property
def value(self): def value(self):
return self._value return self._value
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
    """Read-only mapping whose values are built on first access.

    Each key maps to a zero-argument factory. The factory runs the first
    time its key is looked up and the result is cached for later lookups.
    Iteration and length reflect the factory mapping, so keys are visible
    before their values have been created.
    """

    def __init__(self, factory: Dict[str, Callable[[], T]]):
        self._factory = factory
        # Values materialised so far.
        self._dict: Dict[str, T] = {}

    def __getitem__(self, key) -> T:
        cache = self._dict
        if key in cache:
            return cache[key]
        if key not in self._factory:
            raise KeyError(key)
        value = cache[key] = self._factory[key]()
        return value

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.forward_context import get_forward_context
from vllm.vllm_flash_attn import flash_attn_varlen_func
class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor that exposes the FlashAttention implementation
    and metadata classes together with its KV-cache layout."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the supported head sizes: multiples of 32 from 32 to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "flash-attn-vllm-v1"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata class used by this backend."""
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape
        ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        remainder = block_size % 16
        if remainder != 0:
            raise ValueError("Block size must be a multiple of 16.")
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return shape
@dataclass
class FlashAttentionMetadata:
    """Per-forward-pass metadata for the FlashAttention backend.

    NOTE(review): field meanings below are inferred from their names and
    the diagram; confirm against the code that builds this object.
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Presumably the largest query_len in the batch -- TODO confirm.
    max_query_len: int
    # Presumably per-sequence start offsets into the query tokens
    # -- TODO confirm.
    query_start_loc: torch.Tensor
    # Presumably the largest seq_len in the batch -- TODO confirm.
    max_seq_len: int
    # Presumably per-sequence start offsets by seq_len -- TODO confirm.
    seq_start_loc: torch.Tensor
    # Presumably maps each sequence to its KV-cache blocks -- TODO confirm.
    block_table: torch.Tensor
    # Presumably maps each token to its KV-cache slot -- TODO confirm.
    slot_mapping: torch.Tensor
class FlashAttentionImpl(AttentionImpl):
    """Attention implementation that dispatches to the
    ``vllm::unified_flash_attention`` custom op."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Store the attention configuration and reject unsupported options.

        Raises:
            ValueError: If block-sparse parameters or a sliding window are
                given, or if ``head_size`` is not a supported head size.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window, sliding_window)
        self.kv_cache_dtype = kv_cache_dtype

        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if sliding_window is not None:
            # NOTE(woosuk): flash-attn's sliding window does not work with
            # paged KV cache.
            raise ValueError(
                "Sliding window is not supported in FlashAttention.")

        valid_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in valid_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {valid_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run decoder attention through the unified FlashAttention op.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. It is not forwarded to
                the op by this method.
            k_scale: Must be 1.0.
            v_scale: Must be 1.0.
            attn_type: Only ``AttentionType.DECODER`` is supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If ``attn_type`` is not the decoder type.
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")

        return torch.ops.vllm.unified_flash_attention(
            query,
            key,
            value,
            self.num_heads,
            self.head_size,
            self.num_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
        )
@torch.library.custom_op("vllm::unified_flash_attention",
                         mutates_args=["kv_cache"])
def unified_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Write the new keys/values into the KV cache, then run attention.

    The attention metadata is taken from the forward context. When no
    context is set (profiling run) an uninitialized tensor shaped like
    `query` is returned and the cache is left untouched.
    """
    forward_context = get_forward_context()
    if forward_context is None:
        # Profiling run.
        return torch.empty_like(query)

    assert isinstance(forward_context, FlashAttentionMetadata)
    attn_metadata: FlashAttentionMetadata = forward_context

    # Remember the flat 2-D shape so the output can be restored to it.
    num_tokens, hidden_size = query.shape

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)

    # Reshape the input keys and values and store them in the cache.
    key_cache, value_cache = kv_cache[0], kv_cache[1]
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    # Attention reads K/V from the cache (via the block table) rather than
    # from the freshly projected `key`/`value` tensors.
    attn_output = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        cu_seqlens_k=attn_metadata.seq_start_loc,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=softmax_scale,
        causal=True,
        alibi_slopes=alibi_slopes,
        window_size=window_size,
        block_table=attn_metadata.block_table,
        softcap=logits_soft_cap,
    )
    return attn_output.view(num_tokens, hidden_size)
@unified_flash_attention.register_fake
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation of the op: only `query` determines the result.

    Returns an uninitialized tensor created with `torch.empty_like(query)`;
    every other argument is ignored.
    """
    fake_output = torch.empty_like(query)
    return fake_output
from typing import Dict, List, Optional
import numpy as np
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.request import Request
logger = init_logger(__name__)
class KVCacheManager:
    """Bookkeeping for KV-cache block ownership.

    Maintains the list of free block IDs, the block IDs held by each
    request, and a per-block reference count.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        sliding_window: Optional[int] = None,
        enable_caching: bool = True,
        num_preallocate_tokens: int = 64,
    ) -> None:
        """Create a manager in which all `num_gpu_blocks` blocks are free.

        Args:
            block_size: Number of tokens per block.
            num_gpu_blocks: Total number of blocks managed.
            sliding_window: Stored on the instance; not otherwise used here.
            enable_caching: When False, `get_computed_blocks` reports no
                cached blocks.
            num_preallocate_tokens: Extra token capacity requested whenever
                new blocks are allocated (rounded up to whole blocks).
        """
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.sliding_window = sliding_window
        self.enable_caching = enable_caching
        # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
        # blocks for each request. For example, when a request reaches the end
        # of its block table, we preallocate N blocks in advance. This way, we
        # reduce the overhead of updating free_block_ids and ref_cnts for each
        # request every step (at the cost of some memory waste).
        # NOTE(woosuk): This is different from the "lookahead" slots since this
        # does not guarantee that the request always has N empty blocks. After
        # the request gets N empty blocks, it starts to use the blocks without
        # further allocation. When it uses up all the N empty blocks, it gets
        # N new empty blocks.
        self.num_preallocate_tokens = num_preallocate_tokens
        self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)

        self.free_block_ids = list(range(num_gpu_blocks))
        self.req_to_block_ids: Dict[str, List[int]] = {}
        self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32)

    def get_computed_blocks(self, request: Request) -> List[int]:
        """Return the block IDs already computed for `request`.

        Always empty for now: prefix caching is not implemented yet.
        """
        if not self.enable_caching:
            # No prefix caching.
            return []
        # TODO(woosuk): Implement hash-based caching.
        return []

    def append_slots(
        self,
        request: Request,
        num_tokens: int,
    ) -> Optional[List[int]]:
        """Grow the block list of an already-allocated request.

        Returns:
            The newly allocated block IDs (empty when the request already
            holds enough blocks), or None when there are not enough free
            blocks to cover `num_tokens` more tokens.
        """
        num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
                                   self.block_size)
        req_block_ids = self.req_to_block_ids[request.request_id]
        if num_required_blocks <= len(req_block_ids):
            # No new block is needed.
            return []

        num_new_blocks = num_required_blocks - len(req_block_ids)
        num_free_blocks = len(self.free_block_ids)
        if num_new_blocks > num_free_blocks:
            # Cannot allocate new blocks.
            return None

        # Allocate new blocks, plus the preallocation margin when available.
        num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
                             num_free_blocks)
        new_block_ids = self._get_new_blocks(num_new_blocks)
        req_block_ids.extend(new_block_ids)
        self.ref_cnts[new_block_ids] += 1
        return new_block_ids

    def allocate_slots(
        self,
        request: Request,
        num_tokens: int,
        computed_block_ids: List[int],
    ) -> Optional[List[int]]:
        """Create the block list for a request being (re)scheduled.

        The request's block list becomes `computed_block_ids` followed by
        the newly allocated blocks; all of them get their reference count
        incremented.

        Returns:
            The newly allocated block IDs, or None when there are not
            enough free blocks for `num_tokens` tokens.
        """
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_free_blocks = len(self.free_block_ids)
        if num_required_blocks > num_free_blocks:
            # Cannot allocate new blocks.
            return None

        num_new_blocks = min(num_required_blocks + self.num_preallocate_blocks,
                             num_free_blocks)
        new_block_ids = self._get_new_blocks(num_new_blocks)
        block_ids = computed_block_ids + new_block_ids
        self.req_to_block_ids[request.request_id] = block_ids
        self.ref_cnts[block_ids] += 1
        return new_block_ids

    def free(self, request: Request) -> None:
        """Release every block held by `request`.

        Blocks whose reference count drops to zero return to the free list.
        Freeing a request that holds no blocks is a no-op; this happens when
        a request is finished while still waiting (never allocated) or
        after it was preempted (already freed).
        """
        block_ids = self.req_to_block_ids.pop(request.request_id, None)
        if not block_ids:
            return
        self.ref_cnts[block_ids] -= 1
        for block_id in block_ids:
            if self.ref_cnts[block_id] == 0:
                self.free_block_ids.append(block_id)

    def _get_new_blocks(self, num_blocks: int) -> List[int]:
        """Detach and return the last `num_blocks` IDs of the free list."""
        assert 0 <= num_blocks <= len(self.free_block_ids)
        # Use an explicit split index: a negative slice would treat
        # `num_blocks == 0` as "take everything" because `-0 == 0`.
        split = len(self.free_block_ids) - num_blocks
        new_block_ids = self.free_block_ids[split:]
        self.free_block_ids = self.free_block_ids[:split]
        return new_block_ids
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
class Scheduler:
    """Token-budget scheduler for the V1 engine.

    Holds a queue of waiting requests and a list of running requests, and
    decides at every step how many tokens each request gets to process.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        """Set up the KV-cache manager, the limits and the request queues."""
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        # TODO: Support LoRA.
        assert lora_config is None, "V1 does not support LoRA yet."

        num_gpu_blocks = cache_config.num_gpu_blocks
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0

        # Create the block space manager.
        self.kv_cache_manager = KVCacheManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=True)
        self.block_size = self.cache_config.block_size

        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = (
            self.scheduler_config.max_num_batched_tokens)
        self.max_model_len = self.scheduler_config.max_model_len

        # req_id -> Request
        self.requests: Dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: Deque[Request] = deque()
        self.running: List[Request] = []

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
        # requests so that they can free the cached states for those requests.
        # This is flushed at the end of each scheduling step.
        self.finished_req_ids: Set[str] = set()

        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
        # them at each scheduling step.
        # Request id -> RunningRequestData
        self.running_reqs_data: Dict[str, RunningRequestData] = {}

    def schedule(self) -> "SchedulerOutput":
        """Pick the requests and token counts for the next model step.

        Running requests are served first; when blocks run out, requests
        are preempted from the tail of the running list. Waiting requests
        are admitted only if nothing was preempted in this step.
        """
        new_reqs: List[Request] = []
        resumed_reqs: List[Request] = []
        running_reqs: List[Request] = []
        preempted_reqs: List[Request] = []

        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and num_tokens,
        # which is equal to len(prompt_token_ids) + len(output_token_ids).
        # At each step, the scheduler tries to assign tokens to the requests
        # so that each request's num_computed_tokens can catch up its
        # num_tokens. This is general enough to cover chunked prefills,
        # prefix caching, and the "jump forward" optimization in the future.

        block_ids_by_req: Dict[str, List[int]] = {}
        num_scheduled_tokens: Dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens

        # First, schedule the RUNNING requests.
        idx = 0
        while idx < len(self.running) and token_budget != 0:
            request = self.running[idx]
            num_new_tokens = min(
                request.num_tokens - request.num_computed_tokens,
                token_budget)
            assert num_new_tokens > 0

            while True:
                new_block_ids = self.kv_cache_manager.append_slots(
                    request, num_new_tokens)
                if new_block_ids is not None:
                    # The request can be scheduled.
                    running_reqs.append(request)
                    block_ids_by_req[request.request_id] = new_block_ids
                    num_scheduled_tokens[request.request_id] = num_new_tokens
                    token_budget -= num_new_tokens
                    idx += 1
                    break

                # The request cannot be scheduled.
                # Preempt the lowest-priority request (the list tail), then
                # retry the allocation.
                victim = self.running.pop()
                self.kv_cache_manager.free(victim)
                victim.status = RequestStatus.PREEMPTED
                victim.num_computed_tokens = 0
                self.waiting.appendleft(victim)
                preempted_reqs.append(victim)
                if victim == request:
                    # No more request to preempt.
                    break

        # Next, schedule the WAITING requests.
        if not preempted_reqs:
            while (self.waiting
                   and len(self.running) != self.max_num_running_reqs
                   and token_budget != 0):
                request = self.waiting[0]
                # Get already-cached tokens.
                computed_block_ids = self.kv_cache_manager.get_computed_blocks(
                    request)
                # NOTE(woosuk): Since incomplete blocks are not eligible for
                # sharing, `num_computed_tokens` is always a multiple of
                # `block_size`.
                num_computed_tokens = len(computed_block_ids) * self.block_size
                # Number of tokens to be scheduled.
                # We use `request.num_tokens` instead of
                # `request.num_prompt_tokens` to consider the resumed requests,
                # which have output tokens.
                num_new_tokens = min(request.num_tokens - num_computed_tokens,
                                     token_budget)
                assert num_new_tokens > 0

                new_block_ids = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens, computed_block_ids)
                if new_block_ids is None:
                    # The request cannot be scheduled.
                    break
                request.num_computed_tokens = num_computed_tokens

                self.waiting.popleft()
                self.running.append(request)
                status = request.status
                if status == RequestStatus.WAITING:
                    new_reqs.append(request)
                elif status == RequestStatus.PREEMPTED:
                    resumed_reqs.append(request)
                else:
                    raise RuntimeError(
                        f"Invalid request status: {request.status}")

                block_ids_by_req[request.request_id] = (computed_block_ids +
                                                        new_block_ids)
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        num_scheduled_reqs = (len(new_reqs) + len(resumed_reqs) +
                              len(running_reqs))
        assert num_scheduled_reqs == len(self.running)

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req, block_ids_by_req[req.request_id],
                                        req.num_computed_tokens)
            for req in new_reqs
        ]
        resumed_reqs_data = [
            ResumedRequestData.from_request(req,
                                            block_ids_by_req[req.request_id],
                                            req.num_computed_tokens)
            for req in resumed_reqs
        ]
        running_reqs_data = [
            self._make_running_request_data(req,
                                            block_ids_by_req[req.request_id],
                                            req.num_computed_tokens)
            for req in running_reqs
        ]
        preempted_req_ids = {req.request_id for req in preempted_reqs}

        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_resumed_reqs=resumed_reqs_data,
            scheduled_running_reqs=running_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            preempted_req_ids=preempted_req_ids,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
        )

        # Start a fresh set; the output keeps the one that was just flushed.
        self.finished_req_ids = set()
        return scheduler_output

    def _make_running_request_data(
        self,
        request: Request,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "RunningRequestData":
        """Return the (cached) RunningRequestData for `request`, refreshed."""
        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
        # them at each scheduling step.
        req_data = self.running_reqs_data.get(request.request_id)
        if req_data is None:
            req_data = RunningRequestData.from_request(request, new_block_ids,
                                                       num_computed_tokens)
            self.running_reqs_data[request.request_id] = req_data
            return req_data

        req_data.new_block_ids = new_block_ids
        req_data.num_computed_tokens = num_computed_tokens
        return req_data

    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> List[Tuple[Request, int]]:
        """Apply one step of model output to the running requests.

        Returns:
            `(request, num_sampled_tokens)` pairs for every request that
            received a sampled token in this step.
        """
        # NOTE(woosuk): This method doesn't consider speculative decoding.
        sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens

        still_running: List[Request] = []
        # (request, num_sampled_tokens)
        sampled: List[Tuple[Request, int]] = []
        for request in self.running:
            req_id = request.request_id
            request.num_computed_tokens += num_scheduled_tokens[req_id]
            # When the request's num_computed_tokens catches up its num_tokens,
            # the request generates output tokens. Otherwise, we ignore the
            # sampler output for the request.
            assert request.num_computed_tokens <= request.num_tokens
            if request.num_computed_tokens != request.num_tokens:
                still_running.append(request)
                continue

            req_index = model_runner_output.req_id_to_index[req_id]
            # NOTE(woosuk): Currently, we assume that each request
            # generates at most one token at each step.
            request.output_token_ids.append(sampled_token_ids[req_index])
            sampled.append((request, 1))
            # TODO: Update the KV cache manager for prefix caching.

            # Check if the request is finished; finished requests are
            # dropped from the running list.
            if not self._check_stop(request):
                still_running.append(request)

        self.running = still_running
        return sampled

    def _check_stop(self, request: Request) -> bool:
        """Finish and free `request` if a stop condition is met.

        Checks, in order: the length caps, the EOS token, and the stop
        token IDs. Returns True when the request was finished.
        """
        length_capped = (request.num_tokens >= self.max_model_len
                         or request.num_output_tokens >= request.max_tokens)
        if length_capped:
            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
            self._free_request(request)
            return True

        sampling_params = request.sampling_params
        last_token_id = request.output_token_ids[-1]

        hit_eos = (not sampling_params.ignore_eos
                   and last_token_id == request.eos_token_id)
        if hit_eos:
            request.status = RequestStatus.FINISHED_STOPPED
            self._free_request(request)
            return True

        if last_token_id in (sampling_params.stop_token_ids or ()):
            request.status = RequestStatus.FINISHED_STOPPED
            request.stop_reason = last_token_id
            self._free_request(request)
            return True

        return False

    def add_request(self, request: Request) -> None:
        """Queue `request` at the back of the waiting queue and index it."""
        self.waiting.append(request)
        self.requests[request.request_id] = request

    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: RequestStatus,
    ) -> None:
        """Handles the finish signal from outside the scheduler.

        For example, the API server can abort a request when the client
        disconnects.
        """
        assert RequestStatus.is_finished(finished_status)
        if isinstance(request_ids, str):
            request_ids = (request_ids, )
        request_ids = set(request_ids)

        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None:
                # Invalid request ID.
                continue

            # A request that is not RUNNING sits in the waiting queue.
            queue = (self.running if request.status == RequestStatus.RUNNING
                     else self.waiting)
            queue.remove(request)
            request.status = finished_status
            self._free_request(request)

    def _free_request(self, request: Request) -> None:
        """Drop all scheduler state for a finished request."""
        assert request.is_finished()
        req_id = request.request_id
        self.kv_cache_manager.free(request)
        self.running_reqs_data.pop(req_id, None)
        del self.requests[req_id]
        self.finished_req_ids.add(req_id)

    def get_num_unfinished_requests(self) -> int:
        """Return the number of waiting plus running requests."""
        num_waiting = len(self.waiting)
        num_running = len(self.running)
        return num_waiting + num_running

    def has_unfinished_requests(self) -> bool:
        """Return True while any request is waiting or running."""
        return self.get_num_unfinished_requests() > 0
@dataclass
class NewRequestData:
    """Data describing a request scheduled for the first time."""

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    multi_modal_data: Optional[MultiModalDataDict]
    sampling_params: SamplingParams
    block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "NewRequestData":
        """Build the record from `request` and its scheduling result.

        `prompt_token_ids` must be present in `request.inputs`; `prompt`
        and `multi_modal_data` are optional and default to None.
        """
        inputs = request.inputs
        token_ids = inputs["prompt_token_ids"]
        prompt_text = inputs.get("prompt")
        mm_data = inputs.get("multi_modal_data")
        return cls(
            req_id=request.request_id,
            prompt_token_ids=token_ids,
            prompt=prompt_text,
            multi_modal_data=mm_data,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
        )
@dataclass
class ResumedRequestData:
    """Data describing a previously preempted request being rescheduled."""

    req_id: str
    block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "ResumedRequestData":
        """Build the record; only `request.request_id` is read."""
        req_id = request.request_id
        return cls(req_id=req_id,
                   block_ids=block_ids,
                   num_computed_tokens=num_computed_tokens)
@dataclass
class RunningRequestData:
    """Data describing an already-running request in a scheduling step."""

    req_id: str
    # Only the blocks newly allocated in this step, not the full block list.
    new_block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "RunningRequestData":
        """Build the record; only `request.request_id` is read."""
        req_id = request.request_id
        return cls(req_id=req_id,
                   new_block_ids=new_block_ids,
                   num_computed_tokens=num_computed_tokens)
@dataclass
class SchedulerOutput:
    """Everything the scheduler decided for a single step."""

    # Requests scheduled for the first time.
    scheduled_new_reqs: List[NewRequestData]
    # Requests rescheduled after having been preempted.
    scheduled_resumed_reqs: List[ResumedRequestData]
    # Requests that were already running.
    scheduled_running_reqs: List[RunningRequestData]

    # Request ID -> number of tokens scheduled in this step.
    num_scheduled_tokens: Dict[str, int]
    # Sum of the values of `num_scheduled_tokens`.
    total_num_scheduled_tokens: int
    # IDs of the requests preempted during this step.
    preempted_req_ids: Set[str]
    # IDs of the requests that finished between the previous step and this
    # one.
    finished_req_ids: Set[str]
import time
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
Union)
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderLLMInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
class LLMEngine:
def __init__(
    self,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    speculative_config: Optional[SpeculativeConfig],
    decoding_config: Optional[DecodingConfig],
    observability_config: Optional[ObservabilityConfig],
    prompt_adapter_config: Optional[PromptAdapterConfig],
    executor_class: Type[GPUExecutor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    input_registry: InputRegistry = INPUT_REGISTRY,
    use_cached_outputs: bool = False,
) -> None:
    """Build the engine: tokenizer, detokenizer, executor and scheduler.

    `stat_loggers` and `use_cached_outputs` are accepted but not used by
    this constructor.
    """
    # Override the configs for V1.
    # FIXME
    if usage_context in (UsageContext.LLM_CLASS,
                         UsageContext.OPENAI_API_SERVER):
        scheduler_config.max_num_seqs = 1024
        scheduler_config.max_num_batched_tokens = (
            8192 if usage_context == UsageContext.LLM_CLASS else 2048)

    logger.info(
        "Initializing an LLM engine (v%s) with config: "
        "model=%r, speculative_config=%r, tokenizer=%r, "
        "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
        "override_neuron_config=%s, "
        "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
        "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
        "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
        "pipeline_parallel_size=%d, "
        "disable_custom_all_reduce=%s, quantization=%s, "
        "enforce_eager=%s, kv_cache_dtype=%s, "
        "quantization_param_path=%s, device_config=%s, "
        "decoding_config=%r, observability_config=%r, "
        "seed=%d, served_model_name=%s, "
        "num_scheduler_steps=%d, enable_prefix_caching=%s, "
        "use_async_output_proc=%s, mm_processor_kwargs=%s)",
        VLLM_VERSION,
        model_config.model,
        speculative_config,
        model_config.tokenizer,
        model_config.skip_tokenizer_init,
        model_config.tokenizer_mode,
        model_config.revision,
        model_config.override_neuron_config,
        model_config.rope_scaling,
        model_config.rope_theta,
        model_config.tokenizer_revision,
        model_config.trust_remote_code,
        model_config.dtype,
        model_config.max_model_len,
        load_config.download_dir,
        load_config.load_format,
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.disable_custom_all_reduce,
        model_config.quantization,
        model_config.enforce_eager,
        cache_config.cache_dtype,
        model_config.quantization_param_path,
        device_config.device,
        decoding_config,
        observability_config,
        model_config.seed,
        model_config.served_model_name,
        scheduler_config.num_scheduler_steps,
        cache_config.enable_prefix_caching,
        model_config.use_async_output_proc,
        model_config.mm_processor_kwargs,
    )

    # Keep every config on the instance; optional ones fall back to a
    # default-constructed config when not provided.
    self.model_config = model_config
    self.cache_config = cache_config
    self.lora_config = lora_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.device_config = device_config
    self.speculative_config = speculative_config
    self.load_config = load_config
    self.decoding_config = decoding_config or DecodingConfig()
    self.prompt_adapter_config = prompt_adapter_config
    self.observability_config = (observability_config
                                 or ObservabilityConfig())
    self.log_stats = log_stats

    # Tokenization side.
    assert not self.model_config.skip_tokenizer_init
    self.tokenizer = self._init_tokenizer()
    if self.tokenizer:
        # Ping the tokenizer to ensure liveness if it runs in a
        # different process.
        self.tokenizer.ping()
    self.detokenizer = Detokenizer(self.model_config.tokenizer)

    # Input processing side.
    self.generation_config_fields = _load_generation_config_dict(
        model_config)
    self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer)
    self.input_registry = input_registry
    self.input_processor = input_registry.create_input_processor(
        model_config)

    # Request id -> Request
    self.requests: Dict[str, Request] = {}
    # NOTE(woosuk): Now that the detokenizer works asynchronously, we need
    # to keep track of how many steps each request has been lagged behind
    # in terms of detokenization.
    # Request id -> how many detokenizer steps the request should wait for.
    self.num_lagged_steps: Dict[str, int] = {}
    # OPTIMIZATION: Cache the request output and update it incrementally.
    # This is used to avoid creating a new RequestOutput object every step.
    # Request id -> RequestOutput
    self.request_outputs: Dict[str, RequestOutput] = {}

    executor_kwargs = dict(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        lora_config=lora_config,
        speculative_config=speculative_config,
        load_config=load_config,
        prompt_adapter_config=prompt_adapter_config,
        observability_config=self.observability_config,
    )
    self.model_executor = executor_class(**executor_kwargs)
    assert self.model_config.task != "embedding"
    self._initialize_kv_caches()

    # Create the scheduler.
    # NOTE: the cache_config here have been updated with the numbers of
    # GPU and CPU blocks, which are profiled in the distributed executor.
    self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
def _initialize_kv_caches(self) -> None:
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
)
if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_gpu_blocks,
num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = 0
self.model_executor.initialize_cache(num_gpu_blocks)
@classmethod
def from_engine_args(
    cls,
    engine_args: EngineArgs,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
    """Creates an LLM engine from the engine arguments."""
    # Derive the engine configs, then the executor class that suits them.
    engine_config = engine_args.create_engine_config()
    executor_class = cls._get_executor_cls(engine_config)

    # Stats logging is on unless explicitly disabled in the arguments.
    return cls(
        **engine_config.to_dict(),
        executor_class=executor_class,
        log_stats=not engine_args.disable_log_stats,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
    )
def _init_tokenizer(self) -> BaseTokenizerGroup:
    """Create the tokenizer group from the engine's configs.

    LoRA support in the tokenizer is enabled iff a LoRA config is present.
    """
    lora_enabled = bool(self.lora_config)
    tokenizer_group = init_tokenizer_from_configs(
        model_config=self.model_config,
        scheduler_config=self.scheduler_config,
        parallel_config=self.parallel_config,
        enable_lora=lora_enabled)
    return tokenizer_group
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _add_processed_request(
    self,
    request_id: str,
    processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs],
    params: Union[SamplingParams, PoolingParams],
    arrival_time: float,
    lora_request: Optional[LoRARequest],
    prompt_adapter_request: Optional[PromptAdapterRequest],
    trace_headers: Optional[Mapping[str, str]] = None,
) -> None:
    """Wrap already-processed inputs in a Request and hand it to the scheduler.

    Args:
        request_id: Unique ID of the request.
        processed_inputs: Inputs returned by the input processor; validated
            here before use.
        params: Must be `SamplingParams`. It is cloned, so the caller's
            object is never modified.
        arrival_time: Arrival timestamp stored on the request.
        lora_request: Only used to look up the EOS token ID.
        prompt_adapter_request: Must be None (not supported).
        trace_headers: Must be None (not supported).
    """
    assert prompt_adapter_request is None
    assert trace_headers is None
    self._validate_model_inputs(processed_inputs)
    eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

    # TODO(woosuk): Support embedding mode.
    assert isinstance(params, SamplingParams)
    sampling_params = params.clone()
    sampling_params.update_from_generation_config(
        self.generation_config_fields, eos_token_id)

    # TODO(woosuk): Check max_logprobs
    # TODO(woosuk): Support encoder-decoder models.
    # Pass the updated clone, not the caller's `params`; otherwise the
    # generation-config update above would be silently dropped.
    req = Request(request_id, processed_inputs, sampling_params, eos_token_id,
                  arrival_time)
    self.requests[request_id] = req
    self.num_lagged_steps[request_id] = 0
    self.scheduler.add_request(req)
def stop_remote_worker_execution_loop(self) -> None:
    """Not supported yet; always raises NotImplementedError."""
    message = "TP not implemented yet."
    raise NotImplementedError(message)
def add_request(
    self,
    request_id: str,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> None:
    """Preprocess a prompt and register it as a new request.

    Args:
        request_id: Unique ID of the request.
        prompt: The prompt to preprocess.
        params: Sampling (or pooling) parameters for the request.
        arrival_time: Defaults to the current time when None.
        lora_request: Only allowed when a LoRA config is set.
        trace_headers: Forwarded unchanged.
        prompt_adapter_request: Forwarded unchanged.
        priority: Must be 0.

    Raises:
        ValueError: If `lora_request` is given but LoRA is not enabled.
    """
    lora_enabled = bool(self.lora_config)
    if lora_request is not None and not lora_enabled:
        raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                         "not enabled!")

    arrival_time = time.time() if arrival_time is None else arrival_time
    assert priority == 0, "vLLM V1 does not support priority at the moment."

    # Two stages: generic preprocessing, then the input processor created
    # from the input registry.
    processed_inputs = self.input_processor(
        self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        ))

    self._add_processed_request(
        request_id=request_id,
        processed_inputs=processed_inputs,
        params=params,
        arrival_time=arrival_time,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
        trace_headers=trace_headers,
    )
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
    """Ask the scheduler to finish the given request(s) as aborted."""
    aborted_status = RequestStatus.FINISHED_ABORTED
    self.scheduler.finish_requests(request_id, aborted_status)
def get_num_unfinished_requests(self) -> int:
    """Gets the number of unfinished requests."""
    # Counts the requests tracked by the engine itself.
    tracked_requests = self.requests
    return len(tracked_requests)
def has_unfinished_requests(self) -> bool:
    """Returns True if there are unfinished requests."""
    num_tracked = len(self.requests)
    return num_tracked > 0
def step(self) -> List[RequestOutput]:
    """Run one engine iteration and return the outputs that are ready.

    If the scheduler has unfinished requests, a step is scheduled and
    executed and the sampled tokens are sent to the detokenizer. Whatever
    the detokenizer has finished is returned in either case.
    """
    # NOTE(woosuk): This method may return an empty list when the
    # detokenizer is still processing the outputs. This should not be
    # considered as the end of the generation process.
    # FIXME(woosuk): Currently, the step method is inefficient because it
    # creates RequestOutput objects for all running requests, while they
    # may not be needed unless the output is streamed to the client.
    scheduler = self.scheduler
    if scheduler.has_unfinished_requests():
        scheduler_output = scheduler.schedule()
        model_output = self.model_executor.execute_model(scheduler_output)
        sampled = scheduler.update_from_output(scheduler_output,
                                               model_output)
        self.send_to_detokenizer(sampled)
    return self.recv_from_detokenizer()
def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
    """Send this step's newly sampled tokens to the detokenizer.

    Args:
        sampled: `(request, num_new_tokens)` pairs produced by the
            scheduler for this step.
    """
    inputs = DetokenizerInputs(
        req_ids=[],
        prompt_token_ids=[],
        new_token_ids=[],
        skip_special_tokens=[],
        spaces_between_special_tokens=[],
        free_req_ids=[],  # TODO(woosuk): Implement freeing.
    )
    for req, num_tokens in sampled:
        inputs.req_ids.append(req.request_id)

        # All outputs being new means the request is first detokenized, so
        # the prompt is sent along. Afterwards the prompt token ids are
        # already cached in the detokenizer and an empty list is sent.
        first_detokenization = len(req.output_token_ids) == num_tokens
        inputs.prompt_token_ids.append(
            req.prompt_token_ids if first_detokenization else [])

        inputs.new_token_ids.append(req.output_token_ids[-num_tokens:])
        inputs.skip_special_tokens.append(
            req.sampling_params.skip_special_tokens)
        inputs.spaces_between_special_tokens.append(
            req.sampling_params.spaces_between_special_tokens)

        # Update the number of lagged steps.
        self.num_lagged_steps[req.request_id] += 1

    self.detokenizer.send(inputs)
def recv_from_detokenizer(self) -> List[RequestOutput]:
    """Collect detokenized text and turn it into request outputs.

    Returns an empty list when the detokenizer has nothing ready. The
    engine's bookkeeping for a request is dropped once the request is
    finished and has no detokenizer steps outstanding.
    """
    detokenizer_output = self.detokenizer.recv()
    if detokenizer_output is None:
        return []

    req_outputs: List[RequestOutput] = []
    for i, req_id in enumerate(detokenizer_output.req_ids):
        new_text = detokenizer_output.detokenized_texts[i]
        req = self.requests[req_id]
        req.output_text += new_text

        # One fewer detokenizer step outstanding for this request.
        self.num_lagged_steps[req_id] -= 1
        finished = (self.num_lagged_steps[req_id] == 0
                    and req.is_finished())

        req_outputs.append(
            self._make_request_output(
                req, detokenizer_output.num_output_token_ids[i],
                detokenizer_output.detokenized_texts[i], finished))

        if finished:
            del self.requests[req_id]
            del self.num_lagged_steps[req_id]
            del self.request_outputs[req_id]
    return req_outputs
def terminate_detokenizer(self) -> None:
    """Ask the detokenizer to shut down by calling its ``terminate()``."""
    self.detokenizer.terminate()
def _make_request_output(
    self,
    request: Request,
    num_output_tokens: int,
    new_output_text: str,
    finished: bool,
) -> RequestOutput:
    """Create or refresh the cached ``RequestOutput`` for ``request``.

    The ``RequestOutput`` is created on first use, stored in
    ``self.request_outputs`` and then updated in place on later calls.
    How it is filled depends on ``request.sampling_params.output_kind``:

    * ``CUMULATIVE``: text is appended; token ids are the first
      ``num_output_tokens`` output tokens.
    * ``DELTA``: text is only the new text; token ids are only the tokens
      that have not been reported by an earlier call.
    * ``FINAL_ONLY``: text and token ids stay empty until ``finished``,
      then carry the full output.

    Args:
        request: The request being reported.
        num_output_tokens: Number of output tokens covered so far.
        new_output_text: Text produced since the previous call.
        finished: Whether this is the final output for the request.

    Returns:
        The (shared, mutated in place) ``RequestOutput`` of the request.
    """
    req_id = request.request_id
    req_output = self.request_outputs.get(req_id)
    if req_output is None:
        # TODO: Support `n` > 1.
        completion_output = CompletionOutput(
            index=0,
            text="",
            token_ids=[],
            cumulative_logprob=None,
            logprobs=None,  # TODO
            finish_reason=None,
            stop_reason=None,
            lora_request=None,
        )
        req_output = RequestOutput(
            request_id=req_id,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_logprobs=None,  # TODO
            outputs=[completion_output],
            finished=False,
            metrics=None,
            lora_request=None,
            encoder_prompt=None,
            encoder_prompt_token_ids=None,
        )
        self.request_outputs[req_id] = req_output

    completion_output = req_output.outputs[0]
    output_kind = request.sampling_params.output_kind
    if output_kind == RequestOutputKind.CUMULATIVE:
        completion_output.text += new_output_text
        completion_output.token_ids = (
            request.output_token_ids[:num_output_tokens])
    elif output_kind == RequestOutputKind.DELTA:
        # In DELTA mode `completion_output.token_ids` holds only the
        # previous delta, so its length is NOT the number of tokens already
        # reported. Using it as the slice start re-emits old tokens from
        # the third step on. Track the cumulative count separately instead.
        delta_offsets = getattr(self, "_delta_token_offsets", None)
        if delta_offsets is None:
            delta_offsets = self._delta_token_offsets = {}
        num_prev_tokens = delta_offsets.get(req_id, 0)
        completion_output.text = new_output_text
        completion_output.token_ids = request.output_token_ids[
            num_prev_tokens:num_output_tokens]
        if finished:
            # No further outputs for this request; drop its bookkeeping.
            delta_offsets.pop(req_id, None)
        else:
            delta_offsets[req_id] = num_output_tokens
    elif output_kind == RequestOutputKind.FINAL_ONLY:
        if finished:
            completion_output.text = request.output_text
            completion_output.token_ids = request.output_token_ids
        else:
            completion_output.text = ""
            completion_output.token_ids = []

    if finished:
        completion_output.finish_reason = request.get_finished_reason()
        completion_output.stop_reason = request.stop_reason
        req_output.finished = finished
    return req_output
def check_health(self) -> None:
    """Run the health checks of the tokenizer (if set) and the executor.

    The tokenizer is only checked when ``self.tokenizer`` is truthy; the
    model executor is always checked afterwards.
    """
    tokenizer = self.tokenizer
    if tokenizer:
        tokenizer.check_health()
    self.model_executor.check_health()
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
                                               EncoderDecoderLLMInputs]):
    """Reject inputs whose prompt is empty or too long.

    Args:
        inputs: Mapping that is read through ``inputs.get``; only the
            ``"prompt_token_ids"`` entry is inspected.

    Raises:
        ValueError: If the prompt token ids are missing or empty, or if the
            model is multimodal and the prompt is longer than
            ``self.model_config.max_model_len``.
    """
    prompt_ids = inputs.get("prompt_token_ids")
    has_prompt = prompt_ids is not None and len(prompt_ids) != 0
    if not has_prompt:
        raise ValueError("Prompt cannot be empty")

    # The length limit is only enforced here for multimodal models.
    if not self.model_config.is_multimodal_model:
        return

    max_prompt_len = self.model_config.max_model_len
    if len(prompt_ids) <= max_prompt_len:
        return
    raise ValueError(
        f"The prompt (total length {len(prompt_ids)}) is too long "
        f"to fit into the model (context length {max_prompt_len}). "
        "Make sure that `max_model_len` is no smaller than the "
        "number of text tokens plus multimodal tokens. For image "
        "inputs, the number of image tokens depends on the number "
        "of images, and possibly their aspect ratios as well.")
@classmethod
def validate_outputs(cls, outputs, output_type):
    """Return ``outputs`` unchanged; ``output_type`` is not inspected."""
    return outputs
def get_model_config(self) -> ModelConfig:
    """Return the ``model_config`` object held by this instance."""
    return self.model_config
def get_parallel_config(self) -> ParallelConfig:
    """Return the ``parallel_config`` object held by this instance."""
    return self.parallel_config
def get_decoding_config(self) -> DecodingConfig:
    """Return the ``decoding_config`` object held by this instance."""
    return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
    """Return the ``scheduler_config`` object held by this instance."""
    return self.scheduler_config
def get_lora_config(self) -> LoRAConfig:
    """Return the ``lora_config`` object held by this instance."""
    return self.lora_config
@classmethod
def _get_executor_cls(cls, engine_config: EngineConfig):
    """Return the executor class to use.

    ``engine_config`` is not consulted: ``GPUExecutor`` is always returned.
    """
    return GPUExecutor
def is_tracing_enabled(self) -> bool:
    """Report whether tracing is enabled; this implementation says no."""
    return False
def do_log_stats(self, *args, **kwargs) -> None:
    """Accept any arguments and do nothing (stats logging is a no-op)."""
    return None
def is_encoder_decoder_model(self) -> bool:
    """Report whether the model is encoder-decoder; always ``False`` here."""
    return False
def start_profile(self) -> None:
    """Do nothing; profiling start is a no-op in this implementation."""
    return None
def stop_profile(self) -> None:
    """Do nothing; profiling stop is a no-op in this implementation."""
    return None
def get_tokenizer_group(self, *args, **kwargs):
    """Return ``self.tokenizer``; any extra arguments are ignored."""
    return self.tokenizer
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the model's generation config as a dictionary.

    Args:
        model_config: Supplies the ``model``, ``trust_remote_code`` and
            ``revision`` values passed to ``try_get_generation_config``.

    Returns:
        ``config.to_diff_dict()`` of the loaded generation config, or an
        empty dict when no generation config could be obtained.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    if generation_config is not None:
        return generation_config.to_diff_dict()
    return {}
import os
from typing import Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
class GPUExecutor:
    """Executor that drives a single in-process ``Worker``.

    Constructing the executor stores all configs, builds the worker, and
    immediately initializes it and loads the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        observability_config: Optional[ObservabilityConfig],
    ) -> None:
        # Required configs.
        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.load_config = load_config
        # Optional configs (may be None).
        self.lora_config = lora_config
        self.speculative_config = speculative_config
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config

        # The worker is created and brought up as part of construction.
        worker = self._create_worker()
        self.worker = worker
        worker.initialize()
        worker.load_model()

    def _create_worker(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Worker:
        """Build the ``Worker`` for the given rank.

        Side effect: sets ``NCCL_CUMEM_ENABLE=0`` in ``os.environ``.

        Args:
            local_rank: Forwarded to the worker as ``local_rank``.
            rank: Forwarded to the worker as ``rank``.
            distributed_init_method: Forwarded to the worker. When ``None``,
                one is derived from this host's IP and an open port.

        Returns:
            A ``Worker`` constructed from the stored configs.
        """
        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        init_method = distributed_init_method
        if init_method is None:
            init_method = get_distributed_init_method(get_ip(),
                                                      get_open_port())
        return Worker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=init_method,
            lora_config=self.lora_config,
            speculative_config=self.speculative_config,
            prompt_adapter_config=self.prompt_adapter_config,
            observability_config=self.observability_config,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Ask the worker how many KV blocks are available."""
        worker = self.worker
        return worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Initialize the worker's KV cache, then compile/warm up the model.

        Args:
            num_gpu_blocks: Block count passed to the worker's
                ``initialize_cache``.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        worker = self.worker
        worker.initialize_cache(num_gpu_blocks)
        worker.compile_or_warm_up_model()

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run one model step on the worker and return its output."""
        return self.worker.execute_model(scheduler_output)

    def check_health(self) -> None:
        """No-op health check.

        GPUExecutor will always be healthy as long as it's running.
        """
        return None
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
@dataclass
class SamplerOutput:
    # Container for the tensors produced by the sampler for one step.
    # Shape notes below are carried over from the original annotations.

    # Sampled token ids, shape [num_reqs].
    sampled_token_ids: torch.Tensor

    # Token ids the logprobs refer to,
    # shape [num_reqs, max_num_logprobs + 1]; may be None.
    logprob_token_ids: Optional[torch.Tensor]
    # Logprob values, shape [num_reqs, max_num_logprobs + 1]; may be None.
    logprobs: Optional[torch.Tensor]

    # TODO: Support prompt logprobs.
    prompt_logprob_token_ids: Optional[torch.Tensor]
    prompt_logprobs: Optional[torch.Tensor]
@dataclass
class ModelRunnerOutput:
    # Per-step result handed back by the model runner.
    # Shape notes below are carried over from the original annotations.

    # Request ids, length num_reqs.
    req_ids: List[str]
    # Maps req_id -> index.
    req_id_to_index: Dict[str, int]

    # Sampled token ids, shape [num_reqs].
    sampled_token_ids_cpu: torch.Tensor

    # Token ids the logprobs refer to,
    # shape [num_reqs, max_num_logprobs + 1]; may be None.
    logprob_token_ids_cpu: Optional[torch.Tensor]
    # Logprob values, shape [num_reqs, max_num_logprobs + 1]; may be None.
    logprobs_cpu: Optional[torch.Tensor]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment