sampling_params.py 3.33 KB
Newer Older
1
from typing import Dict, Optional, Set
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3


Woosuk Kwon's avatar
Woosuk Kwon committed
4
class SamplingParams:
    """Validated bundle of sampling options.

    Every argument is checked in ``__init__``. An out-of-range value or an
    incompatible combination of options raises ``ValueError``.

    Args:
        n: Must be at least 1. With beam search it must be greater than 1;
            with greedy sampling (``temperature == 0`` and no beam search)
            it must be exactly 1.
        presence_penalty: Must lie in [-2, 2].
        frequency_penalty: Must lie in [-2, 2].
        temperature: Must be non-negative. Zero means greedy sampling when
            beam search is off; beam search requires it to be zero.
        top_p: Must lie in (0, 1]. Must be 1 with beam search or greedy
            sampling.
        top_k: -1 disables it; otherwise it must be at least 1. Must be -1
            with beam search or greedy sampling.
        use_beam_search: Turns on the beam-search constraints listed above.
        stop_token_ids: Token ids stored on the instance. ``None`` (the
            default) means an empty set. The argument is copied into a new
            ``set``, so instances never share or alias the container.
        max_tokens: Must be at least 1.
        logprobs: Must be non-negative.

    Raises:
        ValueError: If any of the constraints above is violated.
    """

    def __init__(
        self,
        n: int = 1,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop_token_ids: Optional[Set[int]] = None,
        max_tokens: int = 16,
        logprobs: int = 0,
    ) -> None:
        # --- Per-argument range checks. ---
        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}.")
        if not -2.0 <= presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
        if not -2.0 <= frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
        if temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {temperature}.")
        if not 0.0 < top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
        if top_k < -1 or top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {top_k}.")
        if max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {max_tokens}.")
        if logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {logprobs}.")

        # --- Cross-argument checks for the two deterministic modes. ---
        if use_beam_search:
            if n == 1:
                raise ValueError(
                    "n must be greater than 1 when using beam search.")
            if temperature > 0.0:
                raise ValueError(
                    "temperature must be 0 when using beam search.")
            if top_p < 1.0:
                raise ValueError(
                    "top_p must be 1 when using beam search.")
            if top_k != -1:
                raise ValueError(
                    "top_k must be -1 when using beam search.")
        elif temperature == 0.0:
            # Zero temperature means greedy sampling.
            if n > 1:
                raise ValueError(
                    "n must be 1 when using greedy sampling.")
            if top_p < 1.0:
                raise ValueError(
                    "top_p must be 1 when using greedy sampling.")
            if top_k != -1:
                raise ValueError(
                    "top_k must be -1 when using greedy sampling.")

        self.n = n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.use_beam_search = use_beam_search
        # Build a fresh set for every instance. The previous `set()` default
        # was evaluated once at definition time and shared by all instances,
        # and a caller-supplied set was aliased rather than copied.
        self.stop_token_ids = (set() if stop_token_ids is None
                               else set(stop_token_ids))
        self.max_tokens = max_tokens
        self.logprobs = logprobs

    def __repr__(self) -> str:
        """Return ``SamplingParams(...)`` listing every stored option."""
        return (f"SamplingParams(n={self.n}, "
                f"presence_penalty={self.presence_penalty}, "
                f"frequency_penalty={self.frequency_penalty}, "
                f"temperature={self.temperature}, "
                f"top_p={self.top_p}, "
                f"top_k={self.top_k}, "
                f"use_beam_search={self.use_beam_search}, "
                f"stop_token_ids={self.stop_token_ids}, "
                f"max_tokens={self.max_tokens}, "
                f"logprobs={self.logprobs})")