# SPDX-License-Identifier: Apache-2.0

3
from dataclasses import dataclass
4
from typing import Dict, List, Optional, Set
5
6
7
8
9
10
11
12
13
14

import torch


@dataclass
class SamplingMetadata:
    """Batch-level sampling parameters bundled into a single record.

    This is a plain data container: it only holds tensors, flags and
    per-request lists and performs no computation itself. Field order is the
    positional-argument order of the ``__init__`` generated by ``@dataclass``,
    so fields must not be reordered.

    NOTE(review): the ``all_*`` / ``no_*`` booleans look like precomputed
    batch-wide fast-path flags (e.g. ``no_top_k`` presumably means top-k
    filtering can be skipped for the whole batch) — confirm against the
    sampler that consumes this object.
    """

    # Temperature and greedy/random summary flags.
    temperature: torch.Tensor
    all_greedy: bool
    all_random: bool

    # NOTE(review): presumably speculative-decoding inputs, with one list of
    # draft token ids per request — confirm against the producer.
    rejection_sampling: bool
    spec_token_ids: List[List[int]]

    # Top-p / top-k filtering values and their batch-wide skip flags.
    top_p: torch.Tensor
    top_k: torch.Tensor
    no_top_p: bool
    no_top_k: bool

    # Min-p filtering values and its batch-wide skip flag.
    min_p: torch.Tensor
    no_min_p: bool

    # NOTE(review): keys presumably identify requests (by batch index) that
    # carry their own seeded generator — confirm.
    generators: Dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    # Penalty parameters. ``prompt_token_ids`` is Optional and may be None;
    # NOTE(review): ``no_penalties`` presumably lets the sampler skip penalty
    # application entirely — confirm.
    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    # Per-request generated-token history and stopping constraints.
    # NOTE(review): assumed to be indexed by request position in the batch.
    output_token_ids: List[List[int]]
    min_tokens: List[int]
    stop_token_ids: List[Set[int]]

    # Per-request logit-bias maps; individual entries may be None.
    # NOTE(review): presumably token_id -> additive bias — confirm.
    logit_bias: List[Optional[Dict[int, float]]]