Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
from ..utils import compare_two_settings
def test_custom_dispatcher():
    """Run `compare_two_settings` on google/gemma-2b with identical CLI args.

    Both settings pass `--enforce-eager`; they differ only in environment:
    the first sets `VLLM_DYNAMO_USE_CUSTOM_DISPATCHER=0`, the second adds
    no extra environment variables.
    """
    model = "google/gemma-2b"
    dispatcher_disabled_env = {"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}
    default_env = {}
    compare_two_settings(
        model,
        arg1=["--enforce-eager"],
        arg2=["--enforce-eager"],
        env1=dispatcher_disabled_env,
        env2=default_env,
    )
...@@ -11,12 +11,16 @@ from typing import Any, Callable, Dict, List, Optional ...@@ -11,12 +11,16 @@ from typing import Any, Callable, Dict, List, Optional
import openai import openai
import requests import requests
from openai.types.completion import Completion
from transformers import AutoTokenizer from transformers import AutoTokenizer
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
...@@ -59,35 +63,50 @@ class RemoteOpenAIServer: ...@@ -59,35 +63,50 @@ class RemoteOpenAIServer:
def __init__(self, def __init__(self,
model: str, model: str,
cli_args: List[str], vllm_serve_args: List[str],
*, *,
env_dict: Optional[Dict[str, str]] = None, env_dict: Optional[Dict[str, str]] = None,
auto_port: bool = True, auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None: max_wait_seconds: Optional[float] = None) -> None:
if auto_port: if auto_port:
if "-p" in cli_args or "--port" in cli_args: if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError("You have manually specified the port" raise ValueError("You have manually specified the port "
"when `auto_port=True`.") "when `auto_port=True`.")
cli_args = cli_args + ["--port", str(get_open_port())] # Don't mutate the input args
vllm_serve_args = vllm_serve_args + [
"--port", str(get_open_port())
]
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.") description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser) parser = make_arg_parser(parser)
args = parser.parse_args(cli_args) args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or 'localhost') self.host = str(args.host or 'localhost')
self.port = int(args.port) self.port = int(args.port)
# download the model before starting the server to avoid timeout
is_local = os.path.isdir(model)
if not is_local:
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_config = engine_args.create_engine_config()
dummy_loader = DefaultModelLoader(engine_config.load_config)
dummy_loader._prepare_weights(engine_config.model_config.model,
engine_config.model_config.revision,
fall_back_to_pt=True)
env = os.environ.copy() env = os.environ.copy()
# the current process might initialize cuda, # the current process might initialize cuda,
# to be safe, we should use spawn method # to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None: if env_dict is not None:
env.update(env_dict) env.update(env_dict)
self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args, self.proc = subprocess.Popen(
env=env, ["vllm", "serve", model, *vllm_serve_args],
stdout=sys.stdout, env=env,
stderr=sys.stderr) stdout=sys.stdout,
stderr=sys.stderr,
)
max_wait_seconds = max_wait_seconds or 240 max_wait_seconds = max_wait_seconds or 240
self._wait_for_server(url=self.url_for("health"), self._wait_for_server(url=self.url_for("health"),
timeout=max_wait_seconds) timeout=max_wait_seconds)
...@@ -137,6 +156,7 @@ class RemoteOpenAIServer: ...@@ -137,6 +156,7 @@ class RemoteOpenAIServer:
return openai.AsyncOpenAI( return openai.AsyncOpenAI(
base_url=self.url_for("v1"), base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY, api_key=self.DUMMY_API_KEY,
max_retries=0,
) )
...@@ -414,3 +434,61 @@ def fork_new_process_for_each_test( ...@@ -414,3 +434,61 @@ def fork_new_process_for_each_test(
f" args {args} and kwargs {kwargs}") f" args {args} and kwargs {kwargs}")
return wrapper return wrapper
async def completions_with_server_args(
    prompts: List[str],
    model_name: str,
    server_cli_args: List[str],
    num_logprobs: Optional[int],
    max_wait_seconds: int = 240,
) -> Completion:
    '''Start a `RemoteOpenAIServer`, query its completions API through the
    async client, and return the response.

    The request is issued with `temperature=0`, `stream=False` and
    `max_tokens=5`.

    Args:
      prompts: test prompts
      model_name: model to spin up on the vLLM server
      server_cli_args: CLI args for starting the server
      num_logprobs: Number of logprobs to report (or `None`)
      max_wait_seconds: timeout interval for bringing up server.
                        Default: 240sec

    Returns:
      OpenAI Completion instance
    '''
    request_kwargs = dict(
        model=model_name,
        prompt=prompts,
        temperature=0,
        stream=False,
        max_tokens=5,
        logprobs=num_logprobs,
    )

    completion = None
    with RemoteOpenAIServer(model_name,
                            server_cli_args,
                            max_wait_seconds=max_wait_seconds) as server:
        async_client = server.get_async_client()
        completion = await async_client.completions.create(**request_kwargs)

    # The server context has been exited here; a response must exist.
    assert completion is not None
    return completion
def get_client_text_generations(completions: Completion) -> List[str]:
    '''Collect the `text` field of every choice in the response of an
    Open-AI-protocol completions request, in the order of
    `completions.choices`.
    '''
    generated_texts: List[str] = []
    for choice in completions.choices:
        generated_texts.append(choice.text)
    return generated_texts
def get_client_text_logprob_generations(
        completions: Completion) -> List[TextTextLogprobs]:
    '''Build one tuple per choice from the response of an Open-AI-protocol
    completions request.

    Every tuple holds, in order: the list of texts of *all* choices, the
    concatenation of those texts, and that choice's
    `logprobs.top_logprobs` (or `None` when the choice has no logprobs).
    '''
    text_generations = get_client_text_generations(completions)
    text = ''.join(text_generations)

    results = []
    for choice in completions.choices:
        if choice.logprobs is None:
            top_logprobs = None
        else:
            top_logprobs = choice.logprobs.top_logprobs
        results.append((text_generations, text, top_logprobs))
    return results
...@@ -4,6 +4,12 @@ gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main ...@@ -4,6 +4,12 @@ gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq, TheBloke/Llama-2-7B-GPTQ, main
gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq, TechxGenus/gemma-1.1-2b-it-GPTQ, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
...@@ -13,8 +19,12 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main ...@@ -13,8 +19,12 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
awq, casperhansen/mixtral-instruct-awq, main awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
\ No newline at end of file qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
\ No newline at end of file
import os import os
import torch
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
MODEL_NAME = os.environ.get("MODEL_NAME", MODEL_NAME = os.environ.get("MODEL_NAME",
"robertgshaw2/zephyr-7b-beta-channelwise-gptq") "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
...@@ -8,9 +10,12 @@ QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin") ...@@ -8,9 +10,12 @@ QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
def test_weight_loading(vllm_runner): def test_weight_loading(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME, with vllm_runner(model_name=MODEL_NAME,
revision=REVISION, revision=REVISION,
dtype="auto", dtype=torch.half if QUANTIZATION == "gptq" else "auto",
quantization=QUANTIZATION, quantization=QUANTIZATION,
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model: tensor_parallel_size=2) as model:
......
...@@ -181,92 +181,98 @@ elif core_C_available: ...@@ -181,92 +181,98 @@ elif core_C_available:
ScalarType = torch.classes._core_C.ScalarType ScalarType = torch.classes._core_C.ScalarType
# Needed for dynamo support of ScalarType. if (hasattr(torch, "_library")
@torch._library.register_fake_class("_core_C::ScalarType") and hasattr(torch._library, "register_fake_class")):
class FakeScalarType: # Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
def __init__(self, scalar_type): def __init__(self, scalar_type):
self.ScalarType = scalar_type self.ScalarType = scalar_type
def bias_getter(self) -> int: def bias_getter(self) -> int:
return self.ScalarType.bias return self.ScalarType.bias
def exponent_getter(self) -> int: def exponent_getter(self) -> int:
return self.ScalarType.exponent return self.ScalarType.exponent
def mantissa_getter(self) -> int: def mantissa_getter(self) -> int:
return self.ScalarType.mantissa return self.ScalarType.mantissa
def signed_getter(self) -> bool: def signed_getter(self) -> bool:
return self.ScalarType.signed return self.ScalarType.signed
def size_bits_getter(self) -> int: def size_bits_getter(self) -> int:
return self.ScalarType.size_bits return self.ScalarType.size_bits
@property @property
def size_bits(self) -> int: def size_bits(self) -> int:
return self.ScalarType.size_bits return self.ScalarType.size_bits
def min(self) -> Union[int, float]: def min(self) -> Union[int, float]:
return self.ScalarType.min() return self.ScalarType.min()
def max(self) -> Union[int, float]: def max(self) -> Union[int, float]:
return self.ScalarType.max() return self.ScalarType.max()
def is_signed(self) -> bool: def is_signed(self) -> bool:
return self.ScalarType.is_signed() return self.ScalarType.is_signed()
def is_floating_point(self) -> bool: def is_floating_point(self) -> bool:
return self.ScalarType.is_floating_point() return self.ScalarType.is_floating_point()
def is_integer(self) -> bool: def is_integer(self) -> bool:
return self.ScalarType.is_integer() return self.ScalarType.is_integer()
def has_bias(self) -> bool: def has_bias(self) -> bool:
return self.ScalarType.has_bias() return self.ScalarType.has_bias()
def has_infs(self) -> bool: def has_infs(self) -> bool:
return self.ScalarType.has_infs() return self.ScalarType.has_infs()
def has_nans(self) -> bool: def has_nans(self) -> bool:
return self.ScalarType.has_nans() return self.ScalarType.has_nans()
def is_ieee_754(self) -> bool: def is_ieee_754(self) -> bool:
return self.ScalarType.is_ieee_754() return self.ScalarType.is_ieee_754()
def __str__(self) -> str: def __str__(self) -> str:
return self.ScalarType.__str__() return self.ScalarType.__str__()
def __repr__(self) -> str: def __repr__(self) -> str:
return self.ScalarType.__repr__() return self.ScalarType.__repr__()
def __len__(self) -> int: def __len__(self) -> int:
return self.ScalarType.__len__() return self.ScalarType.__len__()
def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]: def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
return torch.classes._core_C.ScalarType.__obj_flatten__( return torch.classes._core_C.ScalarType.__obj_flatten__(
self.ScalarType) self.ScalarType)
@classmethod @classmethod
def __obj_unflatten__( def __obj_unflatten__(
cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType': cls, flat_type: Tuple[Tuple[str, Any],
return cls( ...]) -> 'ScalarType':
torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type)) return cls(
torch.classes._core_C.ScalarType.__obj_unflatten__(
flat_type))
@classmethod @classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.int_(size_bits, bias) return ScalarType.int_(size_bits, bias)
@classmethod @classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
return ScalarType.uint(size_bits, bias) return ScalarType.uint(size_bits, bias)
@classmethod @classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': def float_IEEE754(cls, exponent: int,
return ScalarType.float_IEEE754(exponent, mantissa) mantissa: int) -> 'ScalarType':
return ScalarType.float_IEEE754(exponent, mantissa)
@classmethod @classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, def float_(cls, exponent: int, mantissa: int,
nan_repr: int) -> 'ScalarType': finite_values_only: bool,
return ScalarType.float_(exponent, mantissa, finite_values_only, nan_repr: int) -> 'ScalarType':
nan_repr) return ScalarType.float_(exponent, mantissa,
finite_values_only, nan_repr)
...@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union ...@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
import vllm.envs as envs
from vllm._core_ext import ScalarType from vllm._core_ext import ScalarType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -278,13 +279,22 @@ def GetAWQShareWorkspace()->torch.Tensor: ...@@ -278,13 +279,22 @@ def GetAWQShareWorkspace()->torch.Tensor:
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton)
return awq_dequantize_triton(qweight, scales, zeros)
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
thx, thy) thx, thy)
# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, # def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
# scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: # scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
# return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # if envs.VLLM_USE_TRITON_AWQ:
# from vllm.model_executor.layers.quantization.awq_triton import (
# awq_gemm_triton)
# return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
# return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def awq_gemm(input: torch.Tensor, weight: torch.Tensor, def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
zeros_and_scales:torch.Tensor, zeros_and_scales:torch.Tensor,
...@@ -317,6 +327,7 @@ def dequant_w4_gemm_colmajor(qweight:torch.Tensor, ...@@ -317,6 +327,7 @@ def dequant_w4_gemm_colmajor(qweight:torch.Tensor,
)->torch.Tensor: )->torch.Tensor:
return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size) return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size)
# gptq # gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
...@@ -434,6 +445,20 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, ...@@ -434,6 +445,20 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """Apply the `_C.gptq_marlin_repack` op to each expert slice.

    The leading dimension of `b_q_weight` is taken as the number of
    experts. For every expert index `e`, `b_q_weight[e]` and `perm[e]` are
    repacked and the result is written into row `e` of a freshly allocated
    tensor of shape `(num_experts, size_k // 16, size_n * 2)` with the same
    device and dtype as `b_q_weight`.

    `size_k` must be a multiple of 16 (enforced with an assertion).
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0

    repacked_shape = (num_experts, size_k // 16, size_n * 2)
    output = torch.empty(repacked_shape,
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)

    # The op is looked up inside the loop so that nothing touches the
    # compiled extension when there are no experts to process.
    for expert_idx in range(num_experts):
        expert_weight = b_q_weight[expert_idx]
        expert_perm = perm[expert_idx]
        output[expert_idx] = torch.ops._C.gptq_marlin_repack(
            expert_weight, expert_perm, size_k, size_n, num_bits)
    return output
def gptq_marlin_gemm(a: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, b_scales: torch.Tensor,
...@@ -611,6 +636,36 @@ def ggml_mul_mat_a8( ...@@ -611,6 +636,36 @@ def ggml_mul_mat_a8(
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      seq_idx_: Optional[torch.Tensor],
                      initial_states_: Optional[torch.Tensor],
                      final_states_out_: Optional[torch.Tensor],
                      silu_activation: bool) -> torch.Tensor:
    """Forward all arguments, in order, to the `_C.causal_conv1d_fwd` op
    and return its result unchanged.
    """
    conv_fwd_op = torch.ops._C.causal_conv1d_fwd
    return conv_fwd_op(
        x,
        weight,
        bias_,
        seq_idx_,
        initial_states_,
        final_states_out_,
        silu_activation,
    )
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool) -> torch.Tensor:
    """Forward all arguments, in order, to the `_C.causal_conv1d_update`
    op and return its result unchanged.
    """
    conv_update_op = torch.ops._C.causal_conv1d_update
    return conv_update_op(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
    )
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool, index_: Optional[torch.Tensor],
                       x: Optional[torch.Tensor]) -> List[torch.Tensor]:
    """Forward all arguments, in order, to the `_C.selective_scan_fwd` op
    and return its result unchanged.
    """
    scan_fwd_op = torch.ops._C.selective_scan_fwd
    return scan_fwd_op(
        u,
        delta,
        A,
        B,
        C,
        D_,
        z_,
        delta_bias_,
        delta_softplus,
        index_,
        x,
    )
# moe # moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor, block_size: int, sorted_token_ids: torch.Tensor,
......
...@@ -19,7 +19,9 @@ class AudioAsset: ...@@ -19,7 +19,9 @@ class AudioAsset:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR) s3_prefix=ASSET_DIR)
return librosa.load(audio_path, sr=None) y, sr = librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr
@property @property
def url(self) -> str: def url(self) -> str:
......
...@@ -83,6 +83,15 @@ class FlashInferBackend(AttentionBackend): ...@@ -83,6 +83,15 @@ class FlashInferBackend(AttentionBackend):
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> List[int]:
return [64, 128, 256] return [64, 128, 256]
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
class FlashInferState(AttentionState): class FlashInferState(AttentionState):
...@@ -113,8 +122,7 @@ class FlashInferState(AttentionState): ...@@ -113,8 +122,7 @@ class FlashInferState(AttentionState):
self.runner.parallel_config)) self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads( num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config) self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \ use_tensor_cores = num_qo_heads // num_kv_heads > 4
(1, 2, 4, 8)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(), self._get_workspace_buffer(),
"NHD", "NHD",
...@@ -172,15 +180,18 @@ class FlashInferState(AttentionState): ...@@ -172,15 +180,18 @@ class FlashInferState(AttentionState):
self.runner.parallel_config)) self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads( num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config) self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \ use_tensor_cores = num_qo_heads // num_kv_heads > 4
(1, 2, 4, 8)
self._graph_decode_wrapper = \ self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD", self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores) use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype( if self.runner.kv_cache_dtype.startswith("fp8"):
self.runner.kv_cache_dtype, self.runner.model_config.dtype) kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(0, paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1, batch_size + 1,
...@@ -368,7 +379,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -368,7 +379,8 @@ class FlashInferMetadata(AttentionMetadata):
def decode_metadata(self) -> Optional["FlashInferMetadata"]: def decode_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported # Currently chunked prefill is not supported
if self.num_prefills > 0: if self.num_prefills > 0:
assert self.num_decode_tokens == 0 assert self.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
return None return None
return self return self
...@@ -578,8 +590,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -578,8 +590,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_tensor = None paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype( if self.runner.kv_cache_dtype.startswith("fp8"):
self.runner.kv_cache_dtype, self.runner.model_config.dtype) kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
return FlashInferMetadata( return FlashInferMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
...@@ -663,7 +680,6 @@ class FlashInferImpl(AttentionImpl): ...@@ -663,7 +680,6 @@ class FlashInferImpl(AttentionImpl):
if attn_metadata.num_decode_tokens > 0: if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, ( assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.") "Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None: if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention. # Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash( ops.reshape_and_cache_flash(
...@@ -676,6 +692,12 @@ class FlashInferImpl(AttentionImpl): ...@@ -676,6 +692,12 @@ class FlashInferImpl(AttentionImpl):
k_scale, k_scale,
v_scale, v_scale,
) )
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
query = query.contiguous( query = query.contiguous(
) # Flashinfer requires query to be contiguous ) # Flashinfer requires query to be contiguous
...@@ -713,5 +735,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -713,5 +735,7 @@ class FlashInferImpl(AttentionImpl):
query, query,
kv_cache, kv_cache,
sm_scale=self.scale, sm_scale=self.scale,
logits_soft_cap=self.logits_soft_cap) logits_soft_cap=self.logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
...@@ -123,7 +123,13 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -123,7 +123,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("TPU version must be 4 or higher.") raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
or tpu_env.get("TYPE", None)
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
assert tpu_type is not None
tpu_type = tpu_type.lower()
if "lite" not in tpu_type: if "lite" not in tpu_type:
if self.num_kv_heads % 2 == 0: if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head" self.megacore_mode = "kv_head"
......
...@@ -226,6 +226,10 @@ def which_attn_to_use( ...@@ -226,6 +226,10 @@ def which_attn_to_use(
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info( logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache.") "Cannot use FlashAttention-2 backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
elif block_size % 16 != 0: elif block_size % 16 != 0:
logger.info( logger.info(
......
"""Token blocks.""" """Token blocks."""
from typing import List, Optional from typing import TYPE_CHECKING, Iterator, List, Optional
from vllm.utils import Device from vllm.utils import Device
DEFAULT_LAST_ACCESSED_TIME = -1 DEFAULT_LAST_ACCESSED_TIME: float = -1
class PhysicalTokenBlock: class PhysicalTokenBlock:
...@@ -59,6 +59,11 @@ class BlockTable: ...@@ -59,6 +59,11 @@ class BlockTable:
def __getitem__(self, key): def __getitem__(self, key):
return self._blocks[key] return self._blocks[key]
# Typing-only stub: it is defined solely when TYPE_CHECKING is true, so it
# gives static checkers an element type for iteration and is never bound
# when the module actually runs.
# NOTE(review): runtime iteration presumably relies on `__getitem__` -
# confirm against the full class.
if TYPE_CHECKING:
def __iter__(self) -> Iterator[PhysicalTokenBlock]:
# Reaching this body would mean the stub was defined at runtime.
raise RuntimeError("Method should be automatically generated")
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, slice): if isinstance(key, slice):
blocks = value blocks = value
......
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List
import torch
import vllm.envs as envs
class TorchCompileWrapperWithCustomDispacther:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.
    """

    def __init__(self, compiled_callable: Callable):
        """Store the compiled callable and register the bytecode hook.

        Args:
            compiled_callable: the callable invoked by the default
                `__call__` implementation.
        """
        self.compiled_callable = compiled_callable
        # Remember the untouched code object of `forward` so the hook can
        # recognise it and `dispatch_to_code` can restore it.
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        Args:
            old_code: the code object that was compiled.
            new_code: the bytecode produced for it; appended to
                `self.compiled_codes` when it belongs to this instance.
        """
        # Ignore compilations of any function other than this class's
        # original `forward`.
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the call stack to the `_compile` frame of
        # `convert_frame.py`, whose local `frame` is the frame being compiled.
        frame = sys._getframe()
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # Several instances share the same `forward` code object; only record
        # bytecode compiled for a call made on this very instance.
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        The original code object is restored when the context exits, even
        if the body of the `with` block raises.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """ # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            # Without `finally`, an exception raised inside the `with` body
            # would leave `forward` permanently pointing at compiled bytecode.
            self.__class__.forward.__code__ = self.original_code_object
import enum import enum
import json import json
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple, from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
Type, Union) Optional, Tuple, Type, Union)
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -32,6 +32,7 @@ if TYPE_CHECKING: ...@@ -32,6 +32,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
_PP_SUPPORTED_MODELS = [ _PP_SUPPORTED_MODELS = [
"AquilaModel", "AquilaModel",
...@@ -61,7 +62,8 @@ class ModelConfig: ...@@ -61,7 +62,8 @@ class ModelConfig:
output when `served_model_name` is not specified. output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use. tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer. available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer. downloading the model and tokenizer.
dtype: Data type for model weights and activations. The "auto" option dtype: Data type for model weights and activations. The "auto" option
...@@ -113,34 +115,39 @@ class ModelConfig: ...@@ -113,34 +115,39 @@ class ModelConfig:
the model name will be the same as `model`. the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality limit_mm_per_prompt: Maximum number of data instances per modality
per prompt. Only applicable for multimodal models. per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
""" """
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: str, tokenizer: str,
tokenizer_mode: str, tokenizer_mode: str,
trust_remote_code: bool, trust_remote_code: bool,
dtype: Union[str, torch.dtype], dtype: Union[str, torch.dtype],
seed: int, seed: int,
revision: Optional[str] = None, revision: Optional[str] = None,
code_revision: Optional[str] = None, code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None, rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None, rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None, spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None, quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None, quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None, enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20, max_logprobs: int = 20,
disable_sliding_window: bool = False, disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None, served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
) -> None: use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
...@@ -172,6 +179,7 @@ class ModelConfig: ...@@ -172,6 +179,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision) self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
# Choose a default enforce_eager value if the user did not specify # Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None) # a value (enforce_eager is None)
...@@ -223,6 +231,9 @@ class ModelConfig: ...@@ -223,6 +231,9 @@ class ModelConfig:
limit_mm_per_prompt) limit_mm_per_prompt)
if not self.skip_tokenizer_init: if not self.skip_tokenizer_init:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self.override_neuron_config = override_neuron_config if is_neuron(
) else None
self._verify_embedding_mode() self._verify_embedding_mode()
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
...@@ -244,10 +255,10 @@ class ModelConfig: ...@@ -244,10 +255,10 @@ class ModelConfig:
def _verify_tokenizer_mode(self) -> None: def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower() tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]: if tokenizer_mode not in ["auto", "slow", "mistral"]:
raise ValueError( raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.") "either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None: def _verify_embedding_mode(self) -> None:
...@@ -264,13 +275,14 @@ class ModelConfig: ...@@ -264,13 +275,14 @@ class ModelConfig:
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm","awq"] # "fp8" rocm_supported_quantization = ["awq", "gptq", "squeezellm"] # "fp8"
optimized_quantization_methods = [ optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors", "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
"experts_int8" "experts_int8"
] ]
tpu_supported_quantization = ["tpu_int8"] tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()
...@@ -319,6 +331,17 @@ class ModelConfig: ...@@ -319,6 +331,17 @@ class ModelConfig:
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization) "non-quantized models.", self.quantization)
if (self.quantization == "awq" and is_hip()
and not envs.VLLM_USE_TRITON_AWQ):
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
if is_neuron(
) and self.quantization not in neuron_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in Neuron Backend.")
def _verify_cuda_graph(self) -> None: def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None: if self.max_seq_len_to_capture is None:
...@@ -326,6 +349,49 @@ class ModelConfig: ...@@ -326,6 +349,49 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len) self.max_model_len)
def verify_async_output_proc(self, parallel_config, speculative_config,
                             device_config) -> None:
    """Turn off ``use_async_output_proc`` when the setup cannot use it.

    The flag is cleared, with a warning, when any of these holds:
    pipeline parallel size above 1, a device type other than ``"cuda"``
    or ``"tpu"``, ``envs.VLLM_USE_RAY_SPMD_WORKER`` set, or
    ``enforce_eager`` enabled. If none of those applies, the flag is
    still cleared silently in embedding mode, and with a warning when a
    speculative config is present.
    """
    if not self.use_async_output_proc:
        # Already disabled, nothing to verify.
        return

    # The first matching condition decides the warning; each of them
    # disables the feature outright and ends the verification.
    blocker = None
    if parallel_config.pipeline_parallel_size > 1:
        blocker = ("Async output processing can not be enabled "
                   "with pipeline parallel")
    elif device_config.device_type not in ("cuda", "tpu"):
        blocker = (
            "Async output processing is only supported for CUDA or TPU. "
            "Disabling it for other platforms.")
    elif envs.VLLM_USE_RAY_SPMD_WORKER:
        blocker = "Async output processing can not be enabled with ray spmd"
    elif self.enforce_eager:
        blocker = ("To see benefits of async output processing, enable CUDA "
                   "graph. Since, enforce-eager is enabled, async output "
                   "processor cannot be used")

    if blocker is not None:
        logger.warning(blocker)
        self.use_async_output_proc = False
        return

    # Async postprocessor is not necessary with embedding mode
    # since there is no token generation
    if self.embedding_mode:
        self.use_async_output_proc = False

    if speculative_config:
        logger.warning("Async output processing is not supported with"
                       " speculative decoding currently.")
        self.use_async_output_proc = False
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
...@@ -353,11 +419,18 @@ class ModelConfig: ...@@ -353,11 +419,18 @@ class ModelConfig:
raise ValueError( raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.") "BitAndBytes quantization with TP or PP is not supported yet.")
# Remove the constraint after the bitsandbytes issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
if self.quantization == "bitsandbytes" and self.enforce_eager is False: if self.quantization == "bitsandbytes" and self.enforce_eager is False:
logger.warning("CUDA graph is not supported on BitAndBytes yet, " logger.warning("CUDA graph is not supported on BitAndBytes yet, "
"fallback to the eager mode.") "fallback to the eager mode.")
self.enforce_eager = True self.enforce_eager = True
if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False
def get_hf_config_sliding_window(self) -> Optional[int]: def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.""" """Get the sliding window size, or None if disabled."""
...@@ -512,6 +585,10 @@ class ModelConfig: ...@@ -512,6 +585,10 @@ class ModelConfig:
"""Extract the embedding model flag.""" """Extract the embedding model flag."""
return self.embedding_mode return self.embedding_mode
@property
def is_multimodal_model(self) -> bool:
    """Whether a multimodal config is attached to this model config."""
    has_multimodal_config = self.multimodal_config is not None
    return has_multimodal_config
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache. """Configuration for the KV cache.
...@@ -888,25 +965,36 @@ class SchedulerConfig: ...@@ -888,25 +965,36 @@ class SchedulerConfig:
num_lookahead_slots: int = 0, num_lookahead_slots: int = 0,
delay_factor: float = 0.0, delay_factor: float = 0.0,
enable_chunked_prefill: bool = False, enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False, embedding_mode: bool = False,
is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None, preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1, num_scheduler_steps: int = 1,
send_delta_data: bool = False) -> None: send_delta_data: bool = False) -> None:
if max_num_batched_tokens is not None: if max_num_batched_tokens is None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
if enable_chunked_prefill: if enable_chunked_prefill:
# It is the values that have the best balance between ITL # It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput. # and TTFT on A100. Note it is not optimized for throughput.
self.max_num_batched_tokens = 512 max_num_batched_tokens = 512
elif embedding_mode:
# For embedding, choose specific value for higher throughput
self.max_num_batched_tokens = max(
max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
else: else:
# If max_model_len is too short, use 2048 as the default value # If max_model_len is too short, use 2048 as the default value
# for higher throughput. # for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048) max_num_batched_tokens = max(max_model_len, 2048)
if embedding_mode:
# For embedding, choose specific value for higher throughput
max_num_batched_tokens = max(
max_num_batched_tokens,
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
max_num_batched_tokens = max(
max_num_batched_tokens,
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)
self.max_num_batched_tokens = max_num_batched_tokens
if enable_chunked_prefill: if enable_chunked_prefill:
logger.info( logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.", "Chunked prefill is enabled with max_num_batched_tokens=%d.",
...@@ -1769,6 +1857,9 @@ class EngineConfig: ...@@ -1769,6 +1857,9 @@ class EngineConfig:
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
""" """
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
......
...@@ -132,7 +132,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -132,7 +132,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def allocate_immutable_blocks(self, prev_block: Optional[Block], def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]], block_token_ids: List[List[int]],
device: Optional[Device]) -> List[Block]: device: Device) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block """Allocates a new group of immutable blocks with the provided block
token IDs on the specified device. token IDs on the specified device.
......
"""Token blocks.""" """Token blocks."""
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively) get_all_blocks_recursively)
...@@ -73,6 +73,11 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -73,6 +73,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# prefix hash will be in this dict, even if they have refcount 0. # prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {} self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A list of immutable block IDs that have been touched by scheduler
# and should be marked as computed after an entire batch of sequences
# are scheduled.
self._touched_blocks: Set[BlockId] = set()
# Used to track status of each physical block id # Used to track status of each physical block id
self._block_tracker: Dict[BlockId, BlockTracker] = {} self._block_tracker: Dict[BlockId, BlockTracker] = {}
for block_id in block_ids: for block_id in block_ids:
...@@ -438,10 +443,14 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -438,10 +443,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert self._refcounter.get(block.block_id) > 0 assert self._refcounter.get(block.block_id) > 0
if block.content_hash not in self._cached_blocks: if block.content_hash not in self._cached_blocks:
# No cached content hash => Set this block as cached # No cached content hash => Set this block as cached.
# (Note that this block is not computed yet => # Note that this block cannot be marked as computed yet
# Will be computed after free()) # because other sequences in the same batch cannot reuse
# this block.
self._cached_blocks[block.content_hash] = block.block_id self._cached_blocks[block.content_hash] = block.block_id
# Mark this block as touched so that it can be marked as
# computed after the entire batch of sequences are scheduled.
self._touched_blocks.add(block.block_id)
return block.block_id return block.block_id
# Reuse the cached content hash # Reuse the cached content hash
...@@ -507,7 +516,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -507,7 +516,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"Mark block as accessed which is not belonged to GPU") "Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
raise NotImplementedError("Marking as computed is incremental") # Mark all touched blocks as computed.
for block_id in self._touched_blocks:
self._block_tracker[block_id].computed = True
self._touched_blocks.clear()
def _track_block_id(self, block_id: Optional[BlockId], def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None: computed: bool) -> None:
......
...@@ -278,7 +278,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -278,7 +278,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# request ID # request ID
self.cross_block_tables: Dict[str, BlockTable] = {} self.cross_block_tables: Dict[str, BlockTable] = {}
def _get_seq_num_required_blocks(self, seq: Sequence) -> int: def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
return 0 if seq is None else seq.n_blocks return 0 if seq is None else seq.n_blocks
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
...@@ -310,13 +310,14 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -310,13 +310,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
return AllocStatus.LATER return AllocStatus.LATER
def _allocate_sequence(self, \ def _allocate_sequence(self, \
seq: Sequence, \ seq: Optional[Sequence], \
ref_count: int, \ ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable: is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = seq.n_blocks num_prompt_blocks = self._get_seq_num_required_blocks(seq)
block_table: BlockTable = BlockTable() block_table: BlockTable = BlockTable()
assert seq is not None
for logical_idx in range(num_prompt_blocks): for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window): and logical_idx >= self.block_sliding_window):
...@@ -680,14 +681,20 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -680,14 +681,20 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for block in block_table: for block in block_table:
block.last_accessed = access_time block.last_accessed = access_time
def compute_full_blocks_in_seq(self, seq: Sequence): def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
if seq.seq_id not in self.block_tables: if seq.seq_id not in self.block_tables:
return return
max_full_block = seq.get_len() // self.block_size - 1
# When chunked prefill is enabled, the computed full blocks
# should be calculated based on the number of computed tokens.
max_computed_tokens = (seq.data.get_num_computed_tokens() +
token_chunk_size)
computed_full_blocks = max_computed_tokens // self.block_size
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
if max_full_block == -1: if computed_full_blocks == 0:
return return
for i in reversed(range(max_full_block)): for i in reversed(range(computed_full_blocks)):
if block_table[i].computed: if block_table[i].computed:
break break
block_table[i].computed = True block_table[i].computed = True
...@@ -717,10 +724,11 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -717,10 +724,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs] ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []]) return commonprefix([ids for ids in ids_list if ids != []])
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
if self.enable_caching: if self.enable_caching:
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq) self.compute_full_blocks_in_seq(seq, token_chunk_size)
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU: if device == Device.GPU:
......
...@@ -120,8 +120,10 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -120,8 +120,10 @@ class BlockSpaceManagerV2(BlockSpaceManager):
) )
if seq_group.is_encoder_decoder(): if seq_group.is_encoder_decoder():
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
num_required_blocks += BlockTable.get_num_required_blocks( num_required_blocks += BlockTable.get_num_required_blocks(
seq_group.get_encoder_seq().get_token_ids(), encoder_seq.get_token_ids(),
block_size=self.block_size, block_size=self.block_size,
) )
...@@ -189,7 +191,9 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -189,7 +191,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
if seq_group.is_encoder_decoder(): if seq_group.is_encoder_decoder():
block_table = self._allocate_sequence(seq_group.get_encoder_seq()) encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
block_table = self._allocate_sequence(encoder_seq)
self.cross_block_tables[request_id] = block_table self.cross_block_tables[request_id] = block_table
def can_append_slots(self, seq_group: SequenceGroup, def can_append_slots(self, seq_group: SequenceGroup,
...@@ -286,12 +290,13 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -286,12 +290,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self._last_access_blocks_tracker.update_last_access( self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now) seq.seq_id, now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup,
# The only need for mark block as computed is for prefix caching, token_chunk_size: int):
# while currently we could determine whether one block is computed # If prefix caching is enabled, mark immutable blocks as computed
# or not by check whether it has content hash. # right after they have been scheduled (for prefill). This assumes
# So this function is useless for block_v2. # the scheduler is synchronous so blocks are actually computed when
pass # scheduling the next batch.
self.block_allocator.mark_blocks_as_computed([])
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]: self, seqs: List[Sequence]) -> GenericSequence[int]:
......
...@@ -77,10 +77,11 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager): ...@@ -77,10 +77,11 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
pass pass
def get_common_computed_block_ids(self, def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]: seq_group: List[Sequence]) -> List[int]:
return None # type: ignore return []
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass pass
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
......
...@@ -115,7 +115,8 @@ class BlockSpaceManager(ABC): ...@@ -115,7 +115,8 @@ class BlockSpaceManager(ABC):
pass pass
@abstractmethod @abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass pass
@abstractmethod @abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment