"platforms/cuda/vscode:/vscode.git/clone" did not exist on "58e0996fe9f8324cf7754d2f12b97180323b7599"
structured_outputs.py 3.3 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.triton_utils import tl, triton
7
8
9
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
from vllm.v1.worker.gpu.input_batch import InputBatch
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11


12
13
14
15
16
17
18
19
20
21
22
23
class StructuredOutputsWorker:
    """Applies structured-output grammar bitmasks to logits with a Triton kernel.

    Owns two ``UvaBufferPool`` staging buffers: one for the logits-row index of
    every bitmask row, and one for the packed bitmask itself.
    """

    def __init__(
        self,
        max_num_logits: int,
        vocab_size: int,
    ):
        """Allocate the staging buffers.

        Args:
            max_num_logits: Number of rows reserved in both buffers.
            vocab_size: Vocabulary size. The bitmask packs 32 tokens into one
                int32 word, so every row holds ``cdiv(vocab_size, 32)`` words.
        """
        # NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
        # to save an unnecessary CPU-to-CPU copy.
        self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
        self.grammar_bitmask = UvaBufferPool(
            (max_num_logits, cdiv(vocab_size, 32)), torch.int32
        )

    def apply_grammar_bitmask(
        self,
        logits: torch.Tensor,
        input_batch: InputBatch,
        grammar_req_ids: list[str],
        grammar_bitmask: np.ndarray,
    ) -> None:
        """Set the logits of grammar-disallowed tokens to ``-inf`` in place.

        Row ``i`` of ``grammar_bitmask`` is applied to the ``i``-th logits row
        collected by walking ``grammar_req_ids`` in order and, for each request,
        taking its logits rows ``cu_num_logits[idx]:cu_num_logits[idx + 1]``.
        A token whose bit is 0 is masked; a token whose bit is 1 is untouched.

        Args:
            logits: Logits tensor, modified in place. The vocabulary is the
                last dimension and ``logits.stride(0)`` is used as the row
                stride. NOTE(review): the kernel adds the token offset directly,
                so the last dimension is presumably contiguous - confirm.
            input_batch: Supplies ``req_ids`` and ``cu_num_logits_np``.
            grammar_req_ids: Ids of the requests that have a grammar. Does
                nothing when empty.
            grammar_bitmask: Packed bitmask, one row per masked logits row.

        Raises:
            KeyError: If a grammar request id is not in ``input_batch.req_ids``.
            AssertionError: If the number of bitmask rows differs from the
                number of logits rows covered by ``grammar_req_ids``.
        """
        if not grammar_req_ids:
            return

        # Construct bitmask -> logits mapping: entry i is the logits row that
        # bitmask row i applies to.
        cu_num_logits = input_batch.cu_num_logits_np.tolist()
        req_id_to_idx = {
            req_id: i for i, req_id in enumerate(input_batch.req_ids)
        }
        mapping: list[int] = []
        for grammar_req_id in grammar_req_ids:
            req_idx = req_id_to_idx[grammar_req_id]
            logits_start_idx = cu_num_logits[req_idx]
            logits_end_idx = cu_num_logits[req_idx + 1]
            mapping.extend(range(logits_start_idx, logits_end_idx))

        # Copy the mapping.
        mapping_np = np.array(mapping, dtype=np.int32)
        logits_indices = self.logits_indices.copy_to_uva(mapping_np)

        # Copy the bitmask.
        bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask)

        # One kernel program row is launched per bitmask row and each one reads
        # its own mapping entry, so the two counts must agree. This is raised
        # explicitly (rather than via `assert`) so the check survives
        # `python -O`; the exception type is kept as AssertionError.
        num_masks = bitmask.shape[0]
        if num_masks != len(mapping):
            raise AssertionError(
                f"grammar_bitmask has {num_masks} rows, but the grammar "
                f"requests cover {len(mapping)} logits rows"
            )

        vocab_size = logits.shape[-1]
        BLOCK_SIZE = 8192
        # Grid axis 0: bitmask row. Grid axis 1: block of BLOCK_SIZE tokens.
        grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
        _apply_grammar_bitmask_kernel[grid](
            logits,
            logits.stride(0),
            logits_indices,
            bitmask,
            bitmask.stride(0),
            vocab_size,
            BLOCK_SIZE=BLOCK_SIZE,
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
66
67
68
69
70
71
72
73


# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
    logits_ptr,
    logits_stride,
    logits_indices_ptr,
    bitmask_ptr,
    bitmask_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Writes -inf into the logits of every token whose bitmask bit is 0.
    #
    # Program axis 0 selects a bitmask row; the logits row it applies to is
    # read from logits_indices_ptr. Program axis 1 selects a block of
    # BLOCK_SIZE consecutive tokens. logits_stride / bitmask_stride are the
    # row strides (in elements) of the logits and the packed bitmask.
    bitmask_idx = tl.program_id(0)
    logits_idx = tl.load(logits_indices_ptr + bitmask_idx)

    # Load the bitmask: the BLOCK_SIZE // 32 packed words covering this block.
    block_id = tl.program_id(1)
    bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
    # NOTE(review): the row stride doubles as the number of words per row
    # here, which presumably relies on the bitmask rows being densely packed
    # - confirm. No `other` is given, so words past the end of the row hold
    # unspecified values; they only reach the store if logits has more than
    # 32 * bitmask_stride columns - confirm that cannot happen.
    packed_bitmask = tl.load(
        bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
        mask=bitmask_offset < bitmask_stride,
    )
    # Unpack the bitmask: bit j of word i belongs to token 32 * i + j of this
    # block. The result is True where the bit is 0, i.e. the token is masked.
    bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
    bitmask = bitmask.reshape(BLOCK_SIZE)

    # Apply the bitmask to the logits, skipping offsets past the vocabulary.
    block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(
        logits_ptr + logits_idx * logits_stride + block_offset,
        -float("inf"),
        mask=bitmask & (block_offset < vocab_size),
    )