metadata.py 4.72 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from dataclasses import dataclass

5
import numpy as np
6
7
8
import torch

from vllm.pooling_params import PoolingParams
9
from vllm.tasks import PoolingTask
10
from vllm.utils.platform_utils import is_pin_memory_available
11
12
13
14
15
16
17
18
19
20

# Queried once at import time via `is_pin_memory_available()`; used as the
# `pin_memory` flag of the CPU prefix-sum buffer that
# `PoolingMetadata.build_pooling_cursor` copies to the device with
# `non_blocking=True`.
pin_memory = is_pin_memory_available()


@dataclass
class PoolingCursor:
    """Per-request token-position bookkeeping for one pooling step.

    All per-request fields are parallel: element ``i`` of each field describes
    the ``i``-th request in the batch, so ``cursor[start:stop]`` slices every
    field consistently. ``PoolingMetadata.build_pooling_cursor`` constructs
    one of these.
    """

    # Position of each request within the batch (0..n-1 when freshly built).
    index: list[int]
    # Offset of each request's first scheduled token within the concatenation
    # of all requests' scheduled tokens for this step.
    first_token_indices_gpu: torch.Tensor
    # Offset of each request's last scheduled token in that same concatenation.
    last_token_indices_gpu: torch.Tensor
    # Full prompt length of each request.
    prompt_lens_cpu: torch.Tensor
    # Sequence length of each request (presumably tokens processed so far,
    # including this step — TODO confirm against the caller).
    seq_lens_cpu: torch.Tensor
    # Number of tokens scheduled for each request in this step.
    num_scheduled_tokens_cpu: torch.Tensor

    def __getitem__(self, indices: slice) -> "PoolingCursor":
        """Return a new cursor restricted to the requests selected by ``indices``."""
        return PoolingCursor(
            index=self.index[indices],
            first_token_indices_gpu=self.first_token_indices_gpu[indices],
            last_token_indices_gpu=self.last_token_indices_gpu[indices],
            prompt_lens_cpu=self.prompt_lens_cpu[indices],
            seq_lens_cpu=self.seq_lens_cpu[indices],
            num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
        )

    def is_partial_prefill(self) -> bool:
        """Return True if any request's scheduled-token count differs from its
        prompt length, i.e. some prompt is not processed in a single step."""
        return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)

    def is_finished(self) -> torch.Tensor:
        """Return an element-wise boolean tensor that is True for each request
        whose sequence length equals its prompt length."""
        return self.prompt_lens_cpu == self.seq_lens_cpu


class PoolingStates:
    """Mutable per-request state that pooling keeps between calls."""

    def __init__(self):
        # Hidden states accumulated for chunked prefill when ALL pooling is used.
        self.hidden_states_cache: list[torch.Tensor] = list()

    def clean(self):
        """Empty the hidden-state cache in place (the list object is kept)."""
        del self.hidden_states_cache[:]

49
50
51
52

@dataclass
class PoolingMetadata:
    """Tensors for pooling.

    Per-request fields are parallel: entry ``i`` of each one describes the
    ``i``-th request in the batch, which is what lets ``__getitem__`` slice
    them all with the same index.
    """

    # Prompt length of each request.
    prompt_lens: torch.Tensor  # CPU Tensor
    # Prompt token ids indexed as ``[request, position]``; only the first
    # ``prompt_lens[i]`` entries of row ``i`` are returned by
    # ``get_prompt_token_ids`` (presumably the rest is padding — confirm).
    # None when token ids were not requested.
    prompt_token_ids: torch.Tensor | None
    # Pooling parameters of each request; every entry must have ``task`` set.
    pooling_params: list[PoolingParams]
    # Mutable per-request state (e.g. the chunked-prefill hidden-state cache).
    pooling_states: list[PoolingStates]
    # Token-position bookkeeping; populated by ``build_pooling_cursor``.
    pooling_cursor: PoolingCursor | None = None

    def __post_init__(self) -> None:
        """Collect each request's pooling task into ``self.tasks``.

        Raises:
            AssertionError: if any entry of ``pooling_params`` has no task.
        """
        pooling_params = self.pooling_params

        tasks: list[PoolingTask] = [
            task
            for pooling_param in pooling_params
            if (task := pooling_param.task) is not None
        ]
        # A request without a task would be silently dropped above and
        # misalign ``tasks`` with the other per-request fields.
        assert len(pooling_params) == len(tasks)

        self.tasks = tasks

    def __getitem__(self, indices: slice) -> "PoolingMetadata":
        """Return new metadata restricted to the requests selected by ``indices``."""
        return PoolingMetadata(
            prompt_lens=self.prompt_lens[indices],
            prompt_token_ids=None
            if self.prompt_token_ids is None
            else self.prompt_token_ids[indices],
            pooling_params=self.pooling_params[indices],
            pooling_states=self.pooling_states[indices],
            pooling_cursor=None
            if self.pooling_cursor is None
            else self.pooling_cursor[indices],
        )

    def get_prompt_token_ids(self) -> list[torch.Tensor]:
        """Return each request's prompt token ids, trimmed to its prompt length.

        Raises:
            AssertionError: if ``prompt_token_ids`` was not populated.
        """
        prompt_token_ids = self.prompt_token_ids
        assert prompt_token_ids is not None, (
            "Please set `requires_token_ids=True` in `get_pooling_updates`"
        )

        # Convert the lengths to Python ints once rather than iterating the
        # tensor, which yields a 0-d tensor per request.
        return [
            prompt_token_ids[i, :num]
            for i, num in enumerate(self.prompt_lens.tolist())
        ]

    def get_pooling_cursor(self) -> PoolingCursor:
        """Return the cursor built by ``build_pooling_cursor``.

        Raises:
            AssertionError: if no cursor has been built yet.
        """
        pooling_cursor = self.pooling_cursor
        assert pooling_cursor is not None, "Should call `build_pooling_cursor` first"

        return pooling_cursor

    def build_pooling_cursor(
        self,
        num_scheduled_tokens_np: np.ndarray,
        seq_lens_cpu: torch.Tensor,
        device: torch.device,
        query_start_loc_gpu: torch.Tensor | None = None,
    ):
        """Build and store ``self.pooling_cursor`` for the current step.

        Args:
            num_scheduled_tokens_np: Number of tokens scheduled for each
                request this step; its length must equal ``len(prompt_lens)``.
            seq_lens_cpu: Per-request sequence lengths, stored on the cursor
                unchanged.
            device: Device the token-offset tensors should live on.
            query_start_loc_gpu: Optional precomputed offsets of length
                ``n_seq + 1`` on ``device``. When given, they are used as-is
                instead of the exclusive prefix sum of
                ``num_scheduled_tokens_np`` computed on the CPU.

        Raises:
            AssertionError: if ``num_scheduled_tokens_np`` and ``prompt_lens``
                differ in length.
            ValueError: if ``query_start_loc_gpu`` has the wrong length or
                lives on a device other than ``device``.
        """
        n_seq = len(num_scheduled_tokens_np)
        prompt_lens = self.prompt_lens

        assert len(prompt_lens) == n_seq

        index = list(range(n_seq))
        # ``torch.from_numpy`` shares memory with the array: later in-place
        # writes to ``num_scheduled_tokens_np`` are visible through the cursor.
        num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np)

        if query_start_loc_gpu is None:
            # Exclusive prefix sum: cumsum[i] is the offset of request i's
            # first scheduled token and cumsum[n_seq] is the total count.
            cumsum = torch.zeros(
                n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
            )
            torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
            cumsum = cumsum.to(device, non_blocking=True)
        else:
            if query_start_loc_gpu.shape[0] != n_seq + 1:
                raise ValueError(
                    "query_start_loc_gpu length does not match "
                    f"the number of sequences: {query_start_loc_gpu.shape[0]} "
                    f"!= {n_seq + 1}."
                )
            # NOTE(review): torch.device equality is strict, so
            # torch.device("cuda") != torch.device("cuda:0"); callers must pass
            # a device that matches the tensor's device exactly.
            if query_start_loc_gpu.device != device:
                raise ValueError(
                    "query_start_loc_gpu must be on the same device as the "
                    f"hidden states: {query_start_loc_gpu.device} != {device}."
                )
            cumsum = query_start_loc_gpu

        self.pooling_cursor = PoolingCursor(
            index=index,
            first_token_indices_gpu=cumsum[:n_seq],
            last_token_indices_gpu=cumsum[1:] - 1,
            prompt_lens_cpu=prompt_lens,
            seq_lens_cpu=seq_lens_cpu,
            num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
        )