"vscode:/vscode.git/clone" did not exist on "53fc16640281a939e01c98be80fa67179ba9a088"
pooler.py 5.41 KB
Newer Older
1
from enum import IntEnum
2
from typing import List, Optional
3
4
5
6

import torch
import torch.nn as nn

7
from vllm.config import PoolerConfig
8
9
10
11
12
13
14
15
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                  PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput


class PoolingType(IntEnum):
    """Integer identifiers for the pooling strategies understood by `Pooler`.

    Members are plain integers 0-4 in declaration order:
    LAST (final token of each prompt), ALL (every token of each prompt),
    CLS (first token of each prompt), STEP (softmaxed scores at step-tag
    positions) and MEAN (average over each prompt's tokens).
    """
    LAST, ALL, CLS, STEP, MEAN = range(5)
20
21
22
23
24
25
26
27
28
29
30


class Pooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Optionally L2-normalizes and/or softmaxes the pooled data.
    3. Returns structured results as `PoolerOutput`.

    Attributes:
        pooling_type: The type of pooling to use.
        normalize: Whether to L2-normalize the pooled data (along dim 1).
        softmax: Whether to apply softmax over the last dimension of the
            pooled data.
        step_tag_id: For STEP pooling, only positions whose prompt token id
            equals this value are kept; ``None`` keeps every position.
        returned_token_ids: For STEP pooling, the hidden-state columns the
            softmax is taken over; ``None`` or empty uses all columns.
    """

    def __init__(
        self,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[List[int]] = None,
    ):
        super().__init__()

        self.pooling_type = pooling_type
        self.normalize = normalize
        self.softmax = softmax
        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: Optional[PoolerConfig],
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[List[int]] = None,
    ) -> Optional["Pooler"]:
        """Build a pooler, preferring values from ``pooler_config``.

        Every ``pooler_config`` field that is not ``None`` overrides the
        matching default passed to this method; ``pooling_type`` in the
        config is looked up by member name in `PoolingType`.

        Returns:
            The configured ``Pooler``, or ``None`` if ``pooler_config`` is
            ``None``.
        """
        if pooler_config is None:
            return None
        return cls(
            pooling_type=PoolingType[pooler_config.pooling_type]
            if pooler_config.pooling_type is not None else pooling_type,
            normalize=pooler_config.pooling_norm
            if pooler_config.pooling_norm is not None else normalize,
            softmax=pooler_config.pooling_softmax
            if pooler_config.pooling_softmax is not None else softmax,
            step_tag_id=pooler_config.pooling_step_tag_id
            if pooler_config.pooling_step_tag_id is not None else step_tag_id,
            returned_token_ids=pooler_config.pooling_returned_token_ids
            if pooler_config.pooling_returned_token_ids is not None else
            returned_token_ids,
        )

    @staticmethod
    def _prompt_spans(prompt_lens: torch.Tensor) -> List[tuple]:
        """Return the ``(start, end)`` row range of each prompt as ints.

        Converting ``prompt_lens`` to Python ints once avoids a device
        synchronization per prompt when the tensor lives on an accelerator.
        """
        spans = []
        start = 0
        for length in prompt_lens.tolist():
            spans.append((start, start + length))
            start += length
        return spans

    @staticmethod
    def _apply_to_pooled(pooled_data, fn):
        """Apply ``fn`` to a batched tensor, or to each tensor in a list.

        ALL and STEP pooling produce a list of per-prompt tensors (which can
        differ in length), so post-processing must be mapped over the list.
        """
        if isinstance(pooled_data, list):
            return [fn(data) for data in pooled_data]
        return fn(pooled_data)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools specific information from hidden states based on metadata.

        Args:
            hidden_states: Hidden states of all prompts in the batch,
                concatenated along dim 0 (one row per token, prompts in
                order), as implied by the prefix-sum indexing below.
            pooling_metadata: Metadata supplying the per-prompt lengths and,
                for STEP pooling, each prompt's token ids.

        Returns:
            A ``PoolerOutput`` holding one ``EmbeddingSequenceGroupOutput``
            per pooled prompt.

        Raises:
            ValueError: If ``pooling_type`` is not a supported value.
        """
        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

        if self.pooling_type == PoolingType.CLS:
            # Row of the first token of each prompt = exclusive prefix sum.
            first_token_flat_indices = (torch.cumsum(prompt_lens, dim=0) -
                                        prompt_lens)
            pooled_data = hidden_states[first_token_flat_indices]
        elif self.pooling_type == PoolingType.LAST:
            last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_flat_indices]
        elif self.pooling_type == PoolingType.ALL:
            pooled_data = [
                hidden_states[start:end]
                for start, end in self._prompt_spans(prompt_lens)
            ]
        elif self.pooling_type == PoolingType.MEAN:
            # Accumulate in float32 and prepend a zero row so each prompt's
            # sum is one difference of prefix sums. A running sum over the
            # whole batch in half precision loses most of its accuracy.
            cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
            cumsum = torch.cat(
                [cumsum.new_zeros((1, *cumsum.shape[1:])), cumsum])
            end_indices = torch.cumsum(prompt_lens, dim=0)
            start_indices = end_indices - prompt_lens
            pooled_data = ((cumsum[end_indices] - cumsum[start_indices]) /
                           prompt_lens.unsqueeze(1))
        elif self.pooling_type == PoolingType.STEP:
            if self.returned_token_ids:
                logits = hidden_states[:,
                                       self.returned_token_ids].softmax(dim=-1)
            else:
                logits = hidden_states.softmax(dim=-1)
            pooled_data = []
            # NOTE(review): assumes seq_data iterates in the same order as
            # prompt_lens — TODO confirm against PoolingMetadata.
            for (start, end), seq_data_i in zip(
                    self._prompt_spans(prompt_lens),
                    pooling_metadata.seq_data.values()):
                prompt_logits = logits[start:end]
                if self.step_tag_id is None:
                    pooled_data.append(prompt_logits)
                else:
                    # Build the mask on the logits' device so boolean
                    # indexing does not mix CPU and accelerator tensors.
                    step_idxs = torch.tensor(
                        seq_data_i.prompt_token_ids,
                        device=logits.device) == self.step_tag_id
                    pooled_data.append(prompt_logits[step_idxs])
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")

        if self.normalize:
            pooled_data = self._apply_to_pooled(
                pooled_data,
                lambda data: nn.functional.normalize(data, p=2, dim=1))

        if self.softmax:
            pooled_data = self._apply_to_pooled(
                pooled_data,
                lambda data: nn.functional.softmax(data, dim=-1))

        pooled_outputs = [
            EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
        ]

        return PoolerOutput(outputs=pooled_outputs)