pooler.py 15.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import IntEnum
5
from typing import Optional, Union
6
7
8

import torch
import torch.nn as nn
9
10
import torch.nn.functional as F
from typing_extensions import assert_never
11

12
from vllm.config import ModelConfig, PoolerConfig
13
14
15
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
16
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
17
18
from vllm.transformers_utils.config import (
    get_cross_encoder_activation_function)
19
20
21
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Poolers accept metadata from either engine version; the code below tells
# them apart with isinstance checks against the two concrete classes.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
22
23
24
25
26


class PoolingType(IntEnum):
    """Closed set of strategies for reducing hidden states to pooled output.

    Integer values are spelled out explicitly; `SimplePooler.from_pooling_type`
    maps each member to its concrete pooler class.
    """
    # Hidden state of the final token of each prompt (LastPool).
    LAST = 0
    # Every token's hidden state, one tensor per prompt (AllPool).
    ALL = 1
    # Hidden state of the first token of each prompt (CLSPool).
    CLS = 2
    # Per-token states, optionally filtered by a step tag id (StepPool).
    STEP = 3
    # Average of the token states of each prompt (MeanPool).
    MEAN = 4
31
32


33
class SimplePooler(nn.Module):
    """Base class for poolers that reduce hidden states per request.

    Subclasses implement `extract_states` to select or aggregate token
    states for every request. `forward` runs that result through
    `self.head` and wraps each per-request entry in a
    `PoolingSequenceGroupOutput`.

    Attributes:
        head: The `PoolerHead` applied to the output of `extract_states`.
    """

    @staticmethod
    def from_pooling_type(
        pooling_type: PoolingType,
        *,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "SimplePooler":
        """Construct the concrete pooler subclass for ``pooling_type``.

        ``step_tag_id`` and ``returned_token_ids`` are forwarded only to
        `StepPool`; every other pooling type asserts that both are unset.
        """
        if pooling_type == PoolingType.STEP:
            return StepPool(normalize=normalize,
                            softmax=softmax,
                            step_tag_id=step_tag_id,
                            returned_token_ids=returned_token_ids)

        if pooling_type == PoolingType.LAST:
            pool_cls: type[SimplePooler] = LastPool
        elif pooling_type == PoolingType.ALL:
            pool_cls = AllPool
        elif pooling_type == PoolingType.CLS:
            pool_cls = CLSPool
        elif pooling_type == PoolingType.MEAN:
            pool_cls = MeanPool
        else:
            assert_never(pooling_type)

        # Step-only options make no sense for the remaining pooling types.
        assert step_tag_id is None and returned_token_ids is None
        return pool_cls(normalize=normalize, softmax=softmax)

    def __init__(self, *, normalize: bool, softmax: bool) -> None:
        super().__init__()

        self.head = PoolerHead(normalize=normalize, softmax=softmax)

    def get_prompt_lens(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor:
        """Return the prompt length of every request in the batch.

        V1 metadata carries the lengths directly. For V0 metadata they are
        built with `PoolingTensors` on the device of ``hidden_states``, which
        must therefore be a single tensor on that path.
        """
        if not isinstance(pooling_metadata, V1PoolingMetadata):
            assert isinstance(hidden_states, torch.Tensor)
            pooling_tensors = PoolingTensors.from_pooling_metadata(
                pooling_metadata, hidden_states.device)
            return pooling_tensors.prompt_lens
        return pooling_metadata.prompt_lens

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Reduce ``hidden_states`` per request; implemented by subclasses."""
        raise NotImplementedError

    def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
        """Wrap one request's pooled tensor in its output container."""
        return PoolingSequenceGroupOutput(data)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract per-request states, post-process them, wrap the result."""
        extracted = self.extract_states(hidden_states, pooling_metadata)
        processed = self.head(extracted, pooling_metadata)
        return PoolerOutput(
            outputs=[self.build_output(item) for item in processed])


class CLSPool(SimplePooler):
    """Pool each request down to the hidden state of its first token."""

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select the first-token hidden state of every request.

        List input: one tensor per request, each of which must hold the full
        prompt (partial prefill is rejected by assertion).
        Tensor input: the code treats requests as concatenated along dim 0
        in ``prompt_lens`` order, so request i starts at the sum of the
        preceding prompt lengths.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if not isinstance(hidden_states, list):
            # Exclusive prefix sum of the lengths = start row of each request.
            start_rows = torch.cumsum(prompt_lens, dim=0) - prompt_lens
            return hidden_states[start_rows]

        for req_state, prompt_len in zip(hidden_states, prompt_lens):
            assert prompt_len == req_state.shape[0], \
                "partial prefill not supported with CLS pooling"
        # zip() keeps the result length at min(len(states), len(prompt_lens)).
        return [
            req_state[0]
            for req_state, _ in zip(hidden_states, prompt_lens)
        ]


class LastPool(SimplePooler):
    """Pool each request down to the hidden state of its last token."""

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select the last-token hidden state of every request.

        List input: takes the final row of each per-request tensor; this path
        does not consult prompt lengths.
        Tensor input: the code treats requests as concatenated along dim 0
        in ``prompt_lens`` order.
        """
        if not isinstance(hidden_states, list):
            prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
            # Inclusive prefix sum minus one = final row of each request.
            final_rows = prompt_lens.cumsum(dim=0) - 1
            return hidden_states[final_rows]

        return [req_state[-1] for req_state in hidden_states]


class AllPool(SimplePooler):
    """Return every token's hidden state, one tensor per request."""

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Split the hidden states into per-request tensors.

        List input: returned as-is (the same list object) after asserting
        that each tensor holds the full prompt.
        Tensor input: sliced along dim 0 into consecutive chunks of
        ``prompt_lens`` rows.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with ALL pooling"
            return hidden_states

        segments: list[torch.Tensor] = []
        start = 0
        for length in prompt_lens:
            end = start + length
            segments.append(hidden_states[start:end])
            start = end
        return segments


class MeanPool(SimplePooler):
    """Pool each request down to the mean of its token hidden states.

    Both code paths accumulate in float32.
    """

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Average the token states of every request.

        List input: one tensor per request, each of which must hold the full
        prompt (partial prefill is rejected by assertion); each is averaged
        over dim 0 in float32.
        Tensor input: the code treats requests as concatenated along dim 0
        in ``prompt_lens`` order. Each request's sum is read off a float32
        running sum as ``cumsum[end - 1] - cumsum[start] + x[start]`` and
        divided by its prompt length.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            result = []
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with mean pooling"
                result.append(torch.mean(req_state, dim=0,
                                         dtype=torch.float32))
            return result

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        # Derive the start rows from prompt_lens itself (exclusive prefix
        # sum). This yields exactly one start per request -- none for an
        # empty batch -- and creates no extra tensor on another device.
        end_indices = torch.cumsum(prompt_lens, dim=0)
        start_indices = end_indices - prompt_lens
        return (cumsum[end_indices - 1] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


class StepPool(SimplePooler):
    """Pooler returning per-token states, optionally filtered.

    For each request, all token states are taken. ``returned_token_ids``
    (when non-empty) selects entries along the last dimension, and
    ``step_tag_id`` (when set) keeps only rows whose prompt token id equals
    it.
    """

    def __init__(
        self,
        *,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ):
        super().__init__(normalize=normalize, softmax=softmax)

        # Prompt token id whose positions are kept; None keeps all rows.
        self.step_tag_id = step_tag_id
        # Indices along the last dim to keep; None or empty keeps all.
        self.returned_token_ids = returned_token_ids

    def get_prompt_token_ids(
        self,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor]:
        """Return one tensor of prompt token ids per request."""
        if isinstance(pooling_metadata, V1PoolingMetadata):
            # V1 keeps the ids in a 2-D tensor indexed [request, position];
            # each row is cut to that request's prompt length (anything past
            # it is presumably padding -- confirm against V1 metadata).
            return [
                pooling_metadata.prompt_token_ids[i, :num]
                for i, num in enumerate(pooling_metadata.prompt_lens)
            ]
        return [
            torch.tensor(seq_data_i.prompt_token_ids)
            for seq_data_i in pooling_metadata.seq_data.values()
        ]

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the (optionally filtered) token states of every request.

        List input: one tensor per request, each of which must hold the full
        prompt (partial prefill is rejected by assertion).
        Tensor input: sliced along dim 0 into consecutive chunks of
        ``prompt_lens`` rows.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
        prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)

        # Step 1: one tensor of token states per request.
        per_request_states: list[torch.Tensor] = []
        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with step pooling"
            per_request_states = hidden_states
        else:
            offset = 0
            for prompt_len in prompt_lens:
                per_request_states.append(
                    hidden_states[offset:offset + prompt_len])
                offset += prompt_len

        returned_token_ids = self.returned_token_ids
        step_tag_id = self.step_tag_id
        select_columns = (returned_token_ids is not None
                          and len(returned_token_ids) > 0)

        # Step 2: filter into a separate list so the per-request states above
        # are still available while they are being read.
        pooled_data: list[torch.Tensor] = []
        for data, token_ids in zip(per_request_states, prompt_token_ids):
            if select_columns:
                data = data[:, returned_token_ids]
            if step_tag_id is not None:
                # Keep only positions whose prompt token is the step tag.
                data = data[token_ids == step_tag_id]
            pooled_data.append(data)

        return pooled_data


class PoolerHead(nn.Module):
    """Post-processing applied to pooled data.

    Steps, in order:
    1. Cast to float32.
    2. Truncate the last dimension to each request's
       ``pooling_param.dimensions`` when any request sets it.
    3. Optionally L2-normalize.
    4. Optionally apply softmax, or sigmoid when the last dim is < 2.
    """

    def __init__(self, *, normalize: bool, softmax: bool) -> None:
        super().__init__()

        self.normalize = normalize
        self.softmax = softmax

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the configured post-processing to ``pooled_data``.

        Args:
            pooled_data: Per-request tensors (list) or one batched tensor.
                On the V1 metadata path it must be a list.
            pooling_metadata: V0 or V1 metadata; supplies each request's
                ``pooling_param.dimensions``.

        Returns:
            A new list of tensors, or a tensor. The incoming list object is
            never modified.
        """
        # Using float32 in PoolerHead. Build a new list instead of assigning
        # into the incoming one: poolers such as AllPool return the caller's
        # own hidden_states list, which must not be mutated here.
        if isinstance(pooled_data, list):
            pooled_data = [data.to(torch.float32) for data in pooled_data]
        else:
            pooled_data = pooled_data.to(torch.float32)

        if isinstance(pooling_metadata, V0PoolingMetadata):
            dimensions_list = [
                pooling_param.dimensions
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        else:
            assert isinstance(pooled_data, list)
            dimensions_list = [
                pooling_param.dimensions
                for pooling_param in pooling_metadata.pooling_params
            ]

        if any(d is not None for d in dimensions_list):
            # Change the output dimension per request (None = keep full).
            assert len(pooled_data) == len(dimensions_list)
            pooled_data = [
                vecs if d is None else vecs[..., :d]
                for vecs, d in zip(pooled_data, dimensions_list)
            ]

        if self.normalize:
            if isinstance(pooled_data, list):
                pooled_data = [
                    F.normalize(data, p=2, dim=-1) for data in pooled_data
                ]
            else:
                pooled_data = F.normalize(pooled_data, p=2, dim=-1)

        if self.softmax:
            # A single output column gets sigmoid; two or more get softmax.
            if isinstance(pooled_data, list):
                pooled_data = [
                    F.softmax(data, dim=-1)
                    if data.shape[-1] >= 2 else F.sigmoid(data)
                    for data in pooled_data
                ]
            else:
                if pooled_data.shape[-1] >= 2:
                    pooled_data = F.softmax(pooled_data, dim=-1)
                else:
                    pooled_data = F.sigmoid(pooled_data)

        return pooled_data


class Pooler(nn.Module):
    """Entry point for building a pooler from a `PoolerConfig`."""

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> SimplePooler:
        """Build a pooler, letting the fields set in ``pooler_config`` win.

        Every field of ``pooler_config`` that is not None overrides the
        matching argument. ``pooler_config.pooling_type`` is looked up by
        member name in `PoolingType`.
        """

        def override(configured, fallback):
            # The config value wins unless it was left unset (None).
            return fallback if configured is None else configured

        if pooler_config.pooling_type is None:
            resolved_type = pooling_type
        else:
            resolved_type = PoolingType[pooler_config.pooling_type]

        return SimplePooler.from_pooling_type(
            pooling_type=resolved_type,
            normalize=override(pooler_config.normalize, normalize),
            softmax=override(pooler_config.softmax, softmax),
            step_tag_id=override(pooler_config.step_tag_id, step_tag_id),
            returned_token_ids=override(pooler_config.returned_token_ids,
                                        returned_token_ids),
        )
357

358

359
360
class ClassifierPooler(nn.Module):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Splits the hidden states into one tensor per request.
    2. Applies the optional pooler to each request, or the classifier
       directly when no pooler is given.
    3. When a pooler is given, applies the classifier once to the stacked
       pooled outputs.
    4. Applies an activation function to the output. For classification
       models it is sigmoid (one label) or softmax over the last dimension.
       For scoring models it comes from
       `get_cross_encoder_activation_function`, mirroring the
       sentence-transformers library.
    """

    def __init__(
        self,
        config: ModelConfig,
        classifier: nn.Module,
        pooler: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.classifier = classifier
        self.pooler = pooler

        if config.task == "score":
            self.default_activation_function = \
                get_cross_encoder_activation_function(config.hf_config)
        elif config.task == "classify":
            # dim=-1 normalizes over the label axis explicitly; a bare
            # nn.Softmax() uses the deprecated implicit-dim rule (warns, and
            # picks dim 0 for 1-D/3-D inputs).
            self.default_activation_function = nn.Sigmoid() \
                if config.hf_config.num_labels == 1 else nn.Softmax(dim=-1)
        else:
            raise NotImplementedError(f"task={config.task!r} is not supported"
                                      " with the classification pooler")

    def get_prompt_lens(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor:
        """Return per-request prompt lengths (V1 direct, V0 via tensors)."""
        if isinstance(pooling_metadata, V1PoolingMetadata):
            return pooling_metadata.prompt_lens
        assert isinstance(hidden_states, torch.Tensor)
        return PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools sentence pair scores from the hidden_states."""
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        # One tensor of token states per request.
        pooled_data = list[torch.Tensor]()
        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with classifier"
            pooled_data = hidden_states
        else:
            offset = 0
            for prompt_len in prompt_lens:
                pooled_data_i = hidden_states[offset:offset + prompt_len]
                offset += prompt_len
                pooled_data.append(pooled_data_i)

        pooled_data_lst = []
        for pooled_data_i in pooled_data:
            if self.pooler is not None:
                final_shape_tensor = self.pooler(pooled_data_i)
            else:
                final_shape_tensor = self.classifier(pooled_data_i)

            pooled_data_lst.append(final_shape_tensor)

        pooled_output = torch.stack(pooled_data_lst)

        if self.pooler is not None:
            # apply classifier once on the full batch if possible
            pooled_output = self.classifier(pooled_output)

        scores = self.default_activation_function(pooled_output).squeeze(-1)

        pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
        return PoolerOutput(outputs=pooled_outputs)