Unverified Commit 027827cc authored by Chujie Zheng's avatar Chujie Zheng Committed by GitHub
Browse files

fix long dtype in topk sampling (#15049)

parent 72a8639b
...@@ -151,7 +151,7 @@ class Sampler(nn.Module): ...@@ -151,7 +151,7 @@ class Sampler(nn.Module):
dim=-1) dim=-1)
# Get with the logprob of the prompt or sampled token. # Get with the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1) token_ids = token_ids.unsqueeze(-1).to(torch.long)
token_logprobs = logprobs.gather(-1, token_ids) token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token. # Compute the ranks of the actual token.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment