Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
......@@ -790,15 +790,7 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# # marlin
# def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
# size_n: int, size_k: int) -> torch.Tensor:
# return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
# size_n, size_k)
# # marlin_24
# marlin_24
# def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# b_meta: torch.Tensor, b_scales: torch.Tensor,
# workspace: torch.Tensor, b_q_type: ScalarType,
......@@ -840,25 +832,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
# is_zp_float: bool = False) -> torch.Tensor:
# return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
# @register_fake("_C::marlin_qqq_gemm")
# def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
# s_tok: torch.Tensor, s_ch: torch.Tensor,
# s_group: torch.Tensor, workspace: torch.Tensor,
# size_m: torch.SymInt, size_n: torch.SymInt,
# size_k: torch.SymInt) -> torch.Tensor:
# return torch.empty((size_m, size_n),
# dtype=torch.float16,
# device=a.device)
# @register_fake("_C::marlin_gemm")
# def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
# b_scales: torch.Tensor, workspace: torch.Tensor,
# size_m: torch.SymInt, size_n: torch.SymInt,
# size_k: torch.SymInt) -> torch.Tensor:
# return torch.empty((size_m, size_n),
# dtype=torch.float16,
# device=a.device)
# @register_fake("_C::awq_dequantize")
# def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
# zeros: torch.Tensor, split_k_iters: torch.SymInt,
......@@ -904,6 +877,30 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
# return torch.empty_like(b_q_weight,
# memory_format=torch.contiguous_format)
# @register_fake("_C::cutlass_w4a8_mm")
# def cutlass_w4a8_mm_fake(
# a: torch.Tensor,
# # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
# b_q: torch.Tensor,
# b_group_scales: torch.Tensor,
# b_group_size: int,
# b_channel_scales: torch.Tensor,
# a_token_scales: torch.Tensor,
# out_type: Optional[torch.dtype] = None,
# maybe_schedule: Optional[str] = None) -> torch.Tensor:
# m = a.size(0)
# n = b_q.size(1)
# out_dtype = out_type if out_type is not None else torch.bfloat16
# return torch.empty((m, n), device=a.device, dtype=out_dtype)
# @register_fake("_C::cutlass_pack_scale_fp8")
# def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
# return torch.empty_like(scales, memory_format=torch.contiguous_format)
# @register_fake("_C::cutlass_encode_and_reorder_int4b")
# def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
# return torch.empty_like(b, memory_format=torch.contiguous_format)
# if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
......@@ -920,7 +917,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
# m = a.size(0)
# return torch.empty((m, n), device=a.device, dtype=a.dtype)
if hasattr(torch.ops._C, "ggml_dequantize"):
@register_fake("_C::ggml_dequantize")
......@@ -1291,6 +1287,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
blockscale_offsets)
def get_cutlass_moe_mm_problem_sizes(
topk_ids: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
num_experts: int,
n: int,
k: int,
blockscale_offsets: Optional[torch.Tensor] = None):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
multiplications used in CUTLASS-based fused MoE.
The function takes in topk_ids (token→expert mapping) and computes:
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
"""
return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k,
blockscale_offsets)
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
......@@ -1484,6 +1502,30 @@ def machete_prepack_B(
group_scales_type)
# CUTLASS W4A8
def cutlass_w4a8_mm(
a: torch.Tensor,
# b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
b_q: torch.Tensor,
b_group_scales: torch.Tensor,
b_group_size: int,
b_channel_scales: torch.Tensor,
a_token_scales: torch.Tensor,
out_type: Optional[torch.dtype] = None,
maybe_schedule: Optional[str] = None) -> torch.Tensor:
return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size,
b_channel_scales, a_token_scales,
out_type, maybe_schedule)
def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor:
return torch.ops._C.cutlass_pack_scale_fp8(scales)
def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
if hasattr(torch.ops._C, "permute_cols"):
@register_fake("_C::permute_cols")
......@@ -1773,15 +1815,6 @@ def scaled_int8_quant(
return output, input_scales, input_azp
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
workspace, size_m, size_n, size_k)
# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
dtype: Optional[torch.dtype]) -> torch.Tensor:
......@@ -1918,6 +1951,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
gating_output)
def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor,
num_expert_group: int, topk_group: int, topk: int,
renormalize: bool, routed_scaling_factor: float):
if not current_platform.is_cuda():
raise NotImplementedError("The fused grouped_topk kernel is only "
"available on CUDA platforms")
return torch.ops._moe_C.grouped_topk(scores, scores_with_bias,
num_expert_group, topk_group, topk,
renormalize, routed_scaling_factor)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor,
b_bias: Optional[torch.Tensor],
......@@ -2045,6 +2089,20 @@ def concat_and_cache_mla(
scale)
def cp_fused_concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
cp_local_token_select_indices: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
kv_cache_dtype, scale)
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
......@@ -2068,14 +2126,28 @@ def convert_fp8(output: torch.Tensor,
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
def gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def gather_and_maybe_dequant_cache(
src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
kv_cache_dtype: str,
scale: torch.Tensor,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, seq_starts)
def cp_gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.cp_gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def get_device_attribute(attribute: int, device: int) -> int:
......@@ -2378,9 +2450,92 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
N = mat2.size(0)
return torch.empty((M, N), dtype=out_dtype)
class CPUDNNLGEMMHandler:
def __init__(self) -> None:
self.handler: Optional[int] = None
self.n = -1
self.k = -1
def __del__(self):
if self.handler is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
def create_onednn_scaled_mm(
weight: torch.Tensor, # [K, N]
weight_scales: torch.Tensor,
output_type: torch.dtype,
dynamic_quant: bool,
use_azp: bool,
primitive_cache_size: int = 128,
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_scaled_mm_handler(
weight, weight_scales, output_type, dynamic_quant, use_azp,
primitive_cache_size)
return handler
def onednn_scaled_int8_quant(input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
azp: Optional[torch.Tensor] = None,
symmetric: bool = True):
"""
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
Args:
input: The input tensor to be quantized to int8.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.
azp: Optional zero-point for the int8 quantization.
Must be provided for asymmetric quantization if `scale` is provided.
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
"""
output = torch.empty_like(input, dtype=torch.int8)
token_num = input.numel() // input.shape[-1]
input = input.view((token_num, input.shape[-1]))
if scale is not None:
# static-per-tensor quantization.
assert symmetric == (
azp
is None), "azp must only be provided for asymmetric quantization."
torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
return output, scale, azp
# dynamic-per-token quantization.
input_scales = torch.empty((token_num, 1),
device=input.device,
dtype=torch.float32)
input_azp = None if symmetric else torch.empty_like(input_scales,
dtype=torch.int32)
torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
input_azp)
return output, input_scales, input_azp
def onednn_scaled_mm(
dnnl_handler: CPUDNNLGEMMHandler,
x: torch.Tensor,
output: torch.Tensor,
input_scale: Optional[torch.Tensor],
input_zp: Optional[torch.Tensor],
input_zp_adj: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp,
input_zp_adj, bias, dnnl_handler.handler)
return output
direct_register_custom_op(
op_name="awq_gemm",
op_func=awq_gemm,
mutates_args=[],
fake_impl=awq_gemm_fake,
)
\ No newline at end of file
)
......@@ -11,7 +11,7 @@ from .base import get_vllm_public_assets
VLM_IMAGES_DIR = "vision_model_images"
ImageAssetName = Literal["stop_sign", "cherry_blossom"]
ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato"]
@dataclass(frozen=True)
......
......@@ -14,7 +14,6 @@ __all__ = [
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]
......@@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING:
......@@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape.
:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization
"""
return False
......@@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
......
......@@ -800,23 +800,33 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
layer: Attention layer instance.
q: Query tensor with shape = [num_tokens, num_heads, head_size]
k: Key tensor with shape = [num_tokens, num_kv_heads, head_size]
v: Value tensor with shape = [num_tokens, num_kv_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Output tensor with shape [num_tokens, num_heads, head_size]
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
NOTE: It in-place updates the output tensor.
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for DifferentialFlashAttentionImpl")
if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[
"lambda_init"]
......
......@@ -376,6 +376,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
......@@ -391,7 +392,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
"""
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
......
......@@ -603,6 +603,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
......@@ -611,7 +612,8 @@ class FlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
......@@ -622,7 +624,7 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
......@@ -925,7 +927,7 @@ class FlashAttentionImpl(AttentionImpl):
def _get_query_key_seq_metadata(
attn_metadata,
attn_metadata: FlashAttentionMetadata,
is_prompt: bool,
attn_type: str,
) -> tuple:
......
......@@ -837,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self.context_chunk_workspace_size // num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
# currently the `gather_and_maybe_dequant_cache` kernel cannot
# handle `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk, self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
......@@ -1090,6 +1090,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
):
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
......@@ -1111,12 +1112,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
for i in range(iters):
toks = prefill_metadata.context_chunk_seq_tot[i]
ops.gather_cache(
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_tables,
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
batch_size=prefill_metadata.num_prefills,
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.context_chunk_starts[i],
)
......@@ -1173,6 +1176,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
......@@ -1208,7 +1212,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
......@@ -1245,12 +1249,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")
......@@ -1298,7 +1303,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
attn_metadata, layer._k_scale)
if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(
......
......@@ -23,7 +23,7 @@ from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
from vllm.utils import SUPPORT_TC, gpuname
......@@ -549,11 +549,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == GroupShape.PER_TENSOR
return quant_key == kFp8StaticTensorSym
# Only supported in the Triton backend
return False
......@@ -568,6 +566,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -605,17 +604,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
use prefill sequence attributes
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
......@@ -626,6 +626,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
if output_block_scale is not None:
raise NotImplementedError(
"fused nvfp4 output quantization is not supported"
" for ROCMFlashAttentionImpl")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
......
......@@ -585,7 +585,7 @@ def get_num_prefill_decode_query_kv_tokens(
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
is `None` when required for the calculations.
"""
num_prefill_query_tokens = 0
num_decode_query_tokens = 0
......
......@@ -439,6 +439,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -477,21 +478,22 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
max_encoder_seq_len)
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
......@@ -654,7 +656,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
for API spec.
Args:
output: shape = [num_prefill_tokens, num_heads, head_size]
query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
......
......@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
......@@ -54,7 +55,7 @@ def check_xformers_availability():
return USE_XFORMERS_OPS
class Attention(nn.Module):
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
......@@ -128,11 +129,17 @@ class Attention(nn.Module):
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.num_heads = num_heads
self.head_size = head_size
......@@ -183,8 +190,7 @@ class Attention(nn.Module):
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
......@@ -291,6 +297,7 @@ class Attention(nn.Module):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
......@@ -488,6 +495,7 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
......@@ -503,7 +511,8 @@ def unified_attention_with_output(
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale)
output_scale=output_scale,
output_block_scale=output_block_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
......@@ -515,6 +524,7 @@ def unified_attention_with_output_fake(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
return
......@@ -522,7 +532,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output"],
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
......@@ -6,12 +6,13 @@ from typing import List, Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_metadata_builder)
subclass_attention_backend)
from ..layer import Attention
......@@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(attention_chunk_size, cm,
block_size)
underlying_builder = underlying_attn_backend.get_builder_cls()
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
common_attn_metadata = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size)
return super().build(common_prefix_len, common_attn_metadata,
fast_build)
# Dynamically create a new attention backend that wraps the
# underlying attention backend but applies
# `make_local_attention_virtual_batches` before calling `build(...)`
builder_cls = subclass_attention_metadata_builder(
name_prefix=prefix,
builder_cls=underlying_attn_backend.get_builder_cls(),
build_preprocess_fn=build_preprocess_fn)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=builder_cls)
builder_cls=ChunkedLocalAttentionBuilder)
return attn_backend
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_common_attn_metadata = copy(common_attn_metadata)
new_common_attn_metadata.causal = False
return super().build(common_prefix_len, new_common_attn_metadata,
fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=EncoderOnlyAttentionBuilder)
return attn_backend
class EncoderOnlyAttention(Attention):
"""
Encoder attention is a special case that doesn't need a KV Cache.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_encoder_only_attention_backend(
underlying_attn_backend)
else:
# in v0 encoder only attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_ONLY, \
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_ONLY,
**kwargs)
......@@ -75,8 +75,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
k_scale = None,
kv_cache_dtype = "auto",
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
......@@ -91,6 +91,8 @@ def flash_mla_with_kvcache(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q.
descale_k: (batch_size), torch.float32. Descaling factors for K.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
......@@ -99,22 +101,22 @@ def flash_mla_with_kvcache(
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm():
if kv_cache_dtype == "fp8":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
"fp8_e4m3",
)
return out, softmax_lse
# if kv_cache_dtype == "fp8":
# out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
# q,
# k_cache,
# None,
# head_dim_v,
# cache_seqlens,
# block_table,
# softmax_scale,
# causal,
# tile_scheduler_metadata,
# num_splits,
# k_scale,
# "fp8_e4m3",
# )
# return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
......@@ -126,12 +128,13 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
# descale_q,
# descale_k,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
......@@ -139,6 +142,8 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
......
......@@ -18,7 +18,7 @@ class BeamSearchSequence:
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
# The tokens include the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
lora_request: Optional[LoRARequest] = None
......
......@@ -11,17 +11,21 @@ generation. Supported dataset types include:
- HuggingFace
- VisionArena
"""
import ast
import base64
import io
import json
import logging
import math
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from collections.abc import Iterator, Mapping
from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast
import numpy as np
from PIL import Image
......@@ -69,13 +73,14 @@ class SampleRequest:
Represents a single inference request for benchmarking.
"""
prompt: Union[str, Any]
prompt: Union[str, list[str]]
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[
Union[MultiModalDataDict, dict, list[dict]]
] = None
lora_request: Optional[LoRARequest] = None
request_id: Optional[str] = None
# -----------------------------------------------------------------------------
......@@ -112,7 +117,9 @@ class BenchmarkDataset(ABC):
def apply_multimodal_chat_transformation(
self,
prompt: str,
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
mm_content: Optional[
Union[MultiModalDataDict, dict, list[dict]]
] = None) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
......@@ -120,7 +127,15 @@ class BenchmarkDataset(ABC):
"""
content = [{"text": prompt, "type": "text"}]
if mm_content is not None:
content.append(mm_content)
if isinstance(mm_content, list):
content.extend(cast(list[dict[str, Any]], mm_content))
elif isinstance(mm_content, dict):
content.append(mm_content)
else:
raise TypeError(
"Could not process multimodal content of type: " +
f"{type(mm_content)}"
)
return [{"role": "user", "content": content}]
def load_data(self) -> None:
......@@ -183,7 +198,8 @@ class BenchmarkDataset(ABC):
@abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase,
num_requests: int) -> list[SampleRequest]:
num_requests: int,
request_id_prefix: str = "") -> list[SampleRequest]:
"""
Abstract method to generate sample requests from the dataset.
......@@ -194,6 +210,8 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str) The prefix of request_id.
Returns:
list[SampleRequest]: A list of sample requests generated from the
......@@ -201,8 +219,12 @@ class BenchmarkDataset(ABC):
"""
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest],
num_requests: int) -> None:
def maybe_oversample_requests(
self,
requests: list[SampleRequest],
num_requests: int,
request_id_prefix: str = "",
) -> None:
"""
Oversamples the list of requests if its size is less than the desired
number.
......@@ -211,11 +233,17 @@ class BenchmarkDataset(ABC):
requests (List[SampleRequest]): The current list of sampled
requests.
num_requests (int): The target number of requests.
request_id_prefix (str) The prefix of the request ids.
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
additional = random.choices(requests,
k=num_requests - len(requests))
additional = deepcopy(
random.choices(requests, k=num_requests - len(requests))
)
for i in range(len(additional)):
req = additional[i]
req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.",
num_requests)
......@@ -266,7 +294,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
"""
Process a single image input and return a multimedia content dictionary.
Supports three input types:
Supports the following input types:
1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
containing raw image data. - Loads the bytes as a PIL.Image.Image.
......@@ -306,94 +334,592 @@ def process_image(image: Any) -> Mapping[str, Any]:
" or str or dictionary with raw image bytes.")
def process_video(video: Any) -> Mapping[str, Any]:
"""
Process a single video input and return a multimedia content dictionary.
Supports the following input types:
1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
containing raw video data.
2. String input: - Treats the string as a URL or local file path. -
Prepends "file://" if the string doesn't start with "http://" or
"file://". - Returns a dictionary with the image URL.
Raises:
ValueError: If the input is not a supported type.
"""
if isinstance(video, dict) and 'bytes' in video:
video_bytes = video['bytes']
video_base64 = base64.b64encode(video_bytes).decode("utf-8")
return {
"type": "video_url",
"video_url": {
"url": f"data:video/mp4;base64,{video_base64}"
},
}
if isinstance(video, str):
video_url = (video if video.startswith(
("http://", "file://")) else f"file://{video}")
return {"type": "video_url", "video_url": {"url": video_url}}
raise ValueError(
f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
)
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDataset(BenchmarkDataset):
"""
Synthetic text-only dataset for serving/throughput benchmarks.
Strategy:
- Sample input/output token lengths per request from integer-uniform ranges
around configured means (controlled by range_ratio).
- Prepend a fixed random prefix of length prefix_len.
- Generate the remaining tokens as a reproducible sequence:
(offset + index + arange(input_len)) % vocab_size.
- Decode then re-encode/truncate to ensure prompt token counts match.
- Uses numpy.default_rng seeded with random_seed for reproducible sampling.
"""
# Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN = 0
DEFAULT_RANGE_RATIO = 0.0
DEFAULT_INPUT_LEN = 1024
DEFAULT_OUTPUT_LEN = 128
def __init__(
self,
**kwargs,
) -> None:
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
random.seed(self.random_seed)
np.random.seed(self.random_seed)
# Use numpy's default_rng for deterministic sampling
# Do not use random.seed() or np.random.seed() elsewhere in this class.
# This ensures that the RNG is isolated from global RNG state.
self._rng = np.random.default_rng(self.random_seed)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
prefix_len: int = DEFAULT_PREFIX_LEN,
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
batchsize: int = 1,
**kwargs,
) -> list[SampleRequest]:
# Enforce range_ratio < 1
assert range_ratio < 1.0, (
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
requests = []
for i in range(num_requests):
prompt, total_input_len = self.generate_token_sequence(
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
vocab_size=vocab_size,
input_len=int(input_lens[i]),
offset=int(offsets[i]),
index=i,
)
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
)
)
# only used for embeddings benchmark.
if batchsize > 1:
batch_requests = []
# Create batched requests
for i in range(0, num_requests, batchsize):
batch = requests[i : i + batchsize]
batch_requests.append(
SampleRequest(
prompt=[req.prompt for req in batch],
prompt_len=sum(req.prompt_len for req in batch),
expected_output_len=0,
request_id=request_id_prefix + str(i // batchsize),
)
)
requests = batch_requests
return requests
def get_prefix(
self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
) -> list[int]:
"""
Get the prefix for the dataset.
"""
return (
self._rng.integers(
0, tokenizer.vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
def get_sampling_params(
self,
num_requests: int,
range_ratio: float,
input_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the sampling parameters for the dataset.
"""
# Enforce range_ratio < 1
if not (0.0 <= range_ratio < 1.0):
raise ValueError("range_ratio must be in [0, 1).")
num_special_tokens = int(tokenizer.num_special_tokens_to_add())
real_input_len = max(0, int(input_len) - num_special_tokens)
# Bounds use floor for low and ceil for high
input_low = math.floor(real_input_len * (1 - range_ratio))
input_high = math.ceil(real_input_len * (1 + range_ratio))
output_low = math.floor(output_len * (1 - range_ratio))
output_high = math.ceil(output_len * (1 + range_ratio))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)
if input_low > input_high:
raise ValueError(
"Invalid input sampling interval: "
f"low={input_low} > high={input_high}"
)
if output_low > output_high:
raise ValueError(
"Invalid output sampling interval: "
f"low={output_low} > high={output_high}"
)
# Add logging for debugging
logger.info(
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
input_low, input_high, output_low, output_high)
input_low,
input_high,
output_low,
output_high,
)
input_lens = np.random.randint(input_low,
input_high + 1,
size=num_requests)
output_lens = np.random.randint(output_low,
output_high + 1,
input_lens = self._rng.integers(input_low, input_high + 1,
size=num_requests)
output_lens = self._rng.integers(output_low, output_high + 1,
size=num_requests)
offsets = self._rng.integers(0, tokenizer.vocab_size,
size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)
return input_lens, output_lens, offsets
requests = []
def generate_token_sequence(
self,
*,
tokenizer: PreTrainedTokenizerBase,
prefix_token_ids: list[int],
prefix_len: int,
vocab_size: int,
input_len: int,
offset: int,
index: int,
) -> tuple[str, int]:
"""
Returns (prompt, total_input_len).
NOTE: After decoding the prompt we have to encode and decode it again.
This is done because in some cases N consecutive tokens
give a string tokenized into != N number of tokens.
For example for GPT2Tokenizer:
[6880, 6881] -> ['Ġcalls', 'here'] ->
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
To avoid uncontrolled change of the prompt length,
the encoded sequence is truncated before being decode again.
"""
# Build the inner sequence by sampling sequentially from the vocab
inner_seq = ((offset + index + np.arange(input_len))
% vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
# Decode, then re-encode and truncate to preserve token count invariants
prompt = tokenizer.decode(token_sequence)
total_input_len = prefix_len + int(input_len)
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
return prompt, total_input_len
# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------
class RandomMultiModalDataset(RandomDataset):
"""
Synthetic multimodal dataset (text + images) that extends RandomDataset.
Status:
- Images: supported via synthetic RGB data.
- Video: not yet supported (TODO: implement video generation method).
- Audio: not yet supported.
Sampling overview:
1) Number of items per request is sampled uniformly from the integer range
[floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
The maximum is further clamped to the sum of per-modality limits.
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
mapping (height, width, num_frames) → probability. We treat
`num_frames`=1 as image and and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized
to sum to 1.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
When a modality reaches its cap, all of its buckets are excluded and the
remaining probabilities are renormalized.
Example bucket configuration:
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
- Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16).
OBS.: Only image sampling is supported for now.
"""
IS_MULTIMODAL = True
# NOTE: video sampling is WIP. Setting it to 0.
DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}
DEFAULT_BASE_ITEMS_PER_REQUEST = 1
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
DEFAULT_MM_ITEM_BUCKET_CONFIG = {
(256, 256, 1): 0.5,
(720, 1280, 1): 0.5,
(720, 1280, 16): 0.0,
}
DEFAULT_ENABLE_MULTIMODAL_CHAT = False
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
"""Generate synthetic PIL image with random RGB values.
NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos.
We could consider a “low-freq” mode (e.g., noise blur)
to emulate network realism instead of max stress.
"""
random_pixels = self._rng.integers(
0,
256,
(height, width, 3),
dtype=np.uint8,
)
return Image.fromarray(random_pixels)
def generate_synthetic_video(self, width: int,
height: int,
num_frames: int) -> Any:
"""Generate synthetic video with random values.
TODO: Finish this method.
"""
raise NotImplementedError("Video sampling is WIP.")
def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
"""Map the configuration to the modality."""
if config[-1] == 1:
return "image"
elif config[-1] > 1:
return "video"
else:
raise ValueError(f"Invalid multimodal item configuration: {config}")
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
float]) -> dict[tuple[int, int, int], float]:
"""
Remove zero probability entries
and normalize the bucket config to sum to 1.
"""
# Raise error if value is negative
if any(v < 0 for v in bucket_config.values()):
raise ValueError("Bucket config values must be non-negative.")
# Remove zero probability entries
bucket_config = {k: v for k, v in bucket_config.items() if v > 0}
# if bucket config is empty, raise error
if not bucket_config:
raise ValueError("Got invalid bucket config. "
"Bucket config values must be non-zero.")
# Normalize the remaining bucket config to sum to 1
total = sum(bucket_config.values())
return {k: v / total for k, v in bucket_config.items()}
def generate_mm_item(self,
mm_item_config: tuple[int, int, int],
) -> Mapping[str, Any]:
"""
Create synthetic images and videos and
apply process_image/process_video respectively.
This follows the OpenAI API chat completions
https://github.com/openai/openai-python
"""
if self.map_config_to_modality(mm_item_config) == "image":
return process_image(self.generate_synthetic_image(
mm_item_config[1],
mm_item_config[0]))
elif self.map_config_to_modality(mm_item_config) == "video":
return process_video(self.generate_synthetic_video(
mm_item_config[1],
mm_item_config[0],
mm_item_config[2]))
else:
raise ValueError(f"Invalid multimodal item configuration: "
f"{mm_item_config}")
def get_mm_item_sampling_params(
self,
base_items_per_request: int,
num_mm_items_range_ratio: float,
limit_mm_per_prompt: dict[str, int],
bucket_config: dict[tuple[int, int, int], float],
) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]:
"""
Get the sampling parameters for the multimodal items.
"""
# Enforce num_mm_items_range_ratio <= 1
if not (0.0 <= num_mm_items_range_ratio <= 1.0):
raise ValueError("num_mm_items_range_ratio must be in [0, 1].")
# Ensure modalities to sample are in limit_mm_per_prompt
for k, v in bucket_config.items():
# get modality from bucket config
modality = self.map_config_to_modality(k)
if modality not in limit_mm_per_prompt:
raise ValueError(f"Modality {modality} is not in "
f"limit_mm_per_prompt: "
f"{limit_mm_per_prompt.keys()}")
# Remove zero probability entries
# and normalize bucket config to sum to 1
bucket_config = self.normalize_bucket_config(bucket_config)
logger.info(
"Normalized bucket config: %s", bucket_config,
)
# Only consider limit per prompt for modalities in bucket config
allowed_modalities = {self.map_config_to_modality(cfg)
for cfg in bucket_config}
limit_mm_per_prompt = {
k: v for k, v in limit_mm_per_prompt.items()
if k in allowed_modalities}
if not limit_mm_per_prompt:
raise ValueError("No valid limits for modalities present in "
"bucket_config.")
logger.info(
"Updated mm-limit-per-prompt: %s", limit_mm_per_prompt,
)
# Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items = min(
sum(limit_mm_per_prompt.values()),
math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
)
# Ensure min num mm items is at least 0
min_num_mm_items = max(
0,
math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
)
# Raise error if min num mm items is greater than max num mm items
if min_num_mm_items > max_num_mm_items:
raise ValueError(f"Min num mm items is greater than max mm items: "
f"{min_num_mm_items} > {max_num_mm_items}")
logger.info(
"Sampling number of multimodal items from [%s, %s]",
min_num_mm_items, max_num_mm_items,
)
return (
min_num_mm_items,
max_num_mm_items,
limit_mm_per_prompt,
bucket_config,
)
def get_mm_item_iterator(
self,
min_num_mm_items: int,
max_num_mm_items: int,
bucket_config: dict[tuple[int, int, int], float],
limit_mm_per_prompt: dict[str, int],
) -> Iterator[tuple[int,int, int]]:
"""
Iterator over the multimodal items for each request
whose size is between min_num_mm_items and max_num_mm_items.
Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached.
Note:
- This function operates on a per-request shallow copy of
`bucket_config` (tuple->float). The original dict passed to
`sample` is not mutated. If this ever changes, a test
is implemented and will fail.
"""
# Get the number of multimodal items to sample
request_num_mm_items = int(
self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
)
# If request_num_mm_items is 0, yield an empty iterator
if request_num_mm_items == 0:
return
# Initialize modality counters
modality_counter = {self.map_config_to_modality(k): 0
for k in bucket_config}
# Copy the bucket config to avoid modifying the original
bucket_config_copy = bucket_config.copy()
# Loop over the number of multimodal items to sample
while sum(modality_counter.values()) < request_num_mm_items:
# Sample a multimodal item config
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
p=list(bucket_config_copy.values()))
modality = self.map_config_to_modality(mm_item_config)
# Check that modality count is less than limit per prompt
if modality_counter[modality] < limit_mm_per_prompt[modality]:
modality_counter[modality] += 1
yield (
mm_item_config
)
else:
# If the counter is greater than the limit per prompt
# set all multimodal items of this modality to 0
for k, v in bucket_config_copy.items():
if self.map_config_to_modality(k) == modality:
bucket_config_copy[k] = 0
# If all configs are 0, break the loop
# This should not happen as request_num_mm_items is at most
# the sum of limit_mm_per_prompt for all modalities
if all(v == 0 for v in bucket_config_copy.values()):
logger.warning("Exhausted all multimodal items "
"of modality %s",
modality)
break
# Renormalize the bucket config
bucket_config_copy = self.normalize_bucket_config(
bucket_config_copy)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
bucket_config: dict[tuple[int, int, int], float] =
DEFAULT_MM_ITEM_BUCKET_CONFIG,
enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
**kwargs,
) -> list[SampleRequest]:
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero.
if any(self.map_config_to_modality(cfg) == "video" and p > 0
for cfg, p in bucket_config.items()):
raise NotImplementedError("Video sampling not implemented; "
"set its probability to 0.")
# Get the sampling parameters for the dataset
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
)
(
min_num_mm_items,
max_num_mm_items,
limit_mm_per_prompt,
bucket_config,
) = self.get_mm_item_sampling_params(
base_items_per_request,
num_mm_items_range_ratio,
limit_mm_per_prompt,
bucket_config,
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
# Add synthetic multimodal items to each request
mm_requests = []
for i in range(num_requests):
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
total_input_len = prefix_len + int(input_lens[i])
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
requests.append(
SampleRequest(
prompt, total_input_len = self.generate_token_sequence(
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
vocab_size=vocab_size,
input_len=int(input_lens[i]),
offset=int(offsets[i]),
index=i,
)
# Get multimodal item iterator for a given request
mm_item_iterator = self.get_mm_item_iterator(
min_num_mm_items,
max_num_mm_items,
bucket_config,
limit_mm_per_prompt,
)
mm_content = cast(list[dict[str, Any]], [
self.generate_mm_item(mm_item_config)
for mm_item_config in mm_item_iterator
])
if enable_multimodal_chat:
# NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it.
mm_chat_prompt: Any = prompt
mm_chat_prompt = self.apply_multimodal_chat_transformation(
prompt, mm_content)
sample_request = SampleRequest(
prompt=mm_chat_prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
multi_modal_data=None,
request_id=request_id_prefix + str(i),
)
else:
sample_request = SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
))
return requests
multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
)
mm_requests.append(sample_request)
return mm_requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
......@@ -432,9 +958,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
samples: list = []
ind = 0
for entry in self.data:
if len(samples) >= num_requests:
break
......@@ -455,9 +983,10 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len
is not None):
continue
# TODO: Also support ShareGPT4Video.
if image_path := entry.get("image"):
mm_content = process_image(image_path)
elif video_path := entry.get("video"):
mm_content = process_video(video_path)
else:
mm_content = None
if enable_multimodal_chat:
......@@ -470,8 +999,10 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len,
lora_request=lora_request,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
))
self.maybe_oversample_requests(samples, num_requests)
ind += 1
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
......@@ -488,8 +1019,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
type=str,
default="random",
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
"prefix_repetition"
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition"
],
help="Name of the dataset to benchmark on.",
)
......@@ -589,6 +1120,103 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
)
random_group.add_argument(
"--random-batch-size",
type=int,
default=1,
help=("Batch size for random sampling. "
"Only used for embeddings benchmark."),
)
# random multimodal dataset options
random_mm_group = parser.add_argument_group(
"random multimodal dataset options extended from random dataset")
random_mm_group.add_argument(
"--random-mm-base-items-per-request",
type=int,
default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
help=(
"Base number of multimodal items per request for random-mm. "
"Actual per-request count is sampled around this base using "
"--random-mm-num-mm-items-range-ratio."
),
)
random_mm_group.add_argument(
"--random-mm-num-mm-items-range-ratio",
type=float,
default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
help=(
"Range ratio r in [0, 1] for sampling items per request. "
"We sample uniformly from the closed integer range "
"[floor(n*(1-r)), ceil(n*(1+r))] "
"where n is the base items per request. "
"r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
"to the sum of per-modality limits from "
"--random-mm-limit-mm-per-prompt. "
"An error is raised if the computed min exceeds the max."
),
)
random_mm_group.add_argument(
"--random-mm-limit-mm-per-prompt",
type=json.loads,
default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
help=(
"Per-modality hard caps for items attached per request, e.g. "
"'{\"image\": 3, \"video\": 0}'. The sampled per-request item "
"count is clamped to the sum of these limits. When a modality "
"reaches its cap, its buckets are excluded and probabilities are "
"renormalized."
"OBS.: Only image sampling is supported for now."
),
)
def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
# If already a dict (e.g., programmatic call), normalize keys
def normalize(d: dict) -> dict[tuple[int, int, int], float]:
out: dict[tuple[int, int, int], float] = {}
for k, val in d.items():
key = k
if isinstance(key, str):
with suppress(Exception):
key = ast.literal_eval(key)
if not (isinstance(key, tuple) and len(key) == 3
and all(isinstance(x, int) for x in key)):
raise ValueError(
f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
)
out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
return out
if isinstance(v, dict):
return normalize(v)
if isinstance(v, str):
# Python literal (supports tuple keys)
parsed = ast.literal_eval(v)
if not isinstance(parsed, dict):
raise ValueError("Bucket config must parse to a dict.")
return normalize(parsed)
raise ValueError("Unsupported value for --random-mm-bucket-config.")
random_mm_group.add_argument(
"--random-mm-bucket-config",
type=_parse_mm_bucket_config,
default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
help=(
"The bucket config is a dictionary mapping a multimodal item"
"sampling configuration to a probability."
"Currently allows for 2 modalities: images and videos. "
"An bucket key is a tuple of (height, width, num_frames)"
"The value is the probability of sampling that specific item. "
"Example: "
"--random-mm-bucket-config "
"{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
"First item: images with resolution 256x256 w.p. 0.5"
"Second item: images with resolution 720x1280 w.p. 0.4 "
"Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1"
"OBS.: If the probabilities do not sum to 1, they are normalized."
"OBS bis.: Only image sampling is supported for now."
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
......@@ -647,6 +1275,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer=tokenizer,
output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template,
request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "sonnet":
......@@ -660,6 +1289,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=False,
request_id_prefix=args.request_id_prefix,
)
else:
assert tokenizer.chat_template or tokenizer.default_chat_template, (
......@@ -671,6 +1301,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=True,
request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "hf":
......@@ -716,10 +1347,11 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
"openai-chat",
"openai-audio",
]:
# multi-modal benchmark is only available on OpenAI Chat backend.
# multi-modal benchmark is only available on OpenAI Chat
# endpoint-type.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' and "
"'openai-audio' backend.")
"'openai-audio' endpoint-type.")
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
......@@ -730,31 +1362,54 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
)
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"sharegpt":
lambda: ShareGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
"burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
"random":
lambda: RandomDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
),
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,
),
"random": lambda: RandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
batchsize=args.random_batch_size,
),
"random-mm":
lambda: RandomMultiModalDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
range_ratio=args.random_range_ratio,
input_len=args.random_input_len,
output_len=args.random_output_len,
base_items_per_request=args.random_mm_base_items_per_request,
limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
bucket_config=args.random_mm_bucket_config,
request_id_prefix=args.request_id_prefix,
),
"prefix_repetition":
lambda: PrefixRepetitionRandomDataset(
......@@ -766,10 +1421,18 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
suffix_len=args.prefix_repetition_suffix_len,
num_prefixes=args.prefix_repetition_num_prefixes,
output_len=args.prefix_repetition_output_len,
request_id_prefix=args.request_id_prefix,
),
}
try:
# Enforce endpoint compatibility for multimodal datasets.
if args.dataset_name == "random-mm" and args.endpoint_type not in [
"openai-chat"]:
raise ValueError(
"Multi-modal content (images) is only supported on "
"'openai-chat' backend."
)
input_requests = dataset_mapping[args.dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
......@@ -839,10 +1502,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["prompt"]
......@@ -864,8 +1528,10 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -909,6 +1575,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
# Calculate average token length for a poem line.
......@@ -934,6 +1601,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines]
samples = []
ind = 0
while len(samples) < num_requests:
extra_lines = random.choices(self.data,
k=num_input_lines - num_prefix_lines)
......@@ -949,7 +1617,9 @@ class SonnetDataset(BenchmarkDataset):
if return_prompt_formatted else prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(ind),
))
ind += 1
return samples
......@@ -1000,6 +1670,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
samples = []
......@@ -1020,6 +1691,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len,
expected_output_len=output_len,
lora_request=lora_req,
request_id=request_id_prefix + str(i),
))
return samples
......@@ -1075,11 +1747,13 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list:
# Filter examples with at least 2 conversations
filtered_data = self.data.filter(
lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
ind = 0
dynamic_output = output_len is None
for item in filtered_data:
......@@ -1111,8 +1785,11 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1141,12 +1818,13 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
......@@ -1168,8 +1846,10 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1198,15 +1878,18 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = f"{item['input']}\n\n{item['instruction']} Just output \
the code, do not include any explanation."
prompt = (
f"{item['input']}\n\n{item['instruction']} Just output "
"the code, do not include any explanation."
)
# apply template
prompt = tokenizer.apply_chat_template(
......@@ -1224,8 +1907,10 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1255,13 +1940,14 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0]
......@@ -1282,8 +1968,10 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1305,8 +1993,10 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs) -> list:
sampled_requests = []
ind = 0
dynamic_output = output_len is None
for item in self.data:
......@@ -1331,8 +2021,12 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=None,
request_id=request_id_prefix + str(ind),
))
self.maybe_oversample_requests(sampled_requests, num_requests)
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1403,13 +2097,14 @@ class NextEditPredictionDataset(HuggingFaceDataset):
}
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
request_id_prefix: str = "",
**kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
for sample in self.data:
for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample)
samples.append(
SampleRequest(
......@@ -1417,10 +2112,11 @@ class NextEditPredictionDataset(HuggingFaceDataset):
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids),
request_id=request_id_prefix + str(i),
))
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples, num_requests)
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
......@@ -1470,6 +2166,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = (output_len
......@@ -1477,6 +2174,7 @@ class ASRDataset(HuggingFaceDataset):
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
ind = 0
skipped = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
......@@ -1496,7 +2194,9 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
))
ind += 1
if skipped:
logger.warning(
"%d samples discarded from dataset due to"
......@@ -1504,7 +2204,8 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1541,11 +2242,13 @@ class MLPerfDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
# Force dynamic output length based on reference completion.
dynamic_output = output_len is None
sampled_requests: list[SampleRequest] = []
ind = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
......@@ -1580,10 +2283,13 @@ class MLPerfDataset(HuggingFaceDataset):
prompt=prompt_formatted,
prompt_len=prompt_len,
expected_output_len=expected_output_len,
request_id=request_id_prefix + str(ind),
)
)
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests
......@@ -1616,6 +2322,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
suffix_len: int = DEFAULT_SUFFIX_LEN,
num_prefixes: int = DEFAULT_NUM_PREFIXES,
output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
vocab_size = tokenizer.vocab_size
......
......@@ -9,7 +9,7 @@ import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union
import aiohttp
from tqdm.asyncio import tqdm
......@@ -28,9 +28,10 @@ class RequestFuncInput:
model_name: Optional[str] = None
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
multi_modal_content: Optional[dict | list[dict]] = None
multi_modal_content: Optional[Union[dict, list[dict]]] = None
ignore_eos: bool = False
language: Optional[str] = None
request_id: Optional[str] = None
@dataclass
......@@ -68,8 +69,8 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"repetition_penalty": 1.0,
......@@ -87,6 +88,8 @@ async def async_request_openai_completions(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
......@@ -132,7 +135,7 @@ async def async_request_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
......@@ -210,6 +213,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
......@@ -249,7 +254,7 @@ async def async_request_openai_chat_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
......@@ -311,6 +316,8 @@ async def async_request_openai_audio(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
# Send audio file
def to_bytes(y, sr):
......@@ -387,12 +394,61 @@ async def async_request_openai_audio(
return output
async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
):
api_url = request_func_input.api_url
assert api_url.endswith(
"embeddings"
), "OpenAI Embeddings API URL must end with 'embeddings'."
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
payload = {
"model": request_func_input.model,
"input": request_func_input.prompt,
}
output = RequestFuncOutput()
st = time.perf_counter()
try:
async with session.post(
url=api_url,
headers=headers,
json=payload
) as response:
if response.status == 200:
output.latency = time.perf_counter() - st
data = await response.json()
output.success = True
output.generated_text = ""
output.prompt_len = data.get(
"usage", {}).get(
"prompt_tokens", 0)
else:
output.success = False
output.error = response.reason or ""
except Exception as e:
output.success = False
output.error = str(e)
if pbar:
pbar.update(1)
return output
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"openai-embeddings": async_request_openai_embeddings,
}
OPENAI_COMPATIBLE_BACKENDS = [
......
......@@ -54,7 +54,12 @@ class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any):
if isinstance(o, dict):
return {k: self.clear_inf(v) for k, v in o.items()}
return {
str(k)
if not isinstance(k, (str, int, float, bool, type(None)))
else k: self.clear_inf(v)
for k, v in o.items()
}
elif isinstance(o, list):
return [self.clear_inf(v) for v in o]
elif isinstance(o, float) and math.isinf(o):
......
......@@ -4,7 +4,7 @@ r"""Benchmark online serving throughput.
On the server side, run one of the following commands
to launch the vLLM OpenAI API server:
vllm serve <your_model> <engine arguments>
vllm serve <your_model> <engine arguments>
On the client side, run:
vllm bench serve \
......@@ -26,6 +26,7 @@ import warnings
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Literal, Optional
import aiohttp
......@@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
class TaskType(Enum):
GENERATION = "generation"
EMBEDDING = "embedding"
@dataclass
class BenchmarkMetrics:
completed: int
......@@ -75,6 +81,16 @@ class BenchmarkMetrics:
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
@dataclass
class EmbedBenchmarkMetrics:
completed: int
total_input: int
request_throughput: float
total_token_throughput :float
mean_e2el_ms: float
std_e2el_ms: float
median_e2el_ms: float
percentiles_e2el_ms: float
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
......@@ -146,11 +162,11 @@ async def get_request(
delay_ts = []
for request_index, request in enumerate(input_requests):
current_request_rate = _get_current_request_rate(ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
request_rates.append(current_request_rate)
if current_request_rate == float("inf"):
delay_ts.append(0)
......@@ -160,7 +176,7 @@ async def get_request(
# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))
# Calculate the cumulative delay time from the first sent out requests.
for i in range(1, len(delay_ts)):
delay_ts[i] += delay_ts[i - 1]
......@@ -170,11 +186,11 @@ async def get_request(
# logic would re-scale delay time to ensure the final delay_ts
# align with target_total_delay_s.
#
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# from target_total_delay_s. The purpose of the following logic is to
# close the gap for stablizing the throughput data
# from different random seeds.
# close the gap for stablizing the throughput data
# from different random seeds.
target_total_delay_s = total_requests / request_rate
normalize_factor = target_total_delay_s / delay_ts[-1]
delay_ts = [delay * normalize_factor for delay in delay_ts]
......@@ -189,6 +205,51 @@ async def get_request(
yield request, request_rates[request_index]
def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float]
) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.
Args:
outputs: The outputs of the requests.
dur_s: The duration of the benchmark.
selected_percentiles: The percentiles to select.
Returns:
The calculated benchmark metrics.
"""
total_input = 0
completed = 0
e2els: list[float] = []
for i in range(len(outputs)):
if outputs[i].success:
e2els.append(outputs[i].latency)
completed += 1
total_input += outputs[i].prompt_len
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
metrics = EmbedBenchmarkMetrics(
completed=completed,
total_input=total_input,
request_throughput=completed / dur_s,
total_token_throughput=total_input / dur_s,
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles
],
)
return metrics
def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
......@@ -334,8 +395,16 @@ async def benchmark(
ramp_up_end_rps: Optional[int] = None,
ready_check_timeout_sec: int = 600,
):
task_type = (
TaskType.EMBEDDING
if api_url.endswith("/v1/embeddings")
else TaskType.GENERATION
)
if endpoint_type in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
if task_type == TaskType.EMBEDDING:
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
else:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
else:
raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
......@@ -421,8 +490,8 @@ async def benchmark(
if profile_output.success:
print("Profiler started")
distribution = ("Poisson process" if burstiness == 1.0
else "Gamma distribution")
distribution = ("Poisson process" if burstiness == 1.0
else "Gamma distribution")
if ramp_up_strategy is not None:
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
......@@ -449,7 +518,7 @@ async def benchmark(
session=session,
pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)
......@@ -478,11 +547,12 @@ async def benchmark(
"timestamp": timestamp
})
last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = (
prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
request.request_id,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
......@@ -498,7 +568,8 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body)
extra_body=extra_body,
request_id=request_id,)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
......@@ -511,14 +582,22 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
if task_type == TaskType.GENERATION:
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
else:
metrics = calculate_metrics_for_embeddings(
outputs=outputs,
dur_s=benchmark_duration,
selected_percentiles=selected_percentiles,
)
actual_output_lens = 0
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
......@@ -527,39 +606,55 @@ async def benchmark(
max_concurrency))
if request_rate != float('inf'):
print("{:<40} {:<10.2f}".format("Request rate configured (RPS):",
request_rate ))
request_rate))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
if isinstance(metrics, BenchmarkMetrics):
print("{:<40} {:<10}".format(
"Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
if isinstance(metrics, BenchmarkMetrics):
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
if isinstance(metrics, BenchmarkMetrics):
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
else:
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"request_throughput": metrics.request_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"errors": [output.error for output in outputs],
}
if rps_change_events:
result["rps_change_events"] = rps_change_events
......@@ -596,10 +691,11 @@ async def benchmark(
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
if task_type == TaskType.GENERATION:
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric(
"tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
print("=" * 50)
......@@ -730,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
......@@ -741,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
......@@ -865,6 +961,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
parser.add_argument(
"--request-id-prefix",
type=str,
required=False,
default="benchmark-serving",
help="Specify the prefix of request id.",
)
sampling_group = parser.add_argument_group("sampling parameters")
sampling_group.add_argument(
......@@ -958,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
async def main_async(args: argparse.Namespace) -> dict[str, Any]:
print(args)
random.seed(args.seed)
......@@ -1036,32 +1141,32 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
gc.freeze()
benchmark_result = await benchmark(
endpoint_type=args.endpoint_type,
api_url=api_url,
base_url=base_url,
model_id=model_id,
model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec,
)
endpoint_type=args.endpoint_type,
api_url=api_url,
base_url=base_url,
model_id=model_id,
model_name=model_name,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec,
)
# Save config and results to json
result_json: dict[str, Any] = {}
......@@ -1088,7 +1193,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
< float("inf") else "inf")
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
......@@ -1122,7 +1227,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if args.max_concurrency is not None else "")
label = label or endpoint_type
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
......@@ -1139,4 +1244,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)
return result_json
\ No newline at end of file
return result_json
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment