Unverified Commit 6a7c7711 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Skip for logits_scale == 1.0 (#5291)

parent 0f83ddd4
......@@ -21,7 +21,7 @@ class LogitsProcessor(nn.Module):
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: Optional[float] = 1.0,
scale: float = 1.0,
logits_as_input: bool = False) -> None:
"""
Args:
......@@ -52,7 +52,8 @@ class LogitsProcessor(nn.Module):
logits = self._get_logits(hidden_states, embedding, embedding_bias)
if logits is not None:
logits *= self.scale
if self.scale != 1.0:
logits *= self.scale
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment