pooler.py 27.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Callable, Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import TypeAlias, TypeVar
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.models.adapters import _load_st_projector
18
19
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
20
from vllm.utils.import_utils import resolve_obj_by_qualname
21
22
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
23

24
25
logger = init_logger(__name__)

# A pooling callable: given the hidden states (one flat tensor or a list of
# tensors) and the batch's `PoolingMetadata`, it returns the pooled data.
PoolingFn: TypeAlias = Callable[
    [torch.Tensor | list[torch.Tensor], PoolingMetadata],
    torch.Tensor | list[torch.Tensor],
]

# A classifier callable mapping pooled data to logits / scores.
ClassifierFn: TypeAlias = Callable[[torch.Tensor], torch.Tensor]

# Return type of the per-request selecting/aggregating pooling methods
# (`CLSPool`, `LastPool`, `MeanPool`).
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]

# Return type of `AllPool`: one entry per request; with chunked prefill an
# entry is None until that request's prefill has finished.
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None

# Anything a `PoolingMethod.forward` may return.
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput

# Return types of `TokenPoolerHead.forward` and `TokensPoolerHead.forward`.
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None


42
43
class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods.

    Members are looked up by name (``PoolingType[name]``) from
    ``PoolerConfig.pooling_type``, so member names double as config strings.
    """

    LAST = 0  # last token of each request
    ALL = 1  # every token of each request
    CLS = 2  # first token of each request
    STEP = 3  # handled by StepPooler, not by PoolingMethod.from_pooling_type
    MEAN = 4  # mean over each request's tokens
50
51


52
53
54
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Pooling type and task resolved from a `PoolerConfig`."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build the resolved config for ``task``.

        ``pooler_config.pooling_type`` must be set; it is converted to a
        `PoolingType` by member name.
        """
        type_name = pooler_config.pooling_type
        assert type_name is not None
        return cls(pooling_type=PoolingType[type_name], task=task)
65
66


67
68
69
70
71
72
73
74
75
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Flags a pooler needs set on the request's `PoolingParams`."""

    # Set this flag to enable `get_prompt_token_ids` for your pooler.
    requires_token_ids: bool = False

    def apply(self, params: PoolingParams) -> None:
        """Write the flags held by this update onto ``params`` in place."""
        wanted = self.requires_token_ids
        params.requires_token_ids = wanted


76
def get_classification_activation_function(config: PretrainedConfig):
77
78
79
80
81
82
83
84
85
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "single_label_classification":
        return PoolerClassify()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()
86
87
88
89
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation for a cross-encoder (score) model.

    The activation name is taken, in order of preference, from
    ``config.sentence_transformers["activation_fn"]`` and
    ``config.sbert_ce_default_activation_function``. The name is resolved as a
    qualified object name, instantiated with no arguments and wrapped with
    ``PoolerActivation.wraps``. When neither is set, ``PoolerClassify()`` is
    returned.

    Raises:
        ValueError: if the configured name is not a string under
            ``torch.nn.modules``.
    """
    function_name: str | None = None

    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is not None:
        # The name comes from the model's config files and is imported
        # dynamically, so the allow-list must be a real check: an `assert`
        # is removed when Python runs with -O.
        if not isinstance(function_name, str) or not function_name.startswith(
            "torch.nn.modules."
        ):
            raise ValueError(
                "Loading of activation functions is restricted to "
                "torch.nn.modules for security reasons"
            )
        fn = resolve_obj_by_qualname(function_name)()
        return PoolerActivation.wraps(fn)

    return PoolerClassify()
111

112
113
114
115
116
117
118
119
120
121
122
123
124
125
126

class PoolingMethod(nn.Module, ABC):
    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

127
    @abstractmethod
128
    def get_supported_tasks(self) -> Set[PoolingTask]:
129
130
        raise NotImplementedError

131
132
133
    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate()

134
    @abstractmethod
135
136
    def forward(
        self,
137
        hidden_states: torch.Tensor,
138
        pooling_metadata: PoolingMetadata,
139
140
    ) -> PoolingMethodOutput:
        raise NotImplementedError
141
142


143
class CLSPool(PoolingMethod):
    """Pools each request by selecting its first token's hidden state."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        # Batches containing a partially-prefilled request are rejected.
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_indices = cursor.first_token_indices_gpu
        return hidden_states[first_indices]
158
159


160
class LastPool(PoolingMethod):
    """Pools each request by selecting its last token's hidden state."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        last_indices = pooling_metadata.get_pooling_cursor().last_token_indices_gpu
        return hidden_states[last_indices]
171
172


173
class AllPool(PoolingMethod):
    """Returns every token's hidden states, split per request.

    With chunked prefill enabled, each step's chunk is appended to the
    request's ``hidden_states_cache``. A finished request yields its cached
    chunks concatenated (and the cache is cleaned); an unfinished one
    yields ``None``.
    """

    def __init__(self):
        super().__init__()

        scheduler_config = get_current_vllm_config().scheduler_config
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokensPoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        is_finished = cursor.is_finished()

        chunks = hidden_states.split(cursor.num_scheduled_tokens_cpu.tolist())
        # Reorder the per-request chunks as given by the cursor's index.
        per_request = [chunks[i] for i in cursor.index]

        if not self.enable_chunked_prefill:
            return per_request

        states = pooling_metadata.pooling_states

        # Step 1: stash this step's chunk for every request.
        for state, chunk in zip(states, per_request):
            state.hidden_states_cache.append(chunk)

        # Step 2: emit the accumulated hidden states of finished requests.
        outputs: list[torch.Tensor | None] = []
        for state, done in zip(states, is_finished):
            if not done:
                outputs.append(None)
                continue
            cached = state.hidden_states_cache
            merged = cached[0] if len(cached) == 1 else torch.concat(cached, dim=0)
            outputs.append(merged)
            state.clean()

        return outputs
221
222


223
class MeanPool(PoolingMethod):
    """Pools each request by averaging its tokens' hidden states."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = cursor.prompt_lens_cpu.to(hidden_states.device, non_blocking=True)

        # Accumulate in float32; a lower-precision running sum loses
        # significant accuracy.
        running_sum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        starts = cursor.first_token_indices_gpu
        ends = cursor.last_token_indices_gpu
        # running_sum[ends] - running_sum[starts] drops the first token of
        # each request, so it is added back before dividing by the length.
        segment_sums = running_sum[ends] - running_sum[starts] + hidden_states[starts]
        return segment_sums / lengths.unsqueeze(1)
250
251


252
# Constrained type variable for activations: a tensor maps to a tensor and a
# list of tensors maps to a list of tensors.
_T = TypeVar(
    "_T",
    torch.Tensor,
    list[torch.Tensor],
)
253
254


255
256
257
258
259
260
261
262
class BasePoolerActivation(nn.Module, ABC):
    """Interface for activations applied to pooled outputs."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # Expected shapes:
        #   classify / score -> (batch_size, num_classes)
        #   embed            -> (batch_size, embedding_dim) or list(embedding_dim);
        #                       (batch_size, dimensions) or list(dimensions) with MRL
        raise NotImplementedError
263

264

265
266
267
268
269
270
271
class PoolerActivation(BasePoolerActivation):
    """Activation applied to a tensor, or to each tensor of a list."""

    @staticmethod
    def wraps(module: nn.Module):
        """Convert an ``nn.Module`` activation into a ``PoolerActivation``.

        - ``nn.Identity`` -> ``PoolerIdentity``
        - ``nn.Sigmoid``  -> ``PoolerMultiLabelClassify`` (element-wise
          sigmoid, i.e. exactly what the module computes)
        - ``nn.Softmax``  -> ``PoolerClassify`` (sigmoid below 2 labels,
          softmax otherwise)
        - anything else   -> ``LambdaPoolerActivation`` around the module
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, nn.Sigmoid):
            # PoolerClassify would switch to softmax when num_labels >= 2,
            # silently replacing the sigmoid the model config asked for.
            return PoolerMultiLabelClassify()
        if isinstance(module, nn.Softmax):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply ``forward_chunk`` to a tensor, or element-wise to a list."""
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
284
285


286
287
class PoolerIdentity(PoolerActivation):
    """Pass-through activation: the input is returned unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        unchanged = pooled_data
        return unchanged


291
292
class PoolerNormalize(PoolerActivation):
    """L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, dim=-1, p=2)
294
295


296
297
class PoolerMultiLabelClassify(PoolerActivation):
    """Element-wise sigmoid, scoring every label independently."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(pooled_data)
299
300


301
class PoolerClassify(PoolerActivation):
    """Sigmoid when there are fewer than 2 labels, softmax otherwise.

    Args:
        static_num_labels: if True, read ``num_labels`` once from the current
            model's HF config (0 when absent); if False, take it from the
            last dimension of each input.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(
                vllm_config.model_config.hf_config, "num_labels", 0
            )
            if self.num_labels == 0:
                # forward_chunk applies sigmoid whenever num_labels < 2, so a
                # zero label count results in sigmoid; the message says so.
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to sigmoid. "
                    "Please check if the configuration is correct."
                )
        else:
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = (
            self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
        )

        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
328
329
330
331
332
333
334
335
336
337
338
339


class LambdaPoolerActivation(PoolerActivation):
    """Adapts an arbitrary tensor -> tensor callable into a PoolerActivation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        activation = self.fn
        return activation(pooled_data)


340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Per-token embedding pooler: `StepPooler` for "STEP", else `AllPooler`."""
        head = TokenEmbeddingPoolerHead()
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Per-token classification pooler: `StepPooler` for "STEP", else `AllPooler`."""
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Sequence embedding pooler using the configured pooling type."""
        resolved = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Sequence classification pooler using the configured pooling type."""
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Produce the pooler output for every request in the batch."""
        raise NotImplementedError


416
417
418
419
420
421
class DummyPooler(Pooler):
    """Pooler that performs no pooling and returns the hidden states as-is."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        passthrough = hidden_states
        return passthrough


428
429
class TokenPoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output one token.

    Receives the pooled data for the whole batch together with the batch's
    `PoolingMetadata`.
    """

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Post-process the pooled data into the head output."""
        raise NotImplementedError
438
439


440
class EmbeddingPoolerHead(TokenPoolerHead):
    """Head for sequence embeddings.

    Steps: stack to a tensor if needed, cast to the head dtype, apply the
    Sentence-Transformers projector (if loaded), truncate to each request's
    ``dimensions`` (Matryoshka), then L2-normalize requests whose
    ``normalize`` flag is set.
    """

    def __init__(self) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        # Sentence-Transformers projector, if the model provides one.
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype
        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # [batch_size, hidden_dim]
        pooled_data = pooled_data.to(self.head_dtype)

        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # [batch_size, embedding_dim]

        params = pooling_metadata.pooling_params

        # Matryoshka representation: keep only the requested leading dims.
        dims = [p.dimensions for p in params]
        if any(d is not None for d in dims):
            assert len(pooled_data) == len(dims)
            if len(set(dims)) == 1 and not isinstance(pooled_data, list):
                # Every request asked for the same size: slice the batch at once.
                pooled_data = pooled_data[..., : dims[0]]
            else:
                pooled_data = [
                    row if d is None else row[..., :d]
                    for row, d in zip(pooled_data, dims)
                ]

        # Normalization, batched when all requests agree.
        normalize_flags = [p.normalize for p in params]
        uniform = len(set(normalize_flags)) == 1
        if not uniform:
            pooled_data = [
                self.activation(row) if flag else row
                for row, flag in zip(pooled_data, normalize_flags)
            ]
        elif normalize_flags[0]:
            pooled_data = self.activation(pooled_data)

        # [batch_size, embedding_dim] (or a list of per-request vectors)
        return pooled_data


501
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    The `PoolingMethod` extracts or aggregates data per request; the
    `TokenPoolerHead` then turns it into the pooler output.
    """

    def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
        super().__init__()
        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
530
531


532
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Applies the classifier to the pooled data, if one is given.
    3. Subtracts `pooler_config.logit_bias` from the result, if set.
    4. Applies the activation function to requests whose
       `PoolingParams.use_activation` flag is set.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Activation for sequence classification, from the HF config."""
        return get_classification_activation_function(model_config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Activation for cross-encoder scoring, from the HF config."""
        return get_cross_encoder_activation_function(model_config.hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Turn the ``act_fn`` argument into an activation module.

        Args:
            model_config: model config the activation is derived from.
            static_num_labels: forwarded to ``PoolerClassify`` when ``act_fn``
                is None.
            act_fn: ``"classify"`` or ``"score"`` to derive the activation
                from the model config, None for ``PoolerClassify``, or a
                callable that is returned unchanged.

        Raises:
            ValueError: for any other string.
        """
        if isinstance(act_fn, str):
            if act_fn == "classify":
                return ClassifierPooler.act_fn_for_seq_cls(model_config)
            elif act_fn == "score":
                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
            else:
                raise ValueError(f"act_fn [{act_fn=}] not supported.")
        elif act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)
        else:
            assert callable(act_fn)
            return act_fn

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            vllm_config.model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose: with no classifier and a no-op dtype
            # cast, `pooled_data` is still the tensor returned by the pooling
            # function (possibly a view of `hidden_states`), which an
            # in-place `-=` would modify.
            pooled_data = pooled_data - self.logit_bias

        pooling_params = pooling_metadata.pooling_params
        flags = [p.use_activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
621
622


623
624
625
626
class TokensPoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output multiple tokens.

    Called once per request with that request's pooled data (``None`` while
    its chunked prefill is unfinished) and its `PoolingParams`.
    """

    @abstractmethod
    def forward(
        self,
        pooled_data: TokensPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokensPoolerHeadOutput:
        """Post-process one request's per-token data."""
        raise NotImplementedError


class TokenEmbeddingPoolerHead(TokensPoolerHead):
    """Per-token embedding head for a single request.

    Casts to the head dtype, applies the Sentence-Transformers projector (if
    loaded), truncates to ``dimensions`` and optionally L2-normalizes.
    """

    def __init__(self) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        # Sentence-Transformers projector, if the model provides one.
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype
        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokensPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokensPoolerHeadOutput:
        if pooled_data is None:
            # Chunked prefill not finished yet: nothing to emit.
            return None

        embeddings = pooled_data.to(self.head_dtype)  # [n_tokens, hidden_dim]
        if self.projector is not None:
            embeddings = self.projector(embeddings)  # [n_tokens, embedding_dim]

        # Matryoshka truncation; `dimensions=None` keeps every column.
        embeddings = embeddings[..., : pooling_param.dimensions]

        if pooling_param.normalize:
            embeddings = self.activation(embeddings)

        return embeddings


676
class TokenClassifierPoolerHead(TokensPoolerHead):
    """Per-token classification head for a single request.

    Casts to the head dtype, applies the optional classifier, subtracts the
    configured logit bias and, when ``use_activation`` is set, applies the
    activation.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        # static_num_labels=False: the default PoolerClassify takes the label
        # count from the last dimension of the scores it receives.
        self.activation = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )

    def forward(
        self,
        pooled_data: TokensPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokensPoolerHeadOutput:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose: without a classifier and with a no-op
            # dtype cast, `scores` can be a slice of the model's hidden states
            # (AllPool returns `hidden_states.split(...)` views), which an
            # in-place `-=` would overwrite.
            scores = scores - self.logit_bias

        if pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_token, num_labels]
        return scores


class AllPooler(Pooler):
    """Token-level pooler: runs `AllPool`, then applies the head per request."""

    def __init__(self, head: TokensPoolerHead) -> None:
        super().__init__()
        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokensPoolerOutput:
        per_request = self.pooling(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(per_request) == len(params)

        return [self.head(data, param) for data, param in zip(per_request, params)]
744
745
746


class StepPooler(Pooler):
    """Token-level pooler that filters tokens/columns before the head.

    Per request, optionally keeps only the ``returned_token_ids`` columns and
    only the positions whose prompt token equals ``step_tag_id``.
    """

    def __init__(self, head: TokensPoolerHead) -> None:
        super().__init__()
        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor | None]:
        """Run `AllPool` and apply the per-request column/row selection."""
        per_request = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        params = pooling_metadata.pooling_params

        extracted: list[torch.Tensor | None] = []
        for data, token_ids, param in zip(per_request, prompt_token_ids, params):
            if data is None:
                # Chunked prefill not finished yet.
                extracted.append(None)
                continue

            columns = param.returned_token_ids
            if columns is not None and len(columns) > 0:
                data = data[:, columns]

            tag = param.step_tag_id
            if tag is not None:
                data = data[token_ids == tag]

            extracted.append(data)

        return extracted

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Step tags are matched against the prompt token ids.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokensPoolerOutput:
        per_request = self.extract_states(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(per_request) == len(params)

        return [self.head(data, param) for data, param in zip(per_request, params)]
799
800


801
802
803
804
805
806
807
808
809
810
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        for task, pooler in poolers_by_task.items():
            supported = pooler.get_supported_tasks()
            if task in supported:
                continue
            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {supported}"
            )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        outputs: list[torch.Tensor | None] = []
        start = 0
        # Consecutive requests sharing a task are pooled as one slice.
        for task, run in groupby(pooling_metadata.tasks):
            pooler = self.poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            stop = start + sum(1 for _ in run)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[start:stop],
            )
            outputs.extend(group_output)
            start = stop

        return outputs

    def extra_repr(self) -> str:
        return f"supported_task={self.get_supported_tasks()}"