Unverified commit 009d9e75, authored by Harry Mellor, committed by GitHub
Browse files

Convert `benchmarks` to `ruff format` (#18068)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent b922c2eb
# This local pyproject file is part of the migration from yapf to ruff format. # This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the # It uses the same core rules as the main pyproject.toml file, but with the
# following differences: # following differences:
# - isort profile is set to black
# - ruff line length is overridden to 88 # - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed # - deprecated typing ignores (UP006, UP035) have been removed
[tool.isort]
profile = "black"
[tool.ruff] [tool.ruff]
line-length = 88 line-length = 88
exclude = [ exclude = [
......
...@@ -17,7 +17,7 @@ repos: ...@@ -17,7 +17,7 @@ repos:
- id: ruff - id: ruff
args: [--output-format, github, --fix] args: [--output-format, github, --fix]
- id: ruff-format - id: ruff-format
files: ^(.buildkite).* files: ^(.buildkite|benchmarks)/.*
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.4.1 rev: v2.4.1
hooks: hooks:
...@@ -28,8 +28,6 @@ repos: ...@@ -28,8 +28,6 @@ repos:
rev: 6.0.1 rev: 6.0.1
hooks: hooks:
- id: isort - id: isort
# necessary during the transition from yapf to ruff format
args: [--resolve-all-configs, --config-root, .]
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v20.1.3 rev: v20.1.3
hooks: hooks:
......
...@@ -12,8 +12,7 @@ from typing import Optional, Union ...@@ -12,8 +12,7 @@ from typing import Optional, Union
import aiohttp import aiohttp
import huggingface_hub.constants import huggingface_hub.constants
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
PreTrainedTokenizerFast)
# NOTE(simon): do not import vLLM here so the benchmark script # NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed. # can run without vLLM installed.
...@@ -43,8 +42,7 @@ class RequestFuncOutput: ...@@ -43,8 +42,7 @@ class RequestFuncOutput:
latency: float = 0.0 latency: float = 0.0
output_tokens: int = 0 output_tokens: int = 0
ttft: float = 0.0 # Time to first token ttft: float = 0.0 # Time to first token
itl: list[float] = field( itl: list[float] = field(default_factory=list) # list of inter-token latencies
default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0 prompt_len: int = 0
error: str = "" error: str = ""
...@@ -57,8 +55,9 @@ async def async_request_tgi( ...@@ -57,8 +55,9 @@ async def async_request_tgi(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
params = { params = {
"max_new_tokens": request_func_input.output_len, "max_new_tokens": request_func_input.output_len,
"do_sample": True, "do_sample": True,
...@@ -105,8 +104,7 @@ async def async_request_tgi( ...@@ -105,8 +104,7 @@ async def async_request_tgi(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -133,8 +131,9 @@ async def async_request_trt_llm( ...@@ -133,8 +131,9 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"accumulate_tokens": True, "accumulate_tokens": True,
"text_input": request_func_input.prompt, "text_input": request_func_input.prompt,
...@@ -159,8 +158,7 @@ async def async_request_trt_llm( ...@@ -159,8 +158,7 @@ async def async_request_trt_llm(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
"data:")
data = json.loads(chunk) data = json.loads(chunk)
output.generated_text += data["text_output"] output.generated_text += data["text_output"]
...@@ -172,8 +170,7 @@ async def async_request_trt_llm( ...@@ -172,8 +170,7 @@ async def async_request_trt_llm(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -197,9 +194,9 @@ async def async_request_deepspeed_mii( ...@@ -197,9 +194,9 @@ async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"model": request_func_input.model, "model": request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
...@@ -217,19 +214,21 @@ async def async_request_deepspeed_mii( ...@@ -217,19 +214,21 @@ async def async_request_deepspeed_mii(
st = time.perf_counter() st = time.perf_counter()
try: try:
async with session.post(url=request_func_input.api_url, async with session.post(
json=payload) as response: url=request_func_input.api_url, json=payload
) as response:
if response.status == 200: if response.status == 200:
parsed_resp = await response.json() parsed_resp = await response.json()
output.latency = time.perf_counter() - st output.latency = time.perf_counter() - st
if "choices" in parsed_resp: if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][ output.generated_text = parsed_resp["choices"][0]["text"]
"text"]
elif "text" in parsed_resp: elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0] output.generated_text = parsed_resp["text"][0]
else: else:
output.error = ("Unexpected response format: " output.error = (
"neither 'choices' nor 'text' found") "Unexpected response format: "
"neither 'choices' nor 'text' found"
)
output.success = False output.success = False
output.success = True output.success = True
else: else:
...@@ -250,15 +249,17 @@ async def async_request_openai_completions( ...@@ -250,15 +249,17 @@ async def async_request_openai_completions(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("completions", "profile")), (
("completions", "profile") "OpenAI Completions API URL must end with 'completions' or 'profile'."
), "OpenAI Completions API URL must end with 'completions' or 'profile'." )
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
"temperature": 0.0, "temperature": 0.0,
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
...@@ -273,9 +274,7 @@ async def async_request_openai_completions( ...@@ -273,9 +274,7 @@ async def async_request_openai_completions(
payload["ignore_eos"] = request_func_input.ignore_eos payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -284,8 +283,9 @@ async def async_request_openai_completions( ...@@ -284,8 +283,9 @@ async def async_request_openai_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(
headers=headers) as response: url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
first_chunk_received = False first_chunk_received = False
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
...@@ -293,8 +293,7 @@ async def async_request_openai_completions( ...@@ -293,8 +293,7 @@ async def async_request_openai_completions(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
data = json.loads(chunk) data = json.loads(chunk)
...@@ -314,21 +313,20 @@ async def async_request_openai_completions( ...@@ -314,21 +313,20 @@ async def async_request_openai_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text += text or "" generated_text += text or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
if first_chunk_received: if first_chunk_received:
output.success = True output.success = True
else: else:
output.success = False output.success = False
output.error = ( output.error = (
"Never received a valid chunk to calculate TTFT." "Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!") "This response will be marked as failed!"
)
output.generated_text = generated_text output.generated_text = generated_text
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
else: else:
...@@ -349,23 +347,22 @@ async def async_request_openai_chat_completions( ...@@ -349,23 +347,22 @@ async def async_request_openai_chat_completions(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("chat/completions", "profile")), (
("chat/completions", "profile") "OpenAI Chat Completions API URL must end with 'chat/completions'."
), "OpenAI Chat Completions API URL must end with 'chat/completions'." )
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content: if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content) content.append(request_func_input.multi_modal_content)
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"messages": [ "messages": [
{ {"role": "user", "content": content},
"role": "user",
"content": content
},
], ],
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
...@@ -391,16 +388,16 @@ async def async_request_openai_chat_completions( ...@@ -391,16 +388,16 @@ async def async_request_openai_chat_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(
headers=headers) as response: url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
...@@ -414,13 +411,11 @@ async def async_request_openai_chat_completions( ...@@ -414,13 +411,11 @@ async def async_request_openai_chat_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -446,25 +441,28 @@ async def async_request_openai_audio( ...@@ -446,25 +441,28 @@ async def async_request_openai_audio(
) -> RequestFuncOutput: ) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep. # Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile import soundfile
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("transcriptions", "translations")), (
("transcriptions", "translations" "OpenAI Chat Completions API URL must end with 'transcriptions' "
)), "OpenAI Chat Completions API URL must end with 'transcriptions' " )
"or `translations`." "or `translations`."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
"stream": True, "stream": True,
"language": "en", "language": "en",
# Flattened due to multipart/form-data # Flattened due to multipart/form-data
"stream_include_usage": True, "stream_include_usage": True,
"stream_continuous_usage_stats": True "stream_continuous_usage_stats": True,
} }
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
...@@ -479,9 +477,9 @@ async def async_request_openai_audio( ...@@ -479,9 +477,9 @@ async def async_request_openai_audio(
buffer.seek(0) buffer.seek(0)
return buffer return buffer
with to_bytes(*request_func_input.multi_modal_content['audio']) as f: with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData() form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav') form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items(): for key, value in payload.items():
form.add_field(key, str(value)) form.add_field(key, str(value))
...@@ -493,24 +491,22 @@ async def async_request_openai_audio( ...@@ -493,24 +491,22 @@ async def async_request_openai_audio(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, async with session.post(
data=form, url=api_url, data=form, headers=headers
headers=headers) as response: ) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
if choices := data.get("choices"): if choices := data.get("choices"):
content = choices[0]["delta"].get( content = choices[0]["delta"].get("content")
"content")
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = timestamp - st ttft = timestamp - st
...@@ -519,12 +515,14 @@ async def async_request_openai_audio( ...@@ -519,12 +515,14 @@ async def async_request_openai_audio(
# Decoding phase # Decoding phase
else: else:
output.itl.append( output.itl.append(
timestamp - most_recent_timestamp) timestamp - most_recent_timestamp
)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get(
"completion_tokens") "completion_tokens"
)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
...@@ -545,7 +543,7 @@ async def async_request_openai_audio( ...@@ -545,7 +543,7 @@ async def async_request_openai_audio(
def get_model(pretrained_model_name_or_path: str) -> str: def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
from modelscope import snapshot_download from modelscope import snapshot_download
from vllm.model_executor.model_loader.weight_utils import get_lock from vllm.model_executor.model_loader.weight_utils import get_lock
...@@ -556,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str: ...@@ -556,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
model_path = snapshot_download( model_path = snapshot_download(
model_id=pretrained_model_name_or_path, model_id=pretrained_model_name_or_path,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
)
return model_path return model_path
return pretrained_model_name_or_path return pretrained_model_name_or_path
...@@ -569,23 +568,23 @@ def get_tokenizer( ...@@ -569,23 +568,23 @@ def get_tokenizer(
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path is not None and not os.path.exists( if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path): pretrained_model_name_or_path
pretrained_model_name_or_path = get_model( ):
pretrained_model_name_or_path) pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
if tokenizer_mode == "slow": if tokenizer_mode == "slow":
if kwargs.get("use_fast", False): if kwargs.get("use_fast", False):
raise ValueError( raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False kwargs["use_fast"] = False
if tokenizer_mode == "mistral": if tokenizer_mode == "mistral":
try: try:
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
except ImportError as e: except ImportError as e:
raise ImportError("MistralTokenizer requires vllm package.\n" raise ImportError(
"MistralTokenizer requires vllm package.\n"
"Please install it with `pip install vllm` " "Please install it with `pip install vllm` "
"to use mistral tokenizer mode.") from e "to use mistral tokenizer mode."
return MistralTokenizer.from_pretrained( ) from e
str(pretrained_model_name_or_path)) return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
else: else:
return AutoTokenizer.from_pretrained( return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, pretrained_model_name_or_path,
...@@ -608,7 +607,7 @@ ASYNC_REQUEST_FUNCS = { ...@@ -608,7 +607,7 @@ ASYNC_REQUEST_FUNCS = {
} }
OPENAI_COMPATIBLE_BACKENDS = [ OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items() k
if v in (async_request_openai_completions, for k, v in ASYNC_REQUEST_FUNCS.items()
async_request_openai_chat_completions) if v in (async_request_openai_completions, async_request_openai_chat_completions)
] ]
This diff is collapsed.
...@@ -11,9 +11,9 @@ from typing import Any, Optional ...@@ -11,9 +11,9 @@ from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm from tqdm import tqdm
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
...@@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams ...@@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(
results: dict[str, Any]) -> None: args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={"latency": results["latencies"]}, metrics={"latency": results["latencies"]},
extra_info={k: results[k] extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
for k in ["avg_latency", "percentiles"]}) )
if pt_records: if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records) write_to_json(pt_file, pt_records)
...@@ -42,9 +43,11 @@ def main(args: argparse.Namespace): ...@@ -42,9 +43,11 @@ def main(args: argparse.Namespace):
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= ( assert llm.llm_engine.model_config.max_model_len >= (
args.input_len + args.input_len + args.output_len
args.output_len), ("Please ensure that max_model_len is greater than" ), (
" the sum of input_len and output_len.") "Please ensure that max_model_len is greater than"
" the sum of input_len and output_len."
)
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=args.n, n=args.n,
...@@ -55,18 +58,16 @@ def main(args: argparse.Namespace): ...@@ -55,18 +58,16 @@ def main(args: argparse.Namespace):
detokenize=not args.disable_detokenize, detokenize=not args.disable_detokenize,
) )
print(sampling_params) print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(
size=(args.batch_size, 10000, size=(args.batch_size, args.input_len)
args.input_len)) )
dummy_prompts: list[PromptType] = [{ dummy_prompts: list[PromptType] = [
"prompt_token_ids": batch {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
} for batch in dummy_prompt_token_ids.tolist()] ]
def llm_generate(): def llm_generate():
if not args.use_beam_search: if not args.use_beam_search:
llm.generate(dummy_prompts, llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
sampling_params=sampling_params,
use_tqdm=False)
else: else:
llm.beam_search( llm.beam_search(
dummy_prompts, dummy_prompts,
...@@ -85,7 +86,8 @@ def main(args: argparse.Namespace): ...@@ -85,7 +86,8 @@ def main(args: argparse.Namespace):
torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CUDA,
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)), str(profile_dir)
),
) as p: ) as p:
llm_generate() llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total")) print(p.key_averages().table(sort_by="self_cuda_time_total"))
...@@ -103,8 +105,9 @@ def main(args: argparse.Namespace): ...@@ -103,8 +105,9 @@ def main(args: argparse.Namespace):
if args.profile: if args.profile:
profile_dir = args.profile_result_dir profile_dir = args.profile_result_dir
if not profile_dir: if not profile_dir:
profile_dir = (Path(".") / "vllm_benchmark_result" / profile_dir = (
f"latency_result_{time.time()}") Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
)
print(f"Profiling (results will be saved to '{profile_dir}')...") print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir) run_to_completion(profile_dir=profile_dir)
return return
...@@ -135,7 +138,8 @@ def main(args: argparse.Namespace): ...@@ -135,7 +138,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the latency of processing a single batch of " description="Benchmark the latency of processing a single batch of "
"requests till completion.") "requests till completion."
)
parser.add_argument("--input-len", type=int, default=32) parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128) parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
...@@ -152,10 +156,9 @@ if __name__ == "__main__": ...@@ -152,10 +156,9 @@ if __name__ == "__main__":
default=10, default=10,
help="Number of iterations to run for warmup.", help="Number of iterations to run for warmup.",
) )
parser.add_argument("--num-iters", parser.add_argument(
type=int, "--num-iters", type=int, default=30, help="Number of iterations to run."
default=30, )
help="Number of iterations to run.")
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
...@@ -165,8 +168,10 @@ if __name__ == "__main__": ...@@ -165,8 +168,10 @@ if __name__ == "__main__":
"--profile-result-dir", "--profile-result-dir",
type=str, type=str,
default=None, default=None,
help=("path to save the pytorch profiler output. Can be visualized " help=(
"with ui.perfetto.dev or Tensorboard."), "path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."
),
) )
parser.add_argument( parser.add_argument(
"--output-json", "--output-json",
...@@ -177,8 +182,10 @@ if __name__ == "__main__": ...@@ -177,8 +182,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--disable-detokenize", "--disable-detokenize",
action="store_true", action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
......
...@@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str): ...@@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str):
ValueError: If an invalid mode is provided. ValueError: If an invalid mode is provided.
""" """
print("Repeat mode: ", mode) print("Repeat mode: ", mode)
if mode == 'random': if mode == "random":
repeated_prompts = prompts * repeat_count repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts) random.shuffle(repeated_prompts)
return repeated_prompts return repeated_prompts
elif mode == 'tile': elif mode == "tile":
return prompts * repeat_count return prompts * repeat_count
elif mode == 'interleave': elif mode == "interleave":
repeated_prompts = [] repeated_prompts = []
for prompt in prompts: for prompt in prompts:
repeated_prompts.extend([prompt] * repeat_count) repeated_prompts.extend([prompt] * repeat_count)
return repeated_prompts return repeated_prompts
else: else:
raise ValueError(f"Invalid mode: {mode}, only support " raise ValueError(
"'random', 'tile', 'interleave'") f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
)
def main(args): def main(args):
...@@ -109,16 +110,16 @@ def main(args): ...@@ -109,16 +110,16 @@ def main(args):
# we append the document id at the beginning to avoid any of the document # we append the document id at the beginning to avoid any of the document
# being the prefix of other documents # being the prefix of other documents
prompts = [ prompts = [
str(i) + ' '.join(['hi'] * args.document_length) str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents) for i in range(args.num_documents)
] ]
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
warmup_prompts = [ warmup_prompts = [
"This is warm up request " + str(i) + \ "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
' '.join(['hi'] * args.document_length) for i in range(args.num_documents)
for i in range(args.num_documents)] ]
# Create the LLM engine # Create the LLM engine
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
...@@ -142,42 +143,52 @@ def main(args): ...@@ -142,42 +143,52 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description= description="Benchmark the performance with or "
'Benchmark the performance with or without automatic prefix caching.') "without automatic prefix caching."
)
parser.add_argument( parser.add_argument(
'--document-length', "--document-length",
type=int, type=int,
# Roughly the number of tokens for a system paper, # Roughly the number of tokens for a system paper,
# excluding images # excluding images
default=20000, default=20000,
help='Range of input lengths for sampling prompts,' help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").') 'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--num-documents', parser.add_argument(
"--num-documents",
type=int, type=int,
default=8, default=8,
help='Range of input lengths for sampling prompts,' help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").') 'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--output-len', type=int, default=10) parser.add_argument("--output-len", type=int, default=10)
parser.add_argument('--repeat-count', parser.add_argument(
"--repeat-count",
type=int, type=int,
default=2, default=2,
help='Number of times to repeat each prompt') help="Number of times to repeat each prompt",
)
parser.add_argument("--repeat-mode", parser.add_argument(
"--repeat-mode",
type=str, type=str,
default='random', default="random",
help='The mode to repeat prompts. The supported ' help="The mode to repeat prompts. The supported "
'modes are "random", "tile", and "interleave". ' 'modes are "random", "tile", and "interleave". '
'See repeat_prompts() in the source code for details.') "See repeat_prompts() in the source code for details.",
)
parser.add_argument("--shuffle-seed", parser.add_argument(
"--shuffle-seed",
type=int, type=int,
default=0, default=0,
help='Random seed when the repeat mode is "random"') help='Random seed when the repeat mode is "random"',
)
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -63,8 +63,7 @@ class Request: ...@@ -63,8 +63,7 @@ class Request:
output_len: int output_len: int
def sample_tokens(tokenizer: PreTrainedTokenizerBase, def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
length: int) -> list[int]:
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
all_special_ids = set(tokenizer.all_special_ids) all_special_ids = set(tokenizer.all_special_ids)
...@@ -91,8 +90,10 @@ def sample_requests_from_dataset( ...@@ -91,8 +90,10 @@ def sample_requests_from_dataset(
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"], dataset = [
data["conversations"][1]["value"]) for data in dataset] (data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset. # Shuffle the dataset.
random.shuffle(dataset) random.shuffle(dataset)
...@@ -113,8 +114,9 @@ def sample_requests_from_dataset( ...@@ -113,8 +114,9 @@ def sample_requests_from_dataset(
completion = dataset[i][1] completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = (len(completion_token_ids) output_len = (
if fixed_output_len is None else fixed_output_len) len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if min_len <= prompt_len <= max_len: if min_len <= prompt_len <= max_len:
filtered_requests.append(Request(prompt, prompt_len, output_len)) filtered_requests.append(Request(prompt, prompt_len, output_len))
...@@ -128,27 +130,27 @@ def sample_requests_from_random( ...@@ -128,27 +130,27 @@ def sample_requests_from_random(
fixed_output_len: Optional[int], fixed_output_len: Optional[int],
prefix_len: int, prefix_len: int,
) -> list[Request]: ) -> list[Request]:
requests = [] requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len) prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range min_len, max_len = input_length_range
for i in range(num_requests): for i in range(num_requests):
unique_part_token_ids = sample_tokens( unique_part_token_ids = sample_tokens(
tokenizer, tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
random.randint(min_len - prefix_len, max_len - prefix_len)) )
prompt_token_ids = prefix_token_ids + unique_part_token_ids prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
assert (min_len <= prompt_len <= max_len assert min_len <= prompt_len <= max_len, (
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
)
requests.append(Request(prompt, prompt_len, fixed_output_len)) requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests return requests
def repeat_and_sort_requests(requests: list[Request], def repeat_and_sort_requests(
repeat_count: int, requests: list[Request], repeat_count: int, sort: bool = False
sort: bool = False) -> list[str]: ) -> list[str]:
repeated_requests = requests * repeat_count repeated_requests = requests * repeat_count
if sort: if sort:
repeated_requests.sort(key=lambda x: x[1]) repeated_requests.sort(key=lambda x: x[1])
...@@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request], ...@@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request],
def main(args): def main(args):
tokenizer = get_tokenizer(args.model, trust_remote_code=True) tokenizer = get_tokenizer(args.model, trust_remote_code=True)
input_length_range = tuple(map(int, args.input_length_range.split(':'))) input_length_range = tuple(map(int, args.input_length_range.split(":")))
random.seed(args.seed) random.seed(args.seed)
if args.dataset_path is not None: if args.dataset_path is not None:
if args.prefix_len > 0: if args.prefix_len > 0:
raise ValueError("prefix-len is not supported when " raise ValueError(
"dataset-path is provided.") "prefix-len is not supported when dataset-path is provided."
print(f"Start to sample {args.num_prompts} prompts " )
f"from {args.dataset_path}") print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
filtered_requests = sample_requests_from_dataset( filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
...@@ -196,14 +198,16 @@ def main(args): ...@@ -196,14 +198,16 @@ def main(args):
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
sampling_params = SamplingParams(temperature=0, sampling_params = SamplingParams(
temperature=0,
max_tokens=args.output_len, max_tokens=args.output_len,
detokenize=not args.disable_detokenize) detokenize=not args.disable_detokenize,
)
print("Testing filtered requests") print("Testing filtered requests")
prompts = repeat_and_sort_requests(filtered_requests, prompts = repeat_and_sort_requests(
repeat_count=args.repeat_count, filtered_requests, repeat_count=args.repeat_count, sort=args.sort
sort=args.sort) )
print("------start generating------") print("------start generating------")
test_prefix( test_prefix(
...@@ -215,29 +219,35 @@ def main(args): ...@@ -215,29 +219,35 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description= description="Benchmark the performance with or without "
'Benchmark the performance with or without automatic prefix caching.') "automatic prefix caching."
parser.add_argument("--dataset-path", )
type=str, parser.add_argument(
default=None, "--dataset-path", type=str, default=None, help="Path to the dataset."
help="Path to the dataset.") )
parser.add_argument('--output-len', type=int, default=10) parser.add_argument("--output-len", type=int, default=10)
parser.add_argument('--num-prompts', parser.add_argument(
"--num-prompts",
type=int, type=int,
required=True, required=True,
help="Number of the prompts sampled from dataset") help="Number of the prompts sampled from dataset",
parser.add_argument('--repeat-count', )
parser.add_argument(
"--repeat-count",
type=int, type=int,
default=1, default=1,
help='Number of times to repeat each prompt') help="Number of times to repeat each prompt",
parser.add_argument('--sort', )
action='store_true', parser.add_argument(
help='Sort prompts by input length') "--sort", action="store_true", help="Sort prompts by input length"
parser.add_argument('--input-length-range', )
parser.add_argument(
"--input-length-range",
type=str, type=str,
required=True, required=True,
help='Range of input lengths for sampling prompts,' help="Range of input lengths for sampling prompts,"
'specified as "min:max" (e.g., "128:256").') 'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument( parser.add_argument(
"--prefix-len", "--prefix-len",
type=int, type=int,
...@@ -248,10 +258,12 @@ if __name__ == "__main__": ...@@ -248,10 +258,12 @@ if __name__ == "__main__":
"when dataset-path is not provided.", "when dataset-path is not provided.",
) )
parser.add_argument( parser.add_argument(
'--disable-detokenize', "--disable-detokenize",
action='store_true', action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Benchmark offline prioritization.""" """Benchmark offline prioritization."""
import argparse import argparse
import dataclasses import dataclasses
import json import json
...@@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs ...@@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
#Select an equiprobable random priority # Select an equiprobable random priority
def get_random_flag(): def get_random_flag():
return 0 if random.random() < 0.5 else 1 return 0 if random.random() < 0.5 else 1
...@@ -33,8 +34,10 @@ def sample_requests( ...@@ -33,8 +34,10 @@ def sample_requests(
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"], dataset = [
data["conversations"][1]["value"]) for data in dataset] (data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset. # Shuffle the dataset.
random.shuffle(dataset) random.shuffle(dataset)
...@@ -51,8 +54,9 @@ def sample_requests( ...@@ -51,8 +54,9 @@ def sample_requests(
completion = dataset[i][1] completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids output_len = (
) if fixed_output_len is None else fixed_output_len len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4: if prompt_len < 4 or output_len < 4:
# Prune too short sequences. # Prune too short sequences.
continue continue
...@@ -74,13 +78,16 @@ def run_vllm( ...@@ -74,13 +78,16 @@ def run_vllm(
disable_detokenize: bool = False, disable_detokenize: bool = False,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
for request in requests), ( for request in requests
), (
"Please ensure that max_model_len is greater than the sum of" "Please ensure that max_model_len is greater than the sum of"
" input_len and output_len for all requests.") " input_len and output_len for all requests."
)
# Add the requests to the engine. # Add the requests to the engine.
prompts = [] prompts = []
...@@ -97,7 +104,8 @@ def run_vllm( ...@@ -97,7 +104,8 @@ def run_vllm(
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
...@@ -111,26 +119,33 @@ def main(args: argparse.Namespace): ...@@ -111,26 +119,33 @@ def main(args: argparse.Namespace):
# Sample the requests. # Sample the requests.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.dataset is None: if args.dataset is None:
# Synthesize a prompt with the given input length. # Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1) prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len, requests = [
get_random_flag()) for _ in range(args.num_prompts)] (prompt, args.input_len, args.output_len, get_random_flag())
for _ in range(args.num_prompts)
]
else: else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer, requests = sample_requests(
args.output_len) args.dataset, args.num_prompts, tokenizer, args.output_len
)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.n, elapsed_time = run_vllm(
EngineArgs.from_cli_args(args), requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
args.disable_detokenize) )
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len total_num_tokens = sum(
for _, prompt_len, output_len, priority in requests) prompt_len + output_len for _, prompt_len, output_len, priority in requests
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " )
f"{total_num_tokens / elapsed_time:.2f} tokens/s") print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
)
# Output JSON results if specified # Output JSON results if specified
if args.output_json: if args.output_json:
...@@ -147,41 +162,44 @@ def main(args: argparse.Namespace): ...@@ -147,41 +162,44 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", parser.add_argument(
type=str, "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
choices=["vllm", "hf", "mii"], )
default="vllm") parser.add_argument(
parser.add_argument("--dataset", "--dataset", type=str, default=None, help="Path to the dataset."
type=str, )
default=None, parser.add_argument(
help="Path to the dataset.") "--input-len",
parser.add_argument("--input-len",
type=int, type=int,
default=None, default=None,
help="Input prompt length for each request") help="Input prompt length for each request",
parser.add_argument("--output-len", )
parser.add_argument(
"--output-len",
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the " help="Output length for each request. Overrides the "
"output length from the dataset.") "output length from the dataset.",
parser.add_argument("--n", )
type=int, parser.add_argument(
default=1, "--n", type=int, default=1, help="Number of generated sequences per prompt."
help="Number of generated sequences per prompt.") )
parser.add_argument("--num-prompts",
type=int,
default=200,
help="Number of prompts to process.")
parser.add_argument( parser.add_argument(
'--output-json', "--num-prompts", type=int, default=200, help="Number of prompts to process."
)
parser.add_argument(
"--output-json",
type=str, type=str,
default=None, default=None,
help='Path to save the throughput results in JSON format.') help="Path to save the throughput results in JSON format.",
)
parser.add_argument( parser.add_argument(
'--disable-detokenize', "--disable-detokenize",
action='store_true', action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
......
This diff is collapsed.
This diff is collapsed.
...@@ -7,9 +7,9 @@ import os ...@@ -7,9 +7,9 @@ import os
from typing import Any from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace, def convert_to_pytorch_benchmark_format(
metrics: dict[str, list], args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
extra_info: dict[str, Any]) -> list: ) -> list:
""" """
Save the benchmark results in the format used by PyTorch OSS benchmark with Save the benchmark results in the format used by PyTorch OSS benchmark with
one metric per record one metric per record
...@@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, ...@@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
}, },
} }
tp = record["benchmark"]["extra_info"]["args"].get( tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
"tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata # Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info: if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][ record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
"tensor_parallel_size"] = extra_info["tensor_parallel_size"] extra_info["tensor_parallel_size"]
)
records.append(record) records.append(record)
...@@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, ...@@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder): class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any): def clear_inf(self, o: Any):
if isinstance(o, dict): if isinstance(o, dict):
return {k: self.clear_inf(v) for k, v in o.items()} return {k: self.clear_inf(v) for k, v in o.items()}
......
...@@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1] ...@@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1]
# bench # bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_fn(
**kwargs) -> TMeasurement: label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
...@@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ...@@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_int8(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.int8 assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, out = ops.cutlass_scaled_sparse_mm(
torch.bfloat16) a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref): if not torch.allclose(out, out_ref):
...@@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers = [] timers = []
# pytorch impl - bfloat16 # pytorch impl - bfloat16
timers.append( timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", bench_fn(
torch.mm, a.to(dtype=torch.bfloat16), label,
b.to(dtype=torch.bfloat16))) sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16),
)
)
# pytorch impl - float16 # pytorch impl - float16
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, label,
a.to(dtype=torch.float16), b.to(dtype=torch.float16))) sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.float16),
b.to(dtype=torch.float16),
)
)
# cutlass impl # cutlass impl
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, label,
torch.bfloat16)) sub_label,
"cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass with bias # cutlass with bias
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, label,
bias)) sub_label,
"cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass sparse impl # cutlass sparse impl
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16)) sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse with bias # cutlass sparse with bias
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16, bias)) sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
return timers return timers
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_fp8(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, out = ops.cutlass_scaled_sparse_mm(
torch.bfloat16) a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref): if not torch.allclose(out, out_ref):
...@@ -124,13 +180,20 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -124,13 +180,20 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# pytorch impl w. bf16 # pytorch impl w. bf16
timers.append( timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", bench_fn(
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), label,
b.to(dtype=torch.bfloat16, device="cuda"))) sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"),
)
)
# pytorch impl: bf16 output, without fp8 fast accum # pytorch impl: bf16 output, without fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
label,
sub_label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm", "pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm, torch._scaled_mm,
...@@ -138,11 +201,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -138,11 +201,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
b, b,
scale_a=scale_a, scale_a=scale_a,
scale_b=scale_b, scale_b=scale_b,
out_dtype=torch.bfloat16)) out_dtype=torch.bfloat16,
)
)
# pytorch impl: bf16 output, with fp8 fast accum # pytorch impl: bf16 output, with fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
label,
sub_label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm, torch._scaled_mm,
...@@ -151,11 +217,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -151,11 +217,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
scale_a=scale_a, scale_a=scale_a,
scale_b=scale_b, scale_b=scale_b,
out_dtype=torch.bfloat16, out_dtype=torch.bfloat16,
use_fast_accum=True)) use_fast_accum=True,
)
)
# pytorch impl: fp16 output, without fp8 fast accum # pytorch impl: fp16 output, without fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
label,
sub_label, sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm", "pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm, torch._scaled_mm,
...@@ -163,11 +232,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -163,11 +232,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
b, b,
scale_a=scale_a, scale_a=scale_a,
scale_b=scale_b, scale_b=scale_b,
out_dtype=torch.float16)) out_dtype=torch.float16,
)
)
# pytorch impl: fp16 output, with fp8 fast accum # pytorch impl: fp16 output, with fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
label,
sub_label, sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm, torch._scaled_mm,
...@@ -176,45 +248,97 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -176,45 +248,97 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
scale_a=scale_a, scale_a=scale_a,
scale_b=scale_b, scale_b=scale_b,
out_dtype=torch.float16, out_dtype=torch.float16,
use_fast_accum=True)) use_fast_accum=True,
)
)
# cutlass impl: bf16 output # cutlass impl: bf16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, label,
torch.bfloat16)) sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse impl: bf16 output # cutlass sparse impl: bf16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16)) sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse impl: fp16 output # cutlass sparse impl: fp16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.float16)) sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
)
)
# cutlass sparse impl: bf16 output, with bias # cutlass sparse impl: bf16 output, with bias
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ops.cutlass_scaled_sparse_mm,
scale_b, torch.bfloat16, bias)) a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass sparse impl: fp16 output, with bias # cutlass sparse impl: fp16 output, with bias
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ops.cutlass_scaled_sparse_mm,
scale_b, torch.float16, bias.to(dtype=torch.float16))) a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
bias.to(dtype=torch.float16),
)
)
return timers return timers
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label) return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
...@@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print() compare.print()
def run(dtype: torch.dtype, def run(
MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
f"MKN=({m}x{k}x{n})")
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
...@@ -241,10 +365,12 @@ def run(dtype: torch.dtype, ...@@ -241,10 +365,12 @@ def run(dtype: torch.dtype,
# output makers # output makers
def make_output(data: Iterable[TMeasurement], def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
base_description: str, base_description: str,
timestamp=None): timestamp=None,
):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
print_timers(data) print_timers(data)
...@@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement], ...@@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs) data = run(args.dtype, MKNs)
...@@ -319,7 +444,7 @@ def run_model_bench(args): ...@@ -319,7 +444,7 @@ def run_model_bench(args):
pkl.dump(all_data, f) pkl.dump(all_data, f)
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
...@@ -344,12 +469,15 @@ Benchmark Cutlass GEMM. ...@@ -344,12 +469,15 @@ Benchmark Cutlass GEMM.
Output: Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501 """, # noqa: E501
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype", parser.add_argument(
"--dtype",
type=to_torch_dtype, type=to_torch_dtype,
required=True, required=True,
help="Available options are ['int8', 'fp8']") help="Available options are ['int8', 'fp8']",
)
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench") square_parser = subparsers.add_parser("square_bench")
...@@ -368,19 +496,19 @@ Benchmark Cutlass GEMM. ...@@ -368,19 +496,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models", model_parser.add_argument(
"--models",
nargs="+", nargs="+",
type=str, type=str,
default=DEFAULT_MODELS, default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys()) choices=WEIGHT_SHAPES.keys(),
model_parser.add_argument("--tp-sizes", )
nargs="+", model_parser.add_argument(
type=int, "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
default=DEFAULT_TP_SIZES) )
model_parser.add_argument("--batch-sizes", model_parser.add_argument(
nargs="+", "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
type=int, )
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -10,8 +10,9 @@ import vllm._custom_ops as ops ...@@ -10,8 +10,9 @@ import vllm._custom_ops as ops
def to_fp8(tensor: torch.Tensor) -> torch.Tensor: def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp( return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor: def to_int8(tensor: torch.Tensor) -> torch.Tensor:
...@@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor: ...@@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16) return tensor.to(dtype=torch.float16)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_tensors(
k: int) -> tuple[torch.Tensor, torch.Tensor]: dtype: torch.dtype, m: int, n: int, k: int
a = torch.randn((m, k), device='cuda') * 5 ) -> tuple[torch.Tensor, torch.Tensor]:
b = torch.randn((n, k), device='cuda').t() * 5 a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
if dtype == torch.int8: if dtype == torch.int8:
return to_int8(a), to_int8(b) return to_int8(a), to_int8(b)
...@@ -49,9 +51,7 @@ def prune_to_2_4(tensor): ...@@ -49,9 +51,7 @@ def prune_to_2_4(tensor):
# Create binary mask # Create binary mask
mask = torch.zeros_like(reshaped) mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1, mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back # Apply mask and reshape back
pruned = reshaped * mask pruned = reshaped * mask
...@@ -62,10 +62,11 @@ def prune_to_2_4(tensor): ...@@ -62,10 +62,11 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape) return pruned.reshape(original_shape)
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_sparse_tensors(
k: int) -> tuple[torch.Tensor, torch.Tensor]: dtype: torch.dtype, m: int, n: int, k: int
a = torch.randn((m, k), device='cuda') * 5 ) -> tuple[torch.Tensor, torch.Tensor]:
b = torch.randn((n, k), device='cuda').t() * 5 a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
b = prune_to_2_4(b.t()).t() b = prune_to_2_4(b.t()).t()
...@@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, ...@@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
return b_compressed, e, a, b return b_compressed, e, a, b
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, def make_n_rand_sparse_tensors(
m: int, n: int, k: int) -> \ num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: ) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
ABs = [] ABs = []
for _ in range(num_tensors): for _ in range(num_tensors):
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
......
...@@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES ...@@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul) w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
...@@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1] ...@@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1]
# bench # bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_fn(
**kwargs) -> TMeasurement: label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
...@@ -50,39 +52,42 @@ def bench_int8( ...@@ -50,39 +52,42 @@ def bench_int8(
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels.""" """Benchmark INT8-based kernels."""
assert dtype == torch.int8 assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k) a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
bench_fns = { bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
), ),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"cutlass_i8_i8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_i8_i8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_i8_i8_bf16_scaled_mm_azp":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, None, bias),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp, bias),
} }
timers = [] timers = []
...@@ -102,67 +107,59 @@ def bench_fp8( ...@@ -102,67 +107,59 @@ def bench_fp8(
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels.""" """Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous() a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128), block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
device="cuda", block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128),
device="cuda",
dtype=torch.float32)
block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t()
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
print(m, k, n) print(m, k, n)
bench_fns = { bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16
),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
),
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
),
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16
),
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
), ),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"pytorch_fp8_fp8_fp16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.float16,
use_fast_accum=True),
"pytorch_fp8_fp8_bf16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True),
"cutlass_fp8_fp8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_fp8_fp8_fp16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)),
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
block_scale_b.t(), (128, 128)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
block_scale_b_K_major, torch.float16),
} }
timers = [] timers = []
...@@ -175,13 +172,15 @@ def bench_fp8( ...@@ -175,13 +172,15 @@ def bench_fp8(
return timers return timers
def bench(dtype: torch.dtype, def bench(
dtype: torch.dtype,
m: int, m: int,
k: int, k: int,
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
...@@ -195,27 +194,33 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -195,27 +194,33 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print() compare.print()
def run(dtype: torch.dtype, def run(
dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, timers = bench(
dtype,
m, m,
k, k,
n, n,
f"scaled-{dtype}-gemm", f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})", f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels) bench_kernels=bench_kernels,
)
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
return results return results
def make_output(data: Iterable[TMeasurement], def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
base_description: str, base_description: str,
timestamp=None): timestamp=None,
):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
print_timers(data) print_timers(data)
...@@ -226,8 +231,7 @@ def make_output(data: Iterable[TMeasurement], ...@@ -226,8 +231,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels) data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"square_bench-{args.dtype}") make_output(data, MKNs, f"square_bench-{args.dtype}")
...@@ -285,7 +289,7 @@ def run_model_bench(args): ...@@ -285,7 +289,7 @@ def run_model_bench(args):
pkl.dump(all_data, f) pkl.dump(all_data, f)
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
...@@ -310,19 +314,21 @@ Benchmark Cutlass GEMM. ...@@ -310,19 +314,21 @@ Benchmark Cutlass GEMM.
Output: Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501 """, # noqa: E501
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype", parser.add_argument(
"--dtype",
type=to_torch_dtype, type=to_torch_dtype,
required=True, required=True,
help="Available options are ['int8', 'fp8']") help="Available options are ['int8', 'fp8']",
)
parser.add_argument( parser.add_argument(
"--kernels", "--kernels",
nargs="+", nargs="+",
type=str, type=str,
default=None, default=None,
help= help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
"Exact names of the kernels to benchmark. If not set, runs all kernels."
) )
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
...@@ -343,19 +349,19 @@ Benchmark Cutlass GEMM. ...@@ -343,19 +349,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models", model_parser.add_argument(
"--models",
nargs="+", nargs="+",
type=str, type=str,
default=DEFAULT_MODELS, default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys()) choices=WEIGHT_SHAPES.keys(),
model_parser.add_argument("--tp-sizes", )
nargs="+", model_parser.add_argument(
type=int, "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
default=DEFAULT_TP_SIZES) )
model_parser.add_argument("--batch-sizes", model_parser.add_argument(
nargs="+", "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
type=int, )
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -12,39 +12,37 @@ app = Quart(__name__) ...@@ -12,39 +12,37 @@ app = Quart(__name__)
async def forward_request(url, data): async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" async with session.post(url=url, json=data, headers=headers) as response:
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200: if response.status == 200:
# if response.headers.get('Transfer-Encoding') == 'chunked': # if response.headers.get('Transfer-Encoding') == 'chunked':
if True: if True:
async for chunk_bytes in response.content.iter_chunked( async for chunk_bytes in response.content.iter_chunked(1024):
1024):
yield chunk_bytes yield chunk_bytes
else: else:
content = await response.read() content = await response.read()
yield content yield content
@app.route('/v1/completions', methods=['POST']) @app.route("/v1/completions", methods=["POST"])
async def handle_request(): async def handle_request():
try: try:
original_request_data = await request.get_json() original_request_data = await request.get_json()
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill # change max_tokens = 1 to let it only do prefill
prefill_request['max_tokens'] = 1 prefill_request["max_tokens"] = 1
# finish prefill # finish prefill
async for _ in forward_request('http://localhost:8100/v1/completions', async for _ in forward_request(
prefill_request): "http://localhost:8100/v1/completions", prefill_request
):
continue continue
# return decode # return decode
generator = forward_request('http://localhost:8200/v1/completions', generator = forward_request(
original_request_data) "http://localhost:8200/v1/completions", original_request_data
)
response = await make_response(generator) response = await make_response(generator)
response.timeout = None response.timeout = None
...@@ -53,11 +51,12 @@ async def handle_request(): ...@@ -53,11 +51,12 @@ async def handle_request():
except Exception as e: except Exception as e:
import sys import sys
import traceback import traceback
exc_info = sys.exc_info() exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server") print("Error occurred in disagg prefill proxy server")
print(e) print(e)
print("".join(traceback.format_exception(*exc_info))) print("".join(traceback.format_exception(*exc_info)))
if __name__ == '__main__': if __name__ == "__main__":
app.run(port=8000) app.run(port=8000)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment