Unverified Commit f04908ca authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[FIX] Minor bug fixes (#1035)

* [FIX] Minor bug fixes

* Address review comments
parent ab019eea
...@@ -82,8 +82,9 @@ class Sampler(nn.Module): ...@@ -82,8 +82,9 @@ class Sampler(nn.Module):
# We use float32 for probabilities and log probabilities. # We use float32 for probabilities and log probabilities.
# Compute the probabilities. # Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float) probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p and top-k). # Compute the log probabilities.
logprobs = torch.log(probs) # Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
return _sample(probs, logprobs, input_metadata) return _sample(probs, logprobs, input_metadata)
......
...@@ -350,7 +350,7 @@ class SequenceOutputs: ...@@ -350,7 +350,7 @@ class SequenceOutputs:
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs): if not isinstance(other, SequenceOutputs):
return NotImplementedError() raise NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token and self.output_token == other.output_token
and self.logprobs == other.logprobs) and self.logprobs == other.logprobs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment