# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingMetadata:
    """Bundle of sampling-related inputs, declared as plain dataclass fields.

    This module only declares the fields; how they are populated and how
    they are consumed is defined elsewhere. Because ``@dataclass`` generates
    a positional ``__init__`` in declaration order, the field order below is
    part of the public constructor signature — do not reorder fields.
    """

    # NOTE(review): presumably a per-request temperature tensor, with None
    # used when no temperature scaling is needed — confirm against producer.
    temperature: Optional[torch.Tensor]
    # Presumably True when every request in the batch samples greedily /
    # randomly, respectively — confirm.
    all_greedy: bool
    all_random: bool

    # Optional top-p / top-k / min-p filtering parameters. None presumably
    # means the corresponding filter is disabled for the batch — confirm.
    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]
    min_p: Optional[torch.Tensor]

    # int key -> torch.Generator. The key is presumably the request's index
    # in the batch — confirm against the producer.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    # Penalty-related inputs. `no_penalties` presumably lets the consumer
    # skip penalty application entirely — confirm.
    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    # One inner list of token ids per request (presumably the tokens
    # generated so far, indexed by batch position — confirm).
    output_token_ids: list[list[int]]

    # req_index -> (min_tokens, stop_token_ids)
    min_tokens: dict[int, tuple[int, set[int]]]

    # One optional int -> float mapping per request (presumably
    # token_id -> logit bias — confirm); None means no mapping for that entry.
    logit_bias: list[Optional[dict[int, float]]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]