Unverified Commit f04908ca authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[FIX] Minor bug fixes (#1035)

* [FIX] Minor bug fixes

* Address review comments
parent ab019eea
......@@ -82,8 +82,9 @@ class Sampler(nn.Module):
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p and top-k).
logprobs = torch.log(probs)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
......
......@@ -350,7 +350,7 @@ class SequenceOutputs:
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
return NotImplementedError()
raise NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment