Unverified Commit 140c6ede authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

Small fix to ExponentialDecayLengthPenalty docstring (#21308)

Currently, it incorrectly states that the exponential_decay_length_penalty tuple parameter is optional.

Also changed the corresponding type hint to be more specific.
parent 3a6e4a22
...@@ -825,7 +825,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): ...@@ -825,7 +825,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
reached. reached.
Args: Args:
exponential_decay_length_penalty (`tuple(int, float)`, *optional*): exponential_decay_length_penalty (`tuple(int, float)`):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`Union[int, List[int]]`): eos_token_id (`Union[int, List[int]]`):
...@@ -835,7 +835,10 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): ...@@ -835,7 +835,10 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
""" """
def __init__( def __init__(
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int self,
exponential_decay_length_penalty: Tuple[int, float],
eos_token_id: Union[int, List[int]],
input_ids_seq_length: int,
): ):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1] self.regulation_factor = exponential_decay_length_penalty[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment