sampling_params.py 3.2 KB
Newer Older
1
from typing import Dict, Set
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3


Woosuk Kwon's avatar
Woosuk Kwon committed
4
class SamplingParams:
    """Validated container for the sampling options of one request.

    Args:
        n: Number of output sequences; must be at least 1. Must be greater
            than 1 with beam search and exactly 1 with greedy sampling.
        temperature: Must be non-negative. Zero means greedy sampling
            (unless beam search is enabled, which itself requires 0).
        top_p: Must lie in (0, 1]; must be 1 for beam search and greedy
            sampling.
        top_k: Must be -1 (disabled) or at least 1; must be -1 for beam
            search and greedy sampling.
        use_beam_search: Selects beam search; enables the extra
            beam-search consistency checks below.
        stop_token_ids: Stored as given (not copied). Presumably the token
            IDs that end generation — TODO confirm against the caller.
        max_num_steps: Must be at least 1. Presumably an upper bound on
            generation steps — TODO confirm.
        num_logprobs: Must be non-negative. Presumably how many log
            probabilities to report per step — TODO confirm.

    Raises:
        ValueError: If any value is out of range, or the combination is
            inconsistent with beam search or greedy sampling.
    """

    def __init__(
        self,
        n: int,
        temperature: float,
        top_p: float,
        top_k: int,
        use_beam_search: bool,
        stop_token_ids: Set[int],
        max_num_steps: int,
        num_logprobs: int,
    ) -> None:
        # Per-field range checks.
        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}.")
        if temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {temperature}.")
        if not 0.0 < top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
        if top_k < -1 or top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {top_k}.")
        if max_num_steps < 1:
            raise ValueError(
                f"max_num_steps must be at least 1, got {max_num_steps}.")
        if num_logprobs < 0:
            raise ValueError(
                f"num_logprobs must be non-negative, got {num_logprobs}.")

        # Cross-field consistency checks for the two deterministic modes.
        if use_beam_search:
            if n == 1:
                raise ValueError(
                    "n must be greater than 1 when using beam search.")
            if temperature > 0.0:
                raise ValueError(
                    "temperature must be 0 when using beam search.")
            if top_p < 1.0:
                raise ValueError(
                    "top_p must be 1 when using beam search.")
            if top_k != -1:
                raise ValueError(
                    "top_k must be -1 when using beam search.")
        elif temperature == 0.0:
            # Zero temperature means greedy sampling.
            if n > 1:
                raise ValueError(
                    "n must be 1 when using greedy sampling.")
            if top_p < 1.0:
                raise ValueError(
                    "top_p must be 1 when using greedy sampling.")
            if top_k != -1:
                raise ValueError(
                    "top_k must be -1 when using greedy sampling.")

        self.n = n
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.use_beam_search = use_beam_search
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
        self.num_logprobs = num_logprobs

    def __repr__(self) -> str:
        """Return a well-formed ``SamplingParams(field=value, ...)`` string."""
        # Every field is separated by ", " and the call is closed with ")".
        return (f"SamplingParams(n={self.n}, "
                f"temperature={self.temperature}, "
                f"top_p={self.top_p}, "
                f"top_k={self.top_k}, "
                f"use_beam_search={self.use_beam_search}, "
                f"stop_token_ids={self.stop_token_ids}, "
                f"max_num_steps={self.max_num_steps}, "
                f"num_logprobs={self.num_logprobs})")

    @classmethod
    def from_dict(cls, d: Dict) -> "SamplingParams":
        """Build an instance from a plain dict, filling in defaults.

        Missing keys default to: n=1, temperature=1.0, top_p=1.0, top_k=-1,
        use_beam_search=False, stop_token_ids=set(), max_num_steps=16,
        num_logprobs=0. ``stop_token_ids`` may be any iterable; it is
        converted to a new ``set``. All values go through ``__init__``
        validation, so invalid input raises ``ValueError``.
        """
        return cls(
            n=d.get("n", 1),
            temperature=d.get("temperature", 1.0),
            top_p=d.get("top_p", 1.0),
            top_k=d.get("top_k", -1),
            use_beam_search=d.get("use_beam_search", False),
            stop_token_ids=set(d.get("stop_token_ids", set())),
            max_num_steps=d.get("max_num_steps", 16),
            num_logprobs=d.get("num_logprobs", 0),
        )