Unverified commit 009d9e75, authored by Harry Mellor and committed by GitHub
Browse files

Convert `benchmarks` to `ruff format` (#18068)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent b922c2eb
# This local pyproject file is part of the migration from yapf to ruff format. # This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the # It uses the same core rules as the main pyproject.toml file, but with the
# following differences: # following differences:
# - isort profile is set to black
# - ruff line length is overridden to 88 # - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed # - deprecated typing ignores (UP006, UP035) have been removed
[tool.isort]
profile = "black"
[tool.ruff] [tool.ruff]
line-length = 88 line-length = 88
exclude = [ exclude = [
......
...@@ -17,7 +17,7 @@ repos: ...@@ -17,7 +17,7 @@ repos:
- id: ruff - id: ruff
args: [--output-format, github, --fix] args: [--output-format, github, --fix]
- id: ruff-format - id: ruff-format
files: ^(.buildkite).* files: ^(.buildkite|benchmarks)/.*
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.4.1 rev: v2.4.1
hooks: hooks:
...@@ -28,8 +28,6 @@ repos: ...@@ -28,8 +28,6 @@ repos:
rev: 6.0.1 rev: 6.0.1
hooks: hooks:
- id: isort - id: isort
# necessary during the transition from yapf to ruff format
args: [--resolve-all-configs, --config-root, .]
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v20.1.3 rev: v20.1.3
hooks: hooks:
......
...@@ -12,8 +12,7 @@ from typing import Optional, Union ...@@ -12,8 +12,7 @@ from typing import Optional, Union
import aiohttp import aiohttp
import huggingface_hub.constants import huggingface_hub.constants
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
PreTrainedTokenizerFast)
# NOTE(simon): do not import vLLM here so the benchmark script # NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed. # can run without vLLM installed.
...@@ -43,8 +42,7 @@ class RequestFuncOutput: ...@@ -43,8 +42,7 @@ class RequestFuncOutput:
latency: float = 0.0 latency: float = 0.0
output_tokens: int = 0 output_tokens: int = 0
ttft: float = 0.0 # Time to first token ttft: float = 0.0 # Time to first token
itl: list[float] = field( itl: list[float] = field(default_factory=list) # list of inter-token latencies
default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0 prompt_len: int = 0
error: str = "" error: str = ""
...@@ -57,8 +55,9 @@ async def async_request_tgi( ...@@ -57,8 +55,9 @@ async def async_request_tgi(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
params = { params = {
"max_new_tokens": request_func_input.output_len, "max_new_tokens": request_func_input.output_len,
"do_sample": True, "do_sample": True,
...@@ -105,8 +104,7 @@ async def async_request_tgi( ...@@ -105,8 +104,7 @@ async def async_request_tgi(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -133,8 +131,9 @@ async def async_request_trt_llm( ...@@ -133,8 +131,9 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"accumulate_tokens": True, "accumulate_tokens": True,
"text_input": request_func_input.prompt, "text_input": request_func_input.prompt,
...@@ -159,8 +158,7 @@ async def async_request_trt_llm( ...@@ -159,8 +158,7 @@ async def async_request_trt_llm(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
"data:")
data = json.loads(chunk) data = json.loads(chunk)
output.generated_text += data["text_output"] output.generated_text += data["text_output"]
...@@ -172,8 +170,7 @@ async def async_request_trt_llm( ...@@ -172,8 +170,7 @@ async def async_request_trt_llm(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -197,9 +194,9 @@ async def async_request_deepspeed_mii( ...@@ -197,9 +194,9 @@ async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"model": request_func_input.model, "model": request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
...@@ -217,19 +214,21 @@ async def async_request_deepspeed_mii( ...@@ -217,19 +214,21 @@ async def async_request_deepspeed_mii(
st = time.perf_counter() st = time.perf_counter()
try: try:
async with session.post(url=request_func_input.api_url, async with session.post(
json=payload) as response: url=request_func_input.api_url, json=payload
) as response:
if response.status == 200: if response.status == 200:
parsed_resp = await response.json() parsed_resp = await response.json()
output.latency = time.perf_counter() - st output.latency = time.perf_counter() - st
if "choices" in parsed_resp: if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][ output.generated_text = parsed_resp["choices"][0]["text"]
"text"]
elif "text" in parsed_resp: elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0] output.generated_text = parsed_resp["text"][0]
else: else:
output.error = ("Unexpected response format: " output.error = (
"neither 'choices' nor 'text' found") "Unexpected response format: "
"neither 'choices' nor 'text' found"
)
output.success = False output.success = False
output.success = True output.success = True
else: else:
...@@ -250,15 +249,17 @@ async def async_request_openai_completions( ...@@ -250,15 +249,17 @@ async def async_request_openai_completions(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("completions", "profile")), (
("completions", "profile") "OpenAI Completions API URL must end with 'completions' or 'profile'."
), "OpenAI Completions API URL must end with 'completions' or 'profile'." )
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
"temperature": 0.0, "temperature": 0.0,
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
...@@ -273,9 +274,7 @@ async def async_request_openai_completions( ...@@ -273,9 +274,7 @@ async def async_request_openai_completions(
payload["ignore_eos"] = request_func_input.ignore_eos payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -284,8 +283,9 @@ async def async_request_openai_completions( ...@@ -284,8 +283,9 @@ async def async_request_openai_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(
headers=headers) as response: url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
first_chunk_received = False first_chunk_received = False
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
...@@ -293,8 +293,7 @@ async def async_request_openai_completions( ...@@ -293,8 +293,7 @@ async def async_request_openai_completions(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
data = json.loads(chunk) data = json.loads(chunk)
...@@ -314,21 +313,20 @@ async def async_request_openai_completions( ...@@ -314,21 +313,20 @@ async def async_request_openai_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text += text or "" generated_text += text or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
if first_chunk_received: if first_chunk_received:
output.success = True output.success = True
else: else:
output.success = False output.success = False
output.error = ( output.error = (
"Never received a valid chunk to calculate TTFT." "Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!") "This response will be marked as failed!"
)
output.generated_text = generated_text output.generated_text = generated_text
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
else: else:
...@@ -349,23 +347,22 @@ async def async_request_openai_chat_completions( ...@@ -349,23 +347,22 @@ async def async_request_openai_chat_completions(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("chat/completions", "profile")), (
("chat/completions", "profile") "OpenAI Chat Completions API URL must end with 'chat/completions'."
), "OpenAI Chat Completions API URL must end with 'chat/completions'." )
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content: if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content) content.append(request_func_input.multi_modal_content)
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"messages": [ "messages": [
{ {"role": "user", "content": content},
"role": "user",
"content": content
},
], ],
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
...@@ -391,16 +388,16 @@ async def async_request_openai_chat_completions( ...@@ -391,16 +388,16 @@ async def async_request_openai_chat_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(
headers=headers) as response: url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
...@@ -414,13 +411,11 @@ async def async_request_openai_chat_completions( ...@@ -414,13 +411,11 @@ async def async_request_openai_chat_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -446,25 +441,28 @@ async def async_request_openai_audio( ...@@ -446,25 +441,28 @@ async def async_request_openai_audio(
) -> RequestFuncOutput: ) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep. # Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile import soundfile
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("transcriptions", "translations")), (
("transcriptions", "translations" "OpenAI Chat Completions API URL must end with 'transcriptions' "
)), "OpenAI Chat Completions API URL must end with 'transcriptions' " )
"or `translations`." "or `translations`."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
"stream": True, "stream": True,
"language": "en", "language": "en",
# Flattened due to multipart/form-data # Flattened due to multipart/form-data
"stream_include_usage": True, "stream_include_usage": True,
"stream_continuous_usage_stats": True "stream_continuous_usage_stats": True,
} }
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
...@@ -479,9 +477,9 @@ async def async_request_openai_audio( ...@@ -479,9 +477,9 @@ async def async_request_openai_audio(
buffer.seek(0) buffer.seek(0)
return buffer return buffer
with to_bytes(*request_func_input.multi_modal_content['audio']) as f: with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData() form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav') form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items(): for key, value in payload.items():
form.add_field(key, str(value)) form.add_field(key, str(value))
...@@ -493,24 +491,22 @@ async def async_request_openai_audio( ...@@ -493,24 +491,22 @@ async def async_request_openai_audio(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, async with session.post(
data=form, url=api_url, data=form, headers=headers
headers=headers) as response: ) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
if choices := data.get("choices"): if choices := data.get("choices"):
content = choices[0]["delta"].get( content = choices[0]["delta"].get("content")
"content")
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = timestamp - st ttft = timestamp - st
...@@ -519,12 +515,14 @@ async def async_request_openai_audio( ...@@ -519,12 +515,14 @@ async def async_request_openai_audio(
# Decoding phase # Decoding phase
else: else:
output.itl.append( output.itl.append(
timestamp - most_recent_timestamp) timestamp - most_recent_timestamp
)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get(
"completion_tokens") "completion_tokens"
)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -545,7 +543,7 @@ async def async_request_openai_audio( ...@@ -545,7 +543,7 @@ async def async_request_openai_audio(
def get_model(pretrained_model_name_or_path: str) -> str: def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
from modelscope import snapshot_download from modelscope import snapshot_download
from vllm.model_executor.model_loader.weight_utils import get_lock from vllm.model_executor.model_loader.weight_utils import get_lock
...@@ -556,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str: ...@@ -556,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
model_path = snapshot_download( model_path = snapshot_download(
model_id=pretrained_model_name_or_path, model_id=pretrained_model_name_or_path,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
)
return model_path return model_path
return pretrained_model_name_or_path return pretrained_model_name_or_path
...@@ -569,23 +568,23 @@ def get_tokenizer( ...@@ -569,23 +568,23 @@ def get_tokenizer(
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path is not None and not os.path.exists( if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path): pretrained_model_name_or_path
pretrained_model_name_or_path = get_model( ):
pretrained_model_name_or_path) pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
if tokenizer_mode == "slow": if tokenizer_mode == "slow":
if kwargs.get("use_fast", False): if kwargs.get("use_fast", False):
raise ValueError( raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False kwargs["use_fast"] = False
if tokenizer_mode == "mistral": if tokenizer_mode == "mistral":
try: try:
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
except ImportError as e: except ImportError as e:
raise ImportError("MistralTokenizer requires vllm package.\n" raise ImportError(
"Please install it with `pip install vllm` " "MistralTokenizer requires vllm package.\n"
"to use mistral tokenizer mode.") from e "Please install it with `pip install vllm` "
return MistralTokenizer.from_pretrained( "to use mistral tokenizer mode."
str(pretrained_model_name_or_path)) ) from e
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
else: else:
return AutoTokenizer.from_pretrained( return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, pretrained_model_name_or_path,
...@@ -608,7 +607,7 @@ ASYNC_REQUEST_FUNCS = { ...@@ -608,7 +607,7 @@ ASYNC_REQUEST_FUNCS = {
} }
OPENAI_COMPATIBLE_BACKENDS = [ OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items() k
if v in (async_request_openai_completions, for k, v in ASYNC_REQUEST_FUNCS.items()
async_request_openai_chat_completions) if v in (async_request_openai_completions, async_request_openai_chat_completions)
] ]
...@@ -82,14 +82,12 @@ class BenchmarkDataset(ABC): ...@@ -82,14 +82,12 @@ class BenchmarkDataset(ABC):
self.dataset_path = dataset_path self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the # Set the random seed, ensuring that a None value is replaced with the
# default seed. # default seed.
self.random_seed = (random_seed self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
if random_seed is not None else self.DEFAULT_SEED)
self.data = None self.data = None
def apply_multimodal_chat_transformation( def apply_multimodal_chat_transformation(
self, self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
prompt: str, ) -> list[dict]:
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
""" """
Transform a prompt and optional multimodal content into a chat format. Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation This method is used for chat models that expect a specific conversation
...@@ -111,8 +109,7 @@ class BenchmarkDataset(ABC): ...@@ -111,8 +109,7 @@ class BenchmarkDataset(ABC):
NotImplementedError: If a subclass does not implement this method. NotImplementedError: If a subclass does not implement this method.
""" """
# TODO (jenniferzhao): add support for downloading data # TODO (jenniferzhao): add support for downloading data
raise NotImplementedError( raise NotImplementedError("load_data must be implemented in subclasses.")
"load_data must be implemented in subclasses.")
def get_random_lora_request( def get_random_lora_request(
self, self,
...@@ -158,8 +155,9 @@ class BenchmarkDataset(ABC): ...@@ -158,8 +155,9 @@ class BenchmarkDataset(ABC):
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
@abstractmethod @abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase, def sample(
num_requests: int) -> list[SampleRequest]: self, tokenizer: PreTrainedTokenizerBase, num_requests: int
) -> list[SampleRequest]:
""" """
Abstract method to generate sample requests from the dataset. Abstract method to generate sample requests from the dataset.
...@@ -177,8 +175,9 @@ class BenchmarkDataset(ABC): ...@@ -177,8 +175,9 @@ class BenchmarkDataset(ABC):
""" """
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest], def maybe_oversample_requests(
num_requests: int) -> None: self, requests: list[SampleRequest], num_requests: int
) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
number. number.
...@@ -189,11 +188,9 @@ class BenchmarkDataset(ABC): ...@@ -189,11 +188,9 @@ class BenchmarkDataset(ABC):
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, additional = random.choices(requests, k=num_requests - len(requests))
k=num_requests - len(requests))
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", logger.info("Oversampled requests to reach %d total samples.", num_requests)
num_requests)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -218,14 +215,14 @@ def is_valid_sequence( ...@@ -218,14 +215,14 @@ def is_valid_sequence(
""" """
# Check for invalid conditions # Check for invalid conditions
prompt_too_short = prompt_len < min_len prompt_too_short = prompt_len < min_len
output_too_short = (not skip_min_output_len_check) and (output_len output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
< min_len)
prompt_too_long = prompt_len > max_prompt_len prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met # Return True if none of the invalid conditions are met
return not (prompt_too_short or output_too_short or prompt_too_long return not (
or combined_too_long) prompt_too_short or output_too_short or prompt_too_long or combined_too_long
)
@cache @cache
...@@ -257,28 +254,28 @@ def process_image(image: Any) -> Mapping[str, Any]: ...@@ -257,28 +254,28 @@ def process_image(image: Any) -> Mapping[str, Any]:
Raises: Raises:
ValueError: If the input is not a supported type. ValueError: If the input is not a supported type.
""" """
if isinstance(image, dict) and 'bytes' in image: if isinstance(image, dict) and "bytes" in image:
image = Image.open(BytesIO(image['bytes'])) image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image = image.convert("RGB") image = image.convert("RGB")
with io.BytesIO() as image_data: with io.BytesIO() as image_data:
image.save(image_data, format="JPEG") image.save(image_data, format="JPEG")
image_base64 = base64.b64encode( image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
image_data.getvalue()).decode("utf-8")
return { return {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
"url": f"data:image/jpeg;base64,{image_base64}"
},
} }
if isinstance(image, str): if isinstance(image, str):
image_url = (image if image.startswith( image_url = (
("http://", "file://")) else f"file://{image}") image if image.startswith(("http://", "file://")) else f"file://{image}"
)
return {"type": "image_url", "image_url": {"url": image_url}} return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" raise ValueError(
" or str or dictionary with raw image bytes.") f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes."
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -318,8 +315,11 @@ class RandomDataset(BenchmarkDataset): ...@@ -318,8 +315,11 @@ class RandomDataset(BenchmarkDataset):
num_special_tokens = tokenizer.num_special_tokens_to_add() num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint( prefix_token_ids = (
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) np.random.randint(0, vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
# New sampling logic: [X * (1 - b), X * (1 + b)] # New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio)) input_low = int(real_input_len * (1 - range_ratio))
...@@ -329,21 +329,17 @@ class RandomDataset(BenchmarkDataset): ...@@ -329,21 +329,17 @@ class RandomDataset(BenchmarkDataset):
# Add logging for debugging # Add logging for debugging
logger.info("Sampling input_len from [%s, %s]", input_low, input_high) logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
logger.info("Sampling output_len from [%s, %s]", output_low, logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
output_high)
input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
input_lens = np.random.randint(input_low, output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
input_high + 1,
size=num_requests)
output_lens = np.random.randint(output_low,
output_high + 1,
size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests) offsets = np.random.randint(0, vocab_size, size=num_requests)
requests = [] requests = []
for i in range(num_requests): for i in range(num_requests):
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % inner_seq = (
vocab_size).tolist() (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
).tolist()
token_sequence = prefix_token_ids + inner_seq token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence) prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again. # After decoding the prompt we have to encode and decode it again.
...@@ -354,8 +350,9 @@ class RandomDataset(BenchmarkDataset): ...@@ -354,8 +350,9 @@ class RandomDataset(BenchmarkDataset):
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length, # To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again. # the encoded sequence is truncated before being decode again.
re_encoded_sequence = tokenizer.encode( re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
prompt, add_special_tokens=False)[:input_lens[i]] : input_lens[i]
]
prompt = tokenizer.decode(re_encoded_sequence) prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i]) total_input_len = prefix_len + int(input_lens[i])
requests.append( requests.append(
...@@ -363,7 +360,8 @@ class RandomDataset(BenchmarkDataset): ...@@ -363,7 +360,8 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=total_input_len, prompt_len=total_input_len,
expected_output_len=int(output_lens[i]), expected_output_len=int(output_lens[i]),
)) )
)
return requests return requests
...@@ -390,7 +388,8 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -390,7 +388,8 @@ class ShareGPTDataset(BenchmarkDataset):
self.data = json.load(f) self.data = json.load(f)
# Filter entries with at least two conversation turns. # Filter entries with at least two conversation turns.
self.data = [ self.data = [
entry for entry in self.data entry
for entry in self.data
if "conversations" in entry and len(entry["conversations"]) >= 2 if "conversations" in entry and len(entry["conversations"]) >= 2
] ]
random.seed(self.random_seed) random.seed(self.random_seed)
...@@ -416,27 +415,28 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -416,27 +415,28 @@ class ShareGPTDataset(BenchmarkDataset):
) )
lora_request, tokenizer = self.get_random_lora_request( lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
)
prompt_ids = tokenizer(prompt).input_ids prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids) prompt_len = len(prompt_ids)
new_output_len = (len(completion_ids) new_output_len = len(completion_ids) if output_len is None else output_len
if output_len is None else output_len) if not is_valid_sequence(
if not is_valid_sequence(prompt_len, prompt_len,
new_output_len, new_output_len,
skip_min_output_len_check=output_len skip_min_output_len_check=output_len is not None,
is not None): ):
continue continue
if enable_multimodal_chat: if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, None)
prompt, None)
samples.append( samples.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=new_output_len, expected_output_len=new_output_len,
lora_request=lora_request, lora_request=lora_request,
)) )
)
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests)
return samples return samples
...@@ -482,20 +482,20 @@ class SonnetDataset(BenchmarkDataset): ...@@ -482,20 +482,20 @@ class SonnetDataset(BenchmarkDataset):
) -> list: ) -> list:
# Calculate average token length for a poem line. # Calculate average token length for a poem line.
tokenized_lines = [tokenizer(line).input_ids for line in self.data] tokenized_lines = [tokenizer(line).input_ids for line in self.data]
avg_len = sum(len(tokens) avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)
for tokens in tokenized_lines) / len(tokenized_lines)
# Build the base prompt. # Build the base prompt.
base_prompt = "Pick as many lines as you can from these poem lines:\n" base_prompt = "Pick as many lines as you can from these poem lines:\n"
base_msg = [{"role": "user", "content": base_prompt}] base_msg = [{"role": "user", "content": base_prompt}]
base_fmt = tokenizer.apply_chat_template(base_msg, base_fmt = tokenizer.apply_chat_template(
add_generation_prompt=True, base_msg, add_generation_prompt=True, tokenize=False
tokenize=False) )
base_offset = len(tokenizer(base_fmt).input_ids) base_offset = len(tokenizer(base_fmt).input_ids)
if input_len <= base_offset: if input_len <= base_offset:
raise ValueError( raise ValueError(
f"'input_len' must be higher than the base prompt length " f"'input_len' must be higher than the base prompt length "
f"({base_offset}).") f"({base_offset})."
)
# Determine how many poem lines to use. # Determine how many poem lines to use.
num_input_lines = round((input_len - base_offset) / avg_len) num_input_lines = round((input_len - base_offset) / avg_len)
...@@ -504,21 +504,23 @@ class SonnetDataset(BenchmarkDataset): ...@@ -504,21 +504,23 @@ class SonnetDataset(BenchmarkDataset):
samples = [] samples = []
while len(samples) < num_requests: while len(samples) < num_requests:
extra_lines = random.choices(self.data, extra_lines = random.choices(
k=num_input_lines - num_prefix_lines) self.data, k=num_input_lines - num_prefix_lines
)
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
msg = [{"role": "user", "content": prompt}] msg = [{"role": "user", "content": prompt}]
prompt_formatted = tokenizer.apply_chat_template( prompt_formatted = tokenizer.apply_chat_template(
msg, add_generation_prompt=True, tokenize=False) msg, add_generation_prompt=True, tokenize=False
)
prompt_len = len(tokenizer(prompt_formatted).input_ids) prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len: if prompt_len <= input_len:
samples.append( samples.append(
SampleRequest( SampleRequest(
prompt=prompt_formatted prompt=prompt_formatted if return_prompt_formatted else prompt,
if return_prompt_formatted else prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
)) )
)
return samples return samples
...@@ -538,7 +540,9 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -538,7 +540,9 @@ class BurstGPTDataset(BenchmarkDataset):
super().__init__(**kwargs) super().__init__(**kwargs)
self.load_data() self.load_data()
def load_data(self, ): def load_data(
self,
):
if self.dataset_path is None: if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.") raise ValueError("dataset_path must be provided for loading data.")
...@@ -552,8 +556,7 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -552,8 +556,7 @@ class BurstGPTDataset(BenchmarkDataset):
def _sample_loaded_data(self, num_requests: int) -> list: def _sample_loaded_data(self, num_requests: int) -> list:
if num_requests <= len(self.data): if num_requests <= len(self.data):
data = self.data.sample(n=num_requests, data = self.data.sample(n=num_requests, random_state=self.random_seed)
random_state=self.random_seed)
else: else:
data = self.data.sample( data = self.data.sample(
n=num_requests, n=num_requests,
...@@ -577,7 +580,8 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -577,7 +580,8 @@ class BurstGPTDataset(BenchmarkDataset):
input_len = int(data[i][2]) input_len = int(data[i][2])
output_len = int(data[i][3]) output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request( lora_req, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
)
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i + # Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size. # j) modulo vocab_size.
...@@ -589,7 +593,8 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -589,7 +593,8 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len, prompt_len=input_len,
expected_output_len=output_len, expected_output_len=output_len,
lora_request=lora_req, lora_request=lora_req,
)) )
)
return samples return samples
...@@ -632,20 +637,23 @@ class HuggingFaceDataset(BenchmarkDataset): ...@@ -632,20 +637,23 @@ class HuggingFaceDataset(BenchmarkDataset):
class ConversationDataset(HuggingFaceDataset): class ConversationDataset(HuggingFaceDataset):
"""Dataset for conversation data with multimodal support.""" """Dataset for conversation data with multimodal support."""
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' "lmms-lab/LLaVA-OneVision-Data",
"Aeala/ShareGPT_Vicuna_unfiltered",
} }
IS_MULTIMODAL = True IS_MULTIMODAL = True
def sample(self, def sample(
tokenizer: PreTrainedTokenizerBase, self,
num_requests: int, tokenizer: PreTrainedTokenizerBase,
output_len: Optional[int] = None, num_requests: int,
enable_multimodal_chat: bool = False, output_len: Optional[int] = None,
**kwargs) -> list: enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
# Filter examples with at least 2 conversations # Filter examples with at least 2 conversations
filtered_data = self.data.filter( filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
lambda x: len(x["conversations"]) >= 2)
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
...@@ -661,24 +669,22 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -661,24 +669,22 @@ class ConversationDataset(HuggingFaceDataset):
completion_len = len(completion_ids) completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0 assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence( if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
prompt_len, completion_len):
continue continue
mm_content = process_image( mm_content = process_image(item["image"]) if "image" in item else None
item["image"]) if "image" in item else None
if enable_multimodal_chat: if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer # Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the # accurate and we will be using request output to count the
# actual prompt len and output len # actual prompt len and output len
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
prompt, mm_content)
sampled_requests.append( sampled_requests.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -695,10 +701,8 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -695,10 +701,8 @@ class VisionArenaDataset(HuggingFaceDataset):
DEFAULT_OUTPUT_LEN = 128 DEFAULT_OUTPUT_LEN = 128
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
"lmarena-ai/VisionArena-Chat": "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
lambda x: x["conversation"][0][0]["content"], "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
} }
IS_MULTIMODAL = True IS_MULTIMODAL = True
...@@ -710,16 +714,14 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -710,16 +714,14 @@ class VisionArenaDataset(HuggingFaceDataset):
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
if parser_fn is None: if parser_fn is None:
raise ValueError( raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
f"Unsupported dataset path: {self.dataset_path}")
prompt = parser_fn(item) prompt = parser_fn(item)
mm_content = process_image(item["images"][0]) mm_content = process_image(item["images"][0])
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
...@@ -727,15 +729,15 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -727,15 +729,15 @@ class VisionArenaDataset(HuggingFaceDataset):
# Note: when chat is enabled the request prompt_len is no longer # Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the # accurate and we will be using request output to count the
# actual prompt len # actual prompt len
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
prompt, mm_content)
sampled_requests.append( sampled_requests.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -760,14 +762,15 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -760,14 +762,15 @@ class InstructCoderDataset(HuggingFaceDataset):
"likaixin/InstructCoder", "likaixin/InstructCoder",
} }
def sample(self, def sample(
tokenizer: PreTrainedTokenizerBase, self,
num_requests: int, tokenizer: PreTrainedTokenizerBase,
output_len: Optional[int] = None, num_requests: int,
enable_multimodal_chat: bool = False, output_len: Optional[int] = None,
**kwargs) -> list: enable_multimodal_chat: bool = False,
output_len = (output_len **kwargs,
if output_len is not None else self.DEFAULT_OUTPUT_LEN) ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -779,7 +782,8 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -779,7 +782,8 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -794,38 +798,38 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -794,38 +798,38 @@ class MTBenchDataset(HuggingFaceDataset):
MT-Bench Dataset. MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench https://huggingface.co/datasets/philschmid/mt-bench
We create a single turn dataset for MT-Bench. We create a single turn dataset for MT-Bench.
This is similar to Spec decoding benchmark setup in vLLM This is similar to Spec decoding benchmark setup in vLLM
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
""" # noqa: E501 """ # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
"philschmid/mt-bench", "philschmid/mt-bench",
} }
def sample(self, def sample(
tokenizer: PreTrainedTokenizerBase, self,
num_requests: int, tokenizer: PreTrainedTokenizerBase,
output_len: Optional[int] = None, num_requests: int,
enable_multimodal_chat: bool = False, output_len: Optional[int] = None,
**kwargs) -> list: enable_multimodal_chat: bool = False,
output_len = (output_len **kwargs,
if output_len is not None else self.DEFAULT_OUTPUT_LEN) ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item['turns'][0] prompt = item["turns"][0]
# apply template # apply template
prompt = tokenizer.apply_chat_template([{ prompt = tokenizer.apply_chat_template(
"role": "user", [{"role": "user", "content": prompt}],
"content": prompt add_generation_prompt=True,
}], tokenize=False,
add_generation_prompt=True, )
tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append( sampled_requests.append(
...@@ -833,7 +837,8 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -833,7 +837,8 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -847,23 +852,27 @@ class AIMODataset(HuggingFaceDataset): ...@@ -847,23 +852,27 @@ class AIMODataset(HuggingFaceDataset):
""" """
Dataset class for processing a AIMO dataset with reasoning questions. Dataset class for processing a AIMO dataset with reasoning questions.
""" """
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", "AI-MO/aimo-validation-aime",
"AI-MO/NuminaMath-CoT" "AI-MO/NuminaMath-1.5",
"AI-MO/NuminaMath-CoT",
} }
def sample(self, def sample(
tokenizer: PreTrainedTokenizerBase, self,
num_requests: int, tokenizer: PreTrainedTokenizerBase,
output_len: Optional[int] = None, num_requests: int,
**kwargs) -> list: output_len: Optional[int] = None,
**kwargs,
) -> list:
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt, completion = item['problem'], item["solution"] prompt, completion = item["problem"], item["solution"]
prompt_ids = tokenizer(prompt).input_ids prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids completion_ids = tokenizer(completion).input_ids
...@@ -871,10 +880,9 @@ class AIMODataset(HuggingFaceDataset): ...@@ -871,10 +880,9 @@ class AIMODataset(HuggingFaceDataset):
completion_len = len(completion_ids) completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0 assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence(prompt_len, if dynamic_output and not is_valid_sequence(
completion_len, prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
max_prompt_len=2048, ):
max_total_len=32000):
continue continue
sampled_requests.append( sampled_requests.append(
SampleRequest( SampleRequest(
...@@ -882,7 +890,8 @@ class AIMODataset(HuggingFaceDataset): ...@@ -882,7 +890,8 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -905,25 +914,25 @@ You are a code completion assistant and your task is to analyze user edits and t ...@@ -905,25 +914,25 @@ You are a code completion assistant and your task is to analyze user edits and t
### Response: ### Response:
""" # noqa: E501 """ # noqa: E501
def _format_zeta_prompt( def _format_zeta_prompt(
sample: dict, sample: dict, original_start_marker: str = "<|editable_region_start|>"
original_start_marker: str = "<|editable_region_start|>") -> dict: ) -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset. """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
This function formats examples from the NEP dataset This function formats examples from the NEP dataset
into prompts and expected outputs. It could be into prompts and expected outputs. It could be
further extended to support more NEP datasets. further extended to support more NEP datasets.
Args: Args:
sample: The dataset sample containing events, sample: The dataset sample containing events,
inputs, and outputs. inputs, and outputs.
original_start_marker: The marker indicating the original_start_marker: The marker indicating the
start of the editable region. Defaults to start of the editable region. Defaults to
"<|editable_region_start|>". "<|editable_region_start|>".
Returns: Returns:
A dictionary with the formatted prompts and expected outputs. A dictionary with the formatted prompts and expected outputs.
""" """
...@@ -953,10 +962,8 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -953,10 +962,8 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt, "zed-industries/zeta": _format_zeta_prompt,
} }
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
**kwargs): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path)
if formatting_prompt_func is None: if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}") raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = [] samples = []
...@@ -967,8 +974,10 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -967,8 +974,10 @@ class NextEditPredictionDataset(HuggingFaceDataset):
prompt=sample["prompt"], prompt=sample["prompt"],
prompt_len=len(tokenizer(sample["prompt"]).input_ids), prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len( expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids), tokenizer(sample["expected_output"]).input_ids
)) ),
)
)
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests)
...@@ -997,18 +1006,22 @@ class ASRDataset(HuggingFaceDataset): ...@@ -997,18 +1006,22 @@ class ASRDataset(HuggingFaceDataset):
| AMI | Meetings | Spontaneous | ihm, sdm | | AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+ +----------------+----------------------------------------+--------------------------+-----------------------------+
""" # noqa: E501 """ # noqa: E501
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium", "openslr/librispeech_asr",
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech" "facebook/voxpopuli",
"LIUM/tedlium",
"edinburghcstr/ami",
"speechcolab/gigaspeech",
"kensho/spgispeech",
} }
DEFAULT_OUTPUT_LEN = 128 DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True IS_MULTIMODAL = True
# TODO Whisper-specific. Abstract interface when more models are supported. # TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\ TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
"<|notimestamps|>"
skip_long_audios: bool = True skip_long_audios: bool = True
def sample( def sample(
...@@ -1019,8 +1032,8 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1019,8 +1032,8 @@ class ASRDataset(HuggingFaceDataset):
**kwargs, **kwargs,
) -> list: ) -> list:
import librosa import librosa
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = [] sampled_requests = []
...@@ -1043,10 +1056,14 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1043,10 +1056,14 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
)) )
)
if skipped: if skipped:
logger.warning("%d samples discarded from dataset due to" \ logger.warning(
" their length being greater than" \ "%d samples discarded from dataset due to"
" what Whisper supports.", skipped) " their length being greater than"
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -11,9 +11,9 @@ from typing import Any, Optional ...@@ -11,9 +11,9 @@ from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm from tqdm import tqdm
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
...@@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams ...@@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(
results: dict[str, Any]) -> None: args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={"latency": results["latencies"]}, metrics={"latency": results["latencies"]},
extra_info={k: results[k] extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
for k in ["avg_latency", "percentiles"]}) )
if pt_records: if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records) write_to_json(pt_file, pt_records)
...@@ -42,9 +43,11 @@ def main(args: argparse.Namespace): ...@@ -42,9 +43,11 @@ def main(args: argparse.Namespace):
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= ( assert llm.llm_engine.model_config.max_model_len >= (
args.input_len + args.input_len + args.output_len
args.output_len), ("Please ensure that max_model_len is greater than" ), (
" the sum of input_len and output_len.") "Please ensure that max_model_len is greater than"
" the sum of input_len and output_len."
)
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=args.n, n=args.n,
...@@ -55,18 +58,16 @@ def main(args: argparse.Namespace): ...@@ -55,18 +58,16 @@ def main(args: argparse.Namespace):
detokenize=not args.disable_detokenize, detokenize=not args.disable_detokenize,
) )
print(sampling_params) print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(
size=(args.batch_size, 10000, size=(args.batch_size, args.input_len)
args.input_len)) )
dummy_prompts: list[PromptType] = [{ dummy_prompts: list[PromptType] = [
"prompt_token_ids": batch {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
} for batch in dummy_prompt_token_ids.tolist()] ]
def llm_generate(): def llm_generate():
if not args.use_beam_search: if not args.use_beam_search:
llm.generate(dummy_prompts, llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
sampling_params=sampling_params,
use_tqdm=False)
else: else:
llm.beam_search( llm.beam_search(
dummy_prompts, dummy_prompts,
...@@ -80,12 +81,13 @@ def main(args: argparse.Namespace): ...@@ -80,12 +81,13 @@ def main(args: argparse.Namespace):
def run_to_completion(profile_dir: Optional[str] = None): def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir: if profile_dir:
with torch.profiler.profile( with torch.profiler.profile(
activities=[ activities=[
torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CUDA,
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)), str(profile_dir)
),
) as p: ) as p:
llm_generate() llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total")) print(p.key_averages().table(sort_by="self_cuda_time_total"))
...@@ -103,8 +105,9 @@ def main(args: argparse.Namespace): ...@@ -103,8 +105,9 @@ def main(args: argparse.Namespace):
if args.profile: if args.profile:
profile_dir = args.profile_result_dir profile_dir = args.profile_result_dir
if not profile_dir: if not profile_dir:
profile_dir = (Path(".") / "vllm_benchmark_result" / profile_dir = (
f"latency_result_{time.time()}") Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
)
print(f"Profiling (results will be saved to '{profile_dir}')...") print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir) run_to_completion(profile_dir=profile_dir)
return return
...@@ -135,7 +138,8 @@ def main(args: argparse.Namespace): ...@@ -135,7 +138,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the latency of processing a single batch of " description="Benchmark the latency of processing a single batch of "
"requests till completion.") "requests till completion."
)
parser.add_argument("--input-len", type=int, default=32) parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128) parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
...@@ -152,10 +156,9 @@ if __name__ == "__main__": ...@@ -152,10 +156,9 @@ if __name__ == "__main__":
default=10, default=10,
help="Number of iterations to run for warmup.", help="Number of iterations to run for warmup.",
) )
parser.add_argument("--num-iters", parser.add_argument(
type=int, "--num-iters", type=int, default=30, help="Number of iterations to run."
default=30, )
help="Number of iterations to run.")
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
...@@ -165,8 +168,10 @@ if __name__ == "__main__": ...@@ -165,8 +168,10 @@ if __name__ == "__main__":
"--profile-result-dir", "--profile-result-dir",
type=str, type=str,
default=None, default=None,
help=("path to save the pytorch profiler output. Can be visualized " help=(
"with ui.perfetto.dev or Tensorboard."), "path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."
),
) )
parser.add_argument( parser.add_argument(
"--output-json", "--output-json",
...@@ -177,8 +182,10 @@ if __name__ == "__main__": ...@@ -177,8 +182,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--disable-detokenize", "--disable-detokenize",
action="store_true", action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
......
...@@ -76,7 +76,7 @@ def repeat_prompts(prompts, repeat_count, mode: str): ...@@ -76,7 +76,7 @@ def repeat_prompts(prompts, repeat_count, mode: str):
- 'random': Shuffle the prompts randomly after repetition. - 'random': Shuffle the prompts randomly after repetition.
- 'tile': Repeat the entire prompt list in sequence. - 'tile': Repeat the entire prompt list in sequence.
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
- 'interleave': Repeat each prompt consecutively before moving to - 'interleave': Repeat each prompt consecutively before moving to
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
Returns: Returns:
...@@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str): ...@@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str):
ValueError: If an invalid mode is provided. ValueError: If an invalid mode is provided.
""" """
print("Repeat mode: ", mode) print("Repeat mode: ", mode)
if mode == 'random': if mode == "random":
repeated_prompts = prompts * repeat_count repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts) random.shuffle(repeated_prompts)
return repeated_prompts return repeated_prompts
elif mode == 'tile': elif mode == "tile":
return prompts * repeat_count return prompts * repeat_count
elif mode == 'interleave': elif mode == "interleave":
repeated_prompts = [] repeated_prompts = []
for prompt in prompts: for prompt in prompts:
repeated_prompts.extend([prompt] * repeat_count) repeated_prompts.extend([prompt] * repeat_count)
return repeated_prompts return repeated_prompts
else: else:
raise ValueError(f"Invalid mode: {mode}, only support " raise ValueError(
"'random', 'tile', 'interleave'") f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
)
def main(args): def main(args):
...@@ -109,16 +110,16 @@ def main(args): ...@@ -109,16 +110,16 @@ def main(args):
# we append the document id at the beginning to avoid any of the document # we append the document id at the beginning to avoid any of the document
# being the prefix of other documents # being the prefix of other documents
prompts = [ prompts = [
str(i) + ' '.join(['hi'] * args.document_length) str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents) for i in range(args.num_documents)
] ]
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
warmup_prompts = [ warmup_prompts = [
"This is warm up request " + str(i) + \ "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
' '.join(['hi'] * args.document_length) for i in range(args.num_documents)
for i in range(args.num_documents)] ]
# Create the LLM engine # Create the LLM engine
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
...@@ -142,42 +143,52 @@ def main(args): ...@@ -142,42 +143,52 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description= description="Benchmark the performance with or "
'Benchmark the performance with or without automatic prefix caching.') "without automatic prefix caching."
)
parser.add_argument( parser.add_argument(
'--document-length', "--document-length",
type=int, type=int,
# Roughly the number of tokens for a system paper, # Roughly the number of tokens for a system paper,
# excluding images # excluding images
default=20000, default=20000,
help='Range of input lengths for sampling prompts,' help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").') 'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--num-documents',
type=int, parser.add_argument(
default=8, "--num-documents",
help='Range of input lengths for sampling prompts,' type=int,
'specified as "min:max" (e.g., "128:256").') default=8,
help="Range of input lengths for sampling prompts, "
parser.add_argument('--output-len', type=int, default=10) 'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--repeat-count',
type=int, parser.add_argument("--output-len", type=int, default=10)
default=2,
help='Number of times to repeat each prompt') parser.add_argument(
"--repeat-count",
parser.add_argument("--repeat-mode", type=int,
type=str, default=2,
default='random', help="Number of times to repeat each prompt",
help='The mode to repeat prompts. The supported ' )
'modes are "random", "tile", and "interleave". '
'See repeat_prompts() in the source code for details.') parser.add_argument(
"--repeat-mode",
parser.add_argument("--shuffle-seed", type=str,
type=int, default="random",
default=0, help="The mode to repeat prompts. The supported "
help='Random seed when the repeat mode is "random"') 'modes are "random", "tile", and "interleave". '
"See repeat_prompts() in the source code for details.",
)
parser.add_argument(
"--shuffle-seed",
type=int,
default=0,
help='Random seed when the repeat mode is "random"',
)
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -63,8 +63,7 @@ class Request: ...@@ -63,8 +63,7 @@ class Request:
output_len: int output_len: int
def sample_tokens(tokenizer: PreTrainedTokenizerBase, def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
length: int) -> list[int]:
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
all_special_ids = set(tokenizer.all_special_ids) all_special_ids = set(tokenizer.all_special_ids)
...@@ -91,8 +90,10 @@ def sample_requests_from_dataset( ...@@ -91,8 +90,10 @@ def sample_requests_from_dataset(
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"], dataset = [
data["conversations"][1]["value"]) for data in dataset] (data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset. # Shuffle the dataset.
random.shuffle(dataset) random.shuffle(dataset)
...@@ -113,8 +114,9 @@ def sample_requests_from_dataset( ...@@ -113,8 +114,9 @@ def sample_requests_from_dataset(
completion = dataset[i][1] completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = (len(completion_token_ids) output_len = (
if fixed_output_len is None else fixed_output_len) len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if min_len <= prompt_len <= max_len: if min_len <= prompt_len <= max_len:
filtered_requests.append(Request(prompt, prompt_len, output_len)) filtered_requests.append(Request(prompt, prompt_len, output_len))
...@@ -128,27 +130,27 @@ def sample_requests_from_random( ...@@ -128,27 +130,27 @@ def sample_requests_from_random(
fixed_output_len: Optional[int], fixed_output_len: Optional[int],
prefix_len: int, prefix_len: int,
) -> list[Request]: ) -> list[Request]:
requests = [] requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len) prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range min_len, max_len = input_length_range
for i in range(num_requests): for i in range(num_requests):
unique_part_token_ids = sample_tokens( unique_part_token_ids = sample_tokens(
tokenizer, tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
random.randint(min_len - prefix_len, max_len - prefix_len)) )
prompt_token_ids = prefix_token_ids + unique_part_token_ids prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
assert (min_len <= prompt_len <= max_len assert min_len <= prompt_len <= max_len, (
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
)
requests.append(Request(prompt, prompt_len, fixed_output_len)) requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests return requests
def repeat_and_sort_requests(requests: list[Request], def repeat_and_sort_requests(
repeat_count: int, requests: list[Request], repeat_count: int, sort: bool = False
sort: bool = False) -> list[str]: ) -> list[str]:
repeated_requests = requests * repeat_count repeated_requests = requests * repeat_count
if sort: if sort:
repeated_requests.sort(key=lambda x: x[1]) repeated_requests.sort(key=lambda x: x[1])
...@@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request], ...@@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request],
def main(args): def main(args):
tokenizer = get_tokenizer(args.model, trust_remote_code=True) tokenizer = get_tokenizer(args.model, trust_remote_code=True)
input_length_range = tuple(map(int, args.input_length_range.split(':'))) input_length_range = tuple(map(int, args.input_length_range.split(":")))
random.seed(args.seed) random.seed(args.seed)
if args.dataset_path is not None: if args.dataset_path is not None:
if args.prefix_len > 0: if args.prefix_len > 0:
raise ValueError("prefix-len is not supported when " raise ValueError(
"dataset-path is provided.") "prefix-len is not supported when dataset-path is provided."
print(f"Start to sample {args.num_prompts} prompts " )
f"from {args.dataset_path}") print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
filtered_requests = sample_requests_from_dataset( filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
...@@ -196,14 +198,16 @@ def main(args): ...@@ -196,14 +198,16 @@ def main(args):
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
sampling_params = SamplingParams(temperature=0, sampling_params = SamplingParams(
max_tokens=args.output_len, temperature=0,
detokenize=not args.disable_detokenize) max_tokens=args.output_len,
detokenize=not args.disable_detokenize,
)
print("Testing filtered requests") print("Testing filtered requests")
prompts = repeat_and_sort_requests(filtered_requests, prompts = repeat_and_sort_requests(
repeat_count=args.repeat_count, filtered_requests, repeat_count=args.repeat_count, sort=args.sort
sort=args.sort) )
print("------start generating------") print("------start generating------")
test_prefix( test_prefix(
...@@ -215,29 +219,35 @@ def main(args): ...@@ -215,29 +219,35 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description= description="Benchmark the performance with or without "
'Benchmark the performance with or without automatic prefix caching.') "automatic prefix caching."
parser.add_argument("--dataset-path", )
type=str, parser.add_argument(
default=None, "--dataset-path", type=str, default=None, help="Path to the dataset."
help="Path to the dataset.") )
parser.add_argument('--output-len', type=int, default=10) parser.add_argument("--output-len", type=int, default=10)
parser.add_argument('--num-prompts', parser.add_argument(
type=int, "--num-prompts",
required=True, type=int,
help="Number of the prompts sampled from dataset") required=True,
parser.add_argument('--repeat-count', help="Number of the prompts sampled from dataset",
type=int, )
default=1, parser.add_argument(
help='Number of times to repeat each prompt') "--repeat-count",
parser.add_argument('--sort', type=int,
action='store_true', default=1,
help='Sort prompts by input length') help="Number of times to repeat each prompt",
parser.add_argument('--input-length-range', )
type=str, parser.add_argument(
required=True, "--sort", action="store_true", help="Sort prompts by input length"
help='Range of input lengths for sampling prompts,' )
'specified as "min:max" (e.g., "128:256").') parser.add_argument(
"--input-length-range",
type=str,
required=True,
help="Range of input lengths for sampling prompts,"
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument( parser.add_argument(
"--prefix-len", "--prefix-len",
type=int, type=int,
...@@ -248,10 +258,12 @@ if __name__ == "__main__": ...@@ -248,10 +258,12 @@ if __name__ == "__main__":
"when dataset-path is not provided.", "when dataset-path is not provided.",
) )
parser.add_argument( parser.add_argument(
'--disable-detokenize', "--disable-detokenize",
action='store_true', action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Benchmark offline prioritization.""" """Benchmark offline prioritization."""
import argparse import argparse
import dataclasses import dataclasses
import json import json
...@@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs ...@@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
#Select a equi-probable random priority # Select a equi-probable random priority
def get_random_flag(): def get_random_flag():
return 0 if random.random() < 0.5 else 1 return 0 if random.random() < 0.5 else 1
...@@ -33,8 +34,10 @@ def sample_requests( ...@@ -33,8 +34,10 @@ def sample_requests(
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"], dataset = [
data["conversations"][1]["value"]) for data in dataset] (data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset. # Shuffle the dataset.
random.shuffle(dataset) random.shuffle(dataset)
...@@ -51,8 +54,9 @@ def sample_requests( ...@@ -51,8 +54,9 @@ def sample_requests(
completion = dataset[i][1] completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids output_len = (
) if fixed_output_len is None else fixed_output_len len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4: if prompt_len < 4 or output_len < 4:
# Prune too short sequences. # Prune too short sequences.
continue continue
...@@ -74,13 +78,16 @@ def run_vllm( ...@@ -74,13 +78,16 @@ def run_vllm(
disable_detokenize: bool = False, disable_detokenize: bool = False,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of" ), (
" input_len and output_len for all requests.") "Please ensure that max_model_len is greater than the sum of"
" input_len and output_len for all requests."
)
# Add the requests to the engine. # Add the requests to the engine.
prompts = [] prompts = []
...@@ -97,7 +104,8 @@ def run_vllm( ...@@ -97,7 +104,8 @@ def run_vllm(
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
...@@ -111,26 +119,33 @@ def main(args: argparse.Namespace): ...@@ -111,26 +119,33 @@ def main(args: argparse.Namespace):
# Sample the requests. # Sample the requests.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.dataset is None: if args.dataset is None:
# Synthesize a prompt with the given input length. # Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1) prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len, requests = [
get_random_flag()) for _ in range(args.num_prompts)] (prompt, args.input_len, args.output_len, get_random_flag())
for _ in range(args.num_prompts)
]
else: else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer, requests = sample_requests(
args.output_len) args.dataset, args.num_prompts, tokenizer, args.output_len
)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.n, elapsed_time = run_vllm(
EngineArgs.from_cli_args(args), requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
args.disable_detokenize) )
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len total_num_tokens = sum(
for _, prompt_len, output_len, priority in requests) prompt_len + output_len for _, prompt_len, output_len, priority in requests
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " )
f"{total_num_tokens / elapsed_time:.2f} tokens/s") print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
)
# Output JSON results if specified # Output JSON results if specified
if args.output_json: if args.output_json:
...@@ -147,41 +162,44 @@ def main(args: argparse.Namespace): ...@@ -147,41 +162,44 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
default="vllm")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=200,
help="Number of prompts to process.")
parser.add_argument( parser.add_argument(
'--output-json', "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
)
parser.add_argument(
"--dataset", type=str, default=None, help="Path to the dataset."
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="Input prompt length for each request",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.",
)
parser.add_argument(
"--n", type=int, default=1, help="Number of generated sequences per prompt."
)
parser.add_argument(
"--num-prompts", type=int, default=200, help="Number of prompts to process."
)
parser.add_argument(
"--output-json",
type=str, type=str,
default=None, default=None,
help='Path to save the throughput results in JSON format.') help="Path to save the throughput results in JSON format.",
)
parser.add_argument( parser.add_argument(
'--disable-detokenize', "--disable-detokenize",
action='store_true', action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
......
...@@ -20,6 +20,7 @@ On the client side, run: ...@@ -20,6 +20,7 @@ On the client side, run:
--endpoint /generate_stream --endpoint /generate_stream
to the end of the command above. to the end of the command above.
""" """
import argparse import argparse
import asyncio import asyncio
import gc import gc
...@@ -34,12 +35,16 @@ from datetime import datetime ...@@ -34,12 +35,16 @@ from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
RequestFuncOutput)
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
try: try:
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError: except ImportError:
...@@ -50,12 +55,21 @@ try: ...@@ -50,12 +55,21 @@ try:
except ImportError: except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, from benchmark_dataset import (
ConversationDataset, HuggingFaceDataset, AIMODataset,
InstructCoderDataset, MTBenchDataset, ASRDataset,
NextEditPredictionDataset, RandomDataset, BurstGPTDataset,
SampleRequest, ShareGPTDataset, SonnetDataset, ConversationDataset,
VisionArenaDataset) HuggingFaceDataset,
InstructCoderDataset,
MTBenchDataset,
NextEditPredictionDataset,
RandomDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
...@@ -118,7 +132,8 @@ async def get_request( ...@@ -118,7 +132,8 @@ async def get_request(
# Calculate scale parameter theta to maintain the desired request_rate. # Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, ( assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.") f"A positive burstiness factor is expected, but given {burstiness}."
)
theta = 1.0 / (request_rate * burstiness) theta = 1.0 / (request_rate * burstiness)
for request in input_requests: for request in input_requests:
...@@ -164,8 +179,10 @@ def calculate_metrics( ...@@ -164,8 +179,10 @@ def calculate_metrics(
# bundled together # bundled together
# Note : this may inflate the output token count slightly # Note : this may inflate the output token count slightly
output_len = len( output_len = len(
tokenizer(outputs[i].generated_text, tokenizer(
add_special_tokens=False).input_ids) outputs[i].generated_text, add_special_tokens=False
).input_ids
)
actual_output_lens.append(output_len) actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len total_input += input_requests[i].prompt_len
tpot = 0 tpot = 0
...@@ -188,16 +205,19 @@ def calculate_metrics( ...@@ -188,16 +205,19 @@ def calculate_metrics(
if "ttft" in goodput_config_dict: if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts) valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "tpot" in goodput_config_dict: if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots) valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "e2el" in goodput_config_dict: if "e2el" in goodput_config_dict:
valid_metrics.append(e2els) valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
for req_metric in zip(*valid_metrics): for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
...@@ -208,7 +228,8 @@ def calculate_metrics( ...@@ -208,7 +228,8 @@ def calculate_metrics(
warnings.warn( warnings.warn(
"All requests failed. This is likely due to a misconfiguration " "All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.", "on the benchmark arguments.",
stacklevel=2) stacklevel=2,
)
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
...@@ -217,27 +238,31 @@ def calculate_metrics( ...@@ -217,27 +238,31 @@ def calculate_metrics(
request_goodput=good_completed / dur_s, request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s, output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0) * mean_ttft_ms=np.mean(ttfts or 0)
1000, # ttfts is empty if streaming is not supported by backend * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) percentiles_ttft_ms=[
for p in selected_percentiles], (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],
mean_tpot_ms=np.mean(tpots or 0) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) percentiles_tpot_ms=[
for p in selected_percentiles], (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
],
mean_itl_ms=np.mean(itls or 0) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) percentiles_itl_ms=[
for p in selected_percentiles], (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
],
mean_e2el_ms=np.mean(e2els or 0) * 1000, mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) percentiles_e2el_ms=[
for p in selected_percentiles], (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
],
) )
return metrics, actual_output_lens return metrics, actual_output_lens
...@@ -270,10 +295,12 @@ async def benchmark( ...@@ -270,10 +295,12 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}") raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...") print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len, test_mm_content = \ test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0].prompt, input_requests[0].prompt_len, \ input_requests[0].prompt,
input_requests[0].expected_output_len, \ input_requests[0].prompt_len,
input_requests[0].multi_modal_data input_requests[0].expected_output_len,
input_requests[0].multi_modal_data,
)
assert test_mm_content is None or isinstance(test_mm_content, dict) assert test_mm_content is None or isinstance(test_mm_content, dict)
test_input = RequestFuncInput( test_input = RequestFuncInput(
...@@ -293,36 +320,36 @@ async def benchmark( ...@@ -293,36 +320,36 @@ async def benchmark(
if not test_output.success: if not test_output.success:
raise ValueError( raise ValueError(
"Initial test run failed - Please make sure benchmark arguments " "Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}") f"are correctly specified. Error: {test_output.error}"
)
else: else:
print("Initial test run completed. Starting main benchmark run...") print("Initial test run completed. Starting main benchmark run...")
if lora_modules: if lora_modules:
# For each input request, choose a LoRA module at random. # For each input request, choose a LoRA module at random.
lora_modules = iter( lora_modules = iter(
[random.choice(lora_modules) \ [random.choice(lora_modules) for _ in range(len(input_requests))]
for _ in range(len(input_requests))]) )
if profile: if profile:
print("Starting profiler...") print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id, profile_input = RequestFuncInput(
model_name=model_name, model=model_id,
prompt=test_prompt, model_name=model_name,
api_url=base_url + "/start_profile", prompt=test_prompt,
prompt_len=test_prompt_len, api_url=base_url + "/start_profile",
output_len=test_output_len, prompt_len=test_prompt_len,
logprobs=logprobs, output_len=test_output_len,
multi_modal_content=test_mm_content, logprobs=logprobs,
ignore_eos=ignore_eos, multi_modal_content=test_mm_content,
extra_body=extra_body) ignore_eos=ignore_eos,
extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(request_func_input=profile_input)
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
if burstiness == 1.0: distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
distribution = "Poisson process"
else:
distribution = "Gamma distribution"
print(f"Traffic request rate: {request_rate}") print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Burstiness factor: {burstiness} ({distribution})")
...@@ -334,42 +361,45 @@ async def benchmark( ...@@ -334,42 +361,45 @@ async def benchmark(
# and it will simplify the code in limited_request_func. # and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency) # semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext()) # if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency) semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
if max_concurrency else None)
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, pbar):
if semaphore is None: if semaphore is None:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness): async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request.prompt, \ prompt, prompt_len, output_len, mm_content = (
request.prompt_len, request.expected_output_len, \ request.prompt,
request.multi_modal_data request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
)
req_model_id, req_model_name = model_id, model_name req_model_id, req_model_name = model_id, model_name
if lora_modules: if lora_modules:
req_lora_module = next(lora_modules) req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id, request_func_input = RequestFuncInput(
model_name=req_model_name, model=req_model_id,
prompt=prompt, model_name=req_model_name,
api_url=api_url, prompt=prompt,
prompt_len=prompt_len, api_url=api_url,
output_len=output_len, prompt_len=prompt_len,
logprobs=logprobs, output_len=output_len,
multi_modal_content=mm_content, logprobs=logprobs,
ignore_eos=ignore_eos, multi_modal_content=mm_content,
extra_body=extra_body) ignore_eos=ignore_eos,
extra_body=extra_body,
)
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar))) )
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile: if profile:
...@@ -401,22 +431,32 @@ async def benchmark( ...@@ -401,22 +431,32 @@ async def benchmark(
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
) )
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
metrics.total_output)) print(
print("{:<40} {:<10.2f}".format("Request throughput (req/s):", "{:<40} {:<10.2f}".format(
metrics.request_throughput)) "Request throughput (req/s):", metrics.request_throughput
)
)
if goodput_config_dict: if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", print(
metrics.request_goodput)) "{:<40} {:<10.2f}".format(
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", "Request goodput (req/s):", metrics.request_goodput
metrics.output_throughput)) )
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", )
metrics.total_token_throughput)) print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
)
)
result = { result = {
"duration": benchmark_duration, "duration": benchmark_duration,
...@@ -424,8 +464,7 @@ async def benchmark( ...@@ -424,8 +464,7 @@ async def benchmark(
"total_input_tokens": metrics.total_input, "total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output, "total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput, "request_throughput": metrics.request_throughput,
"request_goodput:": "request_goodput:": metrics.request_goodput if goodput_config_dict else None,
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput, "output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput, "total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs], "input_lens": [output.prompt_len for output in outputs],
...@@ -448,29 +487,35 @@ async def benchmark( ...@@ -448,29 +487,35 @@ async def benchmark(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
f"Mean {metric_name} (ms):", "{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"))) f"Mean {metric_name} (ms):",
print("{:<40} {:<10.2f}".format( getattr(metrics, f"mean_{metric_attribute_name}_ms"),
f"Median {metric_name} (ms):", )
getattr(metrics, f"median_{metric_attribute_name}_ms"))) )
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr( result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms") metrics, f"mean_{metric_attribute_name}_ms"
)
result[f"median_{metric_attribute_name}_ms"] = getattr( result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms") metrics, f"median_{metric_attribute_name}_ms"
)
result[f"std_{metric_attribute_name}_ms"] = getattr( result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms") metrics, f"std_{metric_attribute_name}_ms"
for p, value in getattr(metrics, )
f"percentiles_{metric_attribute_name}_ms"): for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token") process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT", process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency")
...@@ -490,12 +535,14 @@ def check_goodput_args(args): ...@@ -490,12 +535,14 @@ def check_goodput_args(args):
raise ValueError( raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. " f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of " "The service level objective name should be one of "
f"{str(VALID_NAMES)}. ") f"{str(VALID_NAMES)}. "
)
if slo_val < 0: if slo_val < 0:
raise ValueError( raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. " f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be " "The service level objective value should be "
"non-negative.") "non-negative."
)
return goodput_config_dict return goodput_config_dict
...@@ -508,31 +555,42 @@ def parse_goodput(slo_pairs): ...@@ -508,31 +555,42 @@ def parse_goodput(slo_pairs):
except ValueError as err: except ValueError as err:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. " "Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" " 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a " "pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err "number in milliseconds."
) from err
return goodput_config_dict return goodput_config_dict
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(
results: dict[str, Any], args: argparse.Namespace, results: dict[str, Any], file_name: str
file_name: str) -> None: ) -> None:
metrics = [ metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", "median_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", "mean_ttft_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" "std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
] ]
# These raw data might be useful, but they are rather big. They can be added # These raw data might be useful, but they are rather big. They can be added
# later if needed # later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={k: [results[k]] metrics={k: [results[k]] for k in metrics},
for k in metrics},
extra_info={ extra_info={
k: results[k] k: results[k]
for k in results if k not in metrics and k not in ignored_metrics for k in results
}) if k not in metrics and k not in ignored_metrics
},
)
if pt_records: if pt_records:
# Don't use json suffix here as we don't want CI to pick it up # Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
...@@ -557,34 +615,42 @@ def main(args: argparse.Namespace): ...@@ -557,34 +615,42 @@ def main(args: argparse.Namespace):
api_url = f"http://{args.host}:{args.port}{args.endpoint}" api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}" base_url = f"http://{args.host}:{args.port}"
tokenizer = get_tokenizer(tokenizer_id, tokenizer = get_tokenizer(
tokenizer_mode=tokenizer_mode, tokenizer_id,
trust_remote_code=args.trust_remote_code) tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code,
)
if args.dataset_name is None: if args.dataset_name is None:
raise ValueError( raise ValueError(
"Please specify '--dataset-name' and the corresponding " "Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.") "'--dataset-path' if required."
)
if args.dataset_name == "sonnet": if args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path) dataset = SonnetDataset(dataset_path=args.dataset_path)
# For the "sonnet" dataset, formatting depends on the backend. # For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat": if args.backend == "openai-chat":
input_requests = dataset.sample(num_requests=args.num_prompts, input_requests = dataset.sample(
input_len=args.sonnet_input_len, num_requests=args.num_prompts,
output_len=args.sonnet_output_len, input_len=args.sonnet_input_len,
prefix_len=args.sonnet_prefix_len, output_len=args.sonnet_output_len,
tokenizer=tokenizer, prefix_len=args.sonnet_prefix_len,
return_prompt_formatted=False) tokenizer=tokenizer,
return_prompt_formatted=False,
)
else: else:
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.") "Tokenizer/model must have chat template for sonnet dataset."
input_requests = dataset.sample(num_requests=args.num_prompts, )
input_len=args.sonnet_input_len, input_requests = dataset.sample(
output_len=args.sonnet_output_len, num_requests=args.num_prompts,
prefix_len=args.sonnet_prefix_len, input_len=args.sonnet_input_len,
tokenizer=tokenizer, output_len=args.sonnet_output_len,
return_prompt_formatted=True) prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=True,
)
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
# all following datasets are implemented from the # all following datasets are implemented from the
...@@ -611,23 +677,30 @@ def main(args: argparse.Namespace): ...@@ -611,23 +677,30 @@ def main(args: argparse.Namespace):
dataset_class = ASRDataset dataset_class = ASRDataset
args.hf_split = "train" args.hf_split = "train"
else: else:
supported_datasets = set([ supported_datasets = set(
dataset_name for cls in HuggingFaceDataset.__subclasses__() [
for dataset_name in cls.SUPPORTED_DATASET_PATHS dataset_name
]) for cls in HuggingFaceDataset.__subclasses__()
for dataset_name in cls.SUPPORTED_DATASET_PATHS
]
)
raise ValueError( raise ValueError(
f"Unsupported dataset path: {args.dataset_path}. " f"Unsupported dataset path: {args.dataset_path}. "
"Huggingface dataset only supports dataset_path" "Huggingface dataset only supports dataset_path"
f" from one of following: {supported_datasets}. " f" from one of following: {supported_datasets}. "
"Please consider contributing if you would " "Please consider contributing if you would "
"like to add support for additional dataset formats.") "like to add support for additional dataset formats."
)
if (dataset_class.IS_MULTIMODAL and backend not in \ if dataset_class.IS_MULTIMODAL and backend not in [
["openai-chat", "openai-audio"]): "openai-chat",
"openai-audio",
]:
# multi-modal benchmark is only available on OpenAI Chat backend. # multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError( raise ValueError(
"Multi-modal content is only supported on 'openai-chat' and " \ "Multi-modal content is only supported on 'openai-chat' and "
"'openai-audio' backend.") "'openai-audio' backend."
)
input_requests = dataset_class( input_requests = dataset_class(
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
dataset_subset=args.hf_subset, dataset_subset=args.hf_subset,
...@@ -642,26 +715,24 @@ def main(args: argparse.Namespace): ...@@ -642,26 +715,24 @@ def main(args: argparse.Namespace):
else: else:
# For datasets that follow a similar structure, use a mapping. # For datasets that follow a similar structure, use a mapping.
dataset_mapping = { dataset_mapping = {
"sharegpt": "sharegpt": lambda: ShareGPTDataset(
lambda: ShareGPTDataset(random_seed=args.seed, random_seed=args.seed, dataset_path=args.dataset_path
dataset_path=args.dataset_path).sample( ).sample(
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
), ),
"burstgpt": "burstgpt": lambda: BurstGPTDataset(
lambda: BurstGPTDataset(random_seed=args.seed, random_seed=args.seed, dataset_path=args.dataset_path
dataset_path=args.dataset_path). ).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
sample(tokenizer=tokenizer, num_requests=args.num_prompts), "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
"random":
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
prefix_len=args.random_prefix_len, prefix_len=args.random_prefix_len,
input_len=args.random_input_len, input_len=args.random_input_len,
output_len=args.random_output_len, output_len=args.random_output_len,
range_ratio=args.random_range_ratio, range_ratio=args.random_range_ratio,
) ),
} }
try: try:
...@@ -677,15 +748,16 @@ def main(args: argparse.Namespace): ...@@ -677,15 +748,16 @@ def main(args: argparse.Namespace):
"top_p": args.top_p, "top_p": args.top_p,
"top_k": args.top_k, "top_k": args.top_k,
"min_p": args.min_p, "min_p": args.min_p,
"temperature": args.temperature "temperature": args.temperature,
}.items() if v is not None }.items()
if v is not None
} }
# Sampling parameters are only supported by openai-compatible backend. # Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError( raise ValueError(
"Sampling parameters are only supported by openai-compatible " "Sampling parameters are only supported by openai-compatible backends."
"backends.") )
if "temperature" not in sampling_params: if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding. sampling_params["temperature"] = 0.0 # Default to greedy decoding.
...@@ -709,15 +781,14 @@ def main(args: argparse.Namespace): ...@@ -709,15 +781,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
profile=args.profile, profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","), selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
extra_body=sampling_params, extra_body=sampling_params,
)) )
)
# Save config and results to json # Save config and results to json
if args.save_result or args.append_result: if args.save_result or args.append_result:
...@@ -742,8 +813,9 @@ def main(args: argparse.Namespace): ...@@ -742,8 +813,9 @@ def main(args: argparse.Namespace):
"Invalid metadata format. Please use KEY=VALUE format." "Invalid metadata format. Please use KEY=VALUE format."
) )
# Traffic # Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate result_json["request_rate"] = (
< float("inf") else "inf") args.request_rate if args.request_rate < float("inf") else "inf"
)
result_json["burstiness"] = args.burstiness result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency result_json["max_concurrency"] = args.max_concurrency
...@@ -753,24 +825,31 @@ def main(args: argparse.Namespace): ...@@ -753,24 +825,31 @@ def main(args: argparse.Namespace):
if not args.save_detailed: if not args.save_detailed:
# Remove fields with too many data points # Remove fields with too many data points
for field in [ for field in [
"input_lens", "output_lens", "ttfts", "itls", "input_lens",
"generated_texts", "errors" "output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]: ]:
if field in result_json: if field in result_json:
del result_json[field] del result_json[field]
# Save to file # Save to file
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}" max_concurrency_str = (
if args.max_concurrency is not None else "") f"-concurrency{args.max_concurrency}"
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa if args.max_concurrency is not None
else ""
)
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename
if args.result_dir: if args.result_dir:
file_name = os.path.join(args.result_dir, file_name) file_name = os.path.join(args.result_dir, file_name)
with open(file_name, with open(
mode="a+" if args.append_result else "w", file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
encoding='utf-8') as outfile: ) as outfile:
# Append a newline. # Append a newline.
if args.append_result and outfile.tell() != 0: if args.append_result and outfile.tell() != 0:
outfile.write("\n") outfile.write("\n")
...@@ -780,7 +859,8 @@ def main(args: argparse.Namespace): ...@@ -780,7 +859,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.") description="Benchmark the online serving throughput."
)
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
...@@ -809,11 +889,13 @@ if __name__ == "__main__": ...@@ -809,11 +889,13 @@ if __name__ == "__main__":
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
) )
parser.add_argument("--dataset-path", parser.add_argument(
type=str, "--dataset-path",
default=None, type=str,
help="Path to the sharegpt/sonnet dataset. " default=None,
"Or the huggingface dataset ID if using HF dataset.") help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.",
)
parser.add_argument( parser.add_argument(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
...@@ -825,7 +907,8 @@ if __name__ == "__main__": ...@@ -825,7 +907,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed " "initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the " "to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, " "actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.") "if the server is not processing requests fast enough to keep up.",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
...@@ -836,8 +919,7 @@ if __name__ == "__main__": ...@@ -836,8 +919,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",
type=str, type=str,
help= help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
) )
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument( parser.add_argument(
...@@ -850,11 +932,13 @@ if __name__ == "__main__": ...@@ -850,11 +932,13 @@ if __name__ == "__main__":
"--logprobs", "--logprobs",
type=int, type=int,
default=None, default=None,
help=("Number of logprobs-per-token to compute & return as part of " help=(
"the request. If unspecified, then either (1) if beam search " "Number of logprobs-per-token to compute & return as part of "
"is disabled, no logprobs are computed & a single dummy " "the request. If unspecified, then either (1) if beam search "
"logprob is returned for each token; or (2) if beam search " "is disabled, no logprobs are computed & a single dummy "
"is enabled 1 logprob per token is computed"), "logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"
),
) )
parser.add_argument( parser.add_argument(
"--request-rate", "--request-rate",
...@@ -938,35 +1022,38 @@ if __name__ == "__main__": ...@@ -938,35 +1022,38 @@ if __name__ == "__main__":
"--ignore-eos", "--ignore-eos",
action="store_true", action="store_true",
help="Set ignore_eos flag when sending the benchmark request." help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.") "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument( parser.add_argument(
"--percentile-metrics", "--percentile-metrics",
type=str, type=str,
default="ttft,tpot,itl", default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. " help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. " "This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
"Default value is \"ttft,tpot,itl\".") 'Default value is "ttft,tpot,itl".',
)
parser.add_argument( parser.add_argument(
"--metric-percentiles", "--metric-percentiles",
type=str, type=str,
default="99", default="99",
help="Comma-separated list of percentiles for selected metrics. " help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
"Default value is \"99\". " 'Default value is "99". '
"Use \"--percentile-metrics\" to select metrics.", 'Use "--percentile-metrics" to select metrics.',
) )
parser.add_argument( parser.add_argument(
"--goodput", "--goodput",
nargs="+", nargs="+",
required=False, required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" " help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in " "pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are " "separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve") "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments # group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group = parser.add_argument_group("sonnet dataset options")
...@@ -974,22 +1061,19 @@ if __name__ == "__main__": ...@@ -974,22 +1061,19 @@ if __name__ == "__main__":
"--sonnet-input-len", "--sonnet-input-len",
type=int, type=int,
default=550, default=550,
help= help="Number of input tokens per request, used only for sonnet dataset.",
"Number of input tokens per request, used only for sonnet dataset.",
) )
sonnet_group.add_argument( sonnet_group.add_argument(
"--sonnet-output-len", "--sonnet-output-len",
type=int, type=int,
default=150, default=150,
help= help="Number of output tokens per request, used only for sonnet dataset.",
"Number of output tokens per request, used only for sonnet dataset.",
) )
sonnet_group.add_argument( sonnet_group.add_argument(
"--sonnet-prefix-len", "--sonnet-prefix-len",
type=int, type=int,
default=200, default=200,
help= help="Number of prefix tokens per request, used only for sonnet dataset.",
"Number of prefix tokens per request, used only for sonnet dataset.",
) )
sharegpt_group = parser.add_argument_group("sharegpt dataset options") sharegpt_group = parser.add_argument_group("sharegpt dataset options")
...@@ -998,22 +1082,21 @@ if __name__ == "__main__": ...@@ -998,22 +1082,21 @@ if __name__ == "__main__":
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the output length " help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.") "from the ShareGPT dataset.",
)
random_group = parser.add_argument_group("random dataset options") random_group = parser.add_argument_group("random dataset options")
random_group.add_argument( random_group.add_argument(
"--random-input-len", "--random-input-len",
type=int, type=int,
default=1024, default=1024,
help= help="Number of input tokens per request, used only for random sampling.",
"Number of input tokens per request, used only for random sampling.",
) )
random_group.add_argument( random_group.add_argument(
"--random-output-len", "--random-output-len",
type=int, type=int,
default=128, default=128,
help= help="Number of output tokens per request, used only for random sampling.",
"Number of output tokens per request, used only for random sampling.",
) )
random_group.add_argument( random_group.add_argument(
"--random-range-ratio", "--random-range-ratio",
...@@ -1028,23 +1111,23 @@ if __name__ == "__main__": ...@@ -1028,23 +1111,23 @@ if __name__ == "__main__":
"--random-prefix-len", "--random-prefix-len",
type=int, type=int,
default=0, default=0,
help=("Number of fixed prefix tokens before the random context " help=(
"in a request. " "Number of fixed prefix tokens before the random context "
"The total input length is the sum of `random-prefix-len` and " "in a request. "
"a random " "The total input length is the sum of `random-prefix-len` and "
"context length sampled from [input_len * (1 - range_ratio), " "a random "
"input_len * (1 + range_ratio)]."), "context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."
),
) )
hf_group = parser.add_argument_group("hf dataset options") hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset", hf_group.add_argument(
type=str, "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
default=None, )
help="Subset of the HF dataset.") hf_group.add_argument(
hf_group.add_argument("--hf-split", "--hf-split", type=str, default=None, help="Split of the HF dataset."
type=str, )
default=None,
help="Split of the HF dataset.")
hf_group.add_argument( hf_group.add_argument(
"--hf-output-len", "--hf-output-len",
type=int, type=int,
...@@ -1058,52 +1141,58 @@ if __name__ == "__main__": ...@@ -1058,52 +1141,58 @@ if __name__ == "__main__":
"--top-p", "--top-p",
type=float, type=float,
default=None, default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible " help="Top-p sampling parameter. Only has effect on openai-compatible backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--top-k", "--top-k",
type=int, type=int,
default=None, default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible " help="Top-k sampling parameter. Only has effect on openai-compatible backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--min-p", "--min-p",
type=float, type=float,
default=None, default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible " help="Min-p sampling parameter. Only has effect on openai-compatible backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--temperature", "--temperature",
type=float, type=float,
default=None, default=None,
help="Temperature sampling parameter. Only has effect on " help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy " "openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).") "decoding (i.e. temperature==0.0).",
)
parser.add_argument( parser.add_argument(
'--tokenizer-mode', "--tokenizer-mode",
type=str, type=str,
default="auto", default="auto",
choices=['auto', 'slow', 'mistral', 'custom'], choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the ' help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will ' 'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* ' "always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*' '"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.') '"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name",
type=str, parser.add_argument(
default=None, "--served-model-name",
help="The model name used in the API. " type=str,
"If not specified, the model name will be the " default=None,
"same as the ``--model`` argument. ") help="The model name used in the API. "
"If not specified, the model name will be the "
parser.add_argument("--lora-modules", "same as the ``--model`` argument. ",
nargs='+', )
default=None,
help="A subset of LoRA module names passed in when " parser.add_argument(
"launching the server. For each request, the " "--lora-modules",
"script chooses a LoRA module at random.") nargs="+",
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.",
)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -19,6 +19,7 @@ On the client side, run: ...@@ -19,6 +19,7 @@ On the client side, run:
--endpoint /generate_stream --endpoint /generate_stream
to the end of the command above. to the end of the command above.
""" """
import argparse import argparse
import asyncio import asyncio
import copy import copy
...@@ -36,11 +37,15 @@ from typing import Optional ...@@ -36,11 +37,15 @@ from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
RequestFuncOutput)
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
RequestFuncInput,
RequestFuncOutput,
)
try: try:
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError: except ImportError:
...@@ -52,7 +57,8 @@ except ImportError: ...@@ -52,7 +57,8 @@ except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.backend_xgrammar import ( from vllm.v1.structured_output.backend_xgrammar import (
has_xgrammar_unsupported_json_features) has_xgrammar_unsupported_json_features,
)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
...@@ -98,6 +104,7 @@ class SampleRequest: ...@@ -98,6 +104,7 @@ class SampleRequest:
prompt_len: The length of the prompt in tokens. prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens. expected_output_len: The expected length of the output in tokens.
""" """
prompt: str prompt: str
prompt_len: int prompt_len: int
expected_output_len: int expected_output_len: int
...@@ -106,32 +113,28 @@ class SampleRequest: ...@@ -106,32 +113,28 @@ class SampleRequest:
completion: str = None completion: str = None
def sample_requests(tokenizer: PreTrainedTokenizerBase, def sample_requests(
args: argparse.Namespace) -> list[SampleRequest]: tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
if args.dataset == 'json' or args.dataset == 'json-unique': ) -> list[SampleRequest]:
if args.dataset == "json" or args.dataset == "json-unique":
if args.json_schema_path is None: if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.dirname(os.path.realpath(__file__))
args.json_schema_path = os.path.join(dir_path, args.json_schema_path = os.path.join(
"structured_schemas", dir_path, "structured_schemas", "structured_schema_1.json"
"structured_schema_1.json") )
json_schemas = [] json_schemas = []
with open(args.json_schema_path) as f: with open(args.json_schema_path) as f:
schema = json.load(f) schema = json.load(f)
if args.dataset == 'json-unique': if args.dataset == "json-unique":
json_schemas = [ json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)]
copy.deepcopy(schema) for _ in range(args.num_prompts)
]
for i in range(len(json_schemas)): for i in range(len(json_schemas)):
if "properties" not in json_schemas[i]: if "properties" not in json_schemas[i]:
json_schemas[i]["properties"] = {} json_schemas[i]["properties"] = {}
json_schemas[i]["properties"][ json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
f"__optional_field_{uuid.uuid4()}"] = { "type": "string",
"type": "description": "An unique optional field to avoid cached schemas",
"string", }
"description":
"An unique optional field to avoid cached schemas"
}
else: else:
json_schemas = [schema] * args.num_prompts json_schemas = [schema] * args.num_prompts
...@@ -142,11 +145,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -142,11 +145,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
return json_schemas[index % len(json_schemas)] return json_schemas[index % len(json_schemas)]
requests = [ requests = [
SampleRequest(prompt=gen_prompt(i), SampleRequest(
prompt_len=len(tokenizer(gen_prompt(i)).input_ids), prompt=gen_prompt(i),
expected_output_len=args.output_len, prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
schema=get_schema(i), expected_output_len=args.output_len,
structure_type=args.structure_type) schema=get_schema(i),
structure_type=args.structure_type,
)
for i in range(args.num_prompts) for i in range(args.num_prompts)
] ]
...@@ -170,11 +175,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -170,11 +175,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens") print(f"Input length of the prompt: {input_len} tokens")
requests = [ requests = [
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=schema, expected_output_len=args.output_len,
structure_type=args.structure_type) schema=schema,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts) for _ in range(args.num_prompts)
] ]
...@@ -188,11 +195,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -188,11 +195,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens") print(f"Input length of the prompt: {input_len} tokens")
requests = [ requests = [
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=regex, expected_output_len=args.output_len,
structure_type=args.structure_type) schema=regex,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts) for _ in range(args.num_prompts)
] ]
...@@ -203,48 +212,55 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -203,48 +212,55 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens") print(f"Input length of the prompt: {input_len} tokens")
requests = [ requests = [
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=choice, expected_output_len=args.output_len,
structure_type=args.structure_type) schema=choice,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts) for _ in range(args.num_prompts)
] ]
elif args.dataset == "xgrammar_bench": elif args.dataset == "xgrammar_bench":
requests: list[SampleRequest] = [] requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval", dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train")
split="train")
full_dataset_len = len(dataset) full_dataset_len = len(dataset)
def _filter_func(item): def _filter_func(item):
import json import json
schema = json.loads(item["schema"]) schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema) return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func) dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset) num_filtered_out = full_dataset_len - len(dataset)
print(f"dataset has {len(dataset)} entries after filtering " print(
f"out {num_filtered_out} entries with unsupported features") f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features"
)
len_dataset = len(dataset) len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts): for data_point_idx in range(args.num_prompts):
idx = data_point_idx idx = data_point_idx
while idx >= len_dataset: while idx >= len_dataset:
idx -= len_dataset idx -= len_dataset
schema = dataset["schema"][idx] schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], prompt = tokenizer.apply_chat_template(
tokenize=False, dataset["prompt"][idx], tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx] completion = dataset["completion"][idx]
requests.append( requests.append(
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=schema, expected_output_len=args.output_len,
structure_type=args.structure_type, schema=schema,
completion=completion)) structure_type=args.structure_type,
completion=completion,
)
)
return requests return requests
...@@ -276,7 +292,8 @@ async def get_request( ...@@ -276,7 +292,8 @@ async def get_request(
# Calculate scale parameter theta to maintain the desired request_rate. # Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, ( assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.") f"A positive burstiness factor is expected, but given {burstiness}."
)
theta = 1.0 / (request_rate * burstiness) theta = 1.0 / (request_rate * burstiness)
for i, request in enumerate(input_requests): for i, request in enumerate(input_requests):
...@@ -318,8 +335,8 @@ def calculate_metrics( ...@@ -318,8 +335,8 @@ def calculate_metrics(
# multiple output tokens may be bundled together # multiple output tokens may be bundled together
# Note : this may inflate the output token count slightly # Note : this may inflate the output token count slightly
output_len = len( output_len = len(
tokenizer(outputs[i].generated_text, tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
add_special_tokens=False).input_ids) )
actual_output_lens.append(output_len) actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len total_input += input_requests[i].prompt_len
tpot = 0 tpot = 0
...@@ -343,16 +360,19 @@ def calculate_metrics( ...@@ -343,16 +360,19 @@ def calculate_metrics(
if "ttft" in goodput_config_dict: if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts) valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "tpot" in goodput_config_dict: if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots) valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "e2el" in goodput_config_dict: if "e2el" in goodput_config_dict:
valid_metrics.append(e2els) valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
for req_metric in zip(*valid_metrics): for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
...@@ -363,7 +383,8 @@ def calculate_metrics( ...@@ -363,7 +383,8 @@ def calculate_metrics(
warnings.warn( warnings.warn(
"All requests failed. This is likely due to a misconfiguration " "All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.", "on the benchmark arguments.",
stacklevel=2) stacklevel=2,
)
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
...@@ -372,27 +393,31 @@ def calculate_metrics( ...@@ -372,27 +393,31 @@ def calculate_metrics(
request_goodput=good_completed / dur_s, request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s, output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0) * mean_ttft_ms=np.mean(ttfts or 0)
1000, # ttfts is empty if streaming is not supported by backend * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) percentiles_ttft_ms=[
for p in selected_percentiles], (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],
mean_tpot_ms=np.mean(tpots or 0) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) percentiles_tpot_ms=[
for p in selected_percentiles], (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
],
mean_itl_ms=np.mean(itls or 0) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) percentiles_itl_ms=[
for p in selected_percentiles], (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
],
mean_e2el_ms=np.mean(e2els or 0) * 1000, mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) percentiles_e2el_ms=[
for p in selected_percentiles], (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
],
) )
return metrics, actual_output_lens return metrics, actual_output_lens
...@@ -429,12 +454,13 @@ async def benchmark( ...@@ -429,12 +454,13 @@ async def benchmark(
print("Starting initial single prompt test run...") print("Starting initial single prompt test run...")
structured_output_req_idx = random.sample( structured_output_req_idx = random.sample(
range(len(input_requests)), range(len(input_requests)), int(len(input_requests) * structured_output_ratio)
int(len(input_requests) * structured_output_ratio)) )
test_request = input_requests[0] test_request = input_requests[0]
test_req_extra_body = (prepare_extra_body(test_request) test_req_extra_body = (
if 0 in structured_output_req_idx else None) prepare_extra_body(test_request) if 0 in structured_output_req_idx else None
)
test_input = RequestFuncInput( test_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=test_request.prompt, prompt=test_request.prompt,
...@@ -448,7 +474,8 @@ async def benchmark( ...@@ -448,7 +474,8 @@ async def benchmark(
if not test_output.success: if not test_output.success:
raise ValueError( raise ValueError(
"Initial test run failed - Please make sure benchmark arguments " "Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}") f"are correctly specified. Error: {test_output.error}"
)
else: else:
print("Initial test run completed. Starting main benchmark run...") print("Initial test run completed. Starting main benchmark run...")
...@@ -467,10 +494,7 @@ async def benchmark( ...@@ -467,10 +494,7 @@ async def benchmark(
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
if burstiness == 1.0: distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
distribution = "Poisson process"
else:
distribution = "Gamma distribution"
print(f"Traffic request rate: {request_rate}") print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Burstiness factor: {burstiness} ({distribution})")
...@@ -482,24 +506,21 @@ async def benchmark( ...@@ -482,24 +506,21 @@ async def benchmark(
# and it will simplify the code in limited_request_func. # and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency) # semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext()) # if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency) semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
if max_concurrency else None)
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, pbar):
if semaphore is None: if semaphore is None:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
expected: list[str] = [] expected: list[str] = []
async for i, request in get_request(input_requests, request_rate, async for i, request in get_request(input_requests, request_rate, burstiness):
burstiness): extra_body = (
extra_body = prepare_extra_body( prepare_extra_body(request) if i in structured_output_req_idx else None
request) if i in structured_output_req_idx else None )
request_func_input = RequestFuncInput( request_func_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=request.prompt, prompt=request.prompt,
...@@ -512,8 +533,9 @@ async def benchmark( ...@@ -512,8 +533,9 @@ async def benchmark(
expected.append(request.completion) expected.append(request.completion)
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar))) )
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile: if profile:
...@@ -545,54 +567,58 @@ async def benchmark( ...@@ -545,54 +567,58 @@ async def benchmark(
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
) )
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
metrics.total_output)) print(
print("{:<40} {:<10.2f}".format("Request throughput (req/s):", "{:<40} {:<10.2f}".format(
metrics.request_throughput)) "Request throughput (req/s):", metrics.request_throughput
)
)
if goodput_config_dict: if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", print(
metrics.request_goodput)) "{:<40} {:<10.2f}".format(
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", "Request goodput (req/s):", metrics.request_goodput
metrics.output_throughput)) )
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", )
metrics.total_token_throughput)) print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
)
)
result = { result = {
"duration": "duration": benchmark_duration,
benchmark_duration, "completed": metrics.completed,
"completed": "total_input_tokens": metrics.total_input,
metrics.completed, "total_output_tokens": metrics.total_output,
"total_input_tokens": "request_throughput": metrics.request_throughput,
metrics.total_input, "output_throughput": metrics.output_throughput,
"total_output_tokens": "total_token_throughput": metrics.total_token_throughput,
metrics.total_output, "ttft_description": pd.Series([output.ttft for output in outputs])
"request_throughput": .describe()
metrics.request_throughput, .to_dict(),
"output_throughput": "tpot_description": pd.Series([output.tpot for output in outputs])
metrics.output_throughput, .describe()
"total_token_throughput": .to_dict(),
metrics.total_token_throughput,
"ttft_description":
pd.Series([output.ttft for output in outputs]).describe().to_dict(),
"tpot_description":
pd.Series([output.tpot for output in outputs]).describe().to_dict(),
"input_lens": [output.prompt_len for output in outputs], "input_lens": [output.prompt_len for output in outputs],
"output_lens": "output_lens": actual_output_lens,
actual_output_lens,
"ttfts": [output.ttft for output in outputs], "ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs], "itls": [output.itl for output in outputs],
"errors": [output.error for output in outputs], "errors": [output.error for output in outputs],
} }
ret = [{ ret = [
'generated': output.generated_text, {"generated": output.generated_text, "expected": gt}
'expected': gt for output, gt in zip(outputs, expected)
} for output, gt in zip(outputs, expected)] ]
def process_one_metric( def process_one_metric(
# E.g., "ttft" # E.g., "ttft"
...@@ -606,29 +632,35 @@ async def benchmark( ...@@ -606,29 +632,35 @@ async def benchmark(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
f"Mean {metric_name} (ms):", "{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"))) f"Mean {metric_name} (ms):",
print("{:<40} {:<10.2f}".format( getattr(metrics, f"mean_{metric_attribute_name}_ms"),
f"Median {metric_name} (ms):", )
getattr(metrics, f"median_{metric_attribute_name}_ms"))) )
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr( result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms") metrics, f"mean_{metric_attribute_name}_ms"
)
result[f"median_{metric_attribute_name}_ms"] = getattr( result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms") metrics, f"median_{metric_attribute_name}_ms"
)
result[f"std_{metric_attribute_name}_ms"] = getattr( result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms") metrics, f"std_{metric_attribute_name}_ms"
for p, value in getattr(metrics, )
f"percentiles_{metric_attribute_name}_ms"): for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token") process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT", process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency")
...@@ -638,13 +670,13 @@ async def benchmark( ...@@ -638,13 +670,13 @@ async def benchmark(
def evaluate(ret, args): def evaluate(ret, args):
def _eval_correctness_json(expected, actual): def _eval_correctness_json(expected, actual):
# extract json string from string using regex # extract json string from string using regex
import re import re
actual = actual.replace('\n', '').replace(' ', '').strip()
actual = actual.replace("\n", "").replace(" ", "").strip()
try: try:
actual = re.search(r'\{.*\}', actual).group() actual = re.search(r"\{.*\}", actual).group()
actual = json.loads(actual) actual = json.loads(actual)
except Exception: except Exception:
return False return False
...@@ -656,28 +688,32 @@ def evaluate(ret, args): ...@@ -656,28 +688,32 @@ def evaluate(ret, args):
def _eval_correctness_regex(expected, actual): def _eval_correctness_regex(expected, actual):
import re import re
return re.match(args.regex, actual) is not None return re.match(args.regex, actual) is not None
def _eval_correctness(expected, actual): def _eval_correctness(expected, actual):
if args.structure_type == 'guided_json': if args.structure_type == "guided_json":
return _eval_correctness_json(expected, actual) return _eval_correctness_json(expected, actual)
elif args.structure_type == 'guided_regex': elif args.structure_type == "guided_regex":
return _eval_correctness_regex(expected, actual) return _eval_correctness_regex(expected, actual)
elif args.structure_type == 'guided_choice': elif args.structure_type == "guided_choice":
return _eval_correctness_choice(expected, actual) return _eval_correctness_choice(expected, actual)
else: else:
return None return None
scores = [] scores = []
for res in ret: for res in ret:
score = _eval_correctness(res['expected'], res['generated']) score = _eval_correctness(res["expected"], res["generated"])
res['correctness'] = score res["correctness"] = score
scores.append(score) scores.append(score)
not_none_scores = [score for score in scores if score is not None] not_none_scores = [score for score in scores if score is not None]
return (sum(not_none_scores) / len(not_none_scores) * return (
100) if len(not_none_scores) > 0 else None (sum(not_none_scores) / len(not_none_scores) * 100)
if len(not_none_scores) > 0
else None
)
def parse_goodput(slo_pairs): def parse_goodput(slo_pairs):
...@@ -689,9 +725,10 @@ def parse_goodput(slo_pairs): ...@@ -689,9 +725,10 @@ def parse_goodput(slo_pairs):
except ValueError as err: except ValueError as err:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. " "Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" " 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a " "pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err "number in milliseconds."
) from err
return goodput_config_dict return goodput_config_dict
...@@ -705,12 +742,14 @@ def check_goodput_args(args): ...@@ -705,12 +742,14 @@ def check_goodput_args(args):
raise ValueError( raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. " f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of " "The service level objective name should be one of "
f"{str(VALID_NAMES)}. ") f"{str(VALID_NAMES)}. "
)
if slo_val < 0: if slo_val < 0:
raise ValueError( raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. " f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be " "The service level objective value should be "
"non-negative.") "non-negative."
)
return goodput_config_dict return goodput_config_dict
...@@ -736,19 +775,19 @@ def main(args: argparse.Namespace): ...@@ -736,19 +775,19 @@ def main(args: argparse.Namespace):
tokenizer_mode=args.tokenizer_mode, tokenizer_mode=args.tokenizer_mode,
) )
if args.dataset == 'grammar': if args.dataset == "grammar":
args.structure_type = 'guided_grammar' args.structure_type = "guided_grammar"
elif args.dataset == 'regex': elif args.dataset == "regex":
args.structure_type = 'guided_regex' args.structure_type = "guided_regex"
elif args.dataset == 'choice': elif args.dataset == "choice":
args.structure_type = 'guided_choice' args.structure_type = "guided_choice"
else: else:
args.structure_type = 'guided_json' args.structure_type = "guided_json"
if args.no_structured_output: if args.no_structured_output:
args.structured_output_ratio = 0 args.structured_output_ratio = 0
if args.save_results: if args.save_results:
result_file_name = f'{args.structured_output_ratio}guided' result_file_name = f"{args.structured_output_ratio}guided"
result_file_name += f"_{backend}" result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}" result_file_name += f"_{args.model.split('/')[-1]}"
...@@ -776,36 +815,29 @@ def main(args: argparse.Namespace): ...@@ -776,36 +815,29 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
profile=args.profile, profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","), selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
structured_output_ratio=args.structured_output_ratio, structured_output_ratio=args.structured_output_ratio,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
)) )
)
# Save config and results to json # Save config and results to json
score = evaluate(ret, args) score = evaluate(ret, args)
print("correct_rate(%)", score, '\n') print("correct_rate(%)", score, "\n")
if args.save_results: if args.save_results:
results = { results = {
"backend": "backend": backend,
backend, "model_id": model_id,
"model_id": "tokenizer_id": tokenizer_id,
model_id, "num_prompts": args.num_prompts,
"tokenizer_id": "request_rate": args.request_rate
tokenizer_id, if args.request_rate < float("inf")
"num_prompts": else "inf",
args.num_prompts, "burstiness": args.burstiness,
"request_rate": "max_concurrency": args.max_concurrency,
args.request_rate if args.request_rate < float("inf") else "inf", "correct_rate(%)": score,
"burstiness":
args.burstiness,
"max_concurrency":
args.max_concurrency,
"correct_rate(%)":
score
} }
results = {"outputs": ret, **results, **benchmark_result} results = {"outputs": ret, **results, **benchmark_result}
...@@ -814,13 +846,14 @@ def main(args: argparse.Namespace): ...@@ -814,13 +846,14 @@ def main(args: argparse.Namespace):
result_file_name = args.result_filename result_file_name = args.result_filename
if args.result_dir: if args.result_dir:
result_file_name = os.path.join(args.result_dir, result_file_name) result_file_name = os.path.join(args.result_dir, result_file_name)
with open(result_file_name, "w", encoding='utf-8') as outfile: with open(result_file_name, "w", encoding="utf-8") as outfile:
json.dump(results, outfile, indent=4) json.dump(results, outfile, indent=4)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.") description="Benchmark the online serving throughput."
)
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
...@@ -842,16 +875,14 @@ if __name__ == "__main__": ...@@ -842,16 +875,14 @@ if __name__ == "__main__":
default="/v1/completions", default="/v1/completions",
help="API endpoint.", help="API endpoint.",
) )
parser.add_argument("--dataset", parser.add_argument(
default='json', "--dataset",
choices=[ default="json",
'json', 'json-unique', 'grammar', 'regex', choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"],
'choice', 'xgrammar_bench' )
]) parser.add_argument(
parser.add_argument("--json-schema-path", "--json-schema-path", type=str, default=None, help="Path to json schema."
type=str, )
default=None,
help="Path to json schema.")
parser.add_argument( parser.add_argument(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
...@@ -863,7 +894,8 @@ if __name__ == "__main__": ...@@ -863,7 +894,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed " "initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the " "to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, " "actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.") "if the server is not processing requests fast enough to keep up.",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
...@@ -873,15 +905,13 @@ if __name__ == "__main__": ...@@ -873,15 +905,13 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",
type=str, type=str,
help= help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
) )
parser.add_argument( parser.add_argument(
"--tokenizer-mode", "--tokenizer-mode",
type=str, type=str,
default="auto", default="auto",
help= help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
) )
parser.add_argument( parser.add_argument(
"--num-prompts", "--num-prompts",
...@@ -958,44 +988,51 @@ if __name__ == "__main__": ...@@ -958,44 +988,51 @@ if __name__ == "__main__":
"--ignore-eos", "--ignore-eos",
action="store_true", action="store_true",
help="Set ignore_eos flag when sending the benchmark request." help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.") "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument( parser.add_argument(
"--percentile-metrics", "--percentile-metrics",
type=str, type=str,
default="ttft,tpot,itl", default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. " help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. " "This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
"Default value is \"ttft,tpot,itl\".") 'Default value is "ttft,tpot,itl".',
)
parser.add_argument( parser.add_argument(
"--metric-percentiles", "--metric-percentiles",
type=str, type=str,
default="99", default="99",
help="Comma-separated list of percentiles for selected metrics. " help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
"Default value is \"99\". " 'Default value is "99". '
"Use \"--percentile-metrics\" to select metrics.", 'Use "--percentile-metrics" to select metrics.',
) )
parser.add_argument( parser.add_argument(
"--goodput", "--goodput",
nargs="+", nargs="+",
required=False, required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" " help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in " "pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are " "separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve") "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
parser.add_argument("--no-structured-output",
action='store_true', parser.add_argument(
default=False, "--no-structured-output",
help="Whether to disable JSON decoding or not.") action="store_true",
parser.add_argument("--structured-output-ratio", default=False,
type=float, help="Whether to disable JSON decoding or not.",
default=1.0, )
help="Ratio of Structured Outputs requests") parser.add_argument(
"--structured-output-ratio",
type=float,
default=1.0,
help="Ratio of Structured Outputs requests",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput.""" """Benchmark offline inference throughput."""
import argparse import argparse
import dataclasses import dataclasses
import json import json
...@@ -11,18 +12,25 @@ from typing import Any, Optional, Union ...@@ -11,18 +12,25 @@ from typing import Any, Optional, Union
import torch import torch
import uvloop import uvloop
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
ConversationDataset, InstructCoderDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
PreTrainedTokenizerBase)
from benchmark_dataset import (
AIMODataset,
BurstGPTDataset,
ConversationDataset,
InstructCoderDataset,
RandomDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args) build_async_engine_client_from_engine_args,
)
from vllm.inputs import TextPrompt, TokensPrompt from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -37,23 +45,30 @@ def run_vllm( ...@@ -37,23 +45,30 @@ def run_vllm(
disable_detokenize: bool = False, disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]: ) -> tuple[float, Optional[list[RequestOutput]]]:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= ( llm.llm_engine.model_config.max_model_len
request.prompt_len + request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of" ), (
" prompt_len and expected_output_len for all requests.") "Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine. # Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = [] prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = [] sampling_params: list[SamplingParams] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], TokensPrompt(
multi_modal_data=request.multi_modal_data) prompt_token_ids=request.prompt["prompt_token_ids"],
if "prompt_token_ids" in request.prompt else \ multi_modal_data=request.multi_modal_data,
TextPrompt(prompt=request.prompt, )
multi_modal_data=request.multi_modal_data)) if "prompt_token_ids" in request.prompt
else TextPrompt(
prompt=request.prompt, multi_modal_data=request.multi_modal_data
)
)
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
...@@ -62,7 +77,8 @@ def run_vllm( ...@@ -62,7 +77,8 @@ def run_vllm(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
lora_requests: Optional[list[LoRARequest]] = None lora_requests: Optional[list[LoRARequest]] = None
if engine_args.enable_lora: if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests] lora_requests = [request.lora_request for request in requests]
...@@ -72,10 +88,9 @@ def run_vllm( ...@@ -72,10 +88,9 @@ def run_vllm(
outputs = None outputs = None
if not use_beam_search: if not use_beam_search:
start = time.perf_counter() start = time.perf_counter()
outputs = llm.generate(prompts, outputs = llm.generate(
sampling_params, prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
lora_request=lora_requests, )
use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
else: else:
assert lora_requests is None, "BeamSearch API does not support LoRA" assert lora_requests is None, "BeamSearch API does not support LoRA"
...@@ -91,30 +106,35 @@ def run_vllm( ...@@ -91,30 +106,35 @@ def run_vllm(
beam_width=n, beam_width=n,
max_tokens=output_len, max_tokens=output_len,
ignore_eos=True, ignore_eos=True,
)) ),
)
end = time.perf_counter() end = time.perf_counter()
return end - start, outputs return end - start, outputs
def run_vllm_chat( def run_vllm_chat(
requests: list[SampleRequest], requests: list[SampleRequest],
n: int, n: int,
engine_args: EngineArgs, engine_args: EngineArgs,
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
""" """
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
multimodal models as it properly handles multimodal inputs and chat multimodal models as it properly handles multimodal inputs and chat
formatting. For non-multimodal models, use run_vllm() instead. formatting. For non-multimodal models, use run_vllm() instead.
""" """
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= ( llm.llm_engine.model_config.max_model_len
request.prompt_len + request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of " ), (
"prompt_len and expected_output_len for all requests.") "Please ensure that max_model_len is greater than the sum of "
"prompt_len and expected_output_len for all requests."
)
prompts = [] prompts = []
sampling_params: list[SamplingParams] = [] sampling_params: list[SamplingParams] = []
...@@ -128,7 +148,8 @@ def run_vllm_chat( ...@@ -128,7 +148,8 @@ def run_vllm_chat(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
start = time.perf_counter() start = time.perf_counter()
outputs = llm.chat(prompts, sampling_params, use_tqdm=True) outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
...@@ -145,14 +166,17 @@ async def run_vllm_async( ...@@ -145,14 +166,17 @@ async def run_vllm_async(
from vllm import SamplingParams from vllm import SamplingParams
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm: engine_args, disable_frontend_multiprocessing
) as llm:
model_config = await llm.get_model_config() model_config = await llm.get_model_config()
assert all( assert all(
model_config.max_model_len >= (request.prompt_len + model_config.max_model_len
request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of" ), (
" prompt_len and expected_output_len for all requests.") "Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine. # Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = [] prompts: list[Union[TextPrompt, TokensPrompt]] = []
...@@ -160,11 +184,15 @@ async def run_vllm_async( ...@@ -160,11 +184,15 @@ async def run_vllm_async(
lora_requests: list[Optional[LoRARequest]] = [] lora_requests: list[Optional[LoRARequest]] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], TokensPrompt(
multi_modal_data=request.multi_modal_data) prompt_token_ids=request.prompt["prompt_token_ids"],
if "prompt_token_ids" in request.prompt else \ multi_modal_data=request.multi_modal_data,
TextPrompt(prompt=request.prompt, )
multi_modal_data=request.multi_modal_data)) if "prompt_token_ids" in request.prompt
else TextPrompt(
prompt=request.prompt, multi_modal_data=request.multi_modal_data
)
)
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
...@@ -173,17 +201,16 @@ async def run_vllm_async( ...@@ -173,17 +201,16 @@ async def run_vllm_async(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
lora_requests.append(request.lora_request) lora_requests.append(request.lora_request)
generators = [] generators = []
start = time.perf_counter() start = time.perf_counter()
for i, (prompt, sp, for i, (prompt, sp, lr) in enumerate(
lr) in enumerate(zip(prompts, sampling_params, lora_requests)): zip(prompts, sampling_params, lora_requests)
generator = llm.generate(prompt, ):
sp, generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
lora_request=lr,
request_id=f"test{i}")
generators.append(generator) generators.append(generator)
all_gens = merge_async_iterators(*generators) all_gens = merge_async_iterators(*generators)
async for i, res in all_gens: async for i, res in all_gens:
...@@ -202,7 +229,8 @@ def run_hf( ...@@ -202,7 +229,8 @@ def run_hf(
disable_detokenize: bool = False, disable_detokenize: bool = False,
) -> float: ) -> float:
llm = AutoModelForCausalLM.from_pretrained( llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
)
if llm.config.model_type == "llama": if llm.config.model_type == "llama":
# To enable padding in the HF backend. # To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
...@@ -225,14 +253,15 @@ def run_hf( ...@@ -225,14 +253,15 @@ def run_hf(
# Check if we can add more requests to the batch. # Check if we can add more requests to the batch.
next_prompt_len = requests[i + 1].prompt_len next_prompt_len = requests[i + 1].prompt_len
next_output_len = requests[i + 1].expected_output_len next_output_len = requests[i + 1].expected_output_len
if (max(max_prompt_len, next_prompt_len) + if (
max(max_output_len, next_output_len)) <= 2048: max(max_prompt_len, next_prompt_len)
+ max(max_output_len, next_output_len)
) <= 2048:
# We can add more requests to the batch. # We can add more requests to the batch.
continue continue
# Generate the sequences. # Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
padding=True).input_ids
llm_outputs = llm.generate( llm_outputs = llm.generate(
input_ids=input_ids.cuda(), input_ids=input_ids.cuda(),
do_sample=True, do_sample=True,
...@@ -262,6 +291,7 @@ def run_mii( ...@@ -262,6 +291,7 @@ def run_mii(
output_len: int, output_len: int,
) -> float: ) -> float:
from mii import client, serve from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size) llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [request.prompt for request in requests] prompts = [request.prompt for request in requests]
...@@ -273,8 +303,9 @@ def run_mii( ...@@ -273,8 +303,9 @@ def run_mii(
return end - start return end - start
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(
results: dict[str, Any]) -> None: args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={ metrics={
...@@ -282,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, ...@@ -282,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
"tokens_per_second": [results["tokens_per_second"]], "tokens_per_second": [results["tokens_per_second"]],
}, },
extra_info={ extra_info={
k: results[k] k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
for k in ["elapsed_time", "num_requests", "total_num_tokens"] },
}) )
if pt_records: if pt_records:
# Don't use json suffix here as we don't want CI to pick it up # Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
...@@ -316,7 +347,8 @@ def get_requests(args, tokenizer): ...@@ -316,7 +347,8 @@ def get_requests(args, tokenizer):
sample_kwargs["enable_multimodal_chat"] = True sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.") "Tokenizer/model must have chat template for sonnet dataset."
)
dataset_cls = SonnetDataset dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True sample_kwargs["return_prompt_formatted"] = True
...@@ -325,21 +357,21 @@ def get_requests(args, tokenizer): ...@@ -325,21 +357,21 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset dataset_cls = VisionArenaDataset
common_kwargs['dataset_subset'] = None common_kwargs["dataset_subset"] = None
common_kwargs['dataset_split'] = "train" common_kwargs["dataset_split"] = "train"
sample_kwargs["enable_multimodal_chat"] = True sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train" common_kwargs["dataset_split"] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset common_kwargs["dataset_subset"] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split common_kwargs["dataset_split"] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_cls = AIMODataset dataset_cls = AIMODataset
common_kwargs['dataset_subset'] = None common_kwargs["dataset_subset"] = None
common_kwargs['dataset_split'] = "train" common_kwargs["dataset_split"] = "train"
else: else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}") raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values # Remove None values
...@@ -354,10 +386,10 @@ def main(args: argparse.Namespace): ...@@ -354,10 +386,10 @@ def main(args: argparse.Namespace):
random.seed(args.seed) random.seed(args.seed)
# Sample the requests. # Sample the requests.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code
)
requests = get_requests(args, tokenizer) requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None is_multi_modal = any(request.multi_modal_data is not None for request in requests)
for request in requests)
request_outputs: Optional[list[RequestOutput]] = None request_outputs: Optional[list[RequestOutput]] = None
if args.backend == "vllm": if args.backend == "vllm":
if args.async_engine: if args.async_engine:
...@@ -368,23 +400,34 @@ def main(args: argparse.Namespace): ...@@ -368,23 +400,34 @@ def main(args: argparse.Namespace):
AsyncEngineArgs.from_cli_args(args), AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing, args.disable_frontend_multiprocessing,
args.disable_detokenize, args.disable_detokenize,
)) )
)
else: else:
elapsed_time, request_outputs = run_vllm( elapsed_time, request_outputs = run_vllm(
requests, args.n, EngineArgs.from_cli_args(args), requests,
args.disable_detokenize) args.n,
EngineArgs.from_cli_args(args),
args.disable_detokenize,
)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(
args.hf_max_batch_size, args.trust_remote_code, requests,
args.disable_detokenize) args.model,
tokenizer,
args.n,
args.hf_max_batch_size,
args.trust_remote_code,
args.disable_detokenize,
)
elif args.backend == "mii": elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, elapsed_time = run_mii(
args.output_len) requests, args.model, args.tensor_parallel_size, args.output_len
)
elif args.backend == "vllm-chat": elif args.backend == "vllm-chat":
elapsed_time, request_outputs = run_vllm_chat( elapsed_time, request_outputs = run_vllm_chat(
requests, args.n, EngineArgs.from_cli_args(args), requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
args.disable_detokenize) )
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
...@@ -396,28 +439,31 @@ def main(args: argparse.Namespace): ...@@ -396,28 +439,31 @@ def main(args: argparse.Namespace):
for ro in request_outputs: for ro in request_outputs:
if not isinstance(ro, RequestOutput): if not isinstance(ro, RequestOutput):
continue continue
total_prompt_tokens += len( total_prompt_tokens += (
ro.prompt_token_ids) if ro.prompt_token_ids else 0 len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
total_output_tokens += sum( )
len(o.token_ids) for o in ro.outputs if o) total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
total_num_tokens = total_prompt_tokens + total_output_tokens total_num_tokens = total_prompt_tokens + total_output_tokens
else: else:
total_num_tokens = sum(r.prompt_len + r.expected_output_len total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
for r in requests)
total_output_tokens = sum(r.expected_output_len for r in requests) total_output_tokens = sum(r.expected_output_len for r in requests)
total_prompt_tokens = total_num_tokens - total_output_tokens total_prompt_tokens = total_num_tokens - total_output_tokens
if is_multi_modal and args.backend != "vllm-chat": if is_multi_modal and args.backend != "vllm-chat":
print("\033[91mWARNING\033[0m: Multi-modal request with " print(
f"{args.backend} backend detected. The " "\033[91mWARNING\033[0m: Multi-modal request with "
"following metrics are not accurate because image tokens are not" f"{args.backend} backend detected. The "
" counted. See vllm-project/vllm/issues/9778 for details.") "following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details."
)
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
# vllm-chat backend counts the image tokens now # vllm-chat backend counts the image tokens now
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " print(
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s") f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
)
print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}") print(f"Total num output tokens: {total_output_tokens}")
...@@ -445,7 +491,8 @@ def validate_args(args): ...@@ -445,7 +491,8 @@ def validate_args(args):
warnings.warn( warnings.warn(
"The '--dataset' argument will be deprecated in the next release. " "The '--dataset' argument will be deprecated in the next release. "
"Please use '--dataset-name' and '--dataset-path' instead.", "Please use '--dataset-name' and '--dataset-path' instead.",
stacklevel=2) stacklevel=2,
)
args.dataset_path = args.dataset args.dataset_path = args.dataset
if not getattr(args, "tokenizer", None): if not getattr(args, "tokenizer", None):
...@@ -458,9 +505,8 @@ def validate_args(args): ...@@ -458,9 +505,8 @@ def validate_args(args):
# === Dataset Configuration === # === Dataset Configuration ===
if not args.dataset and not args.dataset_path: if not args.dataset and not args.dataset_path:
print( print("When dataset path is not set, it will default to random dataset")
"When dataset path is not set, it will default to random dataset") args.dataset_name = "random"
args.dataset_name = 'random'
if args.input_len is None: if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset") raise ValueError("input_len must be provided for a random dataset")
...@@ -468,41 +514,55 @@ def validate_args(args): ...@@ -468,41 +514,55 @@ def validate_args(args):
# --hf-subset and --hf-split: only used # --hf-subset and --hf-split: only used
# when dataset_name is 'hf' # when dataset_name is 'hf'
if args.dataset_name != "hf" and ( if args.dataset_name != "hf" and (
getattr(args, "hf_subset", None) is not None getattr(args, "hf_subset", None) is not None
or getattr(args, "hf_split", None) is not None): or getattr(args, "hf_split", None) is not None
warnings.warn("--hf-subset and --hf-split will be ignored \ ):
warnings.warn(
"--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.", since --dataset-name is not 'hf'.",
stacklevel=2) stacklevel=2,
)
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
if args.dataset_path in ( if args.dataset_path in (
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| ConversationDataset.SUPPORTED_DATASET_PATHS): | ConversationDataset.SUPPORTED_DATASET_PATHS
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 ):
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS assert args.backend == "vllm-chat", (
| AIMODataset.SUPPORTED_DATASET_PATHS): f"{args.dataset_path} needs to use vllm-chat as the backend."
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 ) # noqa: E501
elif args.dataset_path in (
InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS
):
assert args.backend == "vllm", (
f"{args.dataset_path} needs to use vllm as the backend."
) # noqa: E501
else: else:
raise ValueError( raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random' # --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None: if args.dataset_name != "random" and args.random_range_ratio is not None:
warnings.warn("--random-range-ratio will be ignored since \ warnings.warn(
"--random-range-ratio will be ignored since \
--dataset-name is not 'random'.", --dataset-name is not 'random'.",
stacklevel=2) stacklevel=2,
)
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set. # set.
if args.dataset_name not in {"random", "sonnet", None if (
} and args.prefix_len is not None: args.dataset_name not in {"random", "sonnet", None}
warnings.warn("--prefix-len will be ignored since --dataset-name\ and args.prefix_len is not None
):
warnings.warn(
"--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.", is not 'random', 'sonnet', or not set.",
stacklevel=2) stacklevel=2,
)
# === LoRA Settings === # === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm": if getattr(args, "enable_lora", False) and args.backend != "vllm":
raise ValueError( raise ValueError("LoRA benchmarking is only supported for vLLM backend")
"LoRA benchmarking is only supported for vLLM backend")
if getattr(args, "enable_lora", False) and args.lora_path is None: if getattr(args, "enable_lora", False) and args.lora_path is None:
raise ValueError("LoRA path must be provided when enable_lora is True") raise ValueError("LoRA path must be provided when enable_lora is True")
...@@ -512,8 +572,10 @@ def validate_args(args): ...@@ -512,8 +572,10 @@ def validate_args(args):
if args.backend != "hf" and args.hf_max_batch_size is not None: if args.backend != "hf" and args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.") raise ValueError("HF max batch size is only for HF backend.")
if args.backend in {"hf", "mii"} and getattr(args, "quantization", if (
None) is not None: args.backend in {"hf", "mii"}
and getattr(args, "quantization", None) is not None
):
raise ValueError("Quantization is only for vLLM backend.") raise ValueError("Quantization is only for vLLM backend.")
if args.backend == "mii" and args.dtype != "auto": if args.backend == "mii" and args.dtype != "auto":
...@@ -521,29 +583,32 @@ def validate_args(args): ...@@ -521,29 +583,32 @@ def validate_args(args):
if args.backend == "mii" and args.n != 1: if args.backend == "mii" and args.n != 1:
raise ValueError("n must be 1 for MII backend.") raise ValueError("n must be 1 for MII backend.")
if args.backend == "mii" and args.tokenizer != args.model: if args.backend == "mii" and args.tokenizer != args.model:
raise ValueError( raise ValueError("Tokenizer must be the same as the model for MII backend.")
"Tokenizer must be the same as the model for MII backend.")
# --data-parallel is not supported currently. # --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222 # https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1: if args.data_parallel_size > 1:
raise ValueError( raise ValueError(
"Data parallel is not supported in offline benchmark, \ "Data parallel is not supported in offline benchmark, \
please use benchmark serving instead") please use benchmark serving instead"
)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", parser.add_argument(
type=str, "--backend",
choices=["vllm", "hf", "mii", "vllm-chat"], type=str,
default="vllm") choices=["vllm", "hf", "mii", "vllm-chat"],
default="vllm",
)
parser.add_argument( parser.add_argument(
"--dataset-name", "--dataset-name",
type=str, type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
default="sharegpt") default="sharegpt",
)
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,
...@@ -551,57 +616,70 @@ if __name__ == "__main__": ...@@ -551,57 +616,70 @@ if __name__ == "__main__":
help="Path to the ShareGPT dataset, will be deprecated in\ help="Path to the ShareGPT dataset, will be deprecated in\
the next release. The dataset is expected to " the next release. The dataset is expected to "
"be a json in form of list[dict[..., conversations: " "be a json in form of list[dict[..., conversations: "
"list[dict[..., value: <prompt_or_response>]]]]") "list[dict[..., value: <prompt_or_response>]]]]",
parser.add_argument("--dataset-path", )
type=str, parser.add_argument(
default=None, "--dataset-path", type=str, default=None, help="Path to the dataset"
help="Path to the dataset") )
parser.add_argument("--input-len", parser.add_argument(
type=int, "--input-len",
default=None, type=int,
help="Input prompt length for each request") default=None,
parser.add_argument("--output-len", help="Input prompt length for each request",
type=int, )
default=None, parser.add_argument(
help="Output length for each request. Overrides the " "--output-len",
"output length from the dataset.") type=int,
parser.add_argument("--n", default=None,
type=int, help="Output length for each request. Overrides the "
default=1, "output length from the dataset.",
help="Number of generated sequences per prompt.") )
parser.add_argument("--num-prompts", parser.add_argument(
type=int, "--n", type=int, default=1, help="Number of generated sequences per prompt."
default=1000, )
help="Number of prompts to process.")
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument( parser.add_argument(
'--output-json', "--num-prompts", type=int, default=1000, help="Number of prompts to process."
)
parser.add_argument(
"--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.",
)
parser.add_argument(
"--output-json",
type=str, type=str,
default=None, default=None,
help='Path to save the throughput results in JSON format.') help="Path to save the throughput results in JSON format.",
parser.add_argument("--async-engine", )
action='store_true', parser.add_argument(
default=False, "--async-engine",
help="Use vLLM async engine rather than LLM class.") action="store_true",
parser.add_argument("--disable-frontend-multiprocessing", default=False,
action='store_true', help="Use vLLM async engine rather than LLM class.",
default=False, )
help="Disable decoupled async engine frontend.") parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
default=False,
help="Disable decoupled async engine frontend.",
)
parser.add_argument( parser.add_argument(
"--disable-detokenize", "--disable-detokenize",
action="store_true", action="store_true",
help=("Do not detokenize the response (i.e. do not include " help=(
"detokenization time in the measurement)")) "Do not detokenize the response (i.e. do not include "
"detokenization time in the measurement)"
),
)
# LoRA # LoRA
parser.add_argument( parser.add_argument(
"--lora-path", "--lora-path",
type=str, type=str,
default=None, default=None,
help="Path to the LoRA adapters to use. This can be an absolute path, " help="Path to the LoRA adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.") "a relative path, or a Hugging Face model identifier.",
)
parser.add_argument( parser.add_argument(
"--prefix-len", "--prefix-len",
type=int, type=int,
...@@ -615,7 +693,8 @@ if __name__ == "__main__": ...@@ -615,7 +693,8 @@ if __name__ == "__main__":
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
"controls how much of the input is fixed lines versus " "controls how much of the input is fixed lines versus "
"random lines, but the total input length remains approximately " "random lines, but the total input length remains approximately "
"input_len tokens.") "input_len tokens.",
)
# random dataset # random dataset
parser.add_argument( parser.add_argument(
"--random-range-ratio", "--random-range-ratio",
...@@ -629,14 +708,12 @@ if __name__ == "__main__": ...@@ -629,14 +708,12 @@ if __name__ == "__main__":
) )
# hf dtaset # hf dtaset
parser.add_argument("--hf-subset", parser.add_argument(
type=str, "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
default=None, )
help="Subset of the HF dataset.") parser.add_argument(
parser.add_argument("--hf-split", "--hf-split", type=str, default=None, help="Split of the HF dataset."
type=str, )
default=None,
help="Split of the HF dataset.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -7,9 +7,9 @@ import os ...@@ -7,9 +7,9 @@ import os
from typing import Any from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace, def convert_to_pytorch_benchmark_format(
metrics: dict[str, list], args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
extra_info: dict[str, Any]) -> list: ) -> list:
""" """
Save the benchmark results in the format used by PyTorch OSS benchmark with Save the benchmark results in the format used by PyTorch OSS benchmark with
on metric per record on metric per record
...@@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, ...@@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
}, },
} }
tp = record["benchmark"]["extra_info"]["args"].get( tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
"tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata # Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info: if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][ record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
"tensor_parallel_size"] = extra_info["tensor_parallel_size"] extra_info["tensor_parallel_size"]
)
records.append(record) records.append(record)
...@@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, ...@@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder): class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any): def clear_inf(self, o: Any):
if isinstance(o, dict): if isinstance(o, dict):
return {k: self.clear_inf(v) for k, v in o.items()} return {k: self.clear_inf(v) for k, v in o.items()}
......
...@@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1] ...@@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1]
# bench # bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_fn(
**kwargs) -> TMeasurement: label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
...@@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ...@@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_int8(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.int8 assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, out = ops.cutlass_scaled_sparse_mm(
torch.bfloat16) a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref): if not torch.allclose(out, out_ref):
...@@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers = [] timers = []
# pytorch impl - bfloat16 # pytorch impl - bfloat16
timers.append( timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", bench_fn(
torch.mm, a.to(dtype=torch.bfloat16), label,
b.to(dtype=torch.bfloat16))) sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16),
)
)
# pytorch impl - float16 # pytorch impl - float16
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, label,
a.to(dtype=torch.float16), b.to(dtype=torch.float16))) sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.float16),
b.to(dtype=torch.float16),
)
)
# cutlass impl # cutlass impl
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, label,
torch.bfloat16)) sub_label,
"cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass with bias # cutlass with bias
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, label,
bias)) sub_label,
"cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass sparse impl # cutlass sparse impl
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16)) sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse with bias # cutlass sparse with bias
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16, bias)) sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
return timers return timers
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_fp8(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, out = ops.cutlass_scaled_sparse_mm(
torch.bfloat16) a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref): if not torch.allclose(out, out_ref):
...@@ -124,97 +180,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -124,97 +180,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# pytorch impl w. bf16 # pytorch impl w. bf16
timers.append( timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", bench_fn(
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), label,
b.to(dtype=torch.bfloat16, device="cuda"))) sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"),
)
)
# pytorch impl: bf16 output, without fp8 fast accum # pytorch impl: bf16 output, without fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_bf16_scaled_mm", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_bf16_scaled_mm",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.bfloat16)) scale_b=scale_b,
out_dtype=torch.bfloat16,
)
)
# pytorch impl: bf16 output, with fp8 fast accum # pytorch impl: bf16 output, with fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.bfloat16, scale_b=scale_b,
use_fast_accum=True)) out_dtype=torch.bfloat16,
use_fast_accum=True,
)
)
# pytorch impl: fp16 output, without fp8 fast accum # pytorch impl: fp16 output, without fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_fp16_scaled_mm", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_fp16_scaled_mm",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.float16)) scale_b=scale_b,
out_dtype=torch.float16,
)
)
# pytorch impl: fp16 output, with fp8 fast accum # pytorch impl: fp16 output, with fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.float16, scale_b=scale_b,
use_fast_accum=True)) out_dtype=torch.float16,
use_fast_accum=True,
)
)
# cutlass impl: bf16 output # cutlass impl: bf16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, label,
torch.bfloat16)) sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: bf16 output # cutlass impl: bf16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16)) sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: fp16 output # cutlass impl: fp16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.float16)) sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
)
)
# cutlass impl: bf16 output, with bias # cutlass impl: bf16 output, with bias
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", label,
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, sub_label,
scale_b, torch.bfloat16, bias)) "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass impl: fp16 output, with bias # cutlass impl: fp16 output, with bias
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", label,
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, sub_label,
scale_b, torch.float16, bias.to(dtype=torch.float16))) "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
bias.to(dtype=torch.float16),
)
)
return timers return timers
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label) return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
...@@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print() compare.print()
def run(dtype: torch.dtype, def run(
MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
f"MKN=({m}x{k}x{n})")
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
...@@ -241,10 +365,12 @@ def run(dtype: torch.dtype, ...@@ -241,10 +365,12 @@ def run(dtype: torch.dtype,
# output makers # output makers
def make_output(data: Iterable[TMeasurement], def make_output(
MKNs: Iterable[tuple[int, int, int]], data: Iterable[TMeasurement],
base_description: str, MKNs: Iterable[tuple[int, int, int]],
timestamp=None): base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
print_timers(data) print_timers(data)
...@@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement], ...@@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs) data = run(args.dtype, MKNs)
...@@ -319,7 +444,7 @@ def run_model_bench(args): ...@@ -319,7 +444,7 @@ def run_model_bench(args):
pkl.dump(all_data, f) pkl.dump(all_data, f)
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
...@@ -344,12 +469,15 @@ Benchmark Cutlass GEMM. ...@@ -344,12 +469,15 @@ Benchmark Cutlass GEMM.
Output: Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501 """, # noqa: E501
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype",
type=to_torch_dtype, parser.add_argument(
required=True, "--dtype",
help="Available options are ['int8', 'fp8']") type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']",
)
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench") square_parser = subparsers.add_parser("square_bench")
...@@ -368,19 +496,19 @@ Benchmark Cutlass GEMM. ...@@ -368,19 +496,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models", model_parser.add_argument(
nargs="+", "--models",
type=str, nargs="+",
default=DEFAULT_MODELS, type=str,
choices=WEIGHT_SHAPES.keys()) default=DEFAULT_MODELS,
model_parser.add_argument("--tp-sizes", choices=WEIGHT_SHAPES.keys(),
nargs="+", )
type=int, model_parser.add_argument(
default=DEFAULT_TP_SIZES) "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
model_parser.add_argument("--batch-sizes", )
nargs="+", model_parser.add_argument(
type=int, "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
default=DEFAULT_BATCH_SIZES) )
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -10,8 +10,9 @@ import vllm._custom_ops as ops ...@@ -10,8 +10,9 @@ import vllm._custom_ops as ops
def to_fp8(tensor: torch.Tensor) -> torch.Tensor: def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp( return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor: def to_int8(tensor: torch.Tensor) -> torch.Tensor:
...@@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor: ...@@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16) return tensor.to(dtype=torch.float16)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_tensors(
k: int) -> tuple[torch.Tensor, torch.Tensor]: dtype: torch.dtype, m: int, n: int, k: int
a = torch.randn((m, k), device='cuda') * 5 ) -> tuple[torch.Tensor, torch.Tensor]:
b = torch.randn((n, k), device='cuda').t() * 5 a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
if dtype == torch.int8: if dtype == torch.int8:
return to_int8(a), to_int8(b) return to_int8(a), to_int8(b)
...@@ -49,9 +51,7 @@ def prune_to_2_4(tensor): ...@@ -49,9 +51,7 @@ def prune_to_2_4(tensor):
# Create binary mask # Create binary mask
mask = torch.zeros_like(reshaped) mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1, mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back # Apply mask and reshape back
pruned = reshaped * mask pruned = reshaped * mask
...@@ -62,10 +62,11 @@ def prune_to_2_4(tensor): ...@@ -62,10 +62,11 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape) return pruned.reshape(original_shape)
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_sparse_tensors(
k: int) -> tuple[torch.Tensor, torch.Tensor]: dtype: torch.dtype, m: int, n: int, k: int
a = torch.randn((m, k), device='cuda') * 5 ) -> tuple[torch.Tensor, torch.Tensor]:
b = torch.randn((n, k), device='cuda').t() * 5 a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
b = prune_to_2_4(b.t()).t() b = prune_to_2_4(b.t()).t()
...@@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, ...@@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
return b_compressed, e, a, b return b_compressed, e, a, b
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, def make_n_rand_sparse_tensors(
m: int, n: int, k: int) -> \ num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: ) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
ABs = [] ABs = []
for _ in range(num_tensors): for _ in range(num_tensors):
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
......
...@@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES ...@@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul) w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
...@@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1] ...@@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1]
# bench # bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_fn(
**kwargs) -> TMeasurement: label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
...@@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ...@@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
def bench_int8( def bench_int8(
dtype: torch.dtype, dtype: torch.dtype,
m: int, m: int,
k: int, k: int,
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels.""" """Benchmark INT8-based kernels."""
assert dtype == torch.int8 assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k) a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
bench_fns = { bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
), ),
"pytorch_fp16_fp16_fp16_matmul-no-scales": "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), a.to(dtype=torch.float16), b.to(dtype=torch.float16)
"cutlass_i8_i8_bf16_scaled_mm": ),
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
"cutlass_i8_i8_bf16_scaled_mm_bias": a, b, scale_a, scale_b, torch.bfloat16
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, ),
bias), "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
"cutlass_i8_i8_bf16_scaled_mm_azp": a, b, scale_a, scale_b, torch.bfloat16, bias
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj), "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": a, b, scale_a, scale_b, torch.bfloat16, azp_adj
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj, None, bias), "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj, azp), "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj, azp, bias), "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
),
} }
timers = [] timers = []
...@@ -96,73 +101,65 @@ def bench_int8( ...@@ -96,73 +101,65 @@ def bench_int8(
def bench_fp8( def bench_fp8(
dtype: torch.dtype, dtype: torch.dtype,
m: int, m: int,
k: int, k: int,
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels.""" """Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous() a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128), block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
device="cuda", block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128),
device="cuda",
dtype=torch.float32)
block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t()
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
print(m, k, n) print(m, k, n)
bench_fns = { bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
), ),
"pytorch_fp16_fp16_fp16_matmul-no-scales": "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), a.to(dtype=torch.float16), b.to(dtype=torch.float16)
"pytorch_fp8_fp8_fp16_scaled_mm": ),
lambda: torch._scaled_mm( "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16), a, b, scale_a, scale_b, out_dtype=torch.float16
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": ),
lambda: torch._scaled_mm(a, "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
b, a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
scale_a, ),
scale_b, "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
out_dtype=torch.float16, a, b, scale_a, scale_b, out_dtype=torch.bfloat16
use_fast_accum=True), ),
"pytorch_fp8_fp8_bf16_scaled_mm": "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
lambda: torch._scaled_mm( a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
a, b, scale_a, scale_b, out_dtype=torch.bfloat16), ),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
lambda: torch._scaled_mm(a, a, b, scale_a, scale_b, torch.bfloat16
b, ),
scale_a, "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
scale_b, a, b, scale_a, scale_b, torch.float16
out_dtype=torch.bfloat16, ),
use_fast_accum=True), "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
"cutlass_fp8_fp8_bf16_scaled_mm": a, b, scale_a, scale_b, torch.bfloat16, bias
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), ),
"cutlass_fp8_fp8_fp16_scaled_mm": "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
"cutlass_fp8_fp8_bf16_scaled_mm_bias": ),
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
bias), a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
"cutlass_fp8_fp8_fp16_scaled_mm_bias": ),
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
bias.to(dtype=torch.float16)), a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
"triton_fp8_fp8_fp16_scaled_mm_blockwise": ),
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
block_scale_b.t(), (128, 128)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
block_scale_b_K_major, torch.float16),
} }
timers = [] timers = []
...@@ -175,13 +172,15 @@ def bench_fp8( ...@@ -175,13 +172,15 @@ def bench_fp8(
return timers return timers
def bench(dtype: torch.dtype, def bench(
m: int, dtype: torch.dtype,
k: int, m: int,
n: int, k: int,
label: str, n: int,
sub_label: str, label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: sub_label: str,
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
...@@ -195,27 +194,33 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -195,27 +194,33 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print() compare.print()
def run(dtype: torch.dtype, def run(
MKNs: Iterable[tuple[int, int, int]], dtype: torch.dtype,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, timers = bench(
m, dtype,
k, m,
n, k,
f"scaled-{dtype}-gemm", n,
f"MKN=({m}x{k}x{n})", f"scaled-{dtype}-gemm",
bench_kernels=bench_kernels) f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels,
)
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
return results return results
def make_output(data: Iterable[TMeasurement], def make_output(
MKNs: Iterable[tuple[int, int, int]], data: Iterable[TMeasurement],
base_description: str, MKNs: Iterable[tuple[int, int, int]],
timestamp=None): base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
print_timers(data) print_timers(data)
...@@ -226,8 +231,7 @@ def make_output(data: Iterable[TMeasurement], ...@@ -226,8 +231,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels) data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"square_bench-{args.dtype}") make_output(data, MKNs, f"square_bench-{args.dtype}")
...@@ -285,7 +289,7 @@ def run_model_bench(args): ...@@ -285,7 +289,7 @@ def run_model_bench(args):
pkl.dump(all_data, f) pkl.dump(all_data, f)
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
...@@ -310,19 +314,21 @@ Benchmark Cutlass GEMM. ...@@ -310,19 +314,21 @@ Benchmark Cutlass GEMM.
Output: Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501 """, # noqa: E501
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype", parser.add_argument(
type=to_torch_dtype, "--dtype",
required=True, type=to_torch_dtype,
help="Available options are ['int8', 'fp8']") required=True,
help="Available options are ['int8', 'fp8']",
)
parser.add_argument( parser.add_argument(
"--kernels", "--kernels",
nargs="+", nargs="+",
type=str, type=str,
default=None, default=None,
help= help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
"Exact names of the kernels to benchmark. If not set, runs all kernels."
) )
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
...@@ -343,19 +349,19 @@ Benchmark Cutlass GEMM. ...@@ -343,19 +349,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models", model_parser.add_argument(
nargs="+", "--models",
type=str, nargs="+",
default=DEFAULT_MODELS, type=str,
choices=WEIGHT_SHAPES.keys()) default=DEFAULT_MODELS,
model_parser.add_argument("--tp-sizes", choices=WEIGHT_SHAPES.keys(),
nargs="+", )
type=int, model_parser.add_argument(
default=DEFAULT_TP_SIZES) "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
model_parser.add_argument("--batch-sizes", )
nargs="+", model_parser.add_argument(
type=int, "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
default=DEFAULT_BATCH_SIZES) )
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -42,4 +42,4 @@ WEIGHT_SHAPES = { ...@@ -42,4 +42,4 @@ WEIGHT_SHAPES = {
([8192, 57344], 1), ([8192, 57344], 1),
([28672, 8192], 0), ([28672, 8192], 0),
], ],
} }
\ No newline at end of file
...@@ -12,39 +12,37 @@ app = Quart(__name__) ...@@ -12,39 +12,37 @@ app = Quart(__name__)
async def forward_request(url, data): async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" async with session.post(url=url, json=data, headers=headers) as response:
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200: if response.status == 200:
# if response.headers.get('Transfer-Encoding') == 'chunked': # if response.headers.get('Transfer-Encoding') == 'chunked':
if True: if True:
async for chunk_bytes in response.content.iter_chunked( async for chunk_bytes in response.content.iter_chunked(1024):
1024):
yield chunk_bytes yield chunk_bytes
else: else:
content = await response.read() content = await response.read()
yield content yield content
@app.route('/v1/completions', methods=['POST']) @app.route("/v1/completions", methods=["POST"])
async def handle_request(): async def handle_request():
try: try:
original_request_data = await request.get_json() original_request_data = await request.get_json()
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill # change max_tokens = 1 to let it only do prefill
prefill_request['max_tokens'] = 1 prefill_request["max_tokens"] = 1
# finish prefill # finish prefill
async for _ in forward_request('http://localhost:8100/v1/completions', async for _ in forward_request(
prefill_request): "http://localhost:8100/v1/completions", prefill_request
):
continue continue
# return decode # return decode
generator = forward_request('http://localhost:8200/v1/completions', generator = forward_request(
original_request_data) "http://localhost:8200/v1/completions", original_request_data
)
response = await make_response(generator) response = await make_response(generator)
response.timeout = None response.timeout = None
...@@ -53,11 +51,12 @@ async def handle_request(): ...@@ -53,11 +51,12 @@ async def handle_request():
except Exception as e: except Exception as e:
import sys import sys
import traceback import traceback
exc_info = sys.exc_info() exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server") print("Error occurred in disagg prefill proxy server")
print(e) print(e)
print("".join(traceback.format_exception(*exc_info))) print("".join(traceback.format_exception(*exc_info)))
if __name__ == '__main__': if __name__ == "__main__":
app.run(port=8000) app.run(port=8000)
...@@ -8,7 +8,6 @@ from aiohttp import web ...@@ -8,7 +8,6 @@ from aiohttp import web
class RoundRobinProxy: class RoundRobinProxy:
def __init__(self, target_ports): def __init__(self, target_ports):
self.target_ports = target_ports self.target_ports = target_ports
self.port_cycle = itertools.cycle(self.target_ports) self.port_cycle = itertools.cycle(self.target_ports)
...@@ -21,14 +20,15 @@ class RoundRobinProxy: ...@@ -21,14 +20,15 @@ class RoundRobinProxy:
try: try:
# Forward the request # Forward the request
async with session.request( async with session.request(
method=request.method, method=request.method,
url=target_url, url=target_url,
headers=request.headers, headers=request.headers,
data=request.content, data=request.content,
) as response: ) as response:
# Start sending the response # Start sending the response
resp = web.StreamResponse(status=response.status, resp = web.StreamResponse(
headers=response.headers) status=response.status, headers=response.headers
)
await resp.prepare(request) await resp.prepare(request)
# Stream the response content # Stream the response content
...@@ -45,11 +45,11 @@ class RoundRobinProxy: ...@@ -45,11 +45,11 @@ class RoundRobinProxy:
async def main(): async def main():
proxy = RoundRobinProxy([8100, 8200]) proxy = RoundRobinProxy([8100, 8200])
app = web.Application() app = web.Application()
app.router.add_route('*', '/{path:.*}', proxy.handle_request) app.router.add_route("*", "/{path:.*}", proxy.handle_request)
runner = web.AppRunner(app) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite(runner, 'localhost', 8000) site = web.TCPSite(runner, "localhost", 8000)
await site.start() await site.start()
print("Proxy server started on http://localhost:8000") print("Proxy server started on http://localhost:8000")
...@@ -58,5 +58,5 @@ async def main(): ...@@ -58,5 +58,5 @@ async def main():
await asyncio.Event().wait() await asyncio.Event().wait()
if __name__ == '__main__': if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
...@@ -6,43 +6,41 @@ import matplotlib.pyplot as plt ...@@ -6,43 +6,41 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
if __name__ == "__main__": if __name__ == "__main__":
data = [] data = []
for name in ['disagg_prefill', 'chunked_prefill']: for name in ["disagg_prefill", "chunked_prefill"]:
for qps in [2, 4, 6, 8]: for qps in [2, 4, 6, 8]:
with open(f"results/{name}-qps-{qps}.json") as f: with open(f"results/{name}-qps-{qps}.json") as f:
x = json.load(f) x = json.load(f)
x['name'] = name x["name"] = name
x['qps'] = qps x["qps"] = qps
data.append(x) data.append(x)
df = pd.DataFrame.from_dict(data) df = pd.DataFrame.from_dict(data)
dis_df = df[df['name'] == 'disagg_prefill'] dis_df = df[df["name"] == "disagg_prefill"]
chu_df = df[df['name'] == 'chunked_prefill'] chu_df = df[df["name"] == "chunked_prefill"]
plt.style.use('bmh') plt.style.use("bmh")
plt.rcParams['font.size'] = 20 plt.rcParams["font.size"] = 20
for key in [ for key in [
'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', "mean_ttft_ms",
'median_itl_ms', 'p99_itl_ms' "median_ttft_ms",
"p99_ttft_ms",
"mean_itl_ms",
"median_itl_ms",
"p99_itl_ms",
]: ]:
fig, ax = plt.subplots(figsize=(11, 7)) fig, ax = plt.subplots(figsize=(11, 7))
plt.plot(dis_df['qps'], plt.plot(
dis_df[key], dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
label='disagg_prefill', )
marker='o', plt.plot(
linewidth=4) chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
plt.plot(chu_df['qps'], )
chu_df[key],
label='chunked_prefill',
marker='o',
linewidth=4)
ax.legend() ax.legend()
ax.set_xlabel('QPS') ax.set_xlabel("QPS")
ax.set_ylabel(key) ax.set_ylabel(key)
ax.set_ylim(bottom=0) ax.set_ylim(bottom=0)
fig.savefig(f'results/{key}.png') fig.savefig(f"results/{key}.png")
plt.close(fig) plt.close(fig)
...@@ -24,10 +24,12 @@ class bench_params_t: ...@@ -24,10 +24,12 @@ class bench_params_t:
dtype: torch.dtype dtype: torch.dtype
def description(self): def description(self):
return (f'N {self.num_tokens} ' return (
f'x D {self.hidden_size} ' f"N {self.num_tokens} "
f'x R {self.add_residual} ' f"x D {self.hidden_size} "
f'x DT {self.dtype}') f"x R {self.add_residual} "
f"x DT {self.dtype}"
)
def get_bench_params() -> list[bench_params_t]: def get_bench_params() -> list[bench_params_t]:
...@@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]: ...@@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]:
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
bench_params = list(map(lambda x: \ bench_params = list(
bench_params_t(x[0], x[1], x[2], x[3]), combinations)) map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params return bench_params
# Reference impls # Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, def unfused_int8_impl(
residual: Optional[torch.Tensor], rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype): x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype,
):
# Norm # Norm
torch_out = None torch_out = None
if residual is None: if residual is None:
...@@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, ...@@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
torch_out, _, _ = ops.scaled_int8_quant(torch_out) torch_out, _, _ = ops.scaled_int8_quant(torch_out)
def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, def unfused_fp8_impl(
residual: Optional[torch.Tensor], rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype): x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype,
):
# Norm # Norm
torch_out = None torch_out = None
if residual is None: if residual is None:
...@@ -73,22 +82,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, ...@@ -73,22 +82,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
def fused_impl( def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
quant_dtype: torch.dtype): quant_dtype: torch.dtype,
out, _ = ops.rms_norm_dynamic_per_token_quant(x, ):
rms_norm_layer.weight, out, _ = ops.rms_norm_dynamic_per_token_quant(
1e-6, x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
quant_dtype, )
residual=residual)
# Bench functions # Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, def bench_fn(
quant_dtype: torch.dtype, label: str, sub_label: str, rms_norm_layer: RMSNorm,
fn: Callable, description: str) -> TMeasurement: x: torch.Tensor,
residual: torch.Tensor,
quant_dtype: torch.dtype,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
...@@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, ...@@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
description=description, description=description,
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) \
-> Iterable[TMeasurement]:
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
# Make inputs # Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights # Make weights
layer.weight.data.normal_(mean=1.0, std=0.1) layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs # Make inputs
scale = 1 / params.hidden_size scale = 1 / params.hidden_size
x = torch.randn(params.num_tokens, x = (
params.hidden_size, torch.randn(
dtype=params.dtype, params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
device='cuda') * scale )
residual = (torch.randn_like(x) * scale).to(device='cuda') \ * scale
if params.add_residual else None )
residual = (
(torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
)
timers = [] timers = []
# unfused int8 impl. # unfused int8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, bench_fn(
unfused_int8_impl, "unfused_int8_impl")) layer,
x,
residual,
torch.int8,
label,
sub_label,
unfused_int8_impl,
"unfused_int8_impl",
)
)
# unfused fp8 impl. # unfused fp8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, bench_fn(
unfused_fp8_impl, "unfused_fp8_impl")) layer,
x,
residual,
torch.float8_e4m3fn,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# fused int8 impl. # fused int8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, bench_fn(
"fused_int8_impl")) layer,
x,
residual,
torch.int8,
label,
sub_label,
fused_impl,
"fused_int8_impl",
)
)
# fused fp8 impl. # fused fp8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, bench_fn(
fused_impl, "fused_fp8_impl")) layer,
x,
residual,
torch.float8_e4m3fn,
label,
sub_label,
fused_impl,
"fused_fp8_impl",
)
)
print_timers(timers) print_timers(timers)
...@@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]):
def main(): def main():
torch.set_default_device('cuda') torch.set_default_device("cuda")
bench_params = get_bench_params() bench_params = get_bench_params()
timers = [] timers = []
for bp in tqdm(bench_params): for bp in tqdm(bench_params):
timers.extend( timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers) print_timers(timers)
# pickle all the results # pickle all the results
...@@ -172,5 +223,5 @@ def main(): ...@@ -172,5 +223,5 @@ def main():
pkl.dump(timers, f) pkl.dump(timers, f)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment