Unverified Commit 99757cc3 authored by narutolhy's avatar narutolhy Committed by GitHub
Browse files

fix name of probs computed without temperature scaling (#9984)

parent cdddab05
...@@ -80,9 +80,9 @@ class Sampler(nn.Module): ...@@ -80,9 +80,9 @@ class Sampler(nn.Module):
logprobs = torch.nn.functional.log_softmax(logits, dim=-1) logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else: else:
# Post process original logits. if temperatures are all 1.0, no need to rescale # If requested, cache probabilities from original logits before temperature scaling.
if return_logprob and RETURN_ORIGINAL_LOGPROB: if return_logprob and RETURN_ORIGINAL_LOGPROB:
logprobs = torch.softmax(logits, dim=-1) probs_without_temp_scaling = torch.softmax(logits, dim=-1)
# Post process logits # Post process logits
logits.div_(sampling_info.temperatures) logits.div_(sampling_info.temperatures)
...@@ -123,9 +123,10 @@ class Sampler(nn.Module): ...@@ -123,9 +123,10 @@ class Sampler(nn.Module):
if return_logprob: if return_logprob:
# clamp to avoid -inf # clamp to avoid -inf
if RETURN_ORIGINAL_LOGPROB: if RETURN_ORIGINAL_LOGPROB:
logprobs = torch.log(logprobs).clamp( logprobs = torch.log(probs_without_temp_scaling).clamp(
min=torch.finfo(logprobs.dtype).min min=torch.finfo(probs_without_temp_scaling.dtype).min
) )
del probs_without_temp_scaling
else: else:
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment