# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch

from vllm.v1.worker.gpu_input_batch import InputBatch

# Default sampling values, used to fill the padded (unused) request slots of
# the per-request sampling tensors. Only temperature and min_p are supported
# on TPU right now; the remaining options are strictly disabled for now:
#   top_k=-1, top_p=0.0, frequency_penalties=0.0,
#   presence_penalties=0.0, repetition_penalties=0.0
DEFAULT_SAMPLING_PARAMS = {
    "temperature": -1.0,
    "min_p": 0.0,
}


@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.

    # Per-request temperatures, sliced to `padded_num_reqs` rows by
    # `from_input_batch`. Stays None only on its all-greedy early return.
    temperature: torch.Tensor = None

    # Per-request min_p values, padded and sliced like `temperature`.
    min_p: torch.Tensor = None

    # Still too slow on forward_native!
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    # Greedy sampling flag for compiling single xla graph.
    all_greedy: bool = True

    # Generator not supported by xla
    generators: dict[int, torch.Generator] = field(default_factory=dict)

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None

    @classmethod
    def from_input_batch(
        cls,
        input_batch: InputBatch,
        padded_num_reqs: int,
        xla_device: torch.device,
        generate_params_if_all_greedy: bool = False
    ) -> "TPUSupportedSamplingMetadata":
        """
        Copy sampling tensors slices from `input_batch` to on device tensors.

        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
        slices dynamic shapes on device tensors. This impl moves the dynamic
        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

        Side effect: rows `[num_reqs:padded_num_reqs]` of
        `input_batch.temperature_cpu_tensor` and `input_batch.min_p_cpu_tensor`
        are overwritten in place with the values from
        `DEFAULT_SAMPLING_PARAMS`.

        Args:
            input_batch: The input batch containing sampling parameters.
            padded_num_reqs: The padded number of requests.
            xla_device: The XLA device.
            generate_params_if_all_greedy: If True, generate sampling parameters
                even if all requests are greedy. This is useful for cases where
                we want to pre-compile a graph with sampling parameters, even if
                they are not strictly needed for greedy decoding.

        Returns:
            `cls(all_greedy=True)` with no tensors when every request is
            greedy and `generate_params_if_all_greedy` is falsy; otherwise an
            instance whose `temperature` and `min_p` are the CPU tensors
            sliced to `[:padded_num_reqs]` and copied to `xla_device`.
        """
        # Early return to avoid an unnecessary CPU-to-TPU copy.
        if input_batch.all_greedy and not generate_params_if_all_greedy:
            return cls(all_greedy=True)

        num_reqs = input_batch.num_reqs

        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> None:
            # Pad value is the default one; mutates `cpu_tensor` in place.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val

        fill_slice(input_batch.temperature_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["temperature"])
        # TODO Temporarily disabled until sampling options are enabled
        # fill_slice(input_batch.top_p_cpu_tensor, <top_p default>)
        # fill_slice(input_batch.top_k_cpu_tensor, <top_k default>)
        fill_slice(input_batch.min_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["min_p"])

        # Slice the persistent CPU tensors to the fixed pre-compiled padded
        # shape on the host, then copy the fixed-size slices to the device.
        return cls(
            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
            to(xla_device),
            all_greedy=input_batch.all_greedy,
            # TODO enable more and avoid returning None values
            top_p=None,  # input_batch.top_p[:padded_num_reqs],
            top_k=None,  # input_batch.top_k[:padded_num_reqs],
            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            generators=input_batch.generators)