"vllm/vscode:/vscode.git/clone" did not exist on "7439b2056adb22894f6e56c629f4ffc275d1e63d"
Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
...@@ -53,8 +53,8 @@ def create_sequence_group_output( ...@@ -53,8 +53,8 @@ def create_sequence_group_output(
token_id_logprob_rank: int, token_id_logprob_rank: int,
token_id_logprob: float, token_id_logprob: float,
seq_id: SeqId, seq_id: SeqId,
topk_token_ids: List[int], topk_token_ids: List[Optional[int]],
topk_logprobs: List[float], topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput: ) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results. """Create a SequenceGroupOutput given the sampling results.
...@@ -68,7 +68,7 @@ def create_sequence_group_output( ...@@ -68,7 +68,7 @@ def create_sequence_group_output(
""" """
# vLLM logprobs always include the sampled token. In addition, the user may # vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs). # request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = { logprobs: Dict[Optional[int], Logprob] = {
token_id: Logprob( token_id: Logprob(
logprob=token_id_logprob, logprob=token_id_logprob,
rank=token_id_logprob_rank, rank=token_id_logprob_rank,
......
...@@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig ...@@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, from vllm.transformers_utils.configs import (ChameleonConfig, ChatGLMConfig,
JAISConfig, MedusaConfig, DbrxConfig, JAISConfig,
MLPSpeculatorConfig, MPTConfig, MedusaConfig, MLPSpeculatorConfig,
RWConfig) MPTConfig, RWConfig)
if VLLM_USE_MODELSCOPE: if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig from modelscope import AutoConfig
...@@ -18,6 +18,7 @@ else: ...@@ -18,6 +18,7 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chameleon": ChameleonConfig,
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
"mpt": MPTConfig, "mpt": MPTConfig,
......
from vllm.transformers_utils.configs.chameleon import (ChameleonConfig,
ChameleonVQVAEConfig)
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
...@@ -10,6 +12,8 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig ...@@ -10,6 +12,8 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
__all__ = [ __all__ = [
"ChameleonConfig",
"ChameleonVQVAEConfig",
"ChatGLMConfig", "ChatGLMConfig",
"DbrxConfig", "DbrxConfig",
"MPTConfig", "MPTConfig",
......
from typing import List, Optional
from transformers import PretrainedConfig
#TODO (ywang96): Remove this file and import it from
# transformers once the new release with Chameleon support
# is available.
class ChameleonConfig(PretrainedConfig):
model_type = "chameleon"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=65536,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
model_parallel_size=1,
swin_norm=False,
vq_config=None,
vocabulary_map=None,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.mlp_bias = mlp_bias
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.model_parallel_size = model_parallel_size
self.swin_norm = swin_norm
if vq_config is None:
vq_config = {}
self.vq_config = ChameleonVQVAEConfig(**vq_config)
self.vocabulary_map = vocabulary_map
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling,
dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, "
f"`type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in [
"linear", "dynamic"
]:
raise ValueError(
"`rope_scaling`'s type field must be one of ['linear', "
f"'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1, "
f"got {rope_scaling_factor}")
class ChameleonVQVAEConfig(PretrainedConfig):
model_type = "chameleon_vqgan"
def __init__(
self,
embed_dim: int = 256,
num_embeddings: int = 8192,
double_latent: bool = False,
latent_channels: int = 256,
resolution: int = 512,
in_channels: int = 3,
base_channels: int = 128,
channel_multiplier: List[int] = [1, 1, 2, 2, 4], #noqa
num_res_blocks: int = 2,
attn_resolutions: Optional[List[int]] = None,
dropout: float = 0.0,
attn_type: str = "vanilla",
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_embeddings = num_embeddings
self.double_latent = double_latent
self.latent_channels = latent_channels
self.resolution = resolution
self.in_channels = in_channels
self.base_channels = base_channels
self.channel_multiplier = channel_multiplier
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.dropout = dropout
self.attn_type = attn_type
self.initializer_range = initializer_range
...@@ -165,6 +165,12 @@ class Detokenizer: ...@@ -165,6 +165,12 @@ class Detokenizer:
return len(new_decoded_token_text) return len(new_decoded_token_text)
def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""
def _convert_tokens_to_string_with_added_encoders( def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str], output_tokens: List[str],
...@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens( ...@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
read_offset = len(new_tokens) read_offset = len(new_tokens)
prefix_offset = max( prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
return new_tokens, prefix_offset, read_offset return new_tokens, prefix_offset, read_offset
......
...@@ -88,6 +88,9 @@ def get_tokenizer( ...@@ -88,6 +88,9 @@ def get_tokenizer(
"Cannot use the fast tokenizer in slow tokenizer mode.") "Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False kwargs["use_fast"] = False
if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"
try: try:
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, tokenizer_name,
...@@ -134,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, ...@@ -134,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
if lora_request is None: if lora_request is None:
return None return None
try: try:
tokenizer = get_tokenizer(lora_request.lora_local_path, *args, tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
**kwargs)
except OSError as e: except OSError as e:
# No tokenizer was found in the LoRA folder, # No tokenizer was found in the LoRA folder,
# use base model tokenizer # use base model tokenizer
logger.warning( logger.warning(
"No tokenizer found in %s, using base model tokenizer instead. " "No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)", lora_request.lora_local_path, e) "(Exception: %s)", lora_request.lora_path, e)
tokenizer = None tokenizer = None
return tokenizer return tokenizer
......
from typing import Optional from typing import Optional, Type
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray from vllm.executor.ray_utils import ray
...@@ -16,18 +16,22 @@ else: ...@@ -16,18 +16,22 @@ else:
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup: **init_kwargs) -> BaseTokenizerGroup:
tokenizer_cls: Type[BaseTokenizerGroup]
if tokenizer_pool_config is None: if tokenizer_pool_config is None:
return TokenizerGroup(**init_kwargs) tokenizer_cls = TokenizerGroup
if tokenizer_pool_config.pool_type == "ray": elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
tokenizer_pool_config.pool_type, BaseTokenizerGroup):
tokenizer_cls = tokenizer_pool_config.pool_type
elif tokenizer_pool_config.pool_type == "ray":
if RayTokenizerGroupPool is None: if RayTokenizerGroupPool is None:
raise ImportError( raise ImportError(
"RayTokenizerGroupPool is not available. Please install " "RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool.") "the ray package to use the Ray tokenizer group pool.")
return RayTokenizerGroupPool.from_config(tokenizer_pool_config, tokenizer_cls = RayTokenizerGroupPool
**init_kwargs)
else: else:
raise ValueError( raise ValueError(
f"Unknown pool type: {tokenizer_pool_config.pool_type}") f"Unknown pool type: {tokenizer_pool_config.pool_type}")
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] __all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
...@@ -3,12 +3,19 @@ from typing import List, Optional ...@@ -3,12 +3,19 @@ from typing import List, Optional
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
class BaseTokenizerGroup(ABC): class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters.""" """A group of tokenizers that can be used for LoRA adapters."""
@classmethod
@abstractmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "BaseTokenizerGroup":
pass
@abstractmethod @abstractmethod
def ping(self) -> bool: def ping(self) -> bool:
"""Check if the tokenizer group is alive.""" """Check if the tokenizer group is alive."""
......
...@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
_worker_cls = TokenizerGroup _worker_cls = TokenizerGroup
@classmethod @classmethod
def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "RayTokenizerGroupPool": **init_kwargs) -> "RayTokenizerGroupPool":
if not tokenizer_pool_config:
raise ValueError("tokenizer_pool_config must not be None.")
ray_actor_options = (tokenizer_pool_config.extra_config or { ray_actor_options = (tokenizer_pool_config.extra_config or {
"num_cpus": 0 "num_cpus": 0
}) })
......
...@@ -2,6 +2,7 @@ from typing import List, Optional ...@@ -2,6 +2,7 @@ from typing import List, Optional
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer, from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async, get_lora_tokenizer_async,
...@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers = LRUCache[PreTrainedTokenizer]( self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None capacity=max_num_seqs) if enable_lora else None
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "TokenizerGroup":
return cls(**init_kwargs)
def ping(self) -> bool: def ping(self) -> bool:
"""Check if the tokenizer group is alive.""" """Check if the tokenizer group is alive."""
return True return True
......
...@@ -16,12 +16,12 @@ import requests ...@@ -16,12 +16,12 @@ import requests
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
_config_home = envs.VLLM_CONFIG_ROOT _config_home = envs.VLLM_CONFIG_ROOT
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json") _USAGE_STATS_JSON_PATH = os.path.join(_config_home, "usage_stats.json")
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, _USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "do_not_track")
"vllm/do_not_track")
_USAGE_STATS_ENABLED = None _USAGE_STATS_ENABLED = None
_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER _USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER
...@@ -205,7 +205,8 @@ class UsageMessage: ...@@ -205,7 +205,8 @@ class UsageMessage:
def _send_to_server(self, data): def _send_to_server(self, data):
try: try:
requests.post(_USAGE_STATS_SERVER, json=data) global_http_client = global_http_connection.get_sync_client()
global_http_client.post(_USAGE_STATS_SERVER, json=data)
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
# silently ignore unless we are using debug log # silently ignore unless we are using debug log
logging.debug("Failed to send usage data to server") logging.debug("Failed to send usage data to server")
......
...@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, ...@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Union) Union)
import numpy as np import numpy as np
import numpy.typing as npt
import psutil import psutil
import torch import torch
import torch.types import torch.types
...@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
# "fp8_e5m2": torch.uint8, # "fp8_e5m2": torch.uint8,
} }
TORCH_DTYPE_TO_NUMPY_DTYPE = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int32: np.int32,
torch.int64: np.int64,
}
P = ParamSpec('P') P = ParamSpec('P')
K = TypeVar("K") K = TypeVar("K")
T = TypeVar("T") T = TypeVar("T")
...@@ -415,9 +425,10 @@ def init_kmp_env(): ...@@ -415,9 +425,10 @@ def init_kmp_env():
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
def chunk_list(lst: List[T], chunk_size: int) -> List[List[T]]: def chunk_list(lst: List[T], chunk_size: int):
"""Yield successive chunk_size chunks from lst.""" """Yield successive chunk_size chunks from lst."""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] for i in range(0, len(lst), chunk_size):
yield lst[i:i + chunk_size]
def cdiv(a: int, b: int) -> int: def cdiv(a: int, b: int) -> int:
...@@ -616,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: ...@@ -616,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f"(e.g., 1, 2, 3). Given input: {s}") from e f"(e.g., 1, 2, 3). Given input: {s}") from e
def make_tensor_with_pad( def make_ndarray_with_pad(
x: List[List[int]], x: List[List[T]],
max_len: int, pad: T,
pad: int, dtype: npt.DTypeLike,
dtype: torch.dtype, *,
device: Optional[Union[str, torch.device]], max_len: Optional[int] = None,
) -> torch.Tensor: ) -> npt.NDArray:
"""Make a padded tensor of a 2D inputs. """
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches The padding is applied to the end of each inner list until it reaches
`max_len`. `max_len`.
""" """
padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad if max_len is None:
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
padded_x = np.full((len(x), max_len), pad, dtype=dtype)
for ind, blocktb in enumerate(x): for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len assert len(blocktb) <= max_len
padded_x[ind, :len(blocktb)] = blocktb padded_x[ind, :len(blocktb)] = blocktb
return torch.tensor(padded_x, dtype=dtype, device=device)
return padded_x
def make_tensor_with_pad(
x: List[List[T]],
pad: T,
dtype: torch.dtype,
*,
max_len: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()
return tensor
def async_tensor_h2d( def async_tensor_h2d(
...@@ -677,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]], ...@@ -677,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict) return dict(merged_dict)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
def init_cached_hf_modules() -> None: def init_cached_hf_modules() -> None:
""" """
Lazy initialization of the Hugging Face modules. Lazy initialization of the Hugging Face modules.
...@@ -939,3 +986,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -939,3 +986,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
processed_args.append(arg) processed_args.append(arg)
return super().parse_args(processed_args, namespace) return super().parse_args(processed_args, namespace)
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
**kwargs):
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
...@@ -9,4 +9,4 @@ except Exception as e: ...@@ -9,4 +9,4 @@ except Exception as e:
stacklevel=2) stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER" __commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.5.2" __version__ = "0.5.3.post1"
...@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
block_tables, block_tables,
max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device=self.device, device=self.device,
......
...@@ -2,7 +2,8 @@ import dataclasses ...@@ -2,7 +2,8 @@ import dataclasses
import gc import gc
import time import time
import warnings import warnings
from collections import defaultdict import weakref
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union) Tuple, Type, TypeVar, Union)
...@@ -38,6 +39,7 @@ from vllm.model_executor.model_loader import get_model ...@@ -38,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora, from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision) supports_vision)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs) MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
...@@ -47,10 +49,11 @@ from vllm.prompt_adapter.worker_manager import ( ...@@ -47,10 +49,11 @@ from vllm.prompt_adapter.worker_manager import (
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
is_pin_memory_available, make_tensor_with_pad) get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
...@@ -74,7 +77,7 @@ _NUM_WARMUP_ITERS = 2 ...@@ -74,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
@dataclasses.dataclass(frozen=True) @dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase): class ModelInputForGPU(ModelRunnerInputBase):
""" """
This base class contains metadata needed for the base model forward pass This base class contains metadata needed for the base model forward pass
...@@ -124,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -124,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
return cls(**tensor_dict) return cls(**tensor_dict)
@dataclasses.dataclass(frozen=True) @dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
""" """
Used by the ModelRunner. Used by the ModelRunner.
...@@ -165,6 +168,425 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -165,6 +168,425 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
return cls(**tensor_dict) return cls(**tensor_dict)
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""
@dataclass
class InterDataForSeqGroup:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id: str
seq_ids: List[int]
is_prompt: bool
block_tables: Optional[Dict[int, List[int]]]
computed_block_nums: List[int]
n_seqs: int = 0
# Input tokens and positions.
input_tokens: List[List[int]] = field(default_factory=list)
input_positions: List[List[int]] = field(default_factory=list)
# The sequence length (may be capped to the sliding window).
seq_lens: List[int] = field(default_factory=list)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: List[int] = field(default_factory=list)
# The query length.
query_lens: List[int] = field(default_factory=list)
# The number of tokens that are already computed.
context_lens: List[int] = field(default_factory=list)
# The current sliding window block.
curr_sliding_window_blocks: List[int] = field(default_factory=list)
# LoRA inputs.
lora_index_mapping: List[List[int]] = field(default_factory=list)
lora_prompt_mapping: List[List[int]] = field(default_factory=list)
lora_requests: Set[LoRARequest] = field(default_factory=set)
# Prompt adapter inputs.
prompt_adapter_index_mapping: List[int] = field(default_factory=list)
prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False
def __post_init__(self):
self.n_seqs = len(self.seq_ids)
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs
self.lora_index_mapping = [[] for _ in range(self.n_seqs)]
self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)]
def __init__(self,
runner: "GPUModelRunnerBase",
finished_requests_ids: Optional[List[str]] = None):
super().__init__()
# Compute functions for each sequence in a sequence group.
# WARNING: The order of the functions matters!
self.per_seq_compute_fns = [
self._compute_lens,
self._compute_for_prefix_cache_hit,
self._compute_for_sliding_window,
self._compute_lora_input,
]
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
self.per_seq_group_compute_fns = [
self._compute_prompt_adapter_input,
self._compute_multi_modal_input,
]
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.scheduler_config = self.runner.scheduler_config
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.enable_lora = self.runner.lora_config is not None
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None)
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.finished_requests_ids = finished_requests_ids
self.decode_only = True
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self.inter_data_list: List[
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
# Attention metadata inputs.
self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
weakref.proxy(self))
# Engine/Model configurations.
self.chunked_prefill_enabled = (
self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled)
if self.sliding_window is not None:
self.sliding_window_blocks = (
self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens
for the given sequence data.
"""
seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
token_chunk_size = seq_group_metadata.token_chunk_size
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if inter_data.is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)
# Compute tokens.
if inter_data.is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx] = tokens
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
def _compute_for_prefix_cache_hit(
self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Check if hit prefix cache (i.e., some blocks are already computed).
If hit, update input tokens and positions to only compute the
remaining blocks.
"""
computed_block_nums = inter_data.computed_block_nums
# Note that prefix caching does not support sliding window.
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")
# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Update seq_len and curr_sliding_window_block for the given
sequence data (only required by decoding) if sliding window is enabled.
"""
curr_sliding_window_block = 0
sliding_seq_len = inter_data.seq_lens[seq_idx]
if not inter_data.is_prompt and self.sliding_window is not None:
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
curr_sliding_window_block = self.sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = inter_data.seq_lens[seq_idx] % self.block_size
sliding_seq_len = min(
inter_data.seq_lens[seq_idx],
self.block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_block += 1
else:
sliding_seq_len = min(inter_data.seq_lens[seq_idx],
self.sliding_window)
inter_data.curr_sliding_window_blocks[
seq_idx] = curr_sliding_window_block
inter_data.seq_lens[seq_idx] = sliding_seq_len
def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""If LoRA is enabled, compute LoRA index and prompt mapping."""
if not self.enable_lora:
return
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
inter_data.lora_prompt_mapping.append(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs is not None
else 1))
def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If prompt adapter is enabled, compute index and prompt mapping.
"""
# Note that when is_prompt=True, we expect only one sequence
# in the group.
if not self.enable_prompt_adapter:
return
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if prompt_adapter_id <= 0 or not inter_data.is_prompt:
return
# We expect only one sequence in the group when is_prompt=True.
assert inter_data.n_seqs == 1
query_len = inter_data.query_lens[0]
inter_data.prompt_adapter_request = (
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
inter_data.prompt_adapter_index_mapping = [
prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1)
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
mm_data = seq_group_metadata.multi_modal_data
if not mm_data:
return
mm_kwargs = self.multi_modal_input_mapper(mm_data)
inter_data.multi_modal_inputs = mm_kwargs
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder."""
seq_ids = list(seq_group_metadata.seq_data.keys())
n_seqs = len(seq_ids)
is_prompt = seq_group_metadata.is_prompt
if is_prompt:
assert n_seqs == 1
self.decode_only = False
inter_data = self.InterDataForSeqGroup(
request_id=seq_group_metadata.request_id,
seq_ids=seq_ids,
is_prompt=is_prompt,
block_tables=seq_group_metadata.block_tables,
computed_block_nums=seq_group_metadata.computed_block_nums)
self.inter_data_list.append(inter_data)
for seq_idx in range(n_seqs):
for per_seq_fn in self.per_seq_compute_fns:
per_seq_fn(inter_data, seq_idx, seq_group_metadata)
for per_seq_group_fn in self.per_seq_group_compute_fns:
per_seq_group_fn(inter_data, seq_group_metadata)
def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = flatten_2d_lists([
flatten_2d_lists(inter_data.input_tokens)
for inter_data in self.inter_data_list
])
if not input_tokens:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
input_positions = flatten_2d_lists([
flatten_2d_lists(inter_data.input_positions)
for inter_data in self.inter_data_list
])
seq_lens = []
max_decode_seq_len = 0
for inter_data in self.inter_data_list:
seq_lens.extend(inter_data.seq_lens)
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
query_lens = flatten_2d_lists(
[inter_data.query_lens for inter_data in self.inter_data_list])
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
request_ids_to_seq_ids = {
data.request_id: data.seq_ids
for data in self.inter_data_list
}
batch_size = len(input_tokens)
use_captured_graph = (
self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
cuda_graph_pad_size = -1
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
batch_size = graph_batch_size
# Tokens and positions.
input_tokens.extend([0] * cuda_graph_pad_size)
input_positions.extend([0] * cuda_graph_pad_size)
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.runner.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.runner.device)
# Sequence and query lengths.
seq_lens.extend([1] * cuda_graph_pad_size)
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(r for data in self.inter_data_list
for r in data.lora_requests)
lora_index_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_index_mapping)
for inter_data in self.inter_data_list
])
lora_index_mapping.extend([0] * cuda_graph_pad_size)
lora_prompt_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_prompt_mapping)
for inter_data in self.inter_data_list
])
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
# Prompt adapter data.
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
prompt_adapter_mapping = None
if self.enable_prompt_adapter:
prompt_adapter_requests = set(
data.prompt_adapter_request for data in self.inter_data_list
if data.prompt_adapter_request is not None)
prompt_adapter_index_mapping = flatten_2d_lists([
inter_data.prompt_adapter_index_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
prompt_adapter_prompt_mapping = flatten_2d_lists([
inter_data.prompt_adapter_prompt_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
# Multi-modal data.
multi_modal_inputs_list = [
data.multi_modal_inputs for data in self.inter_data_list
if data.multi_modal_inputs is not None
]
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.runner.device)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
""" """
Helper class for shared methods between GPU model runners. Helper class for shared methods between GPU model runners.
...@@ -251,7 +673,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -251,7 +673,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.flashinfer_prefill_workspace_buffer = None self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None self.flashinfer_prefill_wrapper = None
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config, self.model = get_model(model_config=self.model_config,
device_config=self.device_config, device_config=self.device_config,
...@@ -368,464 +794,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -368,464 +794,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs. If cuda graph is required, this API automatically pads inputs.
""" """
input_tokens: List[int] = [] builder = ModelInputForGPUBuilder(weakref.proxy(self),
input_positions: List[int] = [] finished_requests_ids)
slot_mapping: List[int] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_adapter_index_mapping: List[int] = []
prompt_adapter_prompt_mapping: List[int] = []
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
seq_lens: List[int] = []
prefill_seq_lens: List[int] = []
decode_seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
block_tables: List[List[int]] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
decode_only = True
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = 0
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0:
return self._model_input_cls()
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window + self.block_size -
1) // self.block_size
block_aligned_sliding_window = \
sliding_window_blocks * self.block_size
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys()) builder.add_seq_group(seq_group_metadata)
is_prompt = seq_group_metadata.is_prompt return builder.build() # type: ignore
for seq_id in seq_ids:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
seq_data = seq_group_metadata.seq_data[seq_id]
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_data.get_len() - 1
seq_len = min(
seq_data.get_len(),
context_len + seq_group_metadata.token_chunk_size)
if is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and is_prompt)
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
curr_sliding_window_blocks = None
sliding_seq_len = seq_len
sliding_context_len = context_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if (self.sliding_window is not None and not is_prompt):
curr_sliding_window_blocks = sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = seq_len % self.block_size
sliding_seq_len = min(
seq_len, block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_blocks += 1
else:
sliding_seq_len = min(seq_len, self.sliding_window)
sliding_context_len = sliding_seq_len - 1
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
tokens = tokens[context_len:]
# need to think what to set it to when we have both sliding
# window and prefix caching...
assert self.sliding_window is None, \
"Prefix caching is not supported with sliding window"
sliding_context_len = context_len
if self.attn_backend.get_name() == "flash-attn":
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table = seq_group_metadata.block_tables[seq_id]
else:
block_table = computed_block_nums
elif (self.scheduler_config.chunked_prefill_enabled
or not is_prompt):
if seq_group_metadata.block_tables is not None:
# chunked prefill or decode
block_table = seq_group_metadata.block_tables[seq_id]
if curr_sliding_window_blocks is not None:
block_table = block_table[
-curr_sliding_window_blocks:]
else:
# Only happens when memory profiling runs.
block_table = []
else:
# Prefill without chunked prefill or memory profiling.
block_table = []
block_tables.append(block_table)
seq_lens.append(sliding_seq_len)
context_lens.append(sliding_context_len)
query_len = sliding_seq_len - sliding_context_len
query_lens.append(query_len)
input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if is_prompt:
assert len(seq_ids) == 1
num_prefills += 1
num_prefill_tokens += len(tokens)
decode_only = False
prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
num_decode_tokens += query_len
decode_seq_lens.append(sliding_seq_len)
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * query_len
lora_prompt_mapping.extend(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
is not None else 1))
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if prompt_adapter_id > 0 and is_prompt:
prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
prompt_adapter_index_mapping += pm
prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables)
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
if is_prompt:
assert self.scheduler_config.use_v2_block_manager \
or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - self.sliding_window)
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# Prepare input tensors for flashinfer
if self.attn_backend.get_name() == "flashinfer":
seq_len = seq_data.get_len()
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
paged_kv_indices.extend(block_table[:block_table_bound])
paged_kv_indptr.append(paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
batch_size = len(input_tokens)
max_query_len = max(query_lens)
max_prefill_seq_len = max(prefill_seq_lens, default=0)
max_decode_seq_len = max(decode_seq_lens, default=0)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph = (
decode_only and not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size):
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
prompt_adapter_index_mapping.append(0)
if self.attn_backend.get_name() == "flashinfer":
last_paged_kv_indptr = paged_kv_indptr[-1]
paged_kv_indptr.append(last_paged_kv_indptr)
paged_kv_last_page_len.append(0)
batch_size = graph_batch_size
num_decode_tokens = batch_size
if use_captured_graph:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=self.device)
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
logits_soft_cap = getattr(self.model_config.hf_config,
'attn_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
raise ValueError("Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
if self.attn_backend.get_name() == "flashinfer":
if len(paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
device='cpu',
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
device='cpu',
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
paged_kv_last_page_len, device='cpu', dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)
else:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
request_ids_to_seq_ids = {
seq_group_metadata.request_id:
list(seq_group_metadata.seq_data.keys())
for seq_group_metadata in seq_group_metadata_list
}
return self._model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests,
)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
...@@ -847,7 +820,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -847,7 +820,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dummy_lora_request = LoRARequest( dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}", lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id, lora_int_id=lora_id,
lora_local_path="/not/a/real/path", lora_path="/not/a/real/path",
) )
self.lora_manager.add_dummy_lora(dummy_lora_request, self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK) rank=LORA_WARMUP_RANK)
...@@ -1549,15 +1522,3 @@ def _get_graph_batch_size(batch_size: int) -> int: ...@@ -1549,15 +1522,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else: else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False
...@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, ...@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch import torch
from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -113,6 +114,21 @@ class ModelRunnerInputBase(ABC): ...@@ -113,6 +114,21 @@ class ModelRunnerInputBase(ABC):
raise NotImplementedError raise NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
@abstractmethod
def add_seq_group(self, seq_group_metadata):
"""TBA"""
raise NotImplementedError
@abstractmethod
def build(self, *args, **kwargs) -> T:
"""Build metadata with on-device tensors."""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]): class ModelRunnerBase(ABC, Generic[T]):
""" """
Model runner interface that abstracts a particular hardware and/or type of Model runner interface that abstracts a particular hardware and/or type of
...@@ -148,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -148,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
""" """
raise NotImplementedError raise NotImplementedError
@torch.inference_mode() @current_platform.inference_mode()
def execute_model( def execute_model(
self, self,
model_input: T, model_input: T,
......
...@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
max_seq_len = max(seq_lens) max_seq_len = max(seq_lens)
assert max_seq_len > 0 assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens, input_tokens = make_tensor_with_pad(input_tokens,
max_seq_len,
pad=0, pad=0,
max_len=max_seq_len,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = make_tensor_with_pad(input_positions, input_positions = make_tensor_with_pad(input_positions,
max_seq_len,
pad=0, pad=0,
max_len=max_seq_len,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_block_ids = torch.tensor(input_block_ids, input_block_ids = torch.tensor(input_block_ids,
...@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids.append(block_table[0]) input_block_ids.append(block_table[0])
input_tokens = make_tensor_with_pad(input_tokens, input_tokens = make_tensor_with_pad(input_tokens,
max_len=1,
pad=0, pad=0,
max_len=1,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = make_tensor_with_pad(input_positions, input_positions = make_tensor_with_pad(input_positions,
max_len=1,
pad=0, pad=0,
max_len=1,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
......
import time import time
from typing import List, Mapping, Optional, Tuple from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np import numpy as np
import torch import torch
...@@ -12,12 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ...@@ -12,12 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
MultiModalInputs) Logprob, SamplerOutput, SequenceGroupMetadata,
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata,
SequenceOutput) SequenceOutput)
from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -29,7 +34,44 @@ _ENABLE_TOP_P = False ...@@ -29,7 +34,44 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES = 128 _MAX_NUM_SAMPLES = 128
class TPUModelRunner: @dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: AttentionMetadata
input_lens: torch.Tensor
t: torch.Tensor
p: torch.Tensor
num_samples: int
best_of: List[int]
seq_groups: List[List[int]]
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"token_ids": self.token_ids,
"position_ids": self.position_ids,
"input_lens": self.input_lens,
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForTPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForTPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__( def __init__(
self, self,
...@@ -68,10 +110,6 @@ class TPUModelRunner: ...@@ -68,10 +110,6 @@ class TPUModelRunner:
False, False,
) )
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
def load_model(self) -> None: def load_model(self) -> None:
self.device = self.device_config.device self.device = self.device_config.device
...@@ -85,6 +123,7 @@ class TPUModelRunner: ...@@ -85,6 +123,7 @@ class TPUModelRunner:
multimodal_config=self.multimodal_config, multimodal_config=self.multimodal_config,
lora_config=None, lora_config=None,
) )
model = model.eval()
xm.wait_device_ops() xm.wait_device_ops()
model = ModelWrapper(model) model = ModelWrapper(model)
...@@ -153,8 +192,8 @@ class TPUModelRunner: ...@@ -153,8 +192,8 @@ class TPUModelRunner:
# Dummy run. # Dummy run.
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
self.model(token_ids, position_ids, kv_caches, attn_metadata, self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
input_lens, None, t, p, num_samples) num_samples, kv_caches)
def warmup_model( def warmup_model(
self, self,
...@@ -183,7 +222,7 @@ class TPUModelRunner: ...@@ -183,7 +222,7 @@ class TPUModelRunner:
# Decode # Decode
start = time.time() start = time.time()
seq_len = 1 seq_len = 1
batch_size = 1 batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True: while True:
self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
xm.wait_device_ops() xm.wait_device_ops()
...@@ -199,14 +238,12 @@ class TPUModelRunner: ...@@ -199,14 +238,12 @@ class TPUModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[int] = []
input_positions: List[List[int]] = [] input_positions: List[int] = []
prompt_lens: List[int] = [] prompt_lens: List[int] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
...@@ -220,78 +257,62 @@ class TPUModelRunner: ...@@ -220,78 +257,62 @@ class TPUModelRunner:
prompt_len = len(prompt_tokens) prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len) prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens) input_tokens.extend(prompt_tokens)
input_positions.append(list(range(prompt_len))) input_positions.extend(list(range(prompt_len)))
assert seq_group_metadata.block_tables is not None assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
slot_mapping.append([])
for i in range(prompt_len): for i in range(prompt_len):
block_number = block_table[i // self.block_size] block_number = block_table[i // self.block_size]
block_offset = i % self.block_size block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot) slot_mapping.append(slot)
mm_data = seq_group_metadata.multi_modal_data # Add paddings to EACH prompt to the smallest power of 2 that is
if mm_data: # greater than or equal to the prompt length.
mm_kwargs = self.multi_modal_input_mapper(mm_data) # We pad the seq_len to reduce the compilation overhead.
multi_modal_inputs_list.append(mm_kwargs) # We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len = _get_padded_prefill_len(prompt_len)
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
assert len(prompt_lens) > 0 assert len(prompt_lens) > 0
num_prefills = len(prompt_lens) num_prefills = len(prompt_lens)
num_prefill_tokens = sum(prompt_lens) input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
# Add paddings to make the shape [batch_size, max_prompt_len] where device="cpu")
# max_prompt_len is smallest power of 2 that is greater than or equal input_positions = torch.tensor(input_positions,
# to the maximum prompt length. dtype=torch.int32,
# We need the 2D input shape because the Pallas FlashAttention kernel device="cpu")
# does not support packed 1D inputs. slot_mapping = torch.tensor(slot_mapping,
# We pad the seq_len to powers of 2 to reduce the compilation overhead. dtype=torch.int64,
max_prompt_len = _get_padded_prefill_len(max(prompt_lens)) device="cpu")
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.int32,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.int32,
device=self.device)
slot_mapping = make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.int64,
device=self.device)
prompt_lens = torch.tensor(prompt_lens, prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device="cpu")
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, # NOTE: This is not used. num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
block_tables=None, block_tables=None,
context_lens=None, context_lens=None,
) )
return input_tokens, input_positions, attn_metadata, prompt_lens
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_kwargs)
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[List[int]] = []
context_lens: List[int] = [] context_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
batch_idx = 0 batch_idx = 0
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
...@@ -317,11 +338,6 @@ class TPUModelRunner: ...@@ -317,11 +338,6 @@ class TPUModelRunner:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append([slot]) slot_mapping.append([slot])
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
batch_size = _get_padded_batch_size(batch_idx) batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings input_tokens = input_tokens + [[0]] * num_paddings
...@@ -331,22 +347,22 @@ class TPUModelRunner: ...@@ -331,22 +347,22 @@ class TPUModelRunner:
input_tokens = torch.tensor(input_tokens, input_tokens = torch.tensor(input_tokens,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device="cpu")
input_positions = torch.tensor(input_positions, input_positions = torch.tensor(input_positions,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device="cpu")
slot_mapping = torch.tensor(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64, dtype=torch.int64,
device=self.device) device="cpu")
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device="cpu")
block_tables = torch.tensor(self.block_tables[:batch_size], block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32, dtype=torch.int32,
device=self.device) device="cpu")
input_lens = torch.tensor([1] * batch_size, input_lens = torch.tensor([1] * batch_size,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device="cpu")
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
num_prefills=0, num_prefills=0,
num_prefill_tokens=0, num_prefill_tokens=0,
...@@ -355,12 +371,7 @@ class TPUModelRunner: ...@@ -355,12 +371,7 @@ class TPUModelRunner:
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )
return input_tokens, input_positions, attn_metadata, input_lens
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, input_lens,
multi_modal_kwargs)
def _prepare_sample( def _prepare_sample(
self, self,
...@@ -412,16 +423,18 @@ class TPUModelRunner: ...@@ -412,16 +423,18 @@ class TPUModelRunner:
t += [1.0] * num_paddings t += [1.0] * num_paddings
p += [1.0] * num_paddings p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device=self.device) t = torch.tensor(t, dtype=torch.float32, device="cpu")
p = torch.tensor(p, dtype=torch.float32, device=self.device) p = torch.tensor(p, dtype=torch.float32, device="cpu")
return t, p, best_of return t, p, best_of
def _execute_model( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], virtual_engine: int = 0,
) -> List[CompletionSequenceGroupOutput]: finished_requests_ids: Optional[List[str]] = None,
# Prepare inputs. ) -> ModelInputForTPU:
del finished_requests_ids # Unused.
assert virtual_engine == 0
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
...@@ -430,16 +443,104 @@ class TPUModelRunner: ...@@ -430,16 +443,104 @@ class TPUModelRunner:
inputs = self._prepare_prompt(seq_group_metadata_list) inputs = self._prepare_prompt(seq_group_metadata_list)
else: else:
inputs = self._prepare_decode(seq_group_metadata_list) inputs = self._prepare_decode(seq_group_metadata_list)
padded_batch_size = inputs[0].shape[0] input_tokens, input_positions, attn_metadata, input_lens = inputs
padded_batch_size = input_tokens.shape[0]
t, p, best_of = self._prepare_sample(seq_group_metadata_list, t, p, best_of = self._prepare_sample(seq_group_metadata_list,
padded_batch_size) padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
# Execute the model. seq_groups = [
next_token_ids = self.model(inputs[0], inputs[1], kv_caches, list(metadata.seq_data.keys())
*inputs[2:], t, p, num_samples) for metadata in seq_group_metadata_list
# Retrieve the outputs to CPU. ]
next_token_ids = next_token_ids.cpu().tolist() return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
input_lens, t, p, num_samples, best_of,
seq_groups)
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=self.attn_backend)
return model_input
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if num_steps > 1:
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
def _execute_model(*args, clone: bool = False) -> torch.Tensor:
"""Move input args from CPU to device and execute the model."""
def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
if clone:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x = x.clone()
return x.to(self.device)
new_args = []
for arg in args:
if isinstance(arg, torch.Tensor):
arg = _copy_to_device(arg)
elif isinstance(arg, AttentionMetadata):
arg.slot_mapping = _copy_to_device(arg.slot_mapping)
if getattr(arg, "block_tables", None) is not None:
arg.block_tables = _copy_to_device(arg.block_tables)
if getattr(arg, "context_lens", None) is not None:
arg.context_lens = _copy_to_device(arg.context_lens)
new_args.append(arg)
return self.model(*new_args)
num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0
if is_prompt:
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
next_token_ids = []
orig_slot_mapping = model_input.attn_metadata.slot_mapping
batch_size = model_input.input_lens.shape[0]
start_idx = 0
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len
model_input.attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx]
model_input.attn_metadata.num_prefills = 1
output_token_ids = _execute_model(
model_input.token_ids[None, start_idx:end_idx],
model_input.position_ids[None, start_idx:end_idx],
model_input.attn_metadata,
model_input.input_lens[i:i + 1],
model_input.t[i:i + 1],
model_input.p[i:i + 1],
model_input.num_samples,
kv_caches,
clone=True)
# Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx
else:
# Execute the model.
output_token_ids = _execute_model(
model_input.token_ids, model_input.position_ids,
model_input.attn_metadata, model_input.input_lens,
model_input.t, model_input.p, model_input.num_samples,
kv_caches)
# Retrieve the outputs to CPU.
next_token_ids = output_token_ids.cpu().tolist()
# NOTE(woosuk): Minimal code to construct the sampler outputs. # NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend # The TPU backend does not reuse the sampler, since the TPU backend
...@@ -447,13 +548,13 @@ class TPUModelRunner: ...@@ -447,13 +548,13 @@ class TPUModelRunner:
zero_logprob = Logprob(0.0) zero_logprob = Logprob(0.0)
batch_idx = 0 batch_idx = 0
sampler_outputs = [] sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list: for seq_group in model_input.seq_groups:
seq_ids = seq_group
seq_outputs = [] seq_outputs = []
seq_ids = list(seq_group_metadata.seq_data.keys())
if is_prompt: if is_prompt:
assert len(seq_ids) == 1 assert len(seq_ids) == 1
seq_id = seq_ids[0] seq_id = seq_ids[0]
for i in range(best_of[batch_idx]): for i in range(model_input.best_of[batch_idx]):
next_token_id = next_token_ids[batch_idx][i] next_token_id = next_token_ids[batch_idx][i]
seq_outputs.append( seq_outputs.append(
SequenceOutput(seq_id, next_token_id, SequenceOutput(seq_id, next_token_id,
...@@ -468,35 +569,6 @@ class TPUModelRunner: ...@@ -468,35 +569,6 @@ class TPUModelRunner:
batch_idx += 1 batch_idx += 1
sampler_outputs.append( sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None)) CompletionSequenceGroupOutput(seq_outputs, None))
return sampler_outputs
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
num_steps: int = 1,
) -> List[SamplerOutput]:
if num_steps > 1:
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
assert seq_group_metadata_list is not None
assert len(seq_group_metadata_list) > 0
if seq_group_metadata_list[0].is_prompt:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# aware of this limitation, we need to handle batch size > 1
# internally by calling the model multiple times and concatenating
# the outputs.
# FIXME(woosuk): This is a temporary hack to not change the existing
# scheduler. We need to fix this in the future.
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
sampler_outputs += self._execute_model([seq_group_metadata],
kv_caches)
else:
sampler_outputs = self._execute_model(seq_group_metadata_list,
kv_caches)
return [SamplerOutput(sampler_outputs)] return [SamplerOutput(sampler_outputs)]
...@@ -504,39 +576,37 @@ class ModelWrapper(nn.Module): ...@@ -504,39 +576,37 @@ class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
super().__init__() super().__init__()
self.model = model.eval() self.model = model
def forward( def forward(
self, self,
token_ids: torch.Tensor, token_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
input_lens: torch.Tensor, input_lens: torch.Tensor,
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
t: torch.Tensor, t: torch.Tensor,
p: torch.Tensor, p: torch.Tensor,
num_samples: int, num_samples: int,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
) -> torch.Tensor: ) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token. """Executes the forward pass of the model and samples the next token.
Args: Args:
token_ids: The input token IDs of shape [batch_size, seq_len]. token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
attn_metadata: The Pallas attention metadata. attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size]. input_lens: The actual input lengths of shape [batch_size].
multi_modal_kwargs: Keyword arguments from multi-modal data to
pass to the model.
t: The sampling temperature of shape [batch_size]. t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size]. p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
""" """
batch_size, seq_len = token_ids.shape batch_size, seq_len = token_ids.shape
# Calculate the positions to sample from. # Calculate the positions to sample from.
base_indicies = torch.arange( start_indicies = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = base_indicies + input_lens - 1 logits_indices = start_indicies + input_lens - 1
# FIXME(woosuk): This is a temporary hack to avoid using the existing # FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata. # sampler and sampling metadata.
...@@ -573,7 +643,6 @@ class ModelWrapper(nn.Module): ...@@ -573,7 +643,6 @@ class ModelWrapper(nn.Module):
position_ids, position_ids,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
**(multi_modal_kwargs or {}),
) )
hidden_states = hidden_states.flatten(0, 1) hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
...@@ -598,11 +667,10 @@ def _get_padded_prefill_len(x: int) -> int: ...@@ -598,11 +667,10 @@ def _get_padded_prefill_len(x: int) -> int:
def _get_padded_batch_size(batch_size: int) -> int: def _get_padded_batch_size(batch_size: int) -> int:
if batch_size <= 2: # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
return batch_size # To meet this requirement in the simplest way, we set the minimal batch
elif batch_size <= 4: # size to 8.
return 4 if batch_size <= 8:
elif batch_size <= 8:
return 8 return 8
else: else:
return ((batch_size + 15) // 16) * 16 return ((batch_size + 15) // 16) * 16
......
...@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
logger = init_logger(__name__) logger = init_logger(__name__)
class TPUWorker(LoraNotSupportedWorkerBase): class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__( def __init__(
self, self,
...@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype] self.cache_config.cache_dtype]
self.model_runner = TPUModelRunner(model_config, self.model_runner: TPUModelRunner = TPUModelRunner(
parallel_config, model_config,
scheduler_config, parallel_config,
device_config, scheduler_config,
cache_config, device_config,
load_config, cache_config,
multimodal_config, load_config,
is_driver_worker=is_driver_worker) multimodal_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None: def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU" os.environ["PJRT_DEVICE"] = "TPU"
...@@ -98,8 +100,7 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -98,8 +100,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
# Use persistent cache to avoid XLA recompilation. # Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): This does not completely eliminate the recompilation # NOTE(woosuk): This does not completely eliminate the recompilation
# overhead because dynamo does not cache the compiled results. # overhead because dynamo does not cache the compiled results.
xr.initialize_cache(os.path.expanduser(envs.VLLM_XLA_CACHE_PATH), xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, readonly=False)
readonly=False)
def load_model(self): def load_model(self):
self.model_runner.load_model() self.model_runner.load_model()
...@@ -197,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -197,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size = get_dtype_size(self.cache_dtype) dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total return dtype_size * total
def execute_model( @property
def do_metadata_broadcast(self) -> bool:
# TODO(woosuk): Support TP.
return False
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
def prepare_worker_input(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None, execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]: ) -> WorkerInput:
if not self.is_driver_worker: virtual_engine = execute_model_req.virtual_engine
self._execute_model_non_driver() num_seq_groups = len(execute_model_req.seq_group_metadata_list)
return [] blocks_to_swap_in = _make_src_to_dst(
assert execute_model_req is not None execute_model_req.blocks_to_swap_in, "cpu", self.device)
# Issue cache operations. blocks_to_swap_out = _make_src_to_dst(
self.cache_swap( execute_model_req.blocks_to_swap_out, self.device, "cpu")
execute_model_req.blocks_to_swap_in, blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
execute_model_req.blocks_to_swap_out, self.device, self.device)
execute_model_req.blocks_to_copy, return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
# Run the model.
seq_group_metadata_list = execute_model_req.seq_group_metadata_list def execute_worker(self, worker_input: WorkerInput) -> None:
assert len(seq_group_metadata_list) > 0 virtual_engine = worker_input.virtual_engine
output = self.model_runner.execute_model(seq_group_metadata_list, assert virtual_engine == 0
self.tpu_cache)
return output
def cache_swap(
self,
blocks_to_swap_in: List[Tuple[int, int]],
blocks_to_swap_out: List[Tuple[int, int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
attn_backend = self.model_runner.attn_backend attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
if blocks_to_swap_in: # Issue cache operations.
# Swap from CPU to TPU. if worker_input.blocks_to_swap_in is not None:
src_indices, dst_indices = _make_src_to_dst( src_indices, dst_indices = worker_input.blocks_to_swap_in
blocks_to_swap_in, "cpu", self.device) if src_indices.numel() > 0:
for i in range(num_layers): # Swap from CPU to TPU.
tpu_k_cache, tpu_v_cache = self.tpu_cache[i] for i in range(num_layers):
cpu_k_cache, cpu_v_cache = self.cpu_cache[i] tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device) cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
v = cpu_v_cache[:, src_indices].to(self.device) k = cpu_k_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if blocks_to_swap_out:
# Swap from TPU to CPU. if worker_input.blocks_to_swap_out is not None:
src_indices, dst_indices = _make_src_to_dst( src_indices, dst_indices = worker_input.blocks_to_swap_out
blocks_to_swap_out, self.device, "cpu") if src_indices.numel() > 0:
for i in range(num_layers): # Swap from TPU to CPU.
tpu_k_cache, tpu_v_cache = self.tpu_cache[i] for i in range(num_layers):
cpu_k_cache, cpu_v_cache = self.cpu_cache[i] tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu() cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu() cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
if blocks_to_copy:
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device, if worker_input.blocks_to_copy is not None:
self.device) src_indices, dst_indices = worker_input.blocks_to_copy
attn_backend.copy_blocks(self.tpu_cache, src_to_dst) if src_indices.numel() > 0:
attn_backend.copy_blocks(self.tpu_cache,
def start_worker_execution_loop(self) -> None: (src_indices, dst_indices))
while self._execute_model_non_driver():
pass
def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True
def _make_src_to_dst( def _make_src_to_dst(
......
...@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
# initialize_cache. # initialize_cache.
self.cache_engine: List[CacheEngine] self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment