pooling_metadata.py 3.03 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import dataclass
5
from typing import Any, Optional
6
7
8
9
10

import torch

from vllm.pooling_params import PoolingParams
from vllm.utils import is_pin_memory_available
11
from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


class PoolingMetadata:
    """Metadata for pooling operations in the Pooler layer.

    This class holds the necessary information for pooling operations,
    providing context for how to perform pooling and other related operations.

    Attributes:
        seq_groups: List of (seq_ids, pooling_params).
        seq_data: A mapping of sequence ID to additional sequence data.
        prompt_lens: List of the lengths of each prompt.
    """

    def __init__(
27
28
29
30
31
            self,
            seq_groups: list[tuple[list[int], PoolingParams]],
            seq_data: dict[int, Any],  # Specific data related to sequences
            prompt_lens: list[int],
            pooling_cursor: Optional[PoolingCursor] = None) -> None:
32
33
34
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
35
        self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor
36
37
38
39
40
41
42

    def __repr__(self) -> str:
        return ("PoolingMetadata("
                f"seq_groups={self.seq_groups}, "
                f"seq_data={self.seq_data}, "
                f"prompt_lens={self.prompt_lens})")

43
44
45
46
47
    def __getitem__(self, indices: slice):
        return PoolingMetadata(
            seq_groups=self.seq_groups[indices],
            seq_data=dict(list(self.seq_data.items())[indices]),
            prompt_lens=self.prompt_lens[indices],
48
49
            pooling_cursor=None
            if self.pooling_cursor is None else self.pooling_cursor[indices],
50
51
        )

52
53
54
55
56
57
58
    def build_pooling_cursor(self, num_scheduled_tokens: list[int],
                             device: torch.device):
        prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
                                                   prompt_lens,
                                                   device=device)

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

@dataclass
class PoolingTensors:
    """Tensors for pooling."""

    prompt_lens: torch.Tensor

    @classmethod
    def from_pooling_metadata(
        cls,
        pooling_metadata: "PoolingMetadata",
        device: torch.device,
    ) -> "PoolingTensors":
        """
        Create PoolingTensors from PoolingMetadata.

        Args:
            pooling_metadata: PoolingMetadata instance to convert.
            device: Device to store the tensors.
        """
        # Convert prompt lengths to tensor
        pin_memory = is_pin_memory_available()

        prompt_lens_t = torch.tensor(
            pooling_metadata.prompt_lens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )

        return cls(prompt_lens=prompt_lens_t.to(device=device,
                                                non_blocking=True), )