# sampling_params.py
from typing import Dict, Set


class SamplingParams:
    """Validated container for the parameters of one sampling request.

    The constructor checks each argument individually and then checks the
    combinations that are only legal together (beam search, or greedy
    sampling when ``temperature`` is zero).  On success the values are
    stored unchanged as attributes of the same names.
    """

    def __init__(
        self,
        n: int,
        temperature: float,
        top_p: float,
        use_beam_search: bool,
        stop_token_ids: Set[int],
        max_num_steps: int,
        num_logprobs: int,
    ) -> None:
        """Validate and store the sampling parameters.

        Args:
            n: Must be at least 1.  Must be greater than 1 with beam search
                and exactly 1 with greedy sampling.
            temperature: Must be non-negative.  Must be 0 with beam search;
                without beam search, 0 selects greedy sampling.
            top_p: Must lie in (0, 1].  Must be 1 with beam search or
                greedy sampling.
            use_beam_search: Enables the beam-search specific checks.
            stop_token_ids: Stored as given (not copied, not validated).
            max_num_steps: Must be at least 1.
            num_logprobs: Must be non-negative.

        Raises:
            ValueError: If any argument is out of range or the combination
                of arguments is inconsistent.
        """
        # Range checks on the individual arguments.
        if n < 1:
            raise ValueError(f'n must be at least 1, got {n}.')
        if temperature < 0.0:
            raise ValueError(
                f'temperature must be non-negative, got {temperature}.')
        if not 0.0 < top_p <= 1.0:
            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
        if max_num_steps < 1:
            raise ValueError(
                f'max_num_steps must be at least 1, got {max_num_steps}.')
        if num_logprobs < 0:
            raise ValueError(
                f'num_logprobs must be non-negative, got {num_logprobs}.')

        # Cross-argument checks; beam search takes precedence over greedy.
        if use_beam_search:
            if n == 1:
                raise ValueError(
                    'n must be greater than 1 when using beam search.')
            if temperature > 0.0:
                raise ValueError(
                    'temperature must be 0 when using beam search.')
            if top_p < 1.0:
                raise ValueError(
                    'top_p must be 1 when using beam search.')
        elif temperature == 0.0:
            # Zero temperature means greedy sampling.
            if n > 1:
                raise ValueError(
                    'n must be 1 when using greedy sampling.')
            if top_p < 1.0:
                raise ValueError(
                    'top_p must be 1 when using greedy sampling.')

        self.n = n
        self.temperature = temperature
        self.top_p = top_p
        self.use_beam_search = use_beam_search
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
        self.num_logprobs = num_logprobs

    def __repr__(self) -> str:
        """Return a ``SamplingParams(...)`` string listing every field."""
        # The trailing ')' closes the 'SamplingParams(' opened on the first
        # fragment, keeping the representation balanced.
        return (f'SamplingParams(n={self.n}, '
                f'temperature={self.temperature}, '
                f'top_p={self.top_p}, '
                f'use_beam_search={self.use_beam_search}, '
                f'stop_token_ids={self.stop_token_ids}, '
                f'max_num_steps={self.max_num_steps}, '
                f'num_logprobs={self.num_logprobs})')

    @classmethod
    def from_dict(cls, d: Dict) -> 'SamplingParams':
        """Build an instance from a dict, filling in defaults for absent keys.

        Defaults: ``n=1``, ``temperature=1.0``, ``top_p=1.0``,
        ``use_beam_search=False``, ``stop_token_ids=set()``,
        ``max_num_steps=16``, ``num_logprobs=0``.  The ``stop_token_ids``
        value is converted to a new ``set``, so any iterable is accepted.

        Raises:
            ValueError: If the resulting values fail constructor validation.
        """
        return cls(
            n=d.get('n', 1),
            temperature=d.get('temperature', 1.0),
            top_p=d.get('top_p', 1.0),
            use_beam_search=d.get('use_beam_search', False),
            stop_token_ids=set(d.get('stop_token_ids', set())),
            max_num_steps=d.get('max_num_steps', 16),
            num_logprobs=d.get('num_logprobs', 0),
        )