sampling_params.py 4.05 KB
Newer Older
1
from typing import Dict, Set
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3


Woosuk Kwon's avatar
Woosuk Kwon committed
4
class SamplingParams:
    """Validated bundle of parameters controlling how tokens are sampled.

    All arguments are checked in ``__init__``; invalid values or invalid
    combinations raise ``ValueError``.

    Args:
        n: Number of output sequences to produce; must be >= 1.
        presence_penalty: Must lie in [-2, 2].
        frequency_penalty: Must lie in [-2, 2].
        temperature: Must be >= 0. A value of 0 selects greedy sampling
            (unless ``use_beam_search`` is set).
        top_p: Must lie in (0, 1].
        top_k: -1 disables top-k; otherwise must be >= 1.
        use_beam_search: Enables beam search. This requires n > 1,
            temperature == 0, top_p == 1 and top_k == -1.
        stop_token_ids: Set of token ids; stored as given.
        max_num_steps: Must be >= 1.
        num_logprobs: Must be >= 0.

    Raises:
        ValueError: If any argument or combination of arguments is invalid.
    """

    def __init__(
        self,
        n: int,
        presence_penalty: float,
        frequency_penalty: float,
        temperature: float,
        top_p: float,
        top_k: int,
        use_beam_search: bool,
        stop_token_ids: Set[int],
        max_num_steps: int,
        num_logprobs: int,
    ) -> None:
        # --- Per-argument range checks -------------------------------------
        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}.")
        if not -2.0 <= presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
        if not -2.0 <= frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
        if temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {temperature}.")
        if not 0.0 < top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
        if top_k < -1 or top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {top_k}.")
        if max_num_steps < 1:
            raise ValueError(
                f"max_num_steps must be at least 1, got {max_num_steps}.")
        if num_logprobs < 0:
            raise ValueError(
                f"num_logprobs must be non-negative, got {num_logprobs}.")

        # --- Cross-argument consistency checks -----------------------------
        if use_beam_search:
            # Beam search is deterministic: it needs several beams and no
            # stochastic truncation or temperature.
            if n == 1:
                raise ValueError(
                    "n must be greater than 1 when using beam search.")
            if temperature > 0.0:
                raise ValueError(
                    "temperature must be 0 when using beam search.")
            if top_p < 1.0:
                raise ValueError(
                    "top_p must be 1 when using beam search.")
            if top_k != -1:
                raise ValueError(
                    "top_k must be -1 when using beam search.")
        elif temperature == 0.0:
            # Zero temperature means greedy sampling, which yields exactly one
            # deterministic sequence, so top-p/top-k truncation is meaningless.
            if n > 1:
                raise ValueError(
                    "n must be 1 when using greedy sampling.")
            if top_p < 1.0:
                raise ValueError(
                    "top_p must be 1 when using greedy sampling.")
            if top_k != -1:
                raise ValueError(
                    "top_k must be -1 when using greedy sampling.")

        self.n = n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.use_beam_search = use_beam_search
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
        self.num_logprobs = num_logprobs

    def __repr__(self) -> str:
        """Return a constructor-like string listing every parameter."""
        # Fixed: the original omitted the space after ``top_k=...,`` and
        # never closed the parenthesis.
        return (f"SamplingParams(n={self.n}, "
                f"presence_penalty={self.presence_penalty}, "
                f"frequency_penalty={self.frequency_penalty}, "
                f"temperature={self.temperature}, "
                f"top_p={self.top_p}, "
                f"top_k={self.top_k}, "
                f"use_beam_search={self.use_beam_search}, "
                f"stop_token_ids={self.stop_token_ids}, "
                f"max_num_steps={self.max_num_steps}, "
                f"num_logprobs={self.num_logprobs})")

    @classmethod
    def from_dict(cls, d: Dict) -> "SamplingParams":
        """Build a ``SamplingParams`` from a (possibly partial) dict.

        Missing keys fall back to these defaults: n=1, presence_penalty=0.0,
        frequency_penalty=0.0, temperature=1.0, top_p=1.0, top_k=-1,
        use_beam_search=False, stop_token_ids=empty set, max_num_steps=16,
        num_logprobs=0. The caller's dict is not modified.

        Raises:
            ValueError: If ``d`` contains keys that are not recognized, or if
                the resulting parameters fail validation.
        """
        # Work on a shallow copy. Popping keys directly would mutate the
        # caller's dict as a side effect.
        d = dict(d)
        sampling_params = cls(
            n=d.pop("n", 1),
            presence_penalty=d.pop("presence_penalty", 0.0),
            frequency_penalty=d.pop("frequency_penalty", 0.0),
            temperature=d.pop("temperature", 1.0),
            top_p=d.pop("top_p", 1.0),
            top_k=d.pop("top_k", -1),
            use_beam_search=d.pop("use_beam_search", False),
            stop_token_ids=set(d.pop("stop_token_ids", set())),
            max_num_steps=d.pop("max_num_steps", 16),
            num_logprobs=d.pop("num_logprobs", 0),
        )
        # Anything left over was not consumed above and is therefore unknown.
        if d:
            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
        return sampling_params