sampling_params.py 1.06 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
from typing import Optional, Set
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3


Woosuk Kwon's avatar
Woosuk Kwon committed
4
class SamplingParams:
    """Validated container for the sampling options of one request.

    The constructor checks that the options are mutually consistent and then
    stores each of them on the instance under the same name.

    Args:
        n: Must be at least 1. Beam search requires ``n > 1``; greedy
            decoding (zero temperature without beam search) requires
            ``n == 1``. Presumably the number of sequences to produce --
            confirm against callers.
        temperature: Non-negative sampling temperature. Zero means greedy
            decoding; beam search requires a strictly positive value.
        top_p: Must lie in ``(0.0, 1.0]``. Must be exactly 1.0 when beam
            search or greedy decoding is selected.
        use_beam_search: Selects beam search and enables the stricter
            constraints on ``n``, ``temperature`` and ``top_p`` above.
        stop_token_ids: Set of token ids, stored as given (not copied).
            When omitted, every instance receives its own new empty set.
            Presumably the ids that terminate generation -- confirm against
            callers.
        max_num_steps: Must be at least 1. The default of 16 is taken from
            the OpenAI API.
        max_context_len: Either ``None`` or a non-negative integer.

    Raises:
        AssertionError: If any of the constraints above is violated. The
            checks are ``assert`` statements, so they are skipped when
            Python runs with ``-O``.
    """

    def __init__(
        self,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 1.0,
        use_beam_search: bool = False,
        stop_token_ids: Optional[Set[int]] = None,
        max_num_steps: int = 16,  # From OpenAI API.
        max_context_len: Optional[int] = None,
    ) -> None:
        assert n >= 1
        assert temperature >= 0.0
        assert 0.0 < top_p <= 1.0
        if use_beam_search:
            assert n > 1
            assert temperature > 0.0
            assert top_p == 1.0
        elif temperature == 0.0:
            # Zero temperature means greedy decoding.
            assert n == 1
            assert top_p == 1.0
        assert max_num_steps >= 1
        assert max_context_len is None or max_context_len >= 0

        # A literal `[]` default would be created once and shared by every
        # instance; build a fresh set per instance instead.
        if stop_token_ids is None:
            stop_token_ids = set()

        self.n = n
        self.temperature = temperature
        self.top_p = top_p
        self.use_beam_search = use_beam_search
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
        self.max_context_len = max_context_len