Commit 0879736a (unverified), authored by Corey Lowman, committed by GitHub
Browse files

[Perf] Remove hardcoded num_warps=1 (#26183)


Signed-off-by: Corey Lowman <clowman1993@gmail.com>
parent a2691733
......@@ -164,12 +164,12 @@ def rejection_sample(
assert target_probs.shape == (num_tokens, vocab_size)
# Create output buffer.
output_token_ids = torch.empty(
output_token_ids = torch.full(
(batch_size, max_spec_len + 1),
PLACEHOLDER_TOKEN_ID,
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
device=device,
)
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
if sampling_metadata.all_greedy:
is_greedy = None
......@@ -186,7 +186,6 @@ def rejection_sample(
bonus_token_ids,
is_greedy,
max_spec_len,
num_warps=1,
)
if sampling_metadata.all_greedy:
return output_token_ids
......@@ -227,7 +226,6 @@ def rejection_sample(
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
num_warps=1,
)
return output_token_ids
......@@ -329,7 +327,6 @@ def expand_batch_to_tokens(
replace_from,
replace_to,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
num_warps=1,
)
return expanded_x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment