decoding.py 957 Bytes
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
from typing import Optional, Set
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3


Woosuk Kwon's avatar
Woosuk Kwon committed
4
class SamplingParams:
    """Validated bundle of decoding/sampling settings.

    Args:
        n: Number of sequences to produce (presumably per prompt — confirm
            with callers). Must be >= 1. With beam search it must be > 1; with
            greedy decoding (``temperature == 0``) it must be exactly 1.
        temperature: Must be >= 0. Zero means greedy decoding. Must be > 0
            when ``use_beam_search`` is set.
        top_p: Must lie in the half-open interval (0, 1]. Must be exactly 1.0
            for beam search and for greedy decoding.
        use_beam_search: Whether beam search is requested.
        stop_token_ids: Token IDs stored on the instance (presumably tokens
            that end generation). Any iterable of ints is accepted and copied
            into a new ``set``. ``None`` (the default) means an empty set.
        max_context_len: Optional limit; when given it must be >= 0.

    Raises:
        ValueError: If any parameter is out of range or the combination of
            parameters is inconsistent.
    """

    def __init__(
        self,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 1.0,
        use_beam_search: bool = False,
        stop_token_ids: Optional[Set[int]] = None,
        max_context_len: Optional[int] = None,
    ) -> None:
        # Explicit exceptions rather than ``assert`` so validation still runs
        # under ``python -O``. The ``not (a >= b)`` forms deliberately reject
        # NaN, matching the original assertions.
        if not n >= 1:
            raise ValueError(f"n must be at least 1, got {n}.")
        if not temperature >= 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {temperature}.")
        if not 0.0 < top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
        if use_beam_search:
            if not n > 1:
                raise ValueError(
                    f"n must be greater than 1 when using beam search, got {n}.")
            if not temperature > 0.0:
                raise ValueError(
                    "temperature must be positive when using beam search, "
                    f"got {temperature}.")
            if top_p != 1.0:
                raise ValueError(
                    f"top_p must be 1.0 when using beam search, got {top_p}.")
        elif temperature == 0.0:
            # Zero temperature means greedy decoding.
            if n != 1:
                raise ValueError(
                    f"n must be 1 for greedy decoding (temperature 0), got {n}.")
            if top_p != 1.0:
                raise ValueError(
                    "top_p must be 1.0 for greedy decoding (temperature 0), "
                    f"got {top_p}.")
        if max_context_len is not None and not max_context_len >= 0:
            raise ValueError(
                f"max_context_len must be non-negative, got {max_context_len}.")

        self.n = n
        self.temperature = temperature
        self.top_p = top_p
        self.use_beam_search = use_beam_search
        # Build a fresh set per instance. A shared mutable default (the
        # original ``= []``) would leak mutations across instances.
        self.stop_token_ids = (
            set() if stop_token_ids is None else set(stop_token_ids))
        self.max_context_len = max_context_len

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, temperature={self.temperature}, "
            f"top_p={self.top_p}, use_beam_search={self.use_beam_search}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"max_context_len={self.max_context_len})"
        )