Commit 93872128 authored by zhuwenwen's avatar zhuwenwen
Browse files

support head_dim 160 and update benchmark_throughput.py

parent a087cda8
...@@ -348,12 +348,11 @@ def main(args: argparse.Namespace): ...@@ -348,12 +348,11 @@ def main(args: argparse.Namespace):
# Sample the requests. # Sample the requests.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code)
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
if args.dataset is None: if args.dataset is None:
# Synthesize a prompt with the given input length. # Synthesize a prompt with the given input length.
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
prompt = "hi" * (args.input_len - 1) prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len) requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)] for _ in range(args.num_prompts)]
...@@ -363,7 +362,7 @@ def main(args: argparse.Namespace): ...@@ -363,7 +362,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm": if args.backend == "vllm":
run_args = [ run_args = [
requests, args.model, args.tokenizer, args.quantization, warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len, args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.enforce_eager, args.kv_cache_dtype,
......
...@@ -757,6 +757,9 @@ void paged_attention_v1_launcher( ...@@ -757,6 +757,9 @@ void paged_attention_v1_launcher(
case 128: case 128:
LAUNCH_PAGED_ATTENTION_V1(128); LAUNCH_PAGED_ATTENTION_V1(128);
break; break;
case 160:
LAUNCH_PAGED_ATTENTION_V1(160);
break;
case 192: case 192:
LAUNCH_PAGED_ATTENTION_V1(192); LAUNCH_PAGED_ATTENTION_V1(192);
break; break;
...@@ -921,6 +924,9 @@ void paged_attention_v2_launcher( ...@@ -921,6 +924,9 @@ void paged_attention_v2_launcher(
case 128: case 128:
LAUNCH_PAGED_ATTENTION_V2(128); LAUNCH_PAGED_ATTENTION_V2(128);
break; break;
case 160:
LAUNCH_PAGED_ATTENTION_V2(160);
break;
case 192: case 192:
LAUNCH_PAGED_ATTENTION_V2(192); LAUNCH_PAGED_ATTENTION_V2(192);
break; break;
......
...@@ -37,9 +37,6 @@ ...@@ -37,9 +37,6 @@
} else if (HEADDIM == 112) { \ } else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \ constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (HEADDIM == 120) { \
constexpr static int HEAD_SIZE = 120; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \ } else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \ constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
......
...@@ -348,12 +348,11 @@ def main(args: argparse.Namespace): ...@@ -348,12 +348,11 @@ def main(args: argparse.Namespace):
# Sample the requests. # Sample the requests.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code)
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
if args.dataset is None: if args.dataset is None:
# Synthesize a prompt with the given input length. # Synthesize a prompt with the given input length.
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
prompt = "hi" * (args.input_len - 1) prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len) requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)] for _ in range(args.num_prompts)]
...@@ -363,7 +362,7 @@ def main(args: argparse.Namespace): ...@@ -363,7 +362,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm": if args.backend == "vllm":
run_args = [ run_args = [
requests, args.model, args.tokenizer, args.quantization, warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len, args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.enforce_eager, args.kv_cache_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment