Commit 93872128 authored by zhuwenwen's avatar zhuwenwen
Browse files

support head_dim 160 and update benchmark_throughput.py

parent a087cda8
......@@ -348,12 +348,11 @@ def main(args: argparse.Namespace):
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
if args.dataset is None:
# Synthesize a prompt with the given input length.
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
......@@ -363,7 +362,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm":
run_args = [
requests, args.model, args.tokenizer, args.quantization,
warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
......
......@@ -757,6 +757,9 @@ void paged_attention_v1_launcher(
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
case 160:
LAUNCH_PAGED_ATTENTION_V1(160);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V1(192);
break;
......@@ -921,6 +924,9 @@ void paged_attention_v2_launcher(
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
case 160:
LAUNCH_PAGED_ATTENTION_V2(160);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V2(192);
break;
......
......@@ -37,9 +37,6 @@
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 120) { \
constexpr static int HEAD_SIZE = 120; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
......
......@@ -348,12 +348,11 @@ def main(args: argparse.Namespace):
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
if args.dataset is None:
# Synthesize a prompt with the given input length.
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
......@@ -363,7 +362,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm":
run_args = [
requests, args.model, args.tokenizer, args.quantization,
warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment