# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch

7
8
9
10
11
12
from vllm.v1.worker.gpu_input_batch import InputBatch

# Values written into the padding rows of the per-request sampling tensors
# (rows beyond the real request count, up to the padded batch size).
DEFAULT_SAMPLING_PARAMS = {
    "temperature": -1.0,
    "min_p": 0.0,
    # strictly disabled for now
    "top_k": 0,
    # "top_p": 0.0,
    # "frequency_penalties": 0.0,
    # "presence_penalties": 0.0,
    # "repetition_penalties": 0.0,
}
19
20
21
22
23
24
25


@dataclass
class TPUSupportedSamplingMetadata:
    """Sampling metadata with a more XLA-friendly interface for TPU.

    Compared to SamplingMetadata, all arguments should be traceable and no
    optionals should reach the traced graph, to avoid graph recompilation on
    Nones. The tensor fields still default to ``None`` because
    `from_input_batch` builds an instance with only ``all_greedy=True`` when
    no sampling tensors are needed.

    Only the annotated names below are dataclass fields (and therefore
    ``__init__`` parameters); the unannotated ones are plain class attributes
    marking features that are not supported here.
    """

    temperature: Optional[torch.Tensor] = None

    min_p: Optional[torch.Tensor] = None
    # Still too slow on forward_native!
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None

    # Greedy sampling flag for compiling single xla graph.
    all_greedy: bool = True

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None

    # Generator not supported by xla
    _generators: dict[int, torch.Generator] = field(default_factory=dict)

    @property
    def generators(self) -> dict[int, torch.Generator]:
        """Return the per-request generators mapping (read-only accessor)."""
        # Generator not supported by torch/xla. This field must be immutable.
        return self._generators

    @classmethod
    def from_input_batch(
        cls,
        input_batch: "InputBatch",
        padded_num_reqs: int,
        xla_device: torch.device,
        generate_params_if_all_greedy: bool = False
    ) -> "TPUSupportedSamplingMetadata":
        """
        Copy sampling tensors slices from `input_batch` to on device tensors.

        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
        slices dynamic shapes on device tensors. This impl moves the dynamic
        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

        Note that the CPU tensors of `input_batch` are modified in place: the
        rows in `[num_reqs, padded_num_reqs)` are overwritten with the
        default (padding) values before being copied to the device.

        Args:
            input_batch: The input batch containing sampling parameters.
            padded_num_reqs: The padded number of requests.
            xla_device: The XLA device.
            generate_params_if_all_greedy: If True, generate sampling parameters
                even if all requests are greedy. this is useful for cases where
                we want to pre-compile a graph with sampling parameters, even if
                they are not strictly needed for greedy decoding.

        Returns:
            A metadata instance whose `temperature`, `top_k` and `min_p`
            tensors have `padded_num_reqs` rows on `xla_device`, or an
            instance carrying only `all_greedy=True` (all tensors `None`)
            when the batch is all-greedy and parameters were not requested.
        """
        # Early return to avoid unnecessary cpu to tpu copy
        if (input_batch.all_greedy is True
                and generate_params_if_all_greedy is False):
            return cls(all_greedy=True)

        num_reqs = input_batch.num_reqs

        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> None:
            # Pad value is the default one. Writes in place; nothing returned.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val

        fill_slice(input_batch.temperature_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["temperature"])
        fill_slice(input_batch.min_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["min_p"])
        fill_slice(input_batch.top_k_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_k"])
        # TODO Temporarily disabled until sampling options are enabled
        # fill_slice(input_batch.top_p_cpu_tensor,
        #            DEFAULT_SAMPLING_PARAMS["top_p"])

        # Slice the persistent CPU tensors to the fixed pre-compiled padded
        # shape, then copy each slice to the XLA device.
        return cls(
            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
            to(xla_device),
            all_greedy=input_batch.all_greedy,
            # TODO enable more and avoid returning None values
            top_p=None,  # input_batch.top_p[:padded_num_reqs],
            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                xla_device))