Unverified commit cf069aa8, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
...@@ -6,7 +6,7 @@ import sys ...@@ -6,7 +6,7 @@ import sys
import time import time
import traceback import traceback
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Union from typing import Optional, Union
import aiohttp import aiohttp
import huggingface_hub.constants import huggingface_hub.constants
...@@ -41,8 +41,8 @@ class RequestFuncOutput: ...@@ -41,8 +41,8 @@ class RequestFuncOutput:
latency: float = 0.0 latency: float = 0.0
output_tokens: int = 0 output_tokens: int = 0
ttft: float = 0.0 # Time to first token ttft: float = 0.0 # Time to first token
itl: List[float] = field( itl: list[float] = field(
default_factory=list) # List of inter-token latencies default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0 prompt_len: int = 0
error: str = "" error: str = ""
......
...@@ -6,7 +6,6 @@ import json ...@@ -6,7 +6,6 @@ import json
import os import os
import random import random
import time import time
from typing import List
import datasets import datasets
import pandas as pd import pandas as pd
...@@ -39,7 +38,7 @@ class SampleRequest: ...@@ -39,7 +38,7 @@ class SampleRequest:
completion: str = None completion: str = None
def run_vllm(requests: List[SampleRequest], def run_vllm(requests: list[SampleRequest],
engine_args: EngineArgs, engine_args: EngineArgs,
n: int, n: int,
guided_decoding_rate: float = 1.0, guided_decoding_rate: float = 1.0,
...@@ -54,8 +53,8 @@ def run_vllm(requests: List[SampleRequest], ...@@ -54,8 +53,8 @@ def run_vllm(requests: List[SampleRequest],
" prompt_len and expected_output_len for all requests.") " prompt_len and expected_output_len for all requests.")
# Add the requests to the engine. # Add the requests to the engine.
prompts: List[str] = [] prompts: list[str] = []
sampling_params: List[SamplingParams] = [] sampling_params: list[SamplingParams] = []
# create a list containing random selected true or false # create a list containing random selected true or false
guided_decoding_req_idx = random.sample( guided_decoding_req_idx = random.sample(
range(len(requests)), int(len(requests) * guided_decoding_rate)) range(len(requests)), int(len(requests) * guided_decoding_rate))
...@@ -110,7 +109,7 @@ def run_vllm(requests: List[SampleRequest], ...@@ -110,7 +109,7 @@ def run_vllm(requests: List[SampleRequest],
async def run_vllm_async( async def run_vllm_async(
requests: List[SampleRequest], requests: list[SampleRequest],
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
n: int, n: int,
guided_decoding_rate: float = 1.0, guided_decoding_rate: float = 1.0,
...@@ -129,8 +128,8 @@ async def run_vllm_async( ...@@ -129,8 +128,8 @@ async def run_vllm_async(
" prompt_len and expected_output_len for all requests.") " prompt_len and expected_output_len for all requests.")
# Add the requests to the engine. # Add the requests to the engine.
prompts: List[str] = [] prompts: list[str] = []
sampling_params: List[SamplingParams] = [] sampling_params: list[SamplingParams] = []
guided_decoding_req_idx = random.sample( guided_decoding_req_idx = random.sample(
range(len(requests)), int(len(requests) * guided_decoding_rate)) range(len(requests)), int(len(requests) * guided_decoding_rate))
...@@ -203,7 +202,7 @@ async def run_vllm_async( ...@@ -203,7 +202,7 @@ async def run_vllm_async(
def sample_requests(tokenizer: PreTrainedTokenizerBase, def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]: args: argparse.Namespace) -> list[SampleRequest]:
if args.dataset == 'json': if args.dataset == 'json':
if args.json_schema_path is None: if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.dirname(os.path.realpath(__file__))
...@@ -287,7 +286,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -287,7 +286,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
elif args.dataset == "xgrammar_bench": elif args.dataset == "xgrammar_bench":
args.warmup = False args.warmup = False
requests: List[SampleRequest] = [] requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval", dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train") split="train")
print(f"dataset has {len(dataset)} entries") print(f"dataset has {len(dataset)} entries")
......
...@@ -7,7 +7,7 @@ import json ...@@ -7,7 +7,7 @@ import json
import os import os
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -22,7 +22,7 @@ from vllm.utils import FlexibleArgumentParser ...@@ -22,7 +22,7 @@ from vllm.utils import FlexibleArgumentParser
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: Dict[str, Any]) -> None: results: dict[str, Any]) -> None:
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={"latency": results["latencies"]}, metrics={"latency": results["latencies"]},
...@@ -57,7 +57,7 @@ def main(args: argparse.Namespace): ...@@ -57,7 +57,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size, size=(args.batch_size,
args.input_len)) args.input_len))
dummy_prompts: List[PromptType] = [{ dummy_prompts: list[PromptType] = [{
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]
......
...@@ -31,7 +31,7 @@ import dataclasses ...@@ -31,7 +31,7 @@ import dataclasses
import json import json
import random import random
import time import time
from typing import List, Optional, Tuple from typing import Optional
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -77,9 +77,9 @@ def sample_requests_from_dataset( ...@@ -77,9 +77,9 @@ def sample_requests_from_dataset(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int], input_length_range: tuple[int, int],
fixed_output_len: Optional[int], fixed_output_len: Optional[int],
) -> List[Request]: ) -> list[Request]:
if fixed_output_len is not None and fixed_output_len < 4: if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small") raise ValueError("output_len too small")
...@@ -99,7 +99,7 @@ def sample_requests_from_dataset( ...@@ -99,7 +99,7 @@ def sample_requests_from_dataset(
assert min_len >= 0 and max_len >= min_len, "input_length_range too small" assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
# Filter out sequences that are too long or too short # Filter out sequences that are too long or too short
filtered_requests: List[Request] = [] filtered_requests: list[Request] = []
for i in range(len(dataset)): for i in range(len(dataset)):
if len(filtered_requests) == num_requests: if len(filtered_requests) == num_requests:
...@@ -122,10 +122,10 @@ def sample_requests_from_dataset( ...@@ -122,10 +122,10 @@ def sample_requests_from_dataset(
def sample_requests_from_random( def sample_requests_from_random(
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int], input_length_range: tuple[int, int],
fixed_output_len: Optional[int], fixed_output_len: Optional[int],
prefix_len: int, prefix_len: int,
) -> List[Request]: ) -> list[Request]:
requests = [] requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len) prefix_token_ids = sample_tokens(tokenizer, prefix_len)
...@@ -144,9 +144,9 @@ def sample_requests_from_random( ...@@ -144,9 +144,9 @@ def sample_requests_from_random(
return requests return requests
def repeat_and_sort_requests(requests: List[Request], def repeat_and_sort_requests(requests: list[Request],
repeat_count: int, repeat_count: int,
sort: bool = False) -> List[str]: sort: bool = False) -> list[str]:
repeated_requests = requests * repeat_count repeated_requests = requests * repeat_count
if sort: if sort:
repeated_requests.sort(key=lambda x: x[1]) repeated_requests.sort(key=lambda x: x[1])
......
...@@ -5,7 +5,7 @@ import dataclasses ...@@ -5,7 +5,7 @@ import dataclasses
import json import json
import random import random
import time import time
from typing import List, Optional, Tuple from typing import Optional
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
...@@ -23,7 +23,7 @@ def sample_requests( ...@@ -23,7 +23,7 @@ def sample_requests(
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int], fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]: ) -> list[tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4: if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small") raise ValueError("output_len too small")
...@@ -40,7 +40,7 @@ def sample_requests( ...@@ -40,7 +40,7 @@ def sample_requests(
random.shuffle(dataset) random.shuffle(dataset)
# Filter out sequences that are too long or too short # Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = [] filtered_dataset: list[tuple[str, int, int]] = []
for i in range(len(dataset)): for i in range(len(dataset)):
if len(filtered_dataset) == num_requests: if len(filtered_dataset) == num_requests:
break break
...@@ -68,7 +68,7 @@ def sample_requests( ...@@ -68,7 +68,7 @@ def sample_requests(
def run_vllm( def run_vllm(
requests: List[Tuple[str, int, int]], requests: list[tuple[str, int, int]],
n: int, n: int,
engine_args: EngineArgs, engine_args: EngineArgs,
) -> float: ) -> float:
......
...@@ -33,9 +33,10 @@ import os ...@@ -33,9 +33,10 @@ import os
import random import random
import time import time
import warnings import warnings
from collections.abc import AsyncGenerator, Collection
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple from typing import Any, Optional
import numpy as np import numpy as np
import pandas as pd import pandas as pd
...@@ -73,22 +74,22 @@ class BenchmarkMetrics: ...@@ -73,22 +74,22 @@ class BenchmarkMetrics:
mean_ttft_ms: float mean_ttft_ms: float
median_ttft_ms: float median_ttft_ms: float
std_ttft_ms: float std_ttft_ms: float
percentiles_ttft_ms: List[Tuple[float, float]] percentiles_ttft_ms: list[tuple[float, float]]
mean_tpot_ms: float mean_tpot_ms: float
median_tpot_ms: float median_tpot_ms: float
std_tpot_ms: float std_tpot_ms: float
percentiles_tpot_ms: List[Tuple[float, float]] percentiles_tpot_ms: list[tuple[float, float]]
mean_itl_ms: float mean_itl_ms: float
median_itl_ms: float median_itl_ms: float
std_itl_ms: float std_itl_ms: float
percentiles_itl_ms: List[Tuple[float, float]] percentiles_itl_ms: list[tuple[float, float]]
# E2EL stands for end-to-end latency per request. # E2EL stands for end-to-end latency per request.
# It is the time taken on the client side from sending # It is the time taken on the client side from sending
# a request to receiving a complete response. # a request to receiving a complete response.
mean_e2el_ms: float mean_e2el_ms: float
median_e2el_ms: float median_e2el_ms: float
std_e2el_ms: float std_e2el_ms: float
percentiles_e2el_ms: List[Tuple[float, float]] percentiles_e2el_ms: list[tuple[float, float]]
def sample_sharegpt_requests( def sample_sharegpt_requests(
...@@ -96,7 +97,7 @@ def sample_sharegpt_requests( ...@@ -96,7 +97,7 @@ def sample_sharegpt_requests(
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None, fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int, None]]: ) -> list[tuple[str, int, int, None]]:
# Load the dataset. # Load the dataset.
with open(dataset_path, encoding='utf-8') as f: with open(dataset_path, encoding='utf-8') as f:
dataset = json.load(f) dataset = json.load(f)
...@@ -110,7 +111,7 @@ def sample_sharegpt_requests( ...@@ -110,7 +111,7 @@ def sample_sharegpt_requests(
random.shuffle(dataset) random.shuffle(dataset)
# Filter out sequences that are too long or too short # Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = [] filtered_dataset: list[tuple[str, int, int]] = []
for i in range(len(dataset)): for i in range(len(dataset)):
if len(filtered_dataset) == num_requests: if len(filtered_dataset) == num_requests:
break break
...@@ -139,7 +140,7 @@ def sample_burstgpt_requests( ...@@ -139,7 +140,7 @@ def sample_burstgpt_requests(
num_requests: int, num_requests: int,
random_seed: int, random_seed: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int, None]]: ) -> list[tuple[str, int, int, None]]:
df = pd.read_csv(dataset_path) df = pd.read_csv(dataset_path)
gpt4_df = df[df["Model"] == "GPT-4"] gpt4_df = df[df["Model"] == "GPT-4"]
# Remove the failed requests (i.e., response length is 0) # Remove the failed requests (i.e., response length is 0)
...@@ -170,7 +171,7 @@ def sample_sonnet_requests( ...@@ -170,7 +171,7 @@ def sample_sonnet_requests(
output_len: int, output_len: int,
prefix_len: int, prefix_len: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int, None]]: ) -> list[tuple[str, str, int, int, None]]:
assert ( assert (
input_len > prefix_len input_len > prefix_len
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
...@@ -211,7 +212,7 @@ def sample_sonnet_requests( ...@@ -211,7 +212,7 @@ def sample_sonnet_requests(
prefix_lines = poem_lines[:num_prefix_lines] prefix_lines = poem_lines[:num_prefix_lines]
# Sample the rest of lines per request. # Sample the rest of lines per request.
sampled_requests: List[Tuple[str, int, int]] = [] sampled_requests: list[tuple[str, int, int]] = []
for _ in range(num_requests): for _ in range(num_requests):
num_lines_needed = num_input_lines - num_prefix_lines num_lines_needed = num_input_lines - num_prefix_lines
sampled_lines = "".join(prefix_lines + sampled_lines = "".join(prefix_lines +
...@@ -238,8 +239,8 @@ def sample_vision_arena_requests( ...@@ -238,8 +239,8 @@ def sample_vision_arena_requests(
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None, fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: ) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
sampled_requests: List[Tuple[str, int, int, Dict[str, sampled_requests: list[tuple[str, int, int, dict[str,
Collection[str]]]] = [] Collection[str]]]] = []
for data in dataset: for data in dataset:
if len(sampled_requests) == num_requests: if len(sampled_requests) == num_requests:
...@@ -285,7 +286,7 @@ def sample_hf_requests( ...@@ -285,7 +286,7 @@ def sample_hf_requests(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
random_seed: int, random_seed: int,
fixed_output_len: Optional[int] = None, fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: ) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
# Special case for vision_arena dataset # Special case for vision_arena dataset
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \ if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
...@@ -307,7 +308,7 @@ def sample_hf_requests( ...@@ -307,7 +308,7 @@ def sample_hf_requests(
"HF Dataset must have 'conversations' column.") "HF Dataset must have 'conversations' column.")
filter_func = lambda x: len(x["conversations"]) >= 2 filter_func = lambda x: len(x["conversations"]) >= 2
filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func) filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
sampled_requests: List[Tuple[str, int, int, Dict[str, sampled_requests: list[tuple[str, int, int, dict[str,
Collection[str]]]] = [] Collection[str]]]] = []
for data in filtered_dataset: for data in filtered_dataset:
if len(sampled_requests) == num_requests: if len(sampled_requests) == num_requests:
...@@ -370,7 +371,7 @@ def sample_random_requests( ...@@ -370,7 +371,7 @@ def sample_random_requests(
num_prompts: int, num_prompts: int,
range_ratio: float, range_ratio: float,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]: ) -> list[tuple[str, int, int]]:
prefix_token_ids = np.random.randint(0, prefix_token_ids = np.random.randint(0,
tokenizer.vocab_size, tokenizer.vocab_size,
size=prefix_len).tolist() size=prefix_len).tolist()
...@@ -399,10 +400,10 @@ def sample_random_requests( ...@@ -399,10 +400,10 @@ def sample_random_requests(
async def get_request( async def get_request(
input_requests: List[Tuple[str, int, int]], input_requests: list[tuple[str, int, int]],
request_rate: float, request_rate: float,
burstiness: float = 1.0, burstiness: float = 1.0,
) -> AsyncGenerator[Tuple[str, int, int], None]: ) -> AsyncGenerator[tuple[str, int, int], None]:
""" """
Asynchronously generates requests at a specified rate Asynchronously generates requests at a specified rate
with OPTIONAL burstiness. with OPTIONAL burstiness.
...@@ -443,23 +444,23 @@ async def get_request( ...@@ -443,23 +444,23 @@ async def get_request(
def calculate_metrics( def calculate_metrics(
input_requests: List[Tuple[str, int, int]], input_requests: list[tuple[str, int, int]],
outputs: List[RequestFuncOutput], outputs: list[RequestFuncOutput],
dur_s: float, dur_s: float,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: List[str], selected_percentile_metrics: list[str],
selected_percentiles: List[float], selected_percentiles: list[float],
goodput_config_dict: Dict[str, float], goodput_config_dict: dict[str, float],
) -> Tuple[BenchmarkMetrics, List[int]]: ) -> tuple[BenchmarkMetrics, list[int]]:
actual_output_lens: List[int] = [] actual_output_lens: list[int] = []
total_input = 0 total_input = 0
completed = 0 completed = 0
good_completed = 0 good_completed = 0
itls: List[float] = [] itls: list[float] = []
tpots: List[float] = [] tpots: list[float] = []
all_tpots: List[float] = [] all_tpots: list[float] = []
ttfts: List[float] = [] ttfts: list[float] = []
e2els: List[float] = [] e2els: list[float] = []
for i in range(len(outputs)): for i in range(len(outputs)):
if outputs[i].success: if outputs[i].success:
output_len = outputs[i].output_tokens output_len = outputs[i].output_tokens
...@@ -557,19 +558,19 @@ async def benchmark( ...@@ -557,19 +558,19 @@ async def benchmark(
model_id: str, model_id: str,
model_name: str, model_name: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]], input_requests: list[tuple[str, int, int]],
logprobs: Optional[int], logprobs: Optional[int],
best_of: int, best_of: int,
request_rate: float, request_rate: float,
burstiness: float, burstiness: float,
disable_tqdm: bool, disable_tqdm: bool,
profile: bool, profile: bool,
selected_percentile_metrics: List[str], selected_percentile_metrics: list[str],
selected_percentiles: List[str], selected_percentiles: list[str],
ignore_eos: bool, ignore_eos: bool,
goodput_config_dict: Dict[str, float], goodput_config_dict: dict[str, float],
max_concurrency: Optional[int], max_concurrency: Optional[int],
lora_modules: Optional[List[str]], lora_modules: Optional[list[str]],
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
...@@ -652,7 +653,7 @@ async def benchmark( ...@@ -652,7 +653,7 @@ async def benchmark(
pbar=pbar) pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = [] tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness): async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request prompt, prompt_len, output_len, mm_content = request
req_model_id, req_model_name = model_id, model_name req_model_id, req_model_name = model_id, model_name
...@@ -674,7 +675,7 @@ async def benchmark( ...@@ -674,7 +675,7 @@ async def benchmark(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input,
pbar=pbar))) pbar=pbar)))
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile: if profile:
print("Stopping profiler...") print("Stopping profiler...")
...@@ -820,7 +821,7 @@ def parse_goodput(slo_pairs): ...@@ -820,7 +821,7 @@ def parse_goodput(slo_pairs):
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: Dict[str, Any], results: dict[str, Any],
file_name: str) -> None: file_name: str) -> None:
metrics = [ metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
...@@ -974,7 +975,7 @@ def main(args: argparse.Namespace): ...@@ -974,7 +975,7 @@ def main(args: argparse.Namespace):
# Save config and results to json # Save config and results to json
if args.save_result: if args.save_result:
result_json: Dict[str, Any] = {} result_json: dict[str, Any] = {}
# Setup # Setup
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
......
...@@ -30,8 +30,9 @@ import os ...@@ -30,8 +30,9 @@ import os
import random import random
import time import time
import warnings import warnings
from collections.abc import AsyncGenerator
from dataclasses import dataclass from dataclasses import dataclass
from typing import AsyncGenerator, Dict, List, Optional, Tuple from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
...@@ -66,22 +67,22 @@ class BenchmarkMetrics: ...@@ -66,22 +67,22 @@ class BenchmarkMetrics:
mean_ttft_ms: float mean_ttft_ms: float
median_ttft_ms: float median_ttft_ms: float
std_ttft_ms: float std_ttft_ms: float
percentiles_ttft_ms: List[Tuple[float, float]] percentiles_ttft_ms: list[tuple[float, float]]
mean_tpot_ms: float mean_tpot_ms: float
median_tpot_ms: float median_tpot_ms: float
std_tpot_ms: float std_tpot_ms: float
percentiles_tpot_ms: List[Tuple[float, float]] percentiles_tpot_ms: list[tuple[float, float]]
mean_itl_ms: float mean_itl_ms: float
median_itl_ms: float median_itl_ms: float
std_itl_ms: float std_itl_ms: float
percentiles_itl_ms: List[Tuple[float, float]] percentiles_itl_ms: list[tuple[float, float]]
# E2EL stands for end-to-end latency per request. # E2EL stands for end-to-end latency per request.
# It is the time taken on the client side from sending # It is the time taken on the client side from sending
# a request to receiving a complete response. # a request to receiving a complete response.
mean_e2el_ms: float mean_e2el_ms: float
median_e2el_ms: float median_e2el_ms: float
std_e2el_ms: float std_e2el_ms: float
percentiles_e2el_ms: List[Tuple[float, float]] percentiles_e2el_ms: list[tuple[float, float]]
@dataclasses.dataclass @dataclasses.dataclass
...@@ -104,7 +105,7 @@ class SampleRequest: ...@@ -104,7 +105,7 @@ class SampleRequest:
def sample_requests(tokenizer: PreTrainedTokenizerBase, def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]: args: argparse.Namespace) -> list[SampleRequest]:
if args.dataset == 'json': if args.dataset == 'json':
if args.json_schema_path is None: if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.dirname(os.path.realpath(__file__))
...@@ -187,7 +188,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -187,7 +188,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
] ]
elif args.dataset == "xgrammar_bench": elif args.dataset == "xgrammar_bench":
requests: List[SampleRequest] = [] requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval", dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train") split="train")
print(f"dataset has {len(dataset)} entries") print(f"dataset has {len(dataset)} entries")
...@@ -214,10 +215,10 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -214,10 +215,10 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
async def get_request( async def get_request(
input_requests: List[SampleRequest], input_requests: list[SampleRequest],
request_rate: float, request_rate: float,
burstiness: float = 1.0, burstiness: float = 1.0,
) -> AsyncGenerator[Tuple[int, SampleRequest], None]: ) -> AsyncGenerator[tuple[int, SampleRequest], None]:
""" """
Asynchronously generates requests at a specified rate Asynchronously generates requests at a specified rate
with OPTIONAL burstiness. with OPTIONAL burstiness.
...@@ -258,23 +259,23 @@ async def get_request( ...@@ -258,23 +259,23 @@ async def get_request(
def calculate_metrics( def calculate_metrics(
input_requests: List[Tuple[str, int, int]], input_requests: list[tuple[str, int, int]],
outputs: List[RequestFuncOutput], outputs: list[RequestFuncOutput],
dur_s: float, dur_s: float,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: List[str], selected_percentile_metrics: list[str],
selected_percentiles: List[float], selected_percentiles: list[float],
goodput_config_dict: Optional[Dict[str, float]] = None, goodput_config_dict: Optional[dict[str, float]] = None,
) -> Tuple[BenchmarkMetrics, List[int]]: ) -> tuple[BenchmarkMetrics, list[int]]:
actual_output_lens: List[int] = [] actual_output_lens: list[int] = []
total_input = 0 total_input = 0
completed = 0 completed = 0
good_completed = 0 good_completed = 0
itls: List[float] = [] itls: list[float] = []
tpots: List[float] = [] tpots: list[float] = []
all_tpots: List[float] = [] all_tpots: list[float] = []
ttfts: List[float] = [] ttfts: list[float] = []
e2els: List[float] = [] e2els: list[float] = []
for i in range(len(outputs)): for i in range(len(outputs)):
if outputs[i].success: if outputs[i].success:
# We use the tokenizer to count the number of output tokens for all # We use the tokenizer to count the number of output tokens for all
...@@ -368,18 +369,18 @@ async def benchmark( ...@@ -368,18 +369,18 @@ async def benchmark(
base_url: str, base_url: str,
model_id: str, model_id: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_requests: List[SampleRequest], input_requests: list[SampleRequest],
request_rate: float, request_rate: float,
burstiness: float, burstiness: float,
disable_tqdm: bool, disable_tqdm: bool,
profile: bool, profile: bool,
selected_percentile_metrics: List[str], selected_percentile_metrics: list[str],
selected_percentiles: List[str], selected_percentiles: list[str],
ignore_eos: bool, ignore_eos: bool,
max_concurrency: Optional[int], max_concurrency: Optional[int],
guided_decoding_ratio: float, guided_decoding_ratio: float,
guided_decoding_backend: str, guided_decoding_backend: str,
goodput_config_dict: Optional[Dict[str, float]] = None, goodput_config_dict: Optional[dict[str, float]] = None,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
...@@ -459,8 +460,8 @@ async def benchmark( ...@@ -459,8 +460,8 @@ async def benchmark(
pbar=pbar) pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = [] tasks: list[asyncio.Task] = []
expected: List[str] = [] expected: list[str] = []
async for i, request in get_request(input_requests, request_rate, async for i, request in get_request(input_requests, request_rate,
burstiness): burstiness):
extra_body = prepare_extra_body( extra_body = prepare_extra_body(
...@@ -479,7 +480,7 @@ async def benchmark( ...@@ -479,7 +480,7 @@ async def benchmark(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input,
pbar=pbar))) pbar=pbar)))
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile: if profile:
print("Stopping profiler...") print("Stopping profiler...")
......
...@@ -7,7 +7,7 @@ import os ...@@ -7,7 +7,7 @@ import os
import random import random
import time import time
from functools import cache from functools import cache
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Optional
import torch import torch
import uvloop import uvloop
...@@ -74,12 +74,12 @@ def lora_path_on_disk(lora_path: str) -> str: ...@@ -74,12 +74,12 @@ def lora_path_on_disk(lora_path: str) -> str:
return get_adapter_absolute_path(lora_path) return get_adapter_absolute_path(lora_path)
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {} lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
def get_random_lora_request( def get_random_lora_request(
args: argparse.Namespace args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]: ) -> tuple[LoRARequest, Optional[AnyTokenizer]]:
global lora_tokenizer_cache global lora_tokenizer_cache
lora_id = random.randint(1, args.max_loras) lora_id = random.randint(1, args.max_loras)
lora_request = LoRARequest(lora_name=str(lora_id), lora_request = LoRARequest(lora_name=str(lora_id),
...@@ -91,7 +91,7 @@ def get_random_lora_request( ...@@ -91,7 +91,7 @@ def get_random_lora_request(
def sample_requests(tokenizer: PreTrainedTokenizerBase, def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]: args: argparse.Namespace) -> list[SampleRequest]:
dataset_path: str = args.dataset dataset_path: str = args.dataset
num_requests: int = args.num_prompts num_requests: int = args.num_prompts
...@@ -109,7 +109,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -109,7 +109,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
random.shuffle(dataset) random.shuffle(dataset)
# Filter out sequences that are too long or too short # Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = [] filtered_dataset: list[SampleRequest] = []
for data in tqdm(dataset, for data in tqdm(dataset,
total=len(filtered_dataset), total=len(filtered_dataset),
desc="sampling requests"): desc="sampling requests"):
...@@ -165,7 +165,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -165,7 +165,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
def run_vllm( def run_vllm(
requests: List[SampleRequest], requests: list[SampleRequest],
n: int, n: int,
engine_args: EngineArgs, engine_args: EngineArgs,
) -> float: ) -> float:
...@@ -178,8 +178,8 @@ def run_vllm( ...@@ -178,8 +178,8 @@ def run_vllm(
"Please ensure that max_model_len is greater than the sum of" "Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.") " prompt_len and expected_output_len for all requests.")
# Add the requests to the engine. # Add the requests to the engine.
prompts: List[TextPrompt] = [] prompts: list[TextPrompt] = []
sampling_params: List[SamplingParams] = [] sampling_params: list[SamplingParams] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TextPrompt(prompt=request.prompt, TextPrompt(prompt=request.prompt,
...@@ -192,7 +192,7 @@ def run_vllm( ...@@ -192,7 +192,7 @@ def run_vllm(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
)) ))
lora_requests: Optional[List[LoRARequest]] = None lora_requests: Optional[list[LoRARequest]] = None
if engine_args.enable_lora: if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests] lora_requests = [request.lora_request for request in requests]
...@@ -225,7 +225,7 @@ def run_vllm( ...@@ -225,7 +225,7 @@ def run_vllm(
async def run_vllm_async( async def run_vllm_async(
requests: List[SampleRequest], requests: list[SampleRequest],
n: int, n: int,
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False, disable_frontend_multiprocessing: bool = False,
...@@ -242,9 +242,9 @@ async def run_vllm_async( ...@@ -242,9 +242,9 @@ async def run_vllm_async(
" prompt_len and expected_output_len for all requests.") " prompt_len and expected_output_len for all requests.")
# Add the requests to the engine. # Add the requests to the engine.
prompts: List[TextPrompt] = [] prompts: list[TextPrompt] = []
sampling_params: List[SamplingParams] = [] sampling_params: list[SamplingParams] = []
lora_requests: List[Optional[LoRARequest]] = [] lora_requests: list[Optional[LoRARequest]] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TextPrompt(prompt=request.prompt, TextPrompt(prompt=request.prompt,
...@@ -276,7 +276,7 @@ async def run_vllm_async( ...@@ -276,7 +276,7 @@ async def run_vllm_async(
def run_hf( def run_hf(
requests: List[SampleRequest], requests: list[SampleRequest],
model: str, model: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
n: int, n: int,
...@@ -292,7 +292,7 @@ def run_hf( ...@@ -292,7 +292,7 @@ def run_hf(
pbar = tqdm(total=len(requests)) pbar = tqdm(total=len(requests))
start = time.perf_counter() start = time.perf_counter()
batch: List[str] = [] batch: list[str] = []
max_prompt_len = 0 max_prompt_len = 0
max_output_len = 0 max_output_len = 0
for i in range(len(requests)): for i in range(len(requests)):
...@@ -334,7 +334,7 @@ def run_hf( ...@@ -334,7 +334,7 @@ def run_hf(
def run_mii( def run_mii(
requests: List[SampleRequest], requests: list[SampleRequest],
model: str, model: str,
tensor_parallel_size: int, tensor_parallel_size: int,
output_len: int, output_len: int,
...@@ -352,7 +352,7 @@ def run_mii( ...@@ -352,7 +352,7 @@ def run_mii(
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: Dict[str, Any]) -> None: results: dict[str, Any]) -> None:
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={ metrics={
...@@ -479,8 +479,8 @@ if __name__ == "__main__": ...@@ -479,8 +479,8 @@ if __name__ == "__main__":
type=str, type=str,
default=None, default=None,
help="Path to the dataset. The dataset is expected to " help="Path to the dataset. The dataset is expected to "
"be a json in form of List[Dict[..., conversations: " "be a json in form of list[dict[..., conversations: "
"List[Dict[..., value: <prompt_or_response>]]]]") "list[dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--input-len", parser.add_argument("--input-len",
type=int, type=int,
default=None, default=None,
......
...@@ -4,12 +4,12 @@ import argparse ...@@ -4,12 +4,12 @@ import argparse
import json import json
import math import math
import os import os
from typing import Any, Dict, List from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace, def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
metrics: Dict[str, List], metrics: dict[str, list],
extra_info: Dict[str, Any]) -> List: extra_info: dict[str, Any]) -> list:
""" """
Save the benchmark results in the format used by PyTorch OSS benchmark with Save the benchmark results in the format used by PyTorch OSS benchmark with
on metric per record on metric per record
...@@ -64,6 +64,6 @@ class InfEncoder(json.JSONEncoder): ...@@ -64,6 +64,6 @@ class InfEncoder(json.JSONEncoder):
return super().iterencode(self.clear_inf(o), *args, **kwargs) return super().iterencode(self.clear_inf(o), *args, **kwargs)
def write_to_json(filename: str, records: List) -> None: def write_to_json(filename: str, records: list) -> None:
with open(filename, "w") as f: with open(filename, "w") as f:
json.dump(records, f, cls=InfEncoder) json.dump(records, f, cls=InfEncoder)
...@@ -5,7 +5,8 @@ import copy ...@@ -5,7 +5,8 @@ import copy
import itertools import itertools
import pickle as pkl import pickle as pkl
import time import time
from typing import Callable, Iterable, List, Tuple from collections.abc import Iterable
from typing import Callable
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -228,7 +229,7 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -228,7 +229,7 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(dtype: torch.dtype, def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
...@@ -241,7 +242,7 @@ def run(dtype: torch.dtype, ...@@ -241,7 +242,7 @@ def run(dtype: torch.dtype,
# output makers # output makers
def make_output(data: Iterable[TMeasurement], def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
base_description: str, base_description: str,
timestamp=None): timestamp=None):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
...@@ -282,7 +283,7 @@ def run_model_bench(args): ...@@ -282,7 +283,7 @@ def run_model_bench(args):
for i, model in enumerate(args.models): for i, model in enumerate(args.models):
print(f"[{i}] {model}") print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
KNs = [] KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size KN[tp_split_dim] = KN[tp_split_dim] // tp_size
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Cutlass bench utils # Cutlass bench utils
from typing import Iterable, Tuple from collections.abc import Iterable
import torch import torch
...@@ -27,7 +27,7 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor: ...@@ -27,7 +27,7 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
def make_rand_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]: k: int) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5 a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5 b = torch.randn((n, k), device='cuda').t() * 5
...@@ -63,7 +63,7 @@ def prune_to_2_4(tensor): ...@@ -63,7 +63,7 @@ def prune_to_2_4(tensor):
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]: k: int) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5 a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5 b = torch.randn((n, k), device='cuda').t() * 5
...@@ -88,7 +88,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, ...@@ -88,7 +88,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
m: int, n: int, k: int) -> \ m: int, n: int, k: int) -> \
Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
ABs = [] ABs = []
for _ in range(num_tensors): for _ in range(num_tensors):
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
......
...@@ -5,7 +5,8 @@ import copy ...@@ -5,7 +5,8 @@ import copy
import itertools import itertools
import pickle as pkl import pickle as pkl
import time import time
from typing import Callable, Iterable, List, Optional, Tuple from collections.abc import Iterable
from typing import Callable, Optional
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -49,7 +50,7 @@ def bench_int8( ...@@ -49,7 +50,7 @@ def bench_int8(
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels.""" """Benchmark INT8-based kernels."""
assert dtype == torch.int8 assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k) a, b = make_rand_tensors(torch.int8, m, n, k)
...@@ -101,7 +102,7 @@ def bench_fp8( ...@@ -101,7 +102,7 @@ def bench_fp8(
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels.""" """Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
...@@ -180,7 +181,7 @@ def bench(dtype: torch.dtype, ...@@ -180,7 +181,7 @@ def bench(dtype: torch.dtype,
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
...@@ -195,8 +196,8 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -195,8 +196,8 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(dtype: torch.dtype, def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, timers = bench(dtype,
...@@ -212,7 +213,7 @@ def run(dtype: torch.dtype, ...@@ -212,7 +213,7 @@ def run(dtype: torch.dtype,
def make_output(data: Iterable[TMeasurement], def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
base_description: str, base_description: str,
timestamp=None): timestamp=None):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
...@@ -248,7 +249,7 @@ def run_model_bench(args): ...@@ -248,7 +249,7 @@ def run_model_bench(args):
for i, model in enumerate(args.models): for i, model in enumerate(args.models):
print(f"[{i}] {model}") print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
KNs = [] KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size KN[tp_split_dim] = KN[tp_split_dim] // tp_size
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, Iterable, List, Optional from typing import Callable, Optional
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -29,7 +30,7 @@ class bench_params_t: ...@@ -29,7 +30,7 @@ class bench_params_t:
f'x DT {self.dtype}') f'x DT {self.dtype}')
def get_bench_params() -> List[bench_params_t]: def get_bench_params() -> list[bench_params_t]:
## Test Fixtures ## Test Fixtures
NUM_TOKENS = [2**x for x in range(11)] NUM_TOKENS = [2**x for x in range(11)]
HIDDEN_SIZES = list(range(1024, 8129, 1024)) HIDDEN_SIZES = list(range(1024, 8129, 1024))
......
...@@ -9,7 +9,7 @@ from dataclasses import dataclass ...@@ -9,7 +9,7 @@ from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Optional
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -61,15 +61,15 @@ def make_rand_lora_weight_tensor(k: int, ...@@ -61,15 +61,15 @@ def make_rand_lora_weight_tensor(k: int,
def make_rand_tensors( def make_rand_tensors(
a_shape: Tuple[int], a_shape: tuple[int],
b_shape: Tuple[int], b_shape: tuple[int],
c_shape: Tuple[int], c_shape: tuple[int],
a_dtype: torch.dtype, a_dtype: torch.dtype,
b_dtype: torch.dtype, b_dtype: torch.dtype,
c_dtype: torch.dtype, c_dtype: torch.dtype,
num_slices: int, num_slices: int,
device: str = "cuda", device: str = "cuda",
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
""" """
Make LoRA input/output matrices. Make LoRA input/output matrices.
""" """
...@@ -135,7 +135,7 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int, ...@@ -135,7 +135,7 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int,
def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
lora_weights: List[torch.Tensor], lora_weights: list[torch.Tensor],
seq_lens_cpu: torch.Tensor, seq_lens_cpu: torch.Tensor,
prompt_lora_mapping_cpu: torch.Tensor, scaling: float, prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
add_inputs: Optional[bool]): add_inputs: Optional[bool]):
...@@ -204,7 +204,7 @@ class OpType(Enum): ...@@ -204,7 +204,7 @@ class OpType(Enum):
def is_expand_slice_fn(self) -> bool: def is_expand_slice_fn(self) -> bool:
return self in [OpType.BGMV_EXPAND_SLICE] return self in [OpType.BGMV_EXPAND_SLICE]
def num_slices(self) -> List[int]: def num_slices(self) -> list[int]:
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]: if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
# SGMV kernels supports slices # SGMV kernels supports slices
return [1, 2, 3] return [1, 2, 3]
...@@ -215,7 +215,7 @@ class OpType(Enum): ...@@ -215,7 +215,7 @@ class OpType(Enum):
raise ValueError(f"Unrecognized OpType {self}") raise ValueError(f"Unrecognized OpType {self}")
def mkn(self, batch_size: int, seq_length: int, hidden_size: int, def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int) -> Tuple[int, int, int]: lora_rank: int) -> tuple[int, int, int]:
num_tokens = batch_size * seq_length num_tokens = batch_size * seq_length
if self.is_shrink_fn(): if self.is_shrink_fn():
m = num_tokens m = num_tokens
...@@ -230,7 +230,7 @@ class OpType(Enum): ...@@ -230,7 +230,7 @@ class OpType(Enum):
def matmul_dtypes( def matmul_dtypes(
self, op_dtype: torch.dtype self, op_dtype: torch.dtype
) -> Tuple[torch.dtype, torch.dtype, torch.dtype]: ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
""" """
return a type, b type and c type for A x B = C return a type, b type and c type for A x B = C
""" """
...@@ -243,7 +243,7 @@ class OpType(Enum): ...@@ -243,7 +243,7 @@ class OpType(Enum):
def matmul_shapes( def matmul_shapes(
self, batch_size: int, seq_length: int, hidden_size: int, self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int, num_loras: int, lora_rank: int, num_loras: int,
num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]: num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
""" """
Given num_slices, return the shapes of the A, B, and C matrices Given num_slices, return the shapes of the A, B, and C matrices
in A x B = C, for the op_type in A x B = C, for the op_type
...@@ -268,7 +268,7 @@ class OpType(Enum): ...@@ -268,7 +268,7 @@ class OpType(Enum):
def bench_fn(self) -> Callable: def bench_fn(self) -> Callable:
def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
for x in kwargs_list: for x in kwargs_list:
bgmv_expand_slice(**x) bgmv_expand_slice(**x)
...@@ -285,7 +285,7 @@ class OpType(Enum): ...@@ -285,7 +285,7 @@ class OpType(Enum):
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
lora_weights: List[torch.Tensor], lora_weights: list[torch.Tensor],
**kwargs) -> Callable: **kwargs) -> Callable:
"""Each benchmark operation expected the input, lora_weights and outputs """Each benchmark operation expected the input, lora_weights and outputs
in a slightly different format. Refer to self.matmul_shapes(). in a slightly different format. Refer to self.matmul_shapes().
...@@ -384,7 +384,7 @@ class BenchmarkTensors: ...@@ -384,7 +384,7 @@ class BenchmarkTensors:
""" """
# matmul tensors # matmul tensors
input: torch.Tensor input: torch.Tensor
lora_weights_lst: List[torch.Tensor] lora_weights_lst: list[torch.Tensor]
output: torch.Tensor output: torch.Tensor
# metadata tensors # metadata tensors
seq_lens: torch.Tensor seq_lens: torch.Tensor
...@@ -469,7 +469,7 @@ class BenchmarkTensors: ...@@ -469,7 +469,7 @@ class BenchmarkTensors:
for i in range(len(self.lora_weights_lst)): for i in range(len(self.lora_weights_lst)):
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i]) self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
def metadata(self) -> Tuple[int, int, int]: def metadata(self) -> tuple[int, int, int]:
""" """
Return num_seqs, num_tokens and max_seq_len Return num_seqs, num_tokens and max_seq_len
""" """
...@@ -505,7 +505,7 @@ class BenchmarkTensors: ...@@ -505,7 +505,7 @@ class BenchmarkTensors:
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype) self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype) self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors() self.convert_to_sgmv_benchmark_tensors()
self.sanity_check() self.sanity_check()
self.to_device(self.input.device) self.to_device(self.input.device)
...@@ -540,7 +540,7 @@ class BenchmarkTensors: ...@@ -540,7 +540,7 @@ class BenchmarkTensors:
'scaling': 1.0, 'scaling': 1.0,
} }
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]: def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors() self.convert_to_sgmv_benchmark_tensors()
self.sanity_check() self.sanity_check()
...@@ -578,7 +578,7 @@ class BenchmarkTensors: ...@@ -578,7 +578,7 @@ class BenchmarkTensors:
'add_inputs': add_inputs, 'add_inputs': add_inputs,
} }
def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
assert len(self.lora_weights_lst) == 1 assert len(self.lora_weights_lst) == 1
self.to_device(self.input.device) self.to_device(self.input.device)
...@@ -634,7 +634,7 @@ class BenchmarkTensors: ...@@ -634,7 +634,7 @@ class BenchmarkTensors:
'add_inputs': add_inputs 'add_inputs': add_inputs
} }
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
_, num_tokens, _, num_slices = self.metadata() _, num_tokens, _, num_slices = self.metadata()
# Sanity check shapes # Sanity check shapes
...@@ -670,7 +670,7 @@ class BenchmarkTensors: ...@@ -670,7 +670,7 @@ class BenchmarkTensors:
def bench_fn_kwargs(self, def bench_fn_kwargs(self,
op_type: OpType, op_type: OpType,
add_inputs: Optional[bool] = None) -> Dict[str, Any]: add_inputs: Optional[bool] = None) -> dict[str, Any]:
if op_type.is_shrink_fn(): if op_type.is_shrink_fn():
assert add_inputs is None assert add_inputs is None
else: else:
...@@ -734,7 +734,7 @@ def bench_optype(ctx: BenchmarkContext, ...@@ -734,7 +734,7 @@ def bench_optype(ctx: BenchmarkContext,
assert expand_fn_add_inputs is not None assert expand_fn_add_inputs is not None
# BenchmarkContext -> BenchmarkTensors # BenchmarkContext -> BenchmarkTensors
bench_tensors : List[BenchmarkTensors] = \ bench_tensors : list[BenchmarkTensors] = \
[BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)] [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)]
for bt in bench_tensors: for bt in bench_tensors:
bt.sanity_check() bt.sanity_check()
...@@ -746,7 +746,7 @@ def bench_optype(ctx: BenchmarkContext, ...@@ -746,7 +746,7 @@ def bench_optype(ctx: BenchmarkContext,
for bt in bench_tensors for bt in bench_tensors
]) ])
# BenchmarkTensors -> Dict (kwargs) # BenchmarkTensors -> dict (kwargs)
kwargs_list = [ kwargs_list = [
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
for bt in bench_tensors for bt in bench_tensors
...@@ -841,7 +841,7 @@ def use_cuda_graph_recommendation() -> str: ...@@ -841,7 +841,7 @@ def use_cuda_graph_recommendation() -> str:
""" """
def print_timers(timers: List[TMeasurement], def print_timers(timers: list[TMeasurement],
args: Optional[argparse.Namespace] = None): args: Optional[argparse.Namespace] = None):
compare = TBenchmark.Compare(timers) compare = TBenchmark.Compare(timers)
compare.print() compare.print()
...@@ -861,7 +861,7 @@ def print_timers(timers: List[TMeasurement], ...@@ -861,7 +861,7 @@ def print_timers(timers: List[TMeasurement],
"small num_loras the goal should be to match the torch.mm numbers.") "small num_loras the goal should be to match the torch.mm numbers.")
def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
if args.cuda_graph_nops is not None: if args.cuda_graph_nops is not None:
assert args.cuda_graph_nops > 0 assert args.cuda_graph_nops > 0
...@@ -873,7 +873,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): ...@@ -873,7 +873,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
timers = [] timers = []
for bench_ctx in bench_ctxs: for bench_ctx in bench_ctxs:
for seq_len in args.seq_lengths: for seq_len in args.seq_lengths:
bench_ops: List[OpType] = [] bench_ops: list[OpType] = []
if seq_len == 1: if seq_len == 1:
# bench all decode ops # bench all decode ops
bench_ops = [op for op in args.op_types if op.is_decode_op()] bench_ops = [op for op in args.op_types if op.is_decode_op()]
...@@ -921,10 +921,10 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): ...@@ -921,10 +921,10 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
pickle.dump(timers, f) pickle.dump(timers, f)
def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int], def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
args: argparse.Namespace) -> List[BenchmarkContext]: args: argparse.Namespace) -> list[BenchmarkContext]:
ctxs: List[BenchmarkContext] = [] ctxs: list[BenchmarkContext] = []
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
args.sort_by_lora_id): args.sort_by_lora_id):
...@@ -954,7 +954,7 @@ def run_list_bench(args: argparse.Namespace): ...@@ -954,7 +954,7 @@ def run_list_bench(args: argparse.Namespace):
f" LoRA Ranks {args.lora_ranks}") f" LoRA Ranks {args.lora_ranks}")
# Get all benchmarking contexts # Get all benchmarking contexts
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args) hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)
run(args, bench_contexts) run(args, bench_contexts)
...@@ -975,7 +975,7 @@ def run_range_bench(args: argparse.Namespace): ...@@ -975,7 +975,7 @@ def run_range_bench(args: argparse.Namespace):
f" LoRA Ranks {lora_ranks}") f" LoRA Ranks {lora_ranks}")
# Get all benchmarking contexts # Get all benchmarking contexts
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args) hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)
run(args, bench_contexts) run(args, bench_contexts)
...@@ -1002,7 +1002,7 @@ def run_model_bench(args: argparse.Namespace): ...@@ -1002,7 +1002,7 @@ def run_model_bench(args: argparse.Namespace):
f" LoRA Ranks {args.lora_ranks}") f" LoRA Ranks {args.lora_ranks}")
# Get all benchmarking contexts # Get all benchmarking contexts
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args) hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)
run(args, bench_contexts) run(args, bench_contexts)
......
...@@ -7,9 +7,10 @@ import math ...@@ -7,9 +7,10 @@ import math
import os import os
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple from typing import Callable, Optional
import pandas as pd import pandas as pd
import torch import torch
...@@ -102,8 +103,8 @@ def quantize_and_pack(atype: torch.dtype, ...@@ -102,8 +103,8 @@ def quantize_and_pack(atype: torch.dtype,
return w_ref, w_q, w_s, w_zp return w_ref, w_q, w_s, w_zp
def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
group_size: Optional[int]) -> List[BenchmarkTensors]: group_size: Optional[int]) -> list[BenchmarkTensors]:
m, n, k = shape m, n, k = shape
# we want to make sure that weights don't fit into L2 cache between runs so # we want to make sure that weights don't fit into L2 cache between runs so
...@@ -114,7 +115,7 @@ def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, ...@@ -114,7 +115,7 @@ def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
a = rand_data((m, k), types.act_type, scale=5) a = rand_data((m, k), types.act_type, scale=5)
benchmark_tensors: List[BenchmarkTensors] = [] benchmark_tensors: list[BenchmarkTensors] = []
for _ in range(num_weights): for _ in range(num_weights):
w = rand_data((k, n), types.act_type, scale=5) w = rand_data((k, n), types.act_type, scale=5)
...@@ -276,7 +277,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors, ...@@ -276,7 +277,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
def bench_fns(label: str, sub_label: str, description: str, def bench_fns(label: str, sub_label: str, description: str,
fns: List[Callable]): fns: list[Callable]):
min_run_time = 1 if not NVTX_PROFILE else 0.1 min_run_time = 1 if not NVTX_PROFILE else 0.1
res = TBenchmark.Timer( res = TBenchmark.Timer(
...@@ -311,7 +312,7 @@ def bench(types: TypeConfig, ...@@ -311,7 +312,7 @@ def bench(types: TypeConfig,
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
sweep_schedules: bool = True) -> List[TMeasurement]: sweep_schedules: bool = True) -> list[TMeasurement]:
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
sub_label += f", L={len(benchmark_tensors)}" sub_label += f", L={len(benchmark_tensors)}"
...@@ -414,12 +415,12 @@ def bench(types: TypeConfig, ...@@ -414,12 +415,12 @@ def bench(types: TypeConfig,
# runner # runner
def print_timers(timers: List[TMeasurement]): def print_timers(timers: list[TMeasurement]):
compare = TBenchmark.Compare(timers) compare = TBenchmark.Compare(timers)
compare.print() compare.print()
def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
types = TypeConfig( types = TypeConfig(
act_type=args.act_type, act_type=args.act_type,
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
...@@ -431,7 +432,7 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: ...@@ -431,7 +432,7 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
token_scale_type=args.token_scale_type, token_scale_type=args.token_scale_type,
) )
results: List[TMeasurement] = [] results: list[TMeasurement] = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(types, timers = bench(types,
args.group_size, args.group_size,
...@@ -449,8 +450,8 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: ...@@ -449,8 +450,8 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
# output makers # output makers
def make_output( def make_output(
data: List[TMeasurement], data: list[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
base_description: str, base_description: str,
timestamp=None, timestamp=None,
): ):
...@@ -497,7 +498,7 @@ def run_model_bench(args): ...@@ -497,7 +498,7 @@ def run_model_bench(args):
for i, model in enumerate(args.models): for i, model in enumerate(args.models):
print(f"[{i}] {model}") print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
KNs = [] KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size KN[tp_split_dim] = KN[tp_split_dim] // tp_size
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List
import torch import torch
import torch.utils.benchmark as benchmark import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES from benchmark_shapes import WEIGHT_SHAPES
...@@ -31,7 +29,7 @@ ACT_ORDER_OPTS = [False, True] ...@@ -31,7 +29,7 @@ ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True] K_FULL_OPTS = [False, True]
def bench_run(results: List[benchmark.Measurement], model: str, def bench_run(results: list[benchmark.Measurement], model: str,
act_order: bool, is_k_full: bool, quant_type: ScalarType, act_order: bool, is_k_full: bool, quant_type: ScalarType,
group_size: int, size_m: int, size_k: int, size_n: int): group_size: int, size_m: int, size_k: int, size_n: int):
label = "Quant Matmul" label = "Quant Matmul"
...@@ -221,7 +219,7 @@ def main(args): ...@@ -221,7 +219,7 @@ def main(args):
for i, model in enumerate(args.models): for i, model in enumerate(args.models):
print(f"[{i}] {model}") print(f"[{i}] {model}")
results: List[benchmark.Measurement] = [] results: list[benchmark.Measurement] = []
for model in args.models: for model in args.models:
for layer in WEIGHT_SHAPES[model]: for layer in WEIGHT_SHAPES[model]:
......
...@@ -4,7 +4,7 @@ import argparse ...@@ -4,7 +4,7 @@ import argparse
import time import time
from datetime import datetime from datetime import datetime
from itertools import product from itertools import product
from typing import Any, Dict, List, Tuple, TypedDict from typing import Any, TypedDict
import ray import ray
import torch import torch
...@@ -132,7 +132,7 @@ def benchmark_config( ...@@ -132,7 +132,7 @@ def benchmark_config(
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
prepare(i) prepare(i)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -175,8 +175,8 @@ def get_rocm_tuning_space(use_fp16): ...@@ -175,8 +175,8 @@ def get_rocm_tuning_space(use_fp16):
return param_ranges return param_ranges
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]: def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
configs: List[BenchmarkConfig] = [] configs: list[BenchmarkConfig] = []
if current_platform.is_rocm(): if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16) param_ranges = get_rocm_tuning_space(use_fp16)
...@@ -335,7 +335,7 @@ class BenchmarkWorker: ...@@ -335,7 +335,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]: ) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype, dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
...@@ -371,8 +371,8 @@ class BenchmarkWorker: ...@@ -371,8 +371,8 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
search_space: List[Dict[str, int]], search_space: list[dict[str, int]],
) -> Dict[str, int]: ) -> dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -434,7 +434,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: ...@@ -434,7 +434,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
} }
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int, shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool, dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None: use_int8_w8a16: bool) -> None:
...@@ -498,7 +498,7 @@ def main(args: argparse.Namespace): ...@@ -498,7 +498,7 @@ def main(args: argparse.Namespace):
num_gpus = int(ray.available_resources()["GPU"]) num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: List[Any]) -> List[Any]: def _distribute(method: str, inputs: list[Any]) -> list[Any]:
outputs = [] outputs = []
worker_idx = 0 worker_idx = 0
for input_args in inputs: for input_args in inputs:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import random import random
import time import time
from typing import List, Optional from typing import Optional
import torch import torch
...@@ -54,7 +54,7 @@ def main( ...@@ -54,7 +54,7 @@ def main(
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables_lst: List[List[int]] = [] block_tables_lst: list[list[int]] = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
random.randint(0, NUM_BLOCKS - 1) random.randint(0, NUM_BLOCKS - 1)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import itertools import itertools
from typing import Optional, Tuple, Union from typing import Optional, Union
import torch import torch
import triton import triton
...@@ -22,7 +22,7 @@ class HuggingFaceRMSNorm(nn.Module): ...@@ -22,7 +22,7 @@ class HuggingFaceRMSNorm(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
if residual is not None: if residual is not None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from itertools import accumulate from itertools import accumulate
from typing import List, Optional from typing import Optional
import nvtx import nvtx
import torch import torch
...@@ -39,7 +39,7 @@ def benchmark_rope_kernels_multi_lora( ...@@ -39,7 +39,7 @@ def benchmark_rope_kernels_multi_lora(
}) })
# non-batched RoPE takes only one scaling factor, we create multiple # non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior # instances to simulate the same behavior
non_batched_ropes: List[RotaryEmbedding] = [] non_batched_ropes: list[RotaryEmbedding] = []
for scaling_factor in scaling_factors: for scaling_factor in scaling_factors:
non_batched_ropes.append( non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style, get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment