Unverified Commit 3295cd8a authored by Lianmin Zheng, committed by GitHub

Allow skipping warmup in bench_offline_throughput.py (#2103)

parent 5942dfc0
@@ -57,6 +57,7 @@ class BenchArgs:
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
+    skip_warmup: bool = False
     do_not_exit: bool = False

     @staticmethod
@@ -152,6 +153,11 @@ class BenchArgs:
             "additional generate params like sampling params.",
         )
         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--skip-warmup",
+            action="store_true",
+            help="Skip the warmup batches.",
+        )
         parser.add_argument(
             "--do-not-exit",
             action="store_true",
@@ -261,14 +267,15 @@ def throughput_test(
     )

     # Warm up
-    logging.info("\nWarmup...")
-    throughput_test_once(
-        backend_name=bench_args.backend,
-        backend=backend,
-        reqs=warmup_requests,
-        ignore_eos=not bench_args.disable_ignore_eos,
-        extra_request_body=extra_request_body,
-    )
+    if not bench_args.skip_warmup:
+        logging.info("\nWarmup...")
+        throughput_test_once(
+            backend_name=bench_args.backend,
+            backend=backend,
+            reqs=warmup_requests,
+            ignore_eos=not bench_args.disable_ignore_eos,
+            extra_request_body=extra_request_body,
+        )

     logging.info("\nBenchmark...")
     result = throughput_test_once(
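The change follows the usual dataclass-plus-argparse pattern: a new boolean field on BenchArgs defaults to False, a --skip-warmup store_true flag sets it, and the warmup call is wrapped in a guard. A minimal self-contained sketch of the same pattern (the Args class and run_once function below are illustrative stand-ins, not the actual benchmark code):

import argparse
import logging
from dataclasses import dataclass


@dataclass
class Args:
    # Mirrors the new BenchArgs field: warmup runs unless explicitly skipped.
    skip_warmup: bool = False


def run_once(tag: str) -> None:
    # Stand-in for throughput_test_once(); this sketch only logs.
    logging.info("running %s batch", tag)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--skip-warmup",
        action="store_true",
        help="Skip the warmup batches.",
    )
    args = Args(**vars(parser.parse_args()))

    if not args.skip_warmup:
        run_once("warmup")
    run_once("benchmark")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()

With the flag in place, passing --skip-warmup goes straight to the measured run instead of executing the warmup batches first.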
@@ -156,9 +156,6 @@ class TpModelWorkerClient:
         return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
-
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -169,6 +166,9 @@ class TpModelWorkerClient:
             linear_penalties=sampling_info.linear_penalties,
         )

+        # A cuda stream sync here to avoid the cuda illegal memory access error.
+        torch.cuda.current_stream().synchronize()
+
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
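The second file keeps the same stream synchronization but moves it later in forward_batch_generation: instead of syncing at the top of the method, the current CUDA stream is synchronized after the sampling_info copy is prepared and immediately before the batch is pushed onto the input queue for the overlap thread. A minimal sketch of that ordering, assuming a simple producer/consumer setup (the consumer thread and queue below are illustrative, not the sglang ones):

import queue
import threading

import torch


def consumer(q: "queue.Queue") -> None:
    # Reads tensors the producer has finished writing and reports a checksum.
    while True:
        t = q.get()
        if t is None:
            break
        print("consumer sees sum:", t.sum().item())


def main() -> None:
    if not torch.cuda.is_available():
        print("CUDA not available; this sketch needs a GPU.")
        return

    q: "queue.Queue" = queue.Queue()
    worker = threading.Thread(target=consumer, args=(q,), daemon=True)
    worker.start()

    # Producer: launch asynchronous GPU work on the current stream ...
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # enqueued on the current stream, may still be running

    # ... then synchronize the current stream before publishing the result,
    # matching the ordering in the hunk above (sync right before the queue put).
    torch.cuda.current_stream().synchronize()
    q.put(y)

    q.put(None)
    worker.join()


if __name__ == "__main__":
    main()

The point mirrored from the diff is the ordering: the producer synchronizes its stream right before the queue put, so the consuming thread never picks up tensors whose producing kernels are still in flight.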