pooler.py 7.65 KB
Newer Older
1
from enum import IntEnum
2
from typing import List, Optional
3
4
5

import torch
import torch.nn as nn
6
from transformers import PretrainedConfig
7

8
from vllm.config import PoolerConfig
9
10
11
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                  PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
12
13
from vllm.transformers_utils.config import (
    get_cross_encoder_activation_function)
14
15
16
17
18


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods."""
    # Hidden state of each sequence's last token.
    LAST = 0
    # All hidden states of each sequence, one tensor per sequence.
    ALL = 1
    # Hidden state of each sequence's first token.
    CLS = 2
    # Per-token hidden states, optionally filtered to step-tag positions.
    STEP = 3
    # Mean of each sequence's hidden states.
    MEAN = 4


class Pooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Optionally L2-normalizes and/or applies softmax to the pooled data.
    3. Returns structured results as `PoolerOutput`.

    Attributes:
        pooling_type: The type of pooling to use.
        normalize: Whether to L2-normalize the pooled data (along dim 1).
        softmax: Whether to apply softmax to the pooled data (along the
            last dim).
        step_tag_id: STEP pooling only. If set, only positions whose prompt
            token id equals this value are kept.
        returned_token_ids: STEP pooling only. If non-empty, only these
            columns of the hidden states are kept.
    """

    def __init__(
        self,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[List[int]] = None,
    ):
        super().__init__()

        self.pooling_type = pooling_type
        self.normalize = normalize
        self.softmax = softmax
        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[List[int]] = None,
    ) -> "Pooler":
        """Build a Pooler, preferring values set in `pooler_config`.

        Every field of `pooler_config` that is not None overrides the
        corresponding argument. The argument serves as the default.
        `pooler_config.pooling_type` is looked up by name in `PoolingType`.
        """
        return cls(
            pooling_type=PoolingType[pooler_config.pooling_type]
            if pooler_config.pooling_type is not None else pooling_type,
            normalize=pooler_config.normalize
            if pooler_config.normalize is not None else normalize,
            softmax=pooler_config.softmax
            if pooler_config.softmax is not None else softmax,
            step_tag_id=pooler_config.step_tag_id
            if pooler_config.step_tag_id is not None else step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids
            if pooler_config.returned_token_ids is not None else
            returned_token_ids,
        )

    @staticmethod
    def _apply_to_pooled(pooled_data, fn):
        """Apply `fn` to a batched tensor, or to each tensor of a list."""
        if isinstance(pooled_data, list):
            return [fn(data) for data in pooled_data]
        return fn(pooled_data)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools specific information from hidden states based on metadata.

        `hidden_states` holds all sequences of the batch concatenated along
        dim 0. Sequence i occupies `prompt_lens[i]` consecutive rows.

        Raises:
            ValueError: if `self.pooling_type` is not a supported type.
        """
        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

        if self.pooling_type == PoolingType.CLS:
            # Row index of each sequence's first token in the flat batch.
            first_token_flat_indices = torch.zeros_like(prompt_lens)
            first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
                                                         dim=0)[:-1]
            pooled_data = hidden_states[first_token_flat_indices]
        elif self.pooling_type == PoolingType.LAST:
            last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_flat_indices]
        elif self.pooling_type == PoolingType.ALL:
            # Convert lengths to Python ints once. Slicing with 0-d tensors
            # would force a device sync on every iteration.
            offset = 0
            pooled_data = []
            for prompt_len in prompt_lens.tolist():
                pooled_data.append(hidden_states[offset:offset + prompt_len])
                offset += prompt_len
        elif self.pooling_type == PoolingType.MEAN:
            end_indices = torch.cumsum(prompt_lens, dim=0)
            start_indices = end_indices - prompt_lens
            # The prefix sum spans the whole flattened batch. Accumulate in
            # float32 so that subtracting two large prefixes does not lose
            # precision in half-precision models. Cast back afterwards.
            cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
            # Prepend a zero row so that the sum over rows [start, end) is
            # cumsum[end] - cumsum[start] (an exclusive prefix sum).
            cumsum = torch.cat(
                [cumsum.new_zeros((1, *cumsum.shape[1:])), cumsum])
            pooled_data = ((cumsum[end_indices] - cumsum[start_indices]) /
                           prompt_lens.unsqueeze(1)).to(hidden_states.dtype)
        elif self.pooling_type == PoolingType.STEP:
            returned_token_ids = self.returned_token_ids
            if returned_token_ids is not None and len(returned_token_ids) > 0:
                hidden_states = hidden_states[:, returned_token_ids]

            step_tag_id = self.step_tag_id

            offset = 0
            pooled_data = []
            for prompt_len, seq_data_i in zip(
                    prompt_lens.tolist(), pooling_metadata.seq_data.values()):
                pooled_data_i = hidden_states[offset:offset + prompt_len]
                if step_tag_id is not None:
                    # Build the mask on the same device as the data it
                    # indexes.
                    token_ids = torch.tensor(seq_data_i.prompt_token_ids,
                                             device=pooled_data_i.device)
                    pooled_data_i = pooled_data_i[token_ids == step_tag_id]

                offset += prompt_len
                pooled_data.append(pooled_data_i)
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")

        if self.normalize:
            pooled_data = self._apply_to_pooled(
                pooled_data,
                lambda data: nn.functional.normalize(data, p=2, dim=1))

        if self.softmax:
            pooled_data = self._apply_to_pooled(
                pooled_data, lambda data: nn.functional.softmax(data, dim=-1))

        pooled_outputs = [
            EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
        ]

        return PoolerOutput(outputs=pooled_outputs)


class CrossEncodingPooler(nn.Module):
    """A layer that produces sentence-pair scores for cross-encoder models.

    This layer does the following:
    1. Slices each sequence's hidden states out of the flattened batch.
    2. Applies `pooler` to them if one was given; otherwise applies
       `classifier` directly.
    3. Stacks the per-sequence results. When a `pooler` was used, it then
       applies `classifier` once to the stacked batch.
    4. Applies the activation function derived from the model config.
    5. Returns the scores as a `PoolerOutput`.

    Attributes:
        classifier: Module that maps (pooled) hidden states to scores.
        pooler: Optional module that reduces one sequence's hidden states
            before classification.
        default_activation_function: Activation applied to the classifier
            output. Obtained via `get_cross_encoder_activation_function`.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        classifier: nn.Module,
        pooler: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.classifier = classifier
        self.pooler = pooler
        self.default_activation_function = \
            get_cross_encoder_activation_function(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools sentence pair scores from the hidden_states.

        `hidden_states` holds all sequences of the batch concatenated along
        dim 0. Sequence i occupies `prompt_lens[i]` consecutive rows.
        """
        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

        offset = 0
        pooled_data_lst = []
        # Convert lengths to Python ints once. Slicing with 0-d tensors would
        # force a device sync on every iteration.
        for prompt_len in prompt_lens.tolist():
            pooled_data_i = hidden_states[offset:offset + prompt_len]

            if self.pooler is not None:
                final_shape_tensor = self.pooler(pooled_data_i)
            else:
                final_shape_tensor = self.classifier(pooled_data_i)

            pooled_data_lst.append(final_shape_tensor)
            offset += prompt_len

        pooled_output = torch.stack(pooled_data_lst)

        if self.pooler is not None:
            # apply classifier once on the full batch if possible
            pooled_output = self.classifier(pooled_output)
        logits = self.default_activation_function(pooled_output)

        pooled_outputs = [
            EmbeddingSequenceGroupOutput(data.tolist()) for data in logits
        ]
        return PoolerOutput(outputs=pooled_outputs)