Unverified Commit 91732486 authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[fix] reduce dp capture bs (#5634)


Co-authored-by: default avataralcanerian <alcanerian@gmail.com>
parent 2ed96c7a
...@@ -134,7 +134,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ...@@ -134,7 +134,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
) )
gpu_mem = get_device_memory_capacity() gpu_mem = get_device_memory_capacity()
if gpu_mem is not None and gpu_mem > 81920: # Batch size of each rank will not become so large when DP is on
if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
capture_bs += list(range(160, 257, 8)) capture_bs += list(range(160, 257, 8))
if max(capture_bs) > model_runner.req_to_token_pool.size: if max(capture_bs) > model_runner.req_to_token_pool.size:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment