Unverified Commit 6a854c7a authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[V1][Sampler] Don't apply temp for greedy-only (#13311)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent e7eea5a5
...@@ -41,8 +41,6 @@ class Sampler(nn.Module): ...@@ -41,8 +41,6 @@ class Sampler(nn.Module):
logits = self.apply_logits_bias(logits, sampling_metadata) logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties). # Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata) logits = self.apply_penalties(logits, sampling_metadata)
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Sample the next token. # Sample the next token.
sampled = self.sample(logits, sampling_metadata) sampled = self.sample(logits, sampling_metadata)
...@@ -82,9 +80,21 @@ class Sampler(nn.Module): ...@@ -82,9 +80,21 @@ class Sampler(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
assert not (sampling_metadata.all_greedy assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random) and sampling_metadata.all_random)
if sampling_metadata.all_greedy: if sampling_metadata.all_random:
return self.greedy_sample(logits) greedy_sampled = None
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
return greedy_sampled
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Apply min_p.
if not sampling_metadata.no_min_p:
logits = self.apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler( random_sampled = self.topk_topp_sampler(
logits, logits,
sampling_metadata.generators, sampling_metadata.generators,
...@@ -94,13 +104,9 @@ class Sampler(nn.Module): ...@@ -94,13 +104,9 @@ class Sampler(nn.Module):
sampling_metadata.top_p, sampling_metadata.top_p,
) )
if not sampling_metadata.no_min_p: if greedy_sampled is None:
logits = self.apply_min_p(logits, sampling_metadata.min_p)
if sampling_metadata.all_random:
return random_sampled return random_sampled
greedy_sampled = self.greedy_sample(logits)
sampled = torch.where( sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS, sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled, greedy_sampled,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment