Unverified commit 9e0f44be, authored by Komal Kumar Teru and committed by GitHub
Browse files

[cohere][fix][spec-decode]: fix crash when allowed_token_ids is set without penalties (#35654)


Signed-off-by: kkt-cohere <komal@cohere.com>
parent 097eb544
...@@ -271,7 +271,7 @@ class RejectionSampler(nn.Module): ...@@ -271,7 +271,7 @@ class RejectionSampler(nn.Module):
# Calculate indices of target logits. # Calculate indices of target logits.
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties: if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
num_requests = len(sampling_metadata.output_token_ids) num_requests = len(metadata.num_draft_tokens)
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu") num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
original_indices = torch.arange(num_requests, device="cpu") original_indices = torch.arange(num_requests, device="cpu")
repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens) repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment