pooler.py 21.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from abc import ABC, abstractmethod
from dataclasses import dataclass
5
from enum import IntEnum
6
from typing import Callable, Optional, TypeVar, Union
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11
from transformers import PretrainedConfig
12

13
from vllm.config import ModelConfig, PoolerConfig
14
15
16
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
17
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
18
from vllm.utils import resolve_obj_by_qualname
19
20
21
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Pooling metadata from either engine version. Helpers in this file branch on
# isinstance(..., V0PoolingMetadata / V1PoolingMetadata) to handle each.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
22
23
24
25
26


class PoolingType(IntEnum):
    """Supported strategies for reducing per-token hidden states.

    Members are also looked up by *name* (e.g. ``PoolingType["MEAN"]``) when
    a user overrides the pooling type through ``PoolerConfig``.
    """

    # Hidden state of the final token of each request.
    LAST = 0
    # Every token's hidden state, returned without reduction.
    ALL = 1
    # Hidden state of the first token of each request.
    CLS = 2
    # Per-token states filtered at step-tag positions (see StepPooler).
    STEP = 3
    # Average of each request's token hidden states.
    MEAN = 4
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Final pooling settings for a model.

    Produced by merging a user ``PoolerConfig`` with the defaults supplied by
    the model implementation. Frozen, so instances are immutable once built.
    """

    pooling_type: PoolingType
    normalize: bool
    softmax: bool
    # Consumed by StepPooler: rows are kept where the prompt token equals it.
    step_tag_id: Optional[int]
    # Consumed by StepPooler: indices selected along dim 1 of each result.
    returned_token_ids: Optional[list[int]]

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "ResolvedPoolingConfig":
        """Merge user overrides from ``pooler_config`` with model defaults.

        For every field a non-``None`` value on ``pooler_config`` wins;
        otherwise the matching keyword argument is used. A user-supplied
        ``pooling_type`` is a member *name* and is converted with
        ``PoolingType[...]`` (unknown names raise ``KeyError``).
        """

        def pick(override, default):
            return default if override is None else override

        user_type = pooler_config.pooling_type
        resolved_type = (pooling_type
                         if user_type is None else PoolingType[user_type])

        return cls(
            pooling_type=resolved_type,
            normalize=pick(pooler_config.normalize, normalize),
            softmax=pick(pooler_config.softmax, softmax),
            step_tag_id=pick(pooler_config.step_tag_id, step_tag_id),
            returned_token_ids=pick(pooler_config.returned_token_ids,
                                    returned_token_ids),
        )
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the per-request prompt lengths as a tensor.

    V1 metadata already carries ``prompt_lens``. For V0 metadata they are
    built via ``PoolingTensors`` on ``hidden_states``' device, so in that
    case ``hidden_states`` must be a single tensor.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        assert isinstance(hidden_states, torch.Tensor)
        v0_tensors = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device)
        return v0_tensors.prompt_lens

    return pooling_metadata.prompt_lens


def get_classification_activation_function(config: PretrainedConfig):
    """Return the activation applied to classification outputs.

    ``config`` is currently unused; the result is always a fresh
    ``PoolerClassify``.
    """
    activation = PoolerClassify()
    return activation


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation applied to cross-encoder (scoring) outputs.

    The activation's qualified class name is read, in order of precedence,
    from ``config.sentence_transformers["activation_fn"]`` or from
    ``config.sbert_ce_default_activation_function``. When neither is set,
    ``PoolerScore`` is returned.

    Raises:
        ValueError: If the configured name is not a string under
            ``torch.nn.modules.``. The name comes from the model's config, so
            this check is a real ``raise`` rather than an ``assert``, which
            Python strips when run with ``-O``.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerScore()

    # Only classes under torch.nn.modules may be instantiated from a
    # (potentially untrusted) model config.
    if not (isinstance(function_name, str)
            and function_name.startswith("torch.nn.modules.")):
        raise ValueError("Loading of activation functions is restricted to "
                         "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)

def build_output(all_data: torch.Tensor) -> PoolerOutput:
    """Wrap each per-request entry of ``all_data`` into a ``PoolerOutput``.

    ``all_data`` is iterated once; callers in this file pass either a batched
    tensor (iterated along dim 0) or a list of tensors.
    """
    wrapped = list(map(PoolingSequenceGroupOutput, all_data))
    return PoolerOutput(outputs=wrapped)


class BasePooler(nn.Module, ABC):
    """Interface for modules that turn hidden states into ``PoolerOutput``.

    ``ABC`` is mixed in (as for ``PoolingMethod``) so that ``@abstractmethod``
    is actually enforced. On a plain ``nn.Module`` subclass the decorator is
    silently ignored, and an incomplete subclass could be instantiated.
    """

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` for every request in ``pooling_metadata``."""
        raise NotImplementedError


class PoolingMethod(nn.Module, ABC):
    """Strategy that reduces per-token hidden states to per-request data.

    Subclasses implement two entry points:
    ``forward_one`` handles the states of a single request, and
    ``forward_all`` handles one tensor holding every request's tokens,
    delimited by ``prompt_lens``.
    """

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method for ``pooling_type``.

        Raises:
            NotImplementedError: For types with no method here
                (e.g. ``PoolingType.STEP``).
        """
        factories = {
            PoolingType.LAST: LastPool,
            PoolingType.ALL: AllPool,
            PoolingType.CLS: CLSPool,
            PoolingType.MEAN: MeanPool,
        }
        factory = factories.get(pooling_type)
        if factory is None:
            raise NotImplementedError(f"Unsupported method: {pooling_type}")
        return factory()

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool the hidden states of one request.

        Passing ``prompt_len=None`` is equivalent to
        ``prompt_len=len(hidden_states)``.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool a flattened batch whose request boundaries are
        given by ``prompt_lens``."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool every request, dispatching on the layout of ``hidden_states``.

        A list is treated as one tensor per request and pooled item by item.
        A single tensor goes to ``forward_all`` with the prompt lengths.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if not isinstance(hidden_states, list):
            return self.forward_all(hidden_states, prompt_lens)

        return [
            self.forward_one(states, length)
            for states, length in zip(hidden_states, prompt_lens)
        ]
class CLSPool(PoolingMethod):
    """Select the hidden state of each request's first token."""

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The states passed in must cover the whole prompt (no partial
        # prefill chunks).
        if prompt_len is not None:
            assert prompt_len == hidden_states.shape[0], \
                "partial prefill not supported with CLS pooling"
        first_token_state = hidden_states[0]
        return first_token_state

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # A request's first token sits at the total length of all requests
        # before it, i.e. its end offset minus its own length.
        end_offsets = torch.cumsum(prompt_lens, dim=0)
        start_offsets = end_offsets - prompt_lens
        return hidden_states[start_offsets]
class LastPool(PoolingMethod):
    """Select the hidden state of each request's final token."""

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # prompt_len is not checked: the last row of whatever was passed in
        # is returned.
        return hidden_states[-1]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Running totals of the lengths are one past each request's last row.
        last_positions = prompt_lens.cumsum(dim=0) - 1
        return hidden_states[last_positions]
class AllPool(PoolingMethod):
    """Return every token's hidden state, split per request."""

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if prompt_len is not None:
            assert prompt_len == hidden_states.shape[0], \
                "partial prefill not supported with ALL pooling"
        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Walk the flattened batch, cutting one contiguous slice per request.
        chunks: list[torch.Tensor] = []
        cursor = 0
        for length in prompt_lens:
            end = cursor + length
            chunks.append(hidden_states[cursor:end])
            cursor = end
        return chunks
class MeanPool(PoolingMethod):
    """Average each request's token hidden states, returning float32."""

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with MEAN pooling"

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Mean over each request's rows of a flattened batch.

        Uses prefix sums: the sum of rows ``[start, end)`` equals
        ``cumsum[end - 1] - cumsum[start] + hidden_states[start]``.
        """
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        # Start offsets are derived from end offsets rather than seeded with a
        # hard-coded [0] tensor. This keeps both index tensors on prompt_lens'
        # device and dtype, and makes an empty batch yield an empty result
        # instead of indexing row 0 of an empty cumsum.
        end_indices = torch.cumsum(prompt_lens, dim=0)
        start_indices = end_indices - prompt_lens

        return (cumsum[end_indices - 1] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
# Pooled data: a single tensor or a list of tensors. Activations annotate
# forward(pooled_data: _T) -> _T, i.e. the same container type in and out.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
267
268


269
class BasePoolerActivation(nn.Module, ABC):
    """Abstract activation applied to pooled data before it is returned."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Transform ``pooled_data``, returning the same container type.

        Expected shapes:
          * classify (& score): (batch_size, num_classes)
          * embed: (batch_size, embedding_dim) or list(embedding_dim);
            with MRL, (batch_size, dimensions) or list(dimensions)
        """
        raise NotImplementedError
class PoolerActivation(BasePoolerActivation):
    """Activation applied independently to each chunk of pooled data.

    ``forward`` accepts a tensor or a list of tensors. Subclasses only
    implement ``forward_chunk`` for a single tensor.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt an ``nn.Module`` into a ``PoolerActivation``.

        ``nn.Identity`` maps to ``PoolerIdentity``. ``nn.Sigmoid`` and
        ``nn.Softmax`` map to ``PoolerClassify``. Anything else is wrapped in
        ``LambdaPoolerActivation``.

        NOTE(review): ``PoolerClassify`` chooses sigmoid vs. softmax from the
        size of the last dimension, not from the wrapped module's type. An
        ``nn.Sigmoid`` over a multi-label output therefore becomes a softmax —
        confirm this is intended.
        """
        if isinstance(module, nn.Identity):
            wrapper = PoolerIdentity()
        elif isinstance(module, (nn.Sigmoid, nn.Softmax)):
            wrapper = PoolerClassify()
        else:
            wrapper = LambdaPoolerActivation(module)
        return wrapper

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)

        return [self.forward_chunk(chunk) for chunk in pooled_data]
class PoolerIdentity(PoolerActivation):
    """No-op activation: pooled data is returned unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        unchanged = pooled_data
        return unchanged
class PoolerNormalize(PoolerActivation):
    """L2-normalize along the last dimension.

    The computation runs in float32, then the result is cast back to the
    input dtype.
    """

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        input_dtype = pooled_data.dtype
        unit_vectors = F.normalize(pooled_data.float(), p=2, dim=-1)
        return unit_vectors.to(input_dtype)


class PoolerClassify(PoolerActivation):
    """Turn logits into probabilities.

    Uses sigmoid when the last dimension has fewer than two labels, and
    softmax over the last dimension otherwise. Computed in float32, then cast
    back to the input dtype.
    """

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        logits = pooled_data.float()
        if pooled_data.shape[-1] >= 2:
            probs = F.softmax(logits, dim=-1)
        else:
            probs = F.sigmoid(logits)
        return probs.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Sigmoid for single-label scores; multi-label data passes through."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        if pooled_data.shape[-1] >= 2:
            # Multi-label output is returned untouched (no dtype cast).
            return pooled_data
        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Adapter that applies an arbitrary callable to each pooled chunk."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        activation_fn = self.fn
        return activation_fn(pooled_data)


346
347
class PoolerHead(nn.Module):
    """Post-processing applied to pooled data.

    Casts to float32, truncates each request to its requested
    ``dimensions`` (matryoshka / MRL), then applies the configured activation.
    """

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
    ) -> "PoolerHead":
        """Resolve ``pooler_config`` against the given defaults and build."""
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
            step_tag_id=None,
            returned_token_ids=None,
        )

        return cls.from_config(resolved_config)

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
        """Pick the activation from the resolved config.

        ``normalize`` selects ``PoolerNormalize``, ``softmax`` selects
        ``PoolerClassify``, and neither selects ``PoolerIdentity``.

        Raises:
            ValueError: If both ``normalize`` and ``softmax`` are set.
        """
        if pooler_config.normalize and pooler_config.softmax:
            raise ValueError("`normalize=True` and `softmax=True` should not "
                             "be set together")

        activation: PoolerActivation
        if pooler_config.normalize:
            activation = PoolerNormalize()
        elif pooler_config.softmax:
            activation = PoolerClassify()
        else:
            activation = PoolerIdentity()

        return cls(activation)

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()

        self.activation = activation

    @staticmethod
    def _get_dimensions_list(
            pooling_metadata: PoolingMetadata) -> list[Optional[int]]:
        """Return each request's requested output dimension (or ``None``)."""
        if isinstance(pooling_metadata, V0PoolingMetadata):
            return [
                pooling_param.dimensions
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        return [
            pooling_param.dimensions
            for pooling_param in pooling_metadata.pooling_params
        ]

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        # Using float32 in PoolerHead. Build a new list instead of
        # overwriting the caller's list in place.
        if isinstance(pooled_data, list):
            pooled_data = [data.to(torch.float32) for data in pooled_data]
        else:
            pooled_data = pooled_data.to(torch.float32)

        # For matryoshka representation: truncate to the requested dimensions.
        # Works for both a batched tensor and a list, for V0 and V1 metadata.
        dimensions_list = self._get_dimensions_list(pooling_metadata)
        if any(d is not None for d in dimensions_list):
            assert len(pooled_data) == len(dimensions_list)

            if len(set(dimensions_list)) == 1 and not isinstance(
                    pooled_data, list):
                # All requests share one dimension: slice the batch at once.
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        return self.activation(pooled_data)
class SimplePooler(BasePooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data with ``self.pooling``.
    2. Post-processes the pooled data with ``self.head``: a float32 cast,
       optional per-request dimension truncation, then a normalize /
       softmax / identity activation.
    3. Returns structured results as `PoolerOutput`.

    Attributes:
        pooling: The `PoolingMethod` that reduces the hidden states.
        head: The `PoolerHead` applied to the pooled data.
    """

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
    ) -> "SimplePooler":
        """Resolve ``pooler_config`` against the defaults and build.

        STEP pooling is not supported here; it is handled by ``StepPooler``.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
        )
        assert resolved_config.pooling_type != PoolingType.STEP

        return cls.from_config(resolved_config)

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build the pooling method and head from a resolved config."""
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        head = PoolerHead.from_config(pooler_config)

        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


class StepPooler(BasePooler):
    """Pooler for ``PoolingType.STEP``.

    Every token's hidden state is gathered with ``AllPool``. When
    ``returned_token_ids`` is non-empty, only those indices along dim 1 are
    kept. When ``step_tag_id`` is set, only the rows whose prompt token equals
    it are kept. The result is then passed through the ``PoolerHead``.
    """

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
        """Build from a resolved config whose pooling type must be STEP."""
        assert pooler_config.pooling_type == PoolingType.STEP

        head = PoolerHead.from_config(pooler_config)
        return cls(
            head,
            step_tag_id=pooler_config.step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids,
        )

    def __init__(
        self,
        head: PoolerHead,
        *,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head
        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    def get_prompt_token_ids(
        self,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor]:
        """Return each request's prompt token ids as a tensor."""
        if not isinstance(pooling_metadata, V1PoolingMetadata):
            return [
                torch.tensor(seq.prompt_token_ids)
                for seq in pooling_metadata.seq_data.values()
            ]

        # V1 stores padded rows; trim each to its prompt length.
        token_id_rows = pooling_metadata.prompt_token_ids
        return [
            token_id_rows[row, :length]
            for row, length in enumerate(pooling_metadata.prompt_lens)
        ]

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Apply the column and step-tag filters to every request."""
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        per_request_tokens = self.get_prompt_token_ids(pooling_metadata)

        keep_columns = self.returned_token_ids
        select_columns = keep_columns is not None and len(keep_columns) > 0
        tag = self.step_tag_id

        filtered: list[torch.Tensor] = []
        for states, tokens in zip(per_request_states, per_request_tokens):
            if select_columns:
                states = states[:, keep_columns]
            if tag is not None:
                states = states[tokens == tag]
            filtered.append(states)

        return filtered

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        selected = self.extract_states(hidden_states, pooling_metadata)
        processed = self.head(selected, pooling_metadata)
        return build_output(processed)


class Pooler(nn.Module):
    """Factory for the pooler that matches a model's resolved pooling config."""

    @staticmethod
    def from_config_with_defaults(
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> BasePooler:
        """Build a pooler from ``pooler_config`` merged with model defaults.

        Dispatch uses the *resolved* pooling type, so a user override in
        ``pooler_config.pooling_type`` (to or from ``STEP``) selects the
        matching pooler. Checking the default ``pooling_type`` instead would
        send an override to STEP into ``SimplePooler``, which raises
        ``NotImplementedError``. It would also send an override away from
        STEP into ``StepPooler``, which asserts on the type.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
            step_tag_id=step_tag_id,
            returned_token_ids=returned_token_ids,
        )

        if resolved_config.pooling_type == PoolingType.STEP:
            return StepPooler.from_config(resolved_config)

        return SimplePooler.from_config(resolved_config)


# A callable that pools hidden states given the pooling metadata; same call
# signature as PoolingMethod.forward.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# A callable mapping pooled data to the classifier's output tensor.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]

587

588
589
class ClassifierPooler(nn.Module):
    """A pooling layer for classification and scoring tasks.

    This layer does the following:
    1. Pools the hidden states with the supplied ``pooling`` function.
    2. Applies the ``classifier`` to the pooled data.
    3. Applies an activation function to the output, chosen per request by
       its ``use_cross_encoder`` flag. For classification models it is
       either sigmoid or softmax. For scoring models the behavior is
       configuration dependent, as in the sentence-transformers library.
    """

    def __init__(
        self,
        config: ModelConfig,
        pooling: PoolingFn,
        classifier: ClassifierFn,
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier

        # An explicit act_fn overrides both the classification and the
        # cross-encoder activation.
        self.classification_act_fn = get_classification_activation_function(
            config.hf_config) if act_fn is None else act_fn
        self.cross_encoder_act_fn = get_cross_encoder_activation_function(
            config.hf_config) if act_fn is None else act_fn

    def _get_act_fn(self, use_cross_encoder: bool):
        return (self.cross_encoder_act_fn
                if use_cross_encoder else self.classification_act_fn)

    def _classify(
        self, pooled_data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Apply the classifier, in a single call when shapes allow it."""
        if isinstance(pooled_data, torch.Tensor):
            return self.classifier(pooled_data)
        if len({data.shape for data in pooled_data}) == 1:
            return self.classifier(torch.stack(pooled_data))
        # Ragged (or empty) batch: classify request by request.
        return [self.classifier(data) for data in pooled_data]

    @staticmethod
    def _get_use_cross_encoder_list(
            pooling_metadata: PoolingMetadata) -> list[bool]:
        """Return each request's ``use_cross_encoder`` flag."""
        if isinstance(pooling_metadata, V0PoolingMetadata):
            return [
                pooling_param.use_cross_encoder
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        return [
            pooling_param.use_cross_encoder
            for pooling_param in pooling_metadata.pooling_params
        ]

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, classify and activate the hidden states of every request."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_output = self._classify(pooled_data)
        use_cross_encoder_list = self._get_use_cross_encoder_list(
            pooling_metadata)

        # shape of scores: (batch_size, num_labels)
        if len(set(use_cross_encoder_list)) == 1:
            # Homogeneous batch: one activation call over the whole output.
            act_fn = self._get_act_fn(use_cross_encoder_list[0])
            scores = act_fn(pooled_output)
        else:
            # Mixed flags (or an empty batch): activate request by request.
            per_request = [
                self._get_act_fn(use_cross_encoder)(vecs)
                for use_cross_encoder, vecs in zip(use_cross_encoder_list,
                                                   pooled_output)
            ]
            # Re-batch only when the classifier produced one tensor; outputs
            # of a ragged batch may differ in shape and cannot be stacked.
            if isinstance(pooled_output, torch.Tensor) and per_request:
                scores = torch.stack(per_request)
            else:
                scores = per_request

        return build_output(scores)