metadata.py 1.08 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import dataclass
5
from typing import Optional
6
7
8

import torch

9
from vllm.v1.sample.logits_processor import LogitsProcessors
10

11
12
13

@dataclass
class SamplingMetadata:
    """Bundle of sampling-related fields grouped into a single record.

    The class carries temperature, top-p/top-k, per-request generators,
    logprob settings, penalty tensors, token masks/bad-words lists and the
    loaded logits processors.

    No field has a default value, so every field must be supplied at
    construction time; the generated ``__init__`` takes them in the order
    they are declared below.
    """

    temperature: Optional[torch.Tensor]

    # NOTE(review): presumably batch-wide flags telling whether every
    # request samples greedily / randomly -- confirm against the sampler.
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    # NOTE(review): int key is presumably the request index within the
    # batch -- confirm against the code that builds this mapping.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    # NOTE(review): `no_penalties` presumably signals that the penalty
    # tensors below can be skipped -- confirm against the sampler.
    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors