# metadata.py
1
from dataclasses import dataclass
2
from typing import Dict
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

import torch


@dataclass
class SamplingMetadata:
    """Plain container grouping the parameters used when sampling tokens.

    The class only declares typed fields; it performs no validation and
    defines no behavior beyond what ``@dataclass`` generates (``__init__``,
    ``__repr__``, ``__eq__``). All fields are required and must be supplied
    in declaration order or by keyword.
    """

    # Sampling temperature(s) as a tensor.
    # NOTE(review): presumably one value per request in the batch — confirm
    # the expected shape/dtype against the code that builds this object.
    temperature: torch.Tensor
    # NOTE(review): presumably batch-wide flags meaning "every request uses
    # greedy sampling" / "every request uses random sampling" — verify.
    all_greedy: bool
    all_random: bool

    # Top-p (nucleus) and top-k settings as tensors.
    # NOTE(review): shape/dtype are not established here — confirm.
    top_p: torch.Tensor
    top_k: torch.Tensor
    # NOTE(review): presumably True when no request applies the respective
    # filter, letting the consumer skip that step — verify against the caller.
    no_top_p: bool
    no_top_k: bool

    # Random generators keyed by an integer.
    # NOTE(review): the key is presumably a request index within the batch —
    # confirm with the producer of this mapping.
    generators: Dict[int, torch.Generator]

    # NOTE(review): presumably the largest number of log-probabilities
    # requested by any request in the batch — confirm.
    max_num_logprobs: int