Unverified Commit 05ccd0aa authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V1] Ensure using int64 for sampled token ids (#15065)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent f690372b
...@@ -47,6 +47,11 @@ class Sampler(nn.Module): ...@@ -47,6 +47,11 @@ class Sampler(nn.Module):
logits = self.apply_penalties(logits, sampling_metadata) logits = self.apply_penalties(logits, sampling_metadata)
# Sample the next token. # Sample the next token.
sampled = self.sample(logits, sampling_metadata) sampled = self.sample(logits, sampling_metadata)
# Convert sampled token ids to int64 (long) type to ensure compatibility
# with subsequent operations that may use these values as indices.
# This conversion is necessary because FlashInfer sampling operations
# return int32 (while PyTorch argmax and topk return int64).
sampled = sampled.long()
# Gather the logprobs of the topk and sampled token (if requested). # Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested) # Get logprobs and rank tensors (if requested)
...@@ -139,19 +144,21 @@ class Sampler(nn.Module): ...@@ -139,19 +144,21 @@ class Sampler(nn.Module):
or sampled tokens (if sampled or sampled tokens (if sampled
logprobs); 1D token ID tensor logprobs); 1D token ID tensor
with (num tokens) elements with (num tokens) elements
Must be int64.
Returns: Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1) Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens) Sampled token rank tensor, (num tokens)
""" """
assert token_ids.dtype == torch.int64
# Find the topK values. # Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs, topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs, num_logprobs,
dim=-1) dim=-1)
# Get with the logprob of the prompt or sampled token. # Get with the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1).to(torch.long) token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids) token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token. # Compute the ranks of the actual token.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment