"vscode:/vscode.git/clone" did not exist on "4383f1532e87e77b6f961e633230f47467cbd072"
Unverified Commit 9e0f44be authored by Komal Kumar Teru's avatar Komal Kumar Teru Committed by GitHub
Browse files

[cohere][fix][spec-decode]: fix crash when allowed_token_ids is set without penalties (#35654)


Signed-off-by: default avatarkkt-cohere <komal@cohere.com>
parent 097eb544
......@@ -271,7 +271,7 @@ class RejectionSampler(nn.Module):
# Calculate indices of target logits.
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
num_requests = len(sampling_metadata.output_token_ids)
num_requests = len(metadata.num_draft_tokens)
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
original_indices = torch.arange(num_requests, device="cpu")
repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment