# pooler.py (26.6 KB)
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Callable, Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import TypeVar
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.models.adapters import _load_st_projector
18
19
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
20
from vllm.utils.import_utils import resolve_obj_by_qualname
21
from vllm.v1.outputs import PoolerOutput
22
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
23

24
25
logger = init_logger(__name__)

# Signature of a pooling callable: it receives the model's hidden states
# (either one packed tensor or a list of tensors) together with the batch's
# PoolingMetadata, and returns pooled data in either of those two forms.
PoolingFn = Callable[
    [torch.Tensor | list[torch.Tensor], PoolingMetadata],
    torch.Tensor | list[torch.Tensor],
]

# Signature of a classifier callable: tensor in, tensor out. ClassifierPooler
# applies it to pooled features to obtain per-label scores.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]


class PoolingType(IntEnum):
    """Pooling strategies that can be requested for a model.

    Members are looked up by name from ``PoolerConfig.pooling_type``
    (e.g. ``PoolingType["MEAN"]``), so the names are part of the config format.
    """

    LAST = 0  # hidden state of each request's last token
    ALL = 1  # hidden states of every token
    CLS = 2  # hidden state of each request's first token
    STEP = 3  # per-token states, optionally filtered by a step tag id
    MEAN = 4  # average of the token hidden states over the prompt


@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """A pooling task paired with the concrete pooling strategy to use for it."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Resolve ``pooler_config.pooling_type`` into a ``PoolingType``.

        Args:
            task: The pooling task this configuration is resolved for.
            pooler_config: User-facing pooler configuration. Its
                ``pooling_type`` must name a ``PoolingType`` member
                (e.g. ``"MEAN"``).

        Returns:
            A frozen ``ResolvedPoolingConfig`` for ``task``.

        Raises:
            ValueError: If ``pooler_config.pooling_type`` is unset.
            KeyError: If ``pooling_type`` is not a ``PoolingType`` member name.
        """
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would turn this into an opaque `KeyError: None`.
        if pooler_config.pooling_type is None:
            raise ValueError(
                "pooler_config.pooling_type must be set to resolve the "
                f"pooling configuration for task {task!r}"
            )
        return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type])


@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Immutable overrides that a pooler asks to have applied to `PoolingParams`.

    Poolers return one of these from `get_pooling_updates` to declare which
    extra inputs they need for a given task.
    """

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Copy this update onto ``params`` in place.

        The flag is always written, so applying a default update resets
        ``params.requires_token_ids`` to ``False``.
        """
        wanted = self.requires_token_ids
        params.requires_token_ids = wanted


def get_classification_activation_function(config: PretrainedConfig):
68
69
70
71
72
73
74
75
76
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "single_label_classification":
        return PoolerClassify()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()
77
78
79
80
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation applied to cross-encoder (score) outputs.

    The activation name is taken, in order of precedence, from
    ``config.sentence_transformers["activation_fn"]`` or
    ``config.sbert_ce_default_activation_function``. It must be a fully
    qualified class name under ``torch.nn.modules``. That class is
    instantiated without arguments and adapted via
    ``PoolerActivation.wraps``. If no name is configured, ``PoolerClassify``
    is used.

    Raises:
        ValueError: If the configured name is not a string under
            ``torch.nn.modules``.
    """
    function_name: str | None = None

    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model config and is resolved to an arbitrary
    # Python object, so it is restricted to torch.nn.modules. This must be a
    # real check: an `assert` is removed under `python -O`, silently
    # disabling the restriction.
    if not isinstance(function_name, str) or not function_name.startswith(
        "torch.nn.modules."
    ):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)


class PoolingMethod(nn.Module, ABC):
    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

118
    @abstractmethod
119
    def get_supported_tasks(self) -> Set[PoolingTask]:
120
121
        raise NotImplementedError

122
123
124
    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate()

125
126
    @abstractmethod
    def forward_all(
127
        self,
128
        hidden_states: torch.Tensor,
129
        pooling_cursor: PoolingCursor,
130
    ) -> PoolerOutput:
131
132
133
134
        raise NotImplementedError

    def forward(
        self,
135
        hidden_states: torch.Tensor,
136
        pooling_metadata: PoolingMetadata,
137
    ) -> PoolerOutput:
138
139
        pooling_cursor = pooling_metadata.pooling_cursor
        return self.forward_all(hidden_states, pooling_cursor)
140
141


142
class CLSPool(PoolingMethod):
    """Selects the hidden state of each request's first token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> PoolerOutput:
        # Partial (chunked) prefill is rejected outright for this method.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )
        first_rows = pooling_cursor.first_token_indices_gpu
        return hidden_states[first_rows]


class LastPool(PoolingMethod):
    """Selects the hidden state of each request's last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> PoolerOutput:
        last_rows = pooling_cursor.last_token_indices_gpu
        return hidden_states[last_rows]


class AllPool(PoolingMethod):
    """Returns every token's hidden states as one tensor per request.

    When chunked prefill is enabled, each request's chunks are buffered in its
    pooling state until the prefill finishes. Requests that are still
    unfinished yield ``None``.
    """

    def __init__(self):
        super().__init__()

        scheduler_config = get_current_vllm_config().scheduler_config
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward_all(
        self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
    ) -> PoolerOutput:
        raise NotImplementedError(
            "forward_all is not implemented for AllPool. Use forward instead."
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        cursor = pooling_metadata.pooling_cursor
        finished_flags = cursor.is_finished()

        # Cut the packed batch into per-request slices, then arrange them in
        # the order given by the cursor's index.
        chunks = hidden_states.split(cursor.num_scheduled_tokens_cpu.tolist())
        per_request = [chunks[i] for i in cursor.index]

        if not self.enable_chunked_prefill:
            return per_request

        states = pooling_metadata.pooling_states

        # Step 1: buffer this step's chunk for every request.
        for state, chunk in zip(states, per_request):
            state.hidden_states_cache.append(chunk)

        # Step 2: emit the whole sequence only for finished requests.
        outputs: PoolerOutput = []
        for state, done in zip(states, finished_flags):
            if not done:
                outputs.append(None)
                continue
            cached = state.hidden_states_cache
            merged = cached[0] if len(cached) == 1 else torch.concat(cached, dim=0)
            outputs.append(merged)
            state.clean()

        return outputs


class MeanPool(PoolingMethod):
    """Averages each request's token hidden states over its prompt length."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> PoolerOutput:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Prefix sums along the packed token axis. Accumulate in float32,
        # otherwise precision will be lost significantly.
        prefix = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = pooling_cursor.first_token_indices_gpu
        last = pooling_cursor.last_token_indices_gpu

        # Inclusive segment sum over rows [first, last].
        segment_sums = prefix[last] - prefix[first] + hidden_states[first]
        return segment_sums / lengths.unsqueeze(1)


# Pooled data is either one batched tensor or a list of per-request tensors;
# activations return the same form they were given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class BasePoolerActivation(nn.Module, ABC):
    """Interface for activations applied to pooled outputs."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # Expected shapes:
        #   classify / score -> (batch_size, num_classes)
        #   embed            -> (batch_size, embedding_dim) or list(embedding_dim)
        #                       (batch_size, dimensions) or list(dimensions) with MRL
        raise NotImplementedError


class PoolerActivation(BasePoolerActivation):
    @staticmethod
    def wraps(module: nn.Module):
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()
275

276
277
278
279
280
281
282
283
284
285
286
        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
287
288


289
290
class PoolerIdentity(PoolerActivation):
    """No-op activation."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Pass-through: the very same tensor object is returned.
        return pooled_data


class PoolerNormalize(PoolerActivation):
    """Rescales vectors to unit L2 norm along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, dim=-1, p=2)


class PoolerMultiLabelClassify(PoolerActivation):
    """Independent per-label probabilities via an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        logits = pooled_data
        return F.sigmoid(logits)


class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for fewer than 2 labels, else softmax.

    Args:
        static_num_labels: If True, the label count is read once from the
            current model's HF config (``num_labels``). If False, or if the
            config does not provide a positive value, the label count is
            inferred on every call from the last dimension of the input.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        # None means "infer from the last dimension of each input".
        self.num_labels: int | None = None

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            num_labels = getattr(vllm_config.model_config.hf_config, "num_labels", 0)
            if num_labels:
                self.num_labels = num_labels
            else:
                # Previously a 0 here silently selected sigmoid for every
                # output, contradicting this warning's "softmax" fallback.
                # Inferring from the output shape gives softmax for
                # multi-logit outputs and sigmoid for a single logit.
                logger.warning(
                    "num_labels should be > 0 for classification models, "
                    "falling back to inferring it from the last dimension of "
                    "the output (softmax for >= 2 logits, sigmoid otherwise). "
                    "Please check if the configuration is correct."
                )

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = self.num_labels
        if num_labels is None:
            num_labels = pooled_data.shape[-1]

        # A single logit is treated as a binary score -> sigmoid; otherwise
        # the logits form a distribution over the labels -> softmax.
        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)


class LambdaPoolerActivation(PoolerActivation):
    """Adapts an arbitrary tensor-to-tensor callable into a PoolerActivation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        # Stored as an attribute; if `fn` is an nn.Module it is registered
        # as a submodule by nn.Module.__setattr__.
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        apply = self.fn
        return apply(pooled_data)


class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build a per-token embedding pooler.

        Uses ``StepPooler`` for ``"STEP"`` pooling and ``AllPooler`` otherwise.
        """
        head = TokenEmbeddingPoolerHead()
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a per-token classification pooler.

        Uses ``StepPooler`` for ``"STEP"`` pooling and ``AllPooler`` otherwise.
        """
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build a sequence-embedding pooler for the configured pooling type."""
        pooling_type = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        ).pooling_type
        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a sequence-classification pooler for the configured type."""
        pooling_type = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        ).pooling_type
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        raise NotImplementedError


class DummyPooler(Pooler):
    """Pass-through pooler that returns the hidden states untouched."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"score", "plugin"}

    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        # No pooling and no head: the input is handed back as-is.
        return hidden_states


class PoolerHead(nn.Module):
    """Post-pooling stage that applies a ``PoolerActivation`` to pooled data."""

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        # The base head ignores `pooling_metadata`; subclasses such as
        # EmbeddingPoolerHead read per-request options from it.
        return self.activation(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for sequence embeddings.

    Pipeline: stack, cast to the head dtype, apply the optional
    sentence-transformers projector, truncate to each request's requested
    ``dimensions`` (Matryoshka), then L2-normalize requests with
    ``normalize`` set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        vllm_config = get_current_vllm_config()
        # Load ST projector if available
        projector: nn.Module | None = None
        if vllm_config:
            projector = _load_st_projector(vllm_config.model_config)
        self.projector: nn.Module | None = projector
        self.head_dtype = vllm_config.model_config.head_dtype

    def _truncate_dimensions(self, pooled_data, dimensions_list):
        """Slice embeddings to the requested Matryoshka dimensions (if any)."""
        if all(d is None for d in dimensions_list):
            return pooled_data

        assert len(pooled_data) == len(dimensions_list)
        if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
            # A single shared dimension: slice the whole batch at once.
            return pooled_data[..., : dimensions_list[0]]
        return [
            vecs if d is None else vecs[..., :d]
            for vecs, d in zip(pooled_data, dimensions_list)
        ]

    def _maybe_normalize(self, pooled_data, flags):
        """Apply the normalizing activation where each request asks for it."""
        if len(set(flags)) != 1:
            return [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]
        return self.activation(pooled_data) if flags[0] else pooled_data

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        params = pooling_metadata.pooling_params
        pooled_data = self._truncate_dimensions(
            pooled_data, [p.dimensions for p in params]
        )
        pooled_data = self._maybe_normalize(
            pooled_data, [p.normalize for p in params]
        )
        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data


class SimplePooler(Pooler):
    """Composes a ``PoolingMethod`` with a ``PoolerHead``.

    The pooling method extracts or aggregates token states per request. The
    head then post-processes them (e.g. projection, truncation,
    normalization) and returns the ``PoolerOutput``.
    """

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Task support is determined entirely by the pooling method.
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)


class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Applies the classifier (if any) to the pooled features.
    3. Subtracts the configured ``logit_bias`` (if any).
    4. Applies the activation function to requests with ``use_activation``.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Activation for sequence classification, chosen from the HF config."""
        return get_classification_activation_function(model_config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Activation for cross-encoder scoring, chosen from the HF config."""
        return get_cross_encoder_activation_function(model_config.hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Turn an ``act_fn`` spec into a callable activation.

        Args:
            model_config: Source of the HF config for string specs.
            static_num_labels: Forwarded to ``PoolerClassify`` when
                ``act_fn`` is None.
            act_fn: ``"classify"``, ``"score"``, None (default
                ``PoolerClassify``), or a callable used as-is.

        Raises:
            ValueError: For an unknown string spec.
        """
        if isinstance(act_fn, str):
            if act_fn == "classify":
                return ClassifierPooler.act_fn_for_seq_cls(model_config)
            elif act_fn == "score":
                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
            else:
                raise ValueError(f"act_fn [{act_fn=}] not supported.")
        elif act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)
        else:
            assert callable(act_fn)
            return act_fn

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            vllm_config.model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            # Out of place on purpose. `.to()` is a no-op when the dtype
            # already matches, and `pooling` may be an arbitrary PoolingFn
            # that returns a view of `hidden_states`. Without a classifier,
            # an in-place `-=` could therefore write into the model's
            # hidden states.
            pooled_data = pooled_data - self.logit_bias

        flags = [p.use_activation for p in pooling_metadata.pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores


class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Embedding head applied to one request's per-token hidden states."""

    def forward(
        self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
    ) -> PoolerOutput:
        # None marks a request whose chunked prefill is still in progress.
        if pooled_data is None:
            return None

        token_states = pooled_data.to(self.head_dtype)
        # token_states shape: [n_tokens, hidden_dimension]

        if self.projector is not None:
            token_states = self.projector(token_states)
        # token_states shape: [n_tokens, embedding_dimension]

        # Matryoshka truncation; a `dimensions` of None keeps every column.
        token_states = token_states[..., : pooling_param.dimensions]

        if not pooling_param.normalize:
            return token_states
        return self.activation(token_states)


class TokenClassifierPoolerHead(nn.Module):
    """Per-token classification head for one request's hidden states.

    Casts to the head dtype, applies the optional classifier, subtracts the
    configured ``logit_bias``, and applies the activation when the request's
    ``use_activation`` is set. ``None`` inputs (unfinished chunked prefill)
    are passed through.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor | None,
        pooling_param: PoolingParams,
    ) -> PoolerOutput:
        # for unfinished chunked prefill
        if hidden_states is None:
            return None

        hidden_states = hidden_states.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Must be out of place. Without a classifier, `scores` is the
            # input itself (`.to()` returns the same tensor when the dtype
            # already matches). AllPool produces that input via
            # `hidden_states.split(...)`, i.e. a view of the model's hidden
            # states, so an in-place `-=` would corrupt them.
            scores = scores - self.logit_bias

        if pooling_param.use_activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores


class AllPooler(Pooler):
    """Token-level pooler: keeps every token state and applies ``head`` per request."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        per_request = self.pooling(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(per_request) == len(params)

        # Unfinished chunked-prefill entries are None. The token heads in
        # this module return None for them.
        return [self.head(states, param) for states, param in zip(per_request, params)]


class StepPooler(Pooler):
    """Token-level pooler that narrows each request's token states.

    For every finished request, the token states are optionally restricted
    to the columns in ``returned_token_ids``. They are then restricted to
    the positions whose prompt token equals ``step_tag_id``. Finally
    ``head`` is applied per request.
    """

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    @staticmethod
    def _select_steps(data, token_ids, pooling_param):
        """Filter one request's token states by its pooling parameters."""
        # for unfinished chunked prefill
        if data is None:
            return None

        returned_token_ids = pooling_param.returned_token_ids
        if returned_token_ids is not None and len(returned_token_ids) > 0:
            data = data[:, returned_token_ids]

        step_tag_id = pooling_param.step_tag_id
        if step_tag_id is not None:
            data = data[token_ids == step_tag_id]
        return data

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        per_request = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params

        return [
            self._select_steps(data, token_ids, param)
            for data, token_ids, param in zip(
                per_request, prompt_token_ids, pooling_params
            )
        ]

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # extract_states reads prompt token ids, so they must be available.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        selected = self.extract_states(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(selected) == len(params)

        return [self.head(data, param) for data, param in zip(selected, params)]


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task.

    Consecutive requests sharing a task form one group. Each group is served
    by that task's pooler on the matching slice of the metadata, and the
    outputs are concatenated in request order.
    """

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        # Validate up front that each pooler can actually serve its task.
        for task, pooler in poolers_by_task.items():
            supported = pooler.get_supported_tasks()
            if task in supported:
                continue
            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {pooler.get_supported_tasks()}"
            )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        outputs: list[torch.Tensor] = []
        start = 0

        for task, run in groupby(pooling_metadata.tasks):
            pooler = self.poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            end = start + sum(1 for _ in run)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[start:end],
            )
            outputs.extend(group_output)
            start = end

        return outputs

    def extra_repr(self) -> str:
        return f"supported_task={self.get_supported_tasks()}"