"src/vscode:/vscode.git/clone" did not exist on "40b44a43a98b4a5dbc088fd1a6839adbda370fb5"
sampling_params.py 5.28 KB
Newer Older
1
"""Sampling parameters for text generation."""
2
from typing import Optional, Set
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4


Woosuk Kwon's avatar
Woosuk Kwon committed
5
class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    Args:
        n: Number of output sequences to generate from the given prompt. This is
            regarded as the beam width when using beam search.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens.
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens.
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens.
        use_beam_search: Whether to use beam search instead of sampling.
        stop_token_ids: Set of token IDs that indicate the end of a sequence.
            Defaults to an empty set (no extra stop tokens).
        max_tokens: Maximum number of tokens to generate per output sequence.
        logprobs: Number of log probabilities to return per output token.

    Raises:
        ValueError: If any parameter is outside its valid range, or if the
            parameter combination is invalid for beam search / greedy sampling.
    """

    def __init__(
        self,
        n: int = 1,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        # NOTE: `None` sentinel instead of `set()` — a mutable default would be
        # shared across every SamplingParams instance (classic Python pitfall).
        stop_token_ids: Optional[Set[int]] = None,
        max_tokens: int = 16,
        logprobs: int = 0,
    ) -> None:
        self.n = n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.use_beam_search = use_beam_search
        # Fresh set per instance so callers mutating it cannot affect others.
        self.stop_token_ids = set() if stop_token_ids is None else stop_token_ids
        self.max_tokens = max_tokens
        self.logprobs = logprobs

        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        elif self.temperature == 0.0:
            # Zero temperature means greedy sampling.
            self._verify_greedy_sampling()

    def _verify_args(self) -> None:
        """Validate ranges that apply to every sampling mode."""
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")

    def _verify_beam_search(self) -> None:
        """Validate constraints specific to beam search.

        Beam search requires a beam width > 1 and deterministic expansion,
        so temperature/top_p/top_k-based sampling must be disabled.
        (Renamed from the original typo `_verity_beam_search`.)
        """
        if self.n == 1:
            raise ValueError("n must be greater than 1 when using beam search.")
        if self.temperature > 0.0:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")

    def _verify_greedy_sampling(self) -> None:
        """Validate constraints specific to greedy sampling (temperature == 0).

        Greedy decoding yields a single deterministic sequence, so n must be 1
        and top-p/top-k truncation must be disabled.
        """
        if self.n > 1:
            raise ValueError("n must be 1 when using greedy sampling.")
        if self.top_p < 1.0:
            raise ValueError("top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using greedy sampling.")

    def __repr__(self) -> str:
        return (f"SamplingParams(n={self.n}, "
                f"presence_penalty={self.presence_penalty}, "
                f"frequency_penalty={self.frequency_penalty}, "
                f"temperature={self.temperature}, "
                f"top_p={self.top_p}, "
                # Fixed: original was missing the space after "top_k=...,".
                f"top_k={self.top_k}, "
                f"use_beam_search={self.use_beam_search}, "
                f"stop_token_ids={self.stop_token_ids}, "
                f"max_tokens={self.max_tokens}, "
                f"logprobs={self.logprobs})")