# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
4
from collections.abc import Callable, Mapping, Set
5
from dataclasses import dataclass
6
from itertools import groupby
7
from typing import TypeAlias, TypeVar
8
9
10

import torch
import torch.nn as nn
11
import torch.nn.functional as F
12
from transformers import PretrainedConfig
13

14
15
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.config.pooler import PoolerConfig, PoolingTypeStr
16
from vllm.logger import init_logger
17
from vllm.model_executor.models.adapters import _load_st_projector
18
19
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
20
from vllm.utils.import_utils import resolve_obj_by_qualname
21
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput
22
from vllm.v1.pool.metadata import PoolingMetadata
23

24
25
logger = init_logger(__name__)

# A pooling callable: takes hidden states (one tensor or a list of tensors)
# plus the pooling metadata and returns pooled data in either form.
PoolingFn = Callable[
    [torch.Tensor | list[torch.Tensor], PoolingMetadata],
    torch.Tensor | list[torch.Tensor],
]
# A classifier callable: maps one tensor to another tensor.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]

# Result types of the pooling methods defined in this module.
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput

# Result types of the pooler heads defined in this module.
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
40
41


42
43
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with its resolved pooling type."""

    pooling_type: PoolingTypeStr
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build a resolved config for `task` from `pooler_config`.

        Args:
            task: The pooling task this configuration is for.
            pooler_config: Config whose `pooling_type` must already be set.

        Returns:
            A new instance carrying `task` and `pooler_config.pooling_type`.
        """
        # The pooling type is expected to have been resolved by the caller.
        assert pooler_config.pooling_type is not None
        return cls(task=task, pooling_type=pooler_config.pooling_type)
55
56


57
58
59
60
61
62
63
64
65
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Immutable set of adjustments a pooler requests on `PoolingParams`."""

    # Set this flag to enable `get_prompt_token_ids` for your pooler.
    requires_token_ids: bool = False

    def apply(self, params: PoolingParams) -> None:
        """Write this update's fields onto `params` in place."""
        wants_token_ids = self.requires_token_ids
        params.requires_token_ids = wants_token_ids


66
def get_classification_activation_function(config: PretrainedConfig):
    """Select the activation for a sequence-classification head.

    The choice follows `config.problem_type`, in alignment with transformers'
    ForSequenceClassificationLoss:
    https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92

    Args:
        config: Model config; `problem_type` is read with a default of "".

    Returns:
        `PoolerIdentity` for "regression", `PoolerMultiLabelClassify` for
        "multi_label_classification", and `PoolerClassify` for
        "single_label_classification" or any other value.
    """
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()
    # "single_label_classification" and every unrecognised value share the
    # same default activation.
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation function configured for a cross-encoder.

    The qualified name is taken from `config.sentence_transformers
    ["activation_fn"]` when present, otherwise from a non-None
    `config.sbert_ce_default_activation_function`.

    Args:
        config: Model config that may name an activation function.

    Returns:
        The named activation wrapped via `PoolerActivation.wraps`, or a
        `PoolerClassify` when the config names none.

    Raises:
        ValueError: If the configured name is outside `torch.nn.modules.`.
    """
    function_name: str | None = None
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # This is a security restriction on config-supplied names, so it must be
    # an explicit raise: an `assert` would be stripped under `python -O`.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
101

102
103
104

class PoolingMethod(nn.Module, ABC):
    """Abstract base for modules that pool hidden states per request."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod":
        """Instantiate the pooling method named by `pooling_type`.

        Raises:
            ValueError: For "STEP", which is handled by `StepPooler` instead.
            NotImplementedError: For any other unrecognised pooling type.
        """
        if pooling_type == "LAST":
            return LastPool()
        if pooling_type == "ALL":
            return AllPool()
        if pooling_type == "CLS":
            return CLSPool()
        if pooling_type == "MEAN":
            return MeanPool()
        if pooling_type == "STEP":
            raise ValueError(
                "'STEP' pooling is handled by StepPooler "
                "and is not a standalone PoolingMethod."
            )

        raise NotImplementedError(f"Unsupported method: {pooling_type!r}")

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the parameter update for `task`; the default changes nothing."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolingMethodOutput:
        """Pool `hidden_states` according to `pooling_metadata`."""
        raise NotImplementedError
136
137


138
class CLSPool(PoolingMethod):
    """Pools each request to the hidden state at its first-token index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        """Index `hidden_states` at the cursor's first-token indices."""
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        return hidden_states[pooling_cursor.first_token_indices_gpu]
153
154


155
class LastPool(PoolingMethod):
    """Pools each request to the hidden state at its last-token index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        """Index `hidden_states` at the cursor's last-token indices."""
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        return hidden_states[pooling_cursor.last_token_indices_gpu]
166
167


168
class AllPool(PoolingMethod):
    """Returns the hidden states of all scheduled tokens, one entry per request.

    With chunked prefill enabled, chunks are accumulated in each request's
    pooling state and emitted only once that request is reported finished;
    unfinished requests yield `None`.
    """

    def __init__(self):
        super().__init__()

        vllm_config = get_current_vllm_config()
        self.enable_chunked_prefill = (
            vllm_config.scheduler_config.enable_chunked_prefill
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolingMethodOutput:
        """Split `hidden_states` per request and return the per-request pieces.

        Returns:
            A list with one entry per request: a tensor of token hidden
            states, or `None` for a request whose chunked prefill has not
            finished yet.
        """
        pooling_cursor = pooling_metadata.get_pooling_cursor()

        # Cut the flat token dimension into per-entry pieces, then pick the
        # pieces named by the cursor's index.
        hidden_states_all = hidden_states.split(
            pooling_cursor.num_scheduled_tokens_cpu.tolist()
        )
        hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]

        if not self.enable_chunked_prefill:
            return hidden_states_lst

        pooling_states = pooling_metadata.pooling_states

        # If chunked_prefill is enabled
        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
            p.hidden_states_cache.append(hs_chunk)

        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
        output_list = list[torch.Tensor | None]()
        for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
            if finished:
                hidden_states_cache = p.hidden_states_cache
                if len(hidden_states_cache) == 1:
                    # Single chunk: no concatenation needed.
                    output_list.append(hidden_states_cache[0])
                else:
                    output_list.append(torch.concat(hidden_states_cache, dim=0))
                p.clean()
            else:
                output_list.append(None)

        return output_list
215
216


217
class MeanPool(PoolingMethod):
    """Pools each request to the mean of its token hidden states."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        """Average `hidden_states` over each request's token span.

        The per-request sum is taken from a running cumulative sum over the
        token dimension and divided by the request's prompt length.
        """
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu

        # cumsum[end] - cumsum[start] drops the start row itself, so it is
        # added back to get the inclusive sum over [start, end].
        return (
            cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
        ) / prompt_lens.unsqueeze(1)
244
245


246
# Constrained type variable: either a single tensor or a list of tensors.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
247
248


249
250
251
252
253
254
255
256
class BasePoolerActivation(nn.Module, ABC):
    """Abstract activation applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to `pooled_data`.

        Expected shapes:
            classify (& score): (batch_size, num_classes)
            embed: (batch_size, embedding_dim) or list(embedding_dim);
                (batch_size, dimensions) or list(dimensions) if using MRL
        """
        raise NotImplementedError
257

258

259
260
261
262
263
264
265
class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor chunk and applied to a tensor or a list."""

    @staticmethod
    def wraps(module: nn.Module):
        """Return the pooler activation that corresponds to `module`.

        `nn.Identity` maps to `PoolerIdentity`, `nn.Sigmoid` and `nn.Softmax`
        map to `PoolerClassify`, and any other module is wrapped in a
        `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to the tensor, or to each tensor of a list."""
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
278
279


280
281
class PoolerIdentity(PoolerActivation):
    """Activation that returns its input unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data


285
286
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, p=2, dim=-1)
288
289


290
291
class PoolerMultiLabelClassify(PoolerActivation):
    """Activation that applies an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.sigmoid(pooled_data)
293
294


295
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid below two labels, softmax otherwise."""

    def __init__(self, *, static_num_labels: bool = True) -> None:
        """Initialise the activation.

        Args:
            static_num_labels: When true, read `num_labels` once from the
                current vLLM config's `hf_config` (0 if absent). When false,
                the label count is taken from each input's last dimension.
        """
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(
                vllm_config.model_config.hf_config, "num_labels", 0
            )
            if self.num_labels == 0:
                # NOTE(review): the message says softmax, but with
                # num_labels == 0 `forward_chunk` takes the `< 2` (sigmoid)
                # branch - confirm which behaviour is intended.
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to softmax. "
                    "Please check if the configuration is correct."
                )
        else:
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid when fewer than two labels, else softmax over dim -1."""
        num_labels = (
            self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
        )

        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
322
323
324
325
326
327
328
329
330
331
332
333


class LambdaPoolerActivation(PoolerActivation):
    """Adapter exposing an arbitrary tensor callable as a `PoolerActivation`."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        # The callable is stored as given and invoked once per tensor chunk.
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the wrapped callable's result for `pooled_data`."""
        activated = self.fn(pooled_data)
        return activated


334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build a token-wise embedding pooler.

        Returns a `StepPooler` when the pooling type is "STEP", otherwise an
        `AllPooler`; both use a `TokenEmbeddingPoolerHead`.
        """
        head = TokenEmbeddingPoolerHead()

        if pooler_config.pooling_type == "STEP":
            return StepPooler(head=head)

        return AllPooler(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a token-wise classification pooler.

        Returns a `StepPooler` when the pooling type is "STEP", otherwise an
        `AllPooler`; both use a `TokenClassifierPoolerHead` built from
        `classifier` and `act_fn`.
        """
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)

        if pooler_config.pooling_type == "STEP":
            return StepPooler(head=head)

        return AllPooler(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build a `SimplePooler` with an `EmbeddingPoolerHead` for "embed"."""
        resolved_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )

        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
        head = EmbeddingPoolerHead()

        return SimplePooler(pooling=pooling, head=head)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a `ClassifierPooler` for the "classify" task."""
        resolved_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )

        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)

        return ClassifierPooler(
            pooling=pooling,
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states` according to `pooling_metadata`."""
        raise NotImplementedError


410
411
412
413
414
415
class DummyPooler(Pooler):
    """Pooler that returns the hidden states unchanged."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Return `hidden_states` as-is; `pooling_metadata` is not used."""
        return hidden_states


422
423
class TokenPoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output one token."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Post-process the pooled data of a whole batch."""
        raise NotImplementedError
432
433


434
class EmbeddingPoolerHead(TokenPoolerHead):
    """Head for sequence embeddings: cast, project, truncate, normalize."""

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        # NOTE(review): the projector load guards against a falsy
        # `vllm_config`, but the next statement dereferences it
        # unconditionally - confirm whether it can ever be None here.
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Turn pooled hidden states into embeddings.

        Args:
            pooled_data: One tensor, or a list of tensors that is stacked.
            pooling_metadata: Supplies the per-request `pooling_params`
                (`dimensions` and `normalize` are read).

        Returns:
            A tensor when all requests share the same settings, otherwise a
            list with one tensor per request.
        """
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        pooling_params = pooling_metadata.pooling_params

        # for matryoshka representation
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Mixed sizes cannot stay in one tensor, so fall back to a
                # per-request list.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data


495
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Passes the pooled data through the configured head.
    3. Returns the head's output.
    """

    def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by the wrapped pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the wrapped pooling method."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Apply the pooling method, then the head."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return pooled_data
524
525


526
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Applies the pooling callable to the hidden states.
    2. Optionally applies a classification layer to the pooled data.
    3. Optionally subtracts the configured logit bias.
    4. Applies an activation function to the output.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Return the sequence-classification activation for `model_config`."""
        return get_classification_activation_function(model_config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Return the cross-encoder activation for `model_config`."""
        return get_cross_encoder_activation_function(model_config.hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Turn an `act_fn` specification into an activation object.

        Args:
            model_config: Used when `act_fn` is "classify" or "score".
            static_num_labels: Forwarded to `PoolerClassify` when `act_fn`
                is None.
            act_fn: "classify", "score", None, or a ready-made callable.

        Raises:
            ValueError: If `act_fn` is any other string.
        """
        if isinstance(act_fn, str):
            if act_fn == "classify":
                return ClassifierPooler.act_fn_for_seq_cls(model_config)
            elif act_fn == "score":
                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
            else:
                raise ValueError(f"act_fn [{act_fn=}] not supported.")
        elif act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)
        else:
            assert callable(act_fn)
            return act_fn

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            vllm_config.model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        """Pool, classify and activate a batch.

        Returns:
            A tensor when every request agrees on `use_activation`, otherwise
            a list with one tensor per request.
        """
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose: without a classifier, and when the
            # dtype cast is a no-op, `pooled_data` may still be the tensor
            # returned by `self.pooling`, which must not be mutated.
            pooled_data = pooled_data - self.logit_bias

        pooling_params = pooling_metadata.pooling_params
        flags = [p.use_activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
615
616


617
class TokenwisePoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output multiple tokens."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        """Post-process the pooled data of a single request."""
        raise NotImplementedError


629
class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
    """Head for per-token embeddings: cast, project, truncate, normalize."""

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        # NOTE(review): the projector load guards against a falsy
        # `vllm_config`, but the next statement dereferences it
        # unconditionally - confirm whether it can ever be None here.
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        """Turn one request's token hidden states into token embeddings.

        Returns `None` when `pooled_data` is `None`.
        """
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [n_tokens, embedding_dimension]

        # for matryoshka representation
        # (a `dimensions` of None makes this slice a no-op)
        pooled_data = pooled_data[..., : pooling_param.dimensions]

        # for normalize
        if pooling_param.normalize:
            pooled_data = self.activation(pooled_data)

        # pooled_data shape: [n_tokens, embedding_dimension]
        return pooled_data


670
class TokenClassifierPoolerHead(TokenwisePoolerHead):
    """Head for per-token classification: cast, classify, bias, activate."""

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        """Initialise the head.

        Args:
            classifier: Optional callable applied to the token hidden states.
            act_fn: Activation specification resolved through
                `ClassifierPooler.resolve_act_fn` with dynamic label count.
        """
        super().__init__()

        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.activation = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        """Score one request's tokens.

        Returns `None` when `pooled_data` is `None`.
        """
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose: without a classifier, and when the
            # dtype cast is a no-op, `scores` can still be a view into the
            # caller's hidden states (`AllPool` produces `split` views), so
            # an in-place subtraction would write through to them.
            scores = scores - self.logit_bias

        if pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_token, num_labels]
        return scores


class AllPooler(Pooler):
    """Token-wise pooler: `AllPool` followed by a per-request head."""

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        """Pool all tokens per request and apply the head to each request."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
738
739
740


class StepPooler(Pooler):
    """Token-wise pooler that keeps only selected tokens and columns.

    After `AllPool`, each request's data is optionally restricted to the
    columns in `returned_token_ids` and to the rows whose prompt token equals
    `step_tag_id`, then passed to the head.
    """

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor | None]:
        """Pool all tokens and apply each request's step filters.

        Returns:
            One entry per request: the filtered tensor, or `None` when the
            pooling method produced `None` for that request.
        """
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params

        pooled_data = list[torch.Tensor | None]()
        for data, token_id, pooling_param in zip(
            pooled_data_lst, prompt_token_ids, pooling_params
        ):
            # for unfinished chunked prefill
            if data is None:
                pooled_data.append(data)
                continue

            step_tag_id = pooling_param.step_tag_id
            returned_token_ids = pooling_param.returned_token_ids

            # Keep only the requested columns of the second dimension.
            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]

            # Keep only the rows whose prompt token is the step tag.
            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        """Extract the step states and apply the head to each request."""
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
793
794


795
796
797
798
799
800
801
802
803
804
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Store the task-to-pooler mapping after validating it.

        Raises:
            ValueError: If a pooler is registered for a task it does not
                list in `get_supported_tasks()`.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered sub-pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the sub-pooler registered for `task`."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each run of same-task requests through its sub-pooler.

        Consecutive requests with the same task are grouped; each group gets
        the full `hidden_states` and the matching slice of
        `pooling_metadata`. Outputs are concatenated in request order.

        Raises:
            ValueError: If a request's task has no registered sub-pooler.
        """
        poolers_by_task = self.poolers_by_task

        outputs = list[torch.Tensor | None]()
        offset = 0
        for task, group in groupby(pooling_metadata.tasks):
            # Compare against None explicitly: a registered pooler must be
            # used even if it happens to evaluate as falsy.
            pooler = poolers_by_task.get(task)
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        """Return the supported tasks for the module's repr."""
        s = f"supported_task={self.get_supported_tasks()}"
        return s