metadata.py 1.43 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import dataclass
5
from typing import Optional
6
7
8

import torch

9
10
from vllm.v1.sample.logits_processor import LogitsProcessorManager

11
12
13
14

@dataclass
class SamplingMetadata:
    """Batch-level sampling parameters and bookkeeping state.

    A single instance describes a whole batch of requests. Tensor-valued
    fields presumably hold one entry per request in the batch (TODO confirm
    against the code that builds this object). ``Optional`` fields that are
    ``None`` appear to signal that the corresponding feature is unused —
    verify the exact semantics in the sampler that consumes this object.
    """

    # Per-request sampling temperatures, or None.
    # NOTE(review): presumably None when every request is greedy — confirm.
    temperature: Optional[torch.Tensor]

    # Batch-wide flags summarizing the sampling mode across all requests.
    all_greedy: bool
    all_random: bool

    # Top-p / top-k filtering parameters; None presumably means the filter
    # is not applied to this batch (assumed from the Optional type).
    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    # Per-request RNGs keyed by an int (presumably the request's index in the
    # batch — verify against the builder).
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    # Penalty inputs. `no_penalties` is a single batch-wide flag; the penalty
    # tensors presumably carry one value per request (TODO confirm shape).
    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    # Token ids generated so far; one inner list per request (assumed).
    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids (each value is a list of token-id
    # sequences).
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessorManager

    # Optional host-side summaries to avoid device sync in fast paths.
    # When `top_k` is provided, `max_top_k` is the maximum top-k value across
    # the batch on the host (Python int).
    max_top_k: Optional[int] = None
    # True if any request in the batch has top_k == vocab_size (i.e. no top-k).
    has_any_no_top_k: bool = False