# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch


@dataclass
class SamplingMetadata:
10
11
    idx_mapping: torch.Tensor

12
13
14
15
    temperature: torch.Tensor

    top_p: torch.Tensor | None
    top_k: torch.Tensor | None
16
    min_p: torch.Tensor | None
17

18
    # For penalties
19
20
21
    repetition_penalty: torch.Tensor
    frequency_penalty: torch.Tensor
    presence_penalty: torch.Tensor
22
23
    prompt_bin_mask: torch.Tensor
    output_bin_counts: torch.Tensor
24
25
26
27
28
29
30
31
32
33
34
35
36
37

    seeds: torch.Tensor
    pos: torch.Tensor

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: int | None

    @classmethod
    def make_dummy(
        cls,
        num_reqs: int,
        device: torch.device,
    ) -> "SamplingMetadata":
        assert num_reqs > 0
38
39
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)

40
41
42
43
44
45
46
47
        temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
        temperature[0] = 0.5
        # TODO(woosuk): Use top-p and top-k for dummy sampler.
        # Currently, they are disabled because of memory usage.
        # top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
        # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
        top_p = None
        top_k = None
48
        min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device)
49
50
51
52
53
54
55
56
57
        # NOTE(woosuk): We must set penalties to their default values to make sure
        # the penalties kernel does not touch the placeholder bin_counts tensors.
        repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
        frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
        presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)

        # NOTE(woosuk): These are placeholder tensors to avoid None checks in the
        # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
        # specialization and re-compilation at runtime.
58
        prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
59
60
        output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)

61
62
63
64
        seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
        pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
        max_num_logprobs = 20

65
        return cls(
66
            idx_mapping=idx_mapping,
67
68
69
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
70
            min_p=min_p,
71
72
73
            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
74
75
            prompt_bin_mask=prompt_bin_mask,
            output_bin_counts=output_bin_counts,
76
77
78
79
            seeds=seeds,
            pos=pos,
            max_num_logprobs=max_num_logprobs,
        )