Unverified Commit 6a7796e8 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bug]: Limit num_reqs in dummy_run when max_num_seqs is small (#26144)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent 47b93395
...@@ -3060,7 +3060,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3060,7 +3060,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert not uniform_decode assert not uniform_decode
# Create mixed batch: # Create mixed batch:
# first half decode tokens, second half one prefill # first half decode tokens, second half one prefill
num_decode_tokens = num_tokens // 2 num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2)
num_prefill_tokens = num_tokens - num_decode_tokens num_prefill_tokens = num_tokens - num_decode_tokens
num_reqs = num_decode_tokens + 1 num_reqs = num_decode_tokens + 1
...@@ -3072,7 +3072,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3072,7 +3072,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_query_len = num_prefill_tokens max_query_len = num_prefill_tokens
elif uniform_decode: elif uniform_decode:
assert not create_mixed_batch assert not create_mixed_batch
num_reqs = cdiv(num_tokens, max_query_len) num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
num_scheduled_tokens_list = [max_query_len] * num_reqs num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0: if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len num_scheduled_tokens_list[-1] = num_tokens % max_query_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment