# sampling_params.py
from typing import Optional, Set


class SamplingParams:
    """Container for sampling settings, validated at construction time.

    All arguments are stored as same-named attributes. After storing them,
    the constructor runs the general range checks and then, depending on the
    mode, the beam-search checks (``use_beam_search=True``) or the greedy
    checks (``temperature == 0.0``).

    Args:
        n: Must be at least 1; greater than 1 with beam search and exactly 1
            with greedy sampling. (Presumably the number of sequences to
            produce per request -- confirm against the caller.)
        presence_penalty: Must lie in [-2, 2].
        frequency_penalty: Must lie in [-2, 2].
        temperature: Must be non-negative. Zero means greedy sampling.
        top_p: Must lie in (0, 1].
        top_k: Either -1 (disabled) or at least 1.
        use_beam_search: Enables the beam-search consistency checks.
        stop_token_ids: Token ids stored as a set on the instance. ``None``
            (the default) yields a fresh empty set. The given iterable is
            copied, so instances never share state with each other or with
            the caller's object.
        max_tokens: Must be at least 1.
        logprobs: Must be non-negative.

    Raises:
        ValueError: If any argument is out of range or inconsistent with the
            selected mode.
    """

    def __init__(
        self,
        n: int = 1,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop_token_ids: Optional[Set[int]] = None,
        max_tokens: int = 16,
        logprobs: int = 0,
    ) -> None:
        self.n = n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.use_beam_search = use_beam_search
        # A default of `set()` in the signature would be a single object
        # shared by every instance; build a private set per instance instead.
        self.stop_token_ids: Set[int] = (
            set() if stop_token_ids is None else set(stop_token_ids))
        self.max_tokens = max_tokens
        self.logprobs = logprobs

        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        elif self.temperature == 0.0:
            # Zero temperature means greedy sampling.
            self._verify_greedy_sampling()

    def _verify_args(self) -> None:
        """Check each argument's range independently of the sampling mode."""
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError("top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")

    def _verify_beam_search(self) -> None:
        """Check the extra constraints that apply when beam search is on."""
        if self.n == 1:
            raise ValueError("n must be greater than 1 when using beam search.")
        if self.temperature > 0.0:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")

    # Backward-compatible alias for the original, misspelled method name.
    _verity_beam_search = _verify_beam_search

    def _verify_greedy_sampling(self) -> None:
        """Check the extra constraints that apply when temperature is zero."""
        if self.n > 1:
            raise ValueError("n must be 1 when using greedy sampling.")
        if self.top_p < 1.0:
            raise ValueError("top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using greedy sampling.")

    def __repr__(self) -> str:
        """Return a constructor-style summary listing every stored setting."""
        return (f"SamplingParams(n={self.n}, "
                f"presence_penalty={self.presence_penalty}, "
                f"frequency_penalty={self.frequency_penalty}, "
                f"temperature={self.temperature}, "
                f"top_p={self.top_p}, "
                f"top_k={self.top_k}, "
                f"use_beam_search={self.use_beam_search}, "
                f"stop_token_ids={self.stop_token_ids}, "
                f"max_tokens={self.max_tokens}, "
                f"logprobs={self.logprobs})")