# pooler.py 25.7 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Callable, Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import TypeVar
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.models.adapters import _load_st_projector
18
19
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
20
21
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput
22
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
23

24
25
logger = init_logger(__name__)

# A single tensor or a list of tensors; used both as the input and as the
# output type of a pooling function.
_TensorOrTensors = torch.Tensor | list[torch.Tensor]

# Callable signature of a pooling function: (hidden_states, metadata) -> pooled.
PoolingFn = Callable[[_TensorOrTensors, PoolingMetadata], _TensorOrTensors]

# Callable signature of a classification layer: tensor in, tensor out.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
31
32
33
34


class PoolingType(IntEnum):
    """Supported pooling strategies.

    Members are looked up *by name* from ``PoolerConfig.pooling_type`` (see
    ``ResolvedPoolingConfig.from_config``), so the member names are part of
    the configuration contract, not just the integer values.
    """

    LAST = 0  # pooled by LastPool
    ALL = 1  # pooled by AllPool
    CLS = 2  # pooled by CLSPool
    STEP = 3  # no PoolingMethod; token-level poolers choose StepPooler for "STEP"
    MEAN = 4  # pooled by MeanPool
42


43
44
45
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with its resolved ``PoolingType``."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Resolve ``pooler_config.pooling_type`` for ``task``.

        The pooling type is looked up by ``PoolingType`` member name, so an
        unknown name raises ``KeyError``; a ``None`` pooling type trips the
        assertion.
        """
        type_name = pooler_config.pooling_type
        assert type_name is not None

        resolved_type = PoolingType[type_name]
        return cls(task=task, pooling_type=resolved_type)
56
57


58
59
60
61
62
63
64
65
66
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Adjustments a pooler needs applied to each request's ``PoolingParams``."""

    # Set this flag to enable `get_prompt_token_ids` for your pooler.
    requires_token_ids: bool = False

    def apply(self, params: PoolingParams) -> None:
        """Copy the flags of this update onto ``params`` (mutates it in place)."""
        wants_token_ids = self.requires_token_ids
        params.requires_token_ids = wants_token_ids


67
def get_prompt_lens(
    hidden_states: torch.Tensor | list[torch.Tensor],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the per-request prompt lengths stored on ``pooling_metadata``.

    ``hidden_states`` is not consulted; presumably it is accepted only for
    call-site compatibility.
    """
    del hidden_states  # unused
    return pooling_metadata.prompt_lens
72
73


74
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return each request's prompt token IDs, trimmed to its prompt length.

    Row ``i`` of ``pooling_metadata.prompt_token_ids`` is cut to its first
    ``prompt_lens[i]`` entries (the rows are presumably padded — confirm).

    Raises:
        AssertionError: if token IDs were not collected because the pooler
            did not request ``requires_token_ids=True``.
    """
    token_ids = pooling_metadata.prompt_token_ids
    assert token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`"
    )

    trimmed: list[torch.Tensor] = []
    for row, length in enumerate(pooling_metadata.prompt_lens):
        trimmed.append(token_ids[row, :length])
    return trimmed


85
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the per-request ``PoolingParams`` carried by ``pooling_metadata``."""
    return pooling_metadata.pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the pooling task of every request in the batch, in order.

    Every request must carry a task: if any ``pooling_param.task`` is
    ``None`` the length check below fails with ``AssertionError``.
    """
    pooling_params = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = []
    for pooling_param in pooling_params:
        task = pooling_param.task
        if task is not None:
            tasks.append(task)

    # Equal lengths <=> no request was missing its task.
    assert len(pooling_params) == len(tasks)
    return tasks


103
def get_classification_activation_function(config: PretrainedConfig):
104
105
106
107
108
109
110
111
112
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "single_label_classification":
        return PoolerClassify()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()
113
114
115
116
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Choose the activation for a cross-encoder (score) head.

    The activation name is read, in order of precedence, from
    ``config.sentence_transformers["activation_fn"]`` or
    ``config.sbert_ce_default_activation_function``. The name is resolved to
    an object, called with no arguments, and the result is adapted with
    ``PoolerActivation.wraps``. If neither source provides a name,
    ``PoolerClassify`` is used.

    Raises:
        ValueError: if the configured name is not under ``torch.nn.modules``.
    """
    function_name: str | None = None

    st_config = getattr(config, "sentence_transformers", None)
    if st_config is not None and "activation_fn" in st_config:
        function_name = st_config["activation_fn"]
    elif getattr(config, "sbert_ce_default_activation_function", None) is not None:
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model's config and is resolved and then
    # *called*, so restrict it to torch.nn.modules. This must be an explicit
    # raise: an `assert` would be stripped under `python -O`, silently
    # disabling the check.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
138

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

class PoolingMethod(nn.Module, ABC):
    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

154
    @abstractmethod
155
    def get_supported_tasks(self) -> Set[PoolingTask]:
156
157
        raise NotImplementedError

158
159
160
    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate()

161
162
    @abstractmethod
    def forward_all(
163
        self,
164
        hidden_states: torch.Tensor,
165
        pooling_cursor: PoolingCursor,
166
    ) -> list[torch.Tensor] | torch.Tensor:
167
168
169
170
        raise NotImplementedError

    def forward(
        self,
171
        hidden_states: torch.Tensor,
172
        pooling_metadata: PoolingMetadata,
173
    ) -> list[torch.Tensor] | torch.Tensor:
174
175
        pooling_cursor = pooling_metadata.pooling_cursor
        return self.forward_all(hidden_states, pooling_cursor)
176
177


178
class CLSPool(PoolingMethod):
    """Select the hidden state at each request's first-token index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        # Partial prefills are rejected outright.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_token_indices = pooling_cursor.first_token_indices_gpu
        return hidden_states[first_token_indices]
192
193


194
class LastPool(PoolingMethod):
    """Select the hidden state at each request's last-token index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        # No partial-prefill check here, unlike the CLS/ALL/MEAN methods.
        last_token_indices = pooling_cursor.last_token_indices_gpu
        return hidden_states[last_token_indices]
204
205


206
class AllPool(PoolingMethod):
    """Keep every token: return one hidden-state chunk per request."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with ALL pooling"
        )

        # Cut the packed states into consecutive chunks, one per scheduled
        # sequence, then reorder them by the cursor's index.
        chunk_sizes = pooling_cursor.num_scheduled_tokens_cpu.tolist()
        chunks = list(hidden_states.split(chunk_sizes))
        return [chunks[position] for position in pooling_cursor.index]
223
224


225
class MeanPool(PoolingMethod):
    """Average the hidden states from each request's first to last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Accumulate in float32: a running sum in the model dtype loses
        # precision significantly.
        running_sum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = pooling_cursor.first_token_indices_gpu
        last = pooling_cursor.last_token_indices_gpu

        # running_sum[last] - running_sum[first] covers tokens (first, last];
        # adding the first token back gives the inclusive sum [first, last].
        token_sums = running_sum[last] - running_sum[first] + hidden_states[first]
        return token_sums / lengths.unsqueeze(1)
251
252


253
# Pooled data is either one tensor or a list of tensors; activations return
# the same container type they were given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class BasePoolerActivation(nn.Module, ABC):
    """Abstract interface for activations applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation and return the same container type.

        Expected shapes:
          - classify (& score): ``(batch_size, num_classes)``
          - embed: ``(batch_size, embedding_dim)`` or a list of
            ``(embedding_dim,)``; with MRL, ``dimensions`` replaces
            ``embedding_dim``.
        """
        raise NotImplementedError
264

265

266
267
268
269
270
271
272
class PoolerActivation(BasePoolerActivation):
    """Activation defined per chunk; list inputs are mapped element-wise."""

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt an ``nn.Module`` into a ``PoolerActivation``.

        ``nn.Identity`` becomes ``PoolerIdentity``. ``nn.Sigmoid`` and
        ``nn.Softmax`` become ``PoolerClassify``, which picks sigmoid or
        softmax itself (the module's own settings are not used). Anything
        else is wrapped in ``LambdaPoolerActivation``.
        """
        if isinstance(module, nn.Identity):
            activation = PoolerIdentity()
        elif isinstance(module, (nn.Sigmoid, nn.Softmax)):
            activation = PoolerClassify()
        else:
            activation = LambdaPoolerActivation(module)
        return activation

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)
        return [self.forward_chunk(chunk) for chunk in pooled_data]
285
286


287
288
class PoolerIdentity(PoolerActivation):
    """No-op activation (used, e.g., for regression heads)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        unchanged = pooled_data
        return unchanged


292
293
class PoolerNormalize(PoolerActivation):
    """L2-normalize along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, dim=-1, p=2)
295
296


297
298
class PoolerMultiLabelClassify(PoolerActivation):
    """Independent per-label probabilities via an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        probabilities = F.sigmoid(pooled_data)
        return probabilities
300
301


302
class PoolerClassify(PoolerActivation):
    """Sigmoid when there are fewer than 2 labels, else softmax over the last dim.

    Args:
        static_num_labels: If True, read ``num_labels`` once from the current
            vLLM model's ``hf_config`` (0 when the attribute is absent). If
            False, take the label count from each chunk's last dimension.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(
                vllm_config.model_config.hf_config, "num_labels", 0
            )
            if self.num_labels == 0:
                # forward_chunk sends any count below 2 to sigmoid, so the
                # warning must say sigmoid (it previously claimed softmax and
                # was missing a space between the two string pieces).
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to sigmoid. "
                    "Please check if the configuration is correct."
                )
        else:
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = (
            self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
        )

        # A single score (or an unknown/zero label count) uses sigmoid;
        # two or more labels use a softmax over the last dimension.
        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
329
330
331
332
333
334
335
336
337
338
339
340


class LambdaPoolerActivation(PoolerActivation):
    """Adapt an arbitrary tensor-to-tensor callable into a pooler activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        activation_fn = self.fn
        return activation_fn(pooled_data)


341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def _token_level_pooler(pooler_config: PoolerConfig, head: nn.Module):
        """Wrap ``head`` in a StepPooler for "STEP" pooling, else an AllPooler."""
        if pooler_config.pooling_type == "STEP":
            pooler_cls = StepPooler
        else:
            pooler_cls = AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build a per-token embedding pooler."""
        head = TokenEmbeddingPoolerHead()
        return Pooler._token_level_pooler(pooler_config, head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a per-token classification pooler."""
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        return Pooler._token_level_pooler(pooler_config, head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build a sequence-level embedding pooler (pooling method + head)."""
        pooling_type = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        ).pooling_type

        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a sequence-level classification pooler."""
        pooling_type = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        ).pooling_type

        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` for the batch described by the metadata."""
        raise NotImplementedError


417
class PoolerHead(nn.Module):
    """Post-pooling head that applies ``activation`` to the pooled data."""

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ):
        # pooling_metadata is unused by this base head; subclasses such as
        # EmbeddingPoolerHead read per-request params from it.
        activate = self.activation
        return activate(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for pooled embeddings.

    Stacks list inputs, casts to ``head_dtype``, applies the optional
    Sentence-Transformers projector, then per request (from its
    ``PoolingParams``) truncates to ``dimensions`` (matryoshka) and
    L2-normalizes when ``normalize`` is set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        # The original guarded only the projector with `if vllm_config`, yet
        # read `vllm_config.model_config.head_dtype` unconditionally on the
        # next line, so the guard could never help; read the config once.
        model_config = get_current_vllm_config().model_config

        # Load ST projector if available.
        self.projector: nn.Module | None = _load_st_projector(model_config)
        self.head_dtype = model_config.head_dtype

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ):
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        pooling_params = get_pooling_params(pooling_metadata)
        pooled_data = self._apply_dimensions(pooled_data, pooling_params)
        pooled_data = self._apply_normalize(pooled_data, pooling_params)

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data

    @staticmethod
    def _apply_dimensions(
        pooled_data: torch.Tensor | list[torch.Tensor],
        pooling_params: list[PoolingParams],
    ) -> torch.Tensor | list[torch.Tensor]:
        """Truncate each request's embedding to its ``dimensions`` (MRL)."""
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if all(d is None for d in dimensions_list):
            return pooled_data

        assert len(pooled_data) == len(dimensions_list)
        if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
            # Every request asks for the same (non-None) size: one batched slice.
            d = dimensions_list[0]
            return pooled_data[..., :d]

        # Mixed sizes: per-request slices (None keeps the full embedding).
        return [
            vecs if d is None else vecs[..., :d]
            for vecs, d in zip(pooled_data, dimensions_list)
        ]

    def _apply_normalize(
        self,
        pooled_data: torch.Tensor | list[torch.Tensor],
        pooling_params: list[PoolingParams],
    ) -> torch.Tensor | list[torch.Tensor]:
        """Apply ``self.activation`` to requests whose ``normalize`` is set."""
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            # Uniform flag: process the whole batch at once.
            return self.activation(pooled_data) if flags[0] else pooled_data

        return [
            self.activation(vecs) if f else vecs
            for vecs, f in zip(pooled_data, flags)
        ]


489
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    Pipeline:
    1. ``pooling`` extracts specific tokens or aggregates data.
    2. ``head`` post-processes the pooled data (e.g. normalization).
    3. The head's output is returned as the ``PoolerOutput``.
    """

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Task support is determined solely by the pooling method.
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
518
519


520
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the supplied pooling function.
    2. Optionally applies a classification layer to the pooled data.
    3. Subtracts ``logit_bias`` when one is configured.
    4. Applies the activation function to requests whose
       ``PoolingParams.activation`` is set.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Activation for sequence classification, chosen from ``hf_config``."""
        return get_classification_activation_function(model_config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Activation for cross-encoder scoring, chosen from ``hf_config``."""
        return get_cross_encoder_activation_function(model_config.hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Turn the ``act_fn`` argument into a concrete activation.

        Args:
            model_config: Config whose ``hf_config`` drives the string cases.
            static_num_labels: Forwarded to ``PoolerClassify`` when
                ``act_fn`` is None.
            act_fn: ``"classify"`` or ``"score"`` to derive the activation
                from the model config, ``None`` for ``PoolerClassify``, or an
                already-built callable that is returned unchanged.

        Raises:
            ValueError: for any other string.
        """
        if isinstance(act_fn, str):
            if act_fn == "classify":
                return ClassifierPooler.act_fn_for_seq_cls(model_config)
            elif act_fn == "score":
                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
            else:
                raise ValueError(f"act_fn [{act_fn=}] not supported.")
        elif act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)
        else:
            assert callable(act_fn)
            return act_fn

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            vllm_config.model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            # Out of place: with no classifier and no dtype change,
            # `pooled_data` is the pooling output itself, and an in-place
            # `-=` would also modify any tensor that output aliases.
            pooled_data = pooled_data - self.logit_bias

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
609
610


611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Per-request, per-token variant of the embedding head.

    Processes one request's token states using that request's
    ``PoolingParams``: dtype cast, optional projector, matryoshka truncation,
    optional L2 normalization.
    """

    def forward(
        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
    ) -> torch.Tensor:
        # pooled_data shape: [n_tokens, hidden_dimension]
        embeddings = pooled_data.to(self.head_dtype)

        projector = self.projector
        if projector is not None:
            embeddings = projector(embeddings)
        # embeddings shape: [n_tokens, embedding_dimension]

        # Matryoshka truncation; `dimensions=None` slices nothing off.
        embeddings = embeddings[..., : pooling_param.dimensions]

        if pooling_param.normalize:
            embeddings = self.activation(embeddings)

        # embeddings shape: [n_tokens, embedding_dimension]
        return embeddings


class TokenClassifierPoolerHead(nn.Module):
    """Per-request head for token classification.

    Casts one request's token states to ``head_dtype``, applies the optional
    classifier, subtracts ``logit_bias`` if configured, and applies the
    activation when the request's ``PoolingParams.activation`` is set.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_param: PoolingParams,
    ) -> torch.Tensor:
        hidden_states = hidden_states.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Must be out of place: without a classifier (and with no dtype
            # change) `scores` is the tensor passed in, e.g. a `split` view of
            # the model's hidden states produced by AllPool, and `-=` would
            # silently modify that shared buffer.
            scores = scores - self.logit_bias

        if pooling_param.activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores


class AllPooler(Pooler):
    """Token-level pooler: keep every token and run ``head`` per request."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)
        assert len(per_request_states) == len(pooling_params)

        outputs = []
        for states, params in zip(per_request_states, pooling_params):
            outputs.append(self.head(states, params))
        return outputs


class StepPooler(Pooler):
    """Token-level pooler that can filter by step tag and returned token IDs.

    Requests prompt token IDs via ``get_pooling_updates`` so that
    ``get_prompt_token_ids`` can be used to locate step-tag positions.
    """

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return each request's token states after step/ID filtering."""
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)

        extracted: list[torch.Tensor] = []
        for states, token_ids, params in zip(
            per_request_states, prompt_token_ids, pooling_params
        ):
            wanted_columns = params.returned_token_ids
            if wanted_columns is not None and len(wanted_columns) > 0:
                # Keep only the requested indices along dim 1.
                states = states[:, wanted_columns]

            step_tag_id = params.step_tag_id
            if step_tag_id is not None:
                # Keep only positions whose prompt token is the step tag.
                states = states[token_ids == step_tag_id]

            extracted.append(states)

        return extracted

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        per_request_states = self.extract_states(hidden_states, pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)
        assert len(per_request_states) == len(pooling_params)

        return [
            self.head(states, params)
            for states, params in zip(per_request_states, pooling_params)
        ]


755
756
757
758
759
760
761
762
763
764
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Register one sub-pooler per task.

        Raises:
            ValueError: if a pooler is registered for a task it does not
                support.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        # NOTE(review): stored as a plain Mapping rather than nn.ModuleDict;
        # confirm the sub-poolers' parameters are owned/registered elsewhere.
        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each run of same-task requests through its sub-pooler.

        Raises:
            ValueError: if a request's task has no registered sub-pooler.
        """
        poolers_by_task = self.poolers_by_task

        outputs: list[torch.Tensor] = []
        offset = 0
        # groupby merges only *consecutive* equal tasks, so every group is a
        # contiguous slice of the batch starting at `offset`.
        for task, group in groupby(get_tasks(pooling_metadata)):
            # Explicit None check: a sub-pooler must not be rejected just
            # because it happens to be falsy (e.g. defines __len__).
            pooler = poolers_by_task.get(task)
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s