# metadata.py
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch

7
8
9
10
11
12
from vllm.v1.worker.gpu_input_batch import InputBatch

# Fill values written into the padding region ``[num_reqs:padded_num_reqs]``
# of the per-request sampling CPU tensors by
# `TPUSupportedSamplingMetadata.from_input_batch`, so that padded rows always
# carry the same values and the padded tensors keep a fixed shape.
DEFAULT_SAMPLING_PARAMS = dict(
    # NOTE(review): a negative temperature is presumably a sentinel for
    # padded rows rather than a real sampling temperature — confirm against
    # the TPU sampler.
    temperature=-1.0,
    min_p=0.0,
    # strictly disabled for now
    top_k=0,
    top_p=1.0,
    # Penalties are not supported on TPU yet (see `no_penalties`); their
    # intended pad values are kept here for reference:
    # frequency_penalties=0.0,
    # presence_penalties=0.0,
    # repetition_penalties=0.0,
)


@dataclass
class TPUSupportedSamplingMetadata:
    """XLA-friendly sampling metadata for TPU.

    This class exposes a more xla-friendly interface than SamplingMetadata
    on TPU, in particular all arguments should be traceable and no optionals
    are allowed, to avoid graph recompilation on Nones.
    """

    # The tensor fields default to None only so that the all-greedy fast path
    # in `from_input_batch` can skip the host-to-device copies; whenever
    # sampling parameters are produced, all four tensors are populated.
    temperature: Optional[torch.Tensor] = None

    min_p: Optional[torch.Tensor] = None
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None

    all_greedy: bool = True

    # The un-annotated attributes below are plain class attributes, not
    # dataclass fields: they are constant placeholders shared by all
    # instances and cannot be set through the constructor.

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None

    # Generator not supported by xla
    _generators: dict[int, torch.Generator] = field(default_factory=dict)

    @property
    def generators(self) -> dict[int, torch.Generator]:
        # Generator not supported by torch/xla. This field must be immutable.
        return self._generators

    @classmethod
    def from_input_batch(
        cls,
        input_batch: InputBatch,
        padded_num_reqs: int,
        xla_device: torch.device,
        generate_params_if_all_greedy: bool = False
    ) -> "TPUSupportedSamplingMetadata":
        """
        Copy sampling tensor slices from `input_batch` to on-device tensors.

        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
        slices dynamic shapes on device tensors. This impl moves the dynamic
        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

        Side effect: the padding region ``[num_reqs:padded_num_reqs]`` of the
        persistent CPU tensors held by `input_batch` (temperature, min_p,
        top_k, top_p) is overwritten in place with the values from
        `DEFAULT_SAMPLING_PARAMS`.

        Args:
            input_batch: The input batch containing sampling parameters.
            padded_num_reqs: The padded number of requests.
            xla_device: The XLA device.
            generate_params_if_all_greedy: If True, generate sampling parameters
                even if all requests are greedy. This is useful for cases where
                we want to pre-compile a graph with sampling parameters, even if
                they are not strictly needed for greedy decoding.

        Returns:
            If every request is greedy and `generate_params_if_all_greedy` is
            False, an instance with ``all_greedy=True`` and no tensors.
            Otherwise, an instance whose ``temperature``, ``top_p``, ``top_k``
            and ``min_p`` are the first `padded_num_reqs` entries of the
            corresponding CPU tensors, moved to `xla_device`.
        """
        # Early return to avoid unnecessary cpu to tpu copy.
        if input_batch.all_greedy and not generate_params_if_all_greedy:
            return cls(all_greedy=True)

        num_reqs = input_batch.num_reqs
        # NOTE(review): assumes padded_num_reqs >= num_reqs; otherwise the
        # `[:padded_num_reqs]` slices below would silently drop requests.

        def fill_slice(cpu_tensor: torch.Tensor, fill_val: float) -> None:
            # Pad value is the default one; mutates `cpu_tensor` in place.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val

        fill_slice(input_batch.temperature_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["temperature"])
        fill_slice(input_batch.min_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["min_p"])
        fill_slice(input_batch.top_k_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_k"])
        fill_slice(input_batch.top_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_p"])

        # Slice the persistent CPU tensors to a fixed pre-compiled padded
        # shape, then copy them to the device.
        return cls(
            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
            to(xla_device),
            all_greedy=input_batch.all_greedy,
            # TODO enable more and avoid returning None values
            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                xla_device))