pooler.py 17.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import IntEnum
5
from typing import Optional, Union
6
7
8

import torch
import torch.nn as nn
9
10
import torch.nn.functional as F
from typing_extensions import assert_never
11

12
from vllm.config import ModelConfig, PoolerConfig
13
14
15
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
16
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
17
from vllm.transformers_utils.config import (
18
    get_classification_activation_function,
19
    get_cross_encoder_activation_function)
20
21
22
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Pooling metadata may be either the V0 or the V1 class imported above; code
# in this module distinguishes the two with isinstance checks.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
23
24
25
26
27


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods."""
    # Each member selects one pooler class in SimplePooler.from_pooling_type.
    LAST = 0  # LastPool
    ALL = 1  # AllPool
    CLS = 2  # CLSPool
    STEP = 3  # StepPool
    MEAN = 4  # MeanPool
32
33


34
class SimplePooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method
       (`extract_states`, implemented by each subclass).
    2. Passes the extracted data through a `PoolerHead`, which normalizes
       and/or applies softmax if specified.
    3. Returns structured results as `PoolerOutput`.

    Attributes:
        head: The `PoolerHead` applied to the extracted states.
    """

    @staticmethod
    def from_pooling_type(
        pooling_type: PoolingType,
        *,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "SimplePooler":
        """Build the pooler subclass that corresponds to `pooling_type`.

        Args:
            pooling_type: The pooling strategy to instantiate.
            normalize: Forwarded to the pooler's constructor.
            softmax: Forwarded to the pooler's constructor.
            step_tag_id: Only accepted for `PoolingType.STEP`.
            returned_token_ids: Only accepted for `PoolingType.STEP`.

        Returns:
            A `LastPool`, `AllPool`, `CLSPool`, `MeanPool` or `StepPool`.

        Raises:
            AssertionError: If `step_tag_id` or `returned_token_ids` is
                given for any pooling type other than STEP.
        """
        if pooling_type == PoolingType.LAST:
            assert step_tag_id is None and returned_token_ids is None
            return LastPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.ALL:
            assert step_tag_id is None and returned_token_ids is None
            return AllPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.CLS:
            assert step_tag_id is None and returned_token_ids is None
            return CLSPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.MEAN:
            assert step_tag_id is None and returned_token_ids is None
            return MeanPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.STEP:
            return StepPool(normalize=normalize,
                            softmax=softmax,
                            step_tag_id=step_tag_id,
                            returned_token_ids=returned_token_ids)

        # Every enum member is handled above; this is the exhaustiveness
        # check for type checkers.
        assert_never(pooling_type)

    def __init__(self, *, normalize: bool, softmax: bool) -> None:
        super().__init__()

        self.head = PoolerHead(normalize=normalize, softmax=softmax)

    def get_prompt_lens(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor:
        """Return the per-prompt lengths described by `pooling_metadata`.

        V1 metadata already carries `prompt_lens`. For V0 metadata the
        lengths are obtained through `PoolingTensors`, built on the device
        of `hidden_states`, which must therefore be a single tensor.
        """
        if isinstance(pooling_metadata, V1PoolingMetadata):
            return pooling_metadata.prompt_lens
        assert isinstance(hidden_states, torch.Tensor)
        return PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select or aggregate the states to pool; subclasses override."""
        raise NotImplementedError

    def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
        """Wrap one pooled item in a `PoolingSequenceGroupOutput`."""
        return PoolingSequenceGroupOutput(data)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract states, run them through the head, and package them.

        The head's result is iterated item by item, so a tensor result
        yields one output per entry along its first dimension and a list
        result yields one output per element.
        """
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        pooled_outputs = [self.build_output(data) for data in pooled_data]
        return PoolerOutput(outputs=pooled_outputs)


class CLSPool(SimplePooler):
    """Pooler that keeps the first token's hidden state of every prompt."""

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the state at position 0 of each prompt.

        For a list input, one tensor per request is expected and row 0 of
        each is returned (as a list). For a single tensor input, the rows
        at each prompt's start offset are gathered and returned as one
        tensor.

        Raises:
            AssertionError: For list input, if a request's tensor does not
                have exactly `prompt_len` rows.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            result = []
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with CLS pooling"
                result.append(req_state[0])
            return result

        # Start offset of prompt i is the sum of the lengths of prompts
        # 0..i-1, i.e. an exclusive cumulative sum (first entry stays 0).
        first_token_flat_indices = torch.zeros_like(prompt_lens)
        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
        return hidden_states[first_token_flat_indices]


class LastPool(SimplePooler):
    """Pooler that keeps the last token's hidden state of every prompt."""

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the final state of each prompt.

        For a list input, the last row of every element is returned (as a
        list) and the metadata is not consulted. For a single tensor input,
        the rows at each prompt's end offset are gathered into one tensor.
        """
        if isinstance(hidden_states, list):
            return [h[-1] for h in hidden_states]

        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        # The inclusive cumulative sum gives one-past-the-end offsets, so
        # subtracting 1 lands on each prompt's last row.
        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_token_flat_indices]


class AllPool(SimplePooler):
    """Pooler that keeps the hidden states of every token of each prompt."""

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return one tensor per prompt holding all of that prompt's states.

        For a list input, the list is validated and returned as-is (the
        same list object, not a copy). For a single tensor input, it is cut
        into consecutive slices of `prompt_lens` rows each.

        Raises:
            AssertionError: For list input, if a request's tensor does not
                have exactly `prompt_len` rows.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with ALL pooling"
            return hidden_states

        # Walk the flat tensor, taking `prompt_len` rows per prompt.
        offset = 0
        pooled_data = list[torch.Tensor]()
        for prompt_len in prompt_lens:
            pooled_data.append(hidden_states[offset:offset + prompt_len])
            offset += prompt_len

        return pooled_data


class MeanPool(SimplePooler):
    """Pooler that averages the hidden states over each prompt's tokens."""

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the per-prompt mean over the token dimension (dim 0).

        For a list input, each element is averaged in float32 and a list is
        returned. For a single tensor input, all means are computed at once
        from a float32 running sum and returned as one tensor.

        Raises:
            AssertionError: For list input, if a request's tensor does not
                have exactly `prompt_len` rows.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            result = []
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with mean pooling"
                result.append(torch.mean(req_state, dim=0,
                                         dtype=torch.float32))
            return result

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        # start_indices[i] is the first row of prompt i; end_indices[i] is
        # one past its last row.
        start_indices = torch.cat([
            torch.tensor([0], device=hidden_states.device),
            torch.cumsum(prompt_lens[:-1], dim=0)
        ])
        end_indices = torch.cumsum(prompt_lens, dim=0)
        # cumsum[end - 1] - cumsum[start] sums rows start+1 .. end-1, so the
        # row at `start` is added back to obtain the full per-prompt sum.
        return (cumsum[end_indices - 1] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


class StepPool(SimplePooler):
    """Pooler that keeps states at positions holding a given step tag token.

    Attributes:
        step_tag_id: Token id marking the positions to keep; when None, all
            positions are kept.
        returned_token_ids: Indices along the last dimension to keep; when
            None or empty, the last dimension is left untouched.
    """

    def __init__(
        self,
        *,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ):
        super().__init__(normalize=normalize, softmax=softmax)

        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    def get_prompt_token_ids(
        self,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor]:
        """Return one tensor of prompt token ids per prompt.

        V1 metadata stores the ids in a 2-D `prompt_token_ids` tensor, so
        each row is cut down to that prompt's length. For V0 metadata a new
        tensor is built from each entry of `seq_data`.
        """
        if isinstance(pooling_metadata, V1PoolingMetadata):
            return [
                pooling_metadata.prompt_token_ids[i, :num]
                for i, num in enumerate(pooling_metadata.prompt_lens)
            ]
        return [
            torch.tensor(seq_data_i.prompt_token_ids)
            for seq_data_i in pooling_metadata.seq_data.values()
        ]

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return, per prompt, the states selected by the step settings.

        The input is first split into one tensor per prompt (a list input
        is used as-is after validation). Each per-prompt tensor is then
        optionally restricted to `returned_token_ids` along its second
        dimension and to the rows whose prompt token equals `step_tag_id`.

        Raises:
            AssertionError: For list input, if a request's tensor does not
                have exactly `prompt_len` rows.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
        prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)

        pooled_data_lst = list[torch.Tensor]()
        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with step pooling"
            pooled_data_lst = hidden_states
        else:
            # Walk the flat tensor, taking `prompt_len` rows per prompt.
            offset = 0
            for prompt_len in prompt_lens:
                pooled_data_i = hidden_states[offset:offset + prompt_len]
                offset += prompt_len
                pooled_data_lst.append(pooled_data_i)

        pooled_data = list[torch.Tensor]()
        returned_token_ids = self.returned_token_ids
        step_tag_id = self.step_tag_id

        for data, token_id in zip(pooled_data_lst, prompt_token_ids):
            # None and an empty list both mean "keep every column".
            if returned_token_ids:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                # Boolean mask over rows: keep positions whose prompt token
                # is the step tag.
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data


class PoolerHead(nn.Module):
    """Post-processing applied to pooled data.

    In order: cast to float32, optionally truncate the last dimension to
    the per-request `dimensions` pooling parameter, optionally L2-normalize,
    and optionally apply softmax (or sigmoid when the last dimension is
    smaller than 2).

    Attributes:
        normalize: Whether to L2-normalize along the last dimension.
        softmax: Whether to apply softmax/sigmoid along the last dimension.
    """

    def __init__(self, *, normalize: bool, softmax: bool) -> None:
        super().__init__()

        self.normalize = normalize
        self.softmax = softmax

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the configured post-processing to `pooled_data`.

        Args:
            pooled_data: Either one tensor for the whole batch or a list
                with one tensor per request. The caller's list is not
                modified.
            pooling_metadata: Source of the per-request pooling parameters.

        Returns:
            The processed data: a tensor if a tensor was given and no
            per-request handling was needed, otherwise a list of tensors.
        """

        # Using float32 in PoolerHead
        if isinstance(pooled_data, list):
            # Build a new list rather than assigning into the incoming one:
            # some poolers hand over the very list they were called with,
            # and that list belongs to the caller.
            pooled_data = [data.to(torch.float32) for data in pooled_data]
        else:
            pooled_data = pooled_data.to(torch.float32)

        # for matryoshka representation
        if isinstance(pooling_metadata, V0PoolingMetadata):
            dimensions_list = [
                pooling_param.dimensions
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        else:
            assert isinstance(pooled_data, list)
            dimensions_list = [
                pooling_param.dimensions
                for pooling_param in pooling_metadata.pooling_params
            ]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(
                    pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Mixed dimensions (or list input): truncate per request,
                # which always produces a list.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        if self.normalize:
            if isinstance(pooled_data, list):
                pooled_data = [
                    F.normalize(data, p=2, dim=-1) for data in pooled_data
                ]
            else:
                pooled_data = F.normalize(pooled_data, p=2, dim=-1)

        if self.softmax:
            # A last dimension smaller than 2 gets sigmoid instead of
            # softmax.
            if isinstance(pooled_data, list):
                pooled_data = [
                    F.softmax(data, dim=-1)
                    if data.shape[-1] >= 2 else F.sigmoid(data)
                    for data in pooled_data
                ]
            else:
                if pooled_data.shape[-1] >= 2:
                    pooled_data = F.softmax(pooled_data, dim=-1)
                else:
                    pooled_data = F.sigmoid(pooled_data)

        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        return pooled_data


class Pooler(nn.Module):
    """Factory entry point for building a `SimplePooler` from a config."""

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> SimplePooler:
        """Build a pooler, letting `pooler_config` override the defaults.

        For every setting, the value from `pooler_config` is used when it
        is not None; otherwise the corresponding argument is used.

        Args:
            pooler_config: Configuration whose non-None fields take priority.
            pooling_type: Default pooling type.
            normalize: Default for normalization.
            softmax: Default for softmax.
            step_tag_id: Default step tag id.
            returned_token_ids: Default returned token ids.

        Returns:
            The pooler created by `SimplePooler.from_pooling_type`.
        """
        if pooler_config.pooling_type is not None:
            # The config stores the pooling type by member name.
            resolved_type = PoolingType[pooler_config.pooling_type]
        else:
            resolved_type = pooling_type

        resolved_normalize = (pooler_config.normalize
                              if pooler_config.normalize is not None else
                              normalize)
        resolved_softmax = (pooler_config.softmax
                            if pooler_config.softmax is not None else softmax)
        resolved_step_tag_id = (pooler_config.step_tag_id
                                if pooler_config.step_tag_id is not None else
                                step_tag_id)
        resolved_token_ids = (pooler_config.returned_token_ids
                              if pooler_config.returned_token_ids is not None
                              else returned_token_ids)

        return SimplePooler.from_pooling_type(
            pooling_type=resolved_type,
            normalize=resolved_normalize,
            softmax=resolved_softmax,
            step_tag_id=resolved_step_tag_id,
            returned_token_ids=resolved_token_ids,
        )
368

369

370
371
class ClassifierPooler(nn.Module):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Splits the hidden states into one tensor per prompt.
    2. If a pooler layer is given, applies it to each prompt and then
       applies the classification layer to the stacked results; otherwise
       applies the classification layer to each prompt directly.
    3. Applies an activation function to the output. In the case of
       classification models it is either sigmoid or softmax. In the
       case of scoring models, the same behavior is configuration
       dependent, as in the sentence-transformers library.
    """

    def __init__(
        self,
        config: ModelConfig,
        classifier: nn.Module,
        pooler: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.classifier = classifier
        self.pooler = pooler

        # Both activation functions are resolved once from the HF config;
        # which one is used is decided per request in `forward`.
        self.classification_act_fn = get_classification_activation_function(
            config.hf_config)
        self.cross_encoder_act_fn = get_cross_encoder_activation_function(
            config.hf_config)

    def _get_act_fn(self, use_cross_encoder: bool):
        """Pick the cross-encoder or the classification activation."""
        return (self.cross_encoder_act_fn
                if use_cross_encoder else self.classification_act_fn)

    def get_prompt_lens(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor:
        """Return the per-prompt lengths described by `pooling_metadata`.

        V1 metadata already carries `prompt_lens`. For V0 metadata the
        lengths are obtained through `PoolingTensors`, built on the device
        of `hidden_states`, which must therefore be a single tensor.
        """
        if isinstance(pooling_metadata, V1PoolingMetadata):
            return pooling_metadata.prompt_lens
        assert isinstance(hidden_states, torch.Tensor)
        return PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools sentence pair scores from the hidden_states.

        Raises:
            AssertionError: For list input, if a request's tensor does not
                have exactly `prompt_len` rows.
        """
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        pooled_data = list[torch.Tensor]()
        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with classifier"
            pooled_data = hidden_states
        else:
            # Walk the flat tensor, taking `prompt_len` rows per prompt.
            offset = 0
            for prompt_len in prompt_lens:
                pooled_data_i = hidden_states[offset:offset + prompt_len]
                offset += prompt_len
                pooled_data.append(pooled_data_i)

        pooled_data_lst = []
        for pooled_data_i in pooled_data:

            if self.pooler is not None:
                final_shape_tensor = self.pooler(pooled_data_i)
            else:
                final_shape_tensor = self.classifier(pooled_data_i)

            pooled_data_lst.append(final_shape_tensor)

        pooled_output = torch.stack(pooled_data_lst)

        if self.pooler is not None:
            # apply classifier once on the full batch if possible
            pooled_output = self.classifier(pooled_output)

        if isinstance(pooling_metadata, V0PoolingMetadata):
            use_cross_encoder_list = [
                pooling_param.use_cross_encoder
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        else:
            use_cross_encoder_list = [
                pooling_param.use_cross_encoder
                for pooling_param in pooling_metadata.pooling_params
            ]

        # shape of scores: (batch_size, num_labels)
        if all(use_cross_encoder == use_cross_encoder_list[0]
               for use_cross_encoder in use_cross_encoder_list):
            # Uniform batch: one activation call covers every request.
            act_fn = self._get_act_fn(use_cross_encoder_list[0])
            scores = act_fn(pooled_output)
        else:
            # Mixed batch: choose the activation request by request.
            scores = torch.stack([
                self._get_act_fn(use_cross_encoder)(vecs)
                for use_cross_encoder, vecs in zip(use_cross_encoder_list,
                                                   pooled_output)
            ])

        pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
        return PoolerOutput(outputs=pooled_outputs)