Unverified commit a90c97d7 authored by Woosuk Kwon, committed by GitHub

Use FP32 for log probabilities (#19)

parent e3f00d19
@@ -36,10 +36,11 @@ class Sampler(nn.Module):
         # Use in-place division to avoid creating a new tensor.
         logits.div_(t.unsqueeze(dim=1))
+        # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # Compute the log probabilities (before applying top-p).
-        logprobs = torch.log(probs, out=logits)
+        logprobs = torch.log(probs)
         # Apply top-p truncation.
         top_ps = _get_top_ps(input_metadata)
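For context, a minimal sketch of why this change matters, assuming half-precision logits (this is an illustration, not vLLM's actual Sampler code):

```python
import torch

# Model outputs are often fp16; assumed here for illustration.
logits = torch.randn(2, 8, dtype=torch.float16)

# Softmax is computed in float32 for numerical stability.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)

# The old form, torch.log(probs, out=logits), tried to write the float32
# result into the logits tensor: with fp16 logits PyTorch rejects the
# narrowing cast into `out`, and with fp32 logits it silently clobbers
# them. The new form simply keeps the log probabilities in float32.
logprobs = torch.log(probs)
assert logprobs.dtype == torch.float32
```

Keeping the log probabilities in a fresh fp32 tensor also preserves the precision of very small probabilities, whose logarithms are large-magnitude negative values that fp16 represents poorly.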