Commit f246ee95 authored by zhuwenwen
Browse files

update bench run error

parent 10184690
...@@ -103,9 +103,11 @@ def run_vllm( ...@@ -103,9 +103,11 @@ def run_vllm(
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]
use_beam_search = False
print("Warming up...") print("Warming up...")
for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"): for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
if not args.use_beam_search: if not use_beam_search:
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
else: else:
llm.beam_search( llm.beam_search(
...@@ -117,8 +119,6 @@ def run_vllm( ...@@ -117,8 +119,6 @@ def run_vllm(
), ),
) )
use_beam_search = False
outputs = None outputs = None
if not use_beam_search: if not use_beam_search:
if args.profile: if args.profile:
...@@ -790,4 +790,4 @@ if __name__ == "__main__": ...@@ -790,4 +790,4 @@ if __name__ == "__main__":
if args.tokenizer is None: if args.tokenizer is None:
args.tokenizer = args.model args.tokenizer = args.model
validate_args(args) validate_args(args)
main(args) main(args)
\ No newline at end of file
...@@ -788,6 +788,12 @@ class EngineArgs: ...@@ -788,6 +788,12 @@ class EngineArgs:
default=None, default=None,
help="The configurations for speculative decoding. Should be a " help="The configurations for speculative decoding. Should be a "
"JSON string.") "JSON string.")
parser.add_argument(
'--num-speculative-heads',
type=int,
default=EngineArgs.num_speculative_heads,
help='The number of speculative heads to sample from '
'the draft model in speculative decoding.')
# Observability arguments # Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig) observability_kwargs = get_kwargs(ObservabilityConfig)
......
...@@ -103,9 +103,11 @@ def run_vllm( ...@@ -103,9 +103,11 @@ def run_vllm(
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]
use_beam_search = False
print("Warming up...") print("Warming up...")
for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"): for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
if not args.use_beam_search: if not use_beam_search:
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
else: else:
llm.beam_search( llm.beam_search(
...@@ -117,8 +119,6 @@ def run_vllm( ...@@ -117,8 +119,6 @@ def run_vllm(
), ),
) )
use_beam_search = False
outputs = None outputs = None
if not use_beam_search: if not use_beam_search:
if args.profile: if args.profile:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment