"tests/vscode:/vscode.git/clone" did not exist on "98b09ddc2761545f3164d930b143f84737b1ab43"
pooler.py 23.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union
guanyu1's avatar
guanyu1 committed
9
from torch.nn.utils.rnn import pad_sequence
10
11
import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.models.adapters import _load_st_projector
18
19
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
20
21
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput
22
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
23

24
25
logger = init_logger(__name__)

# Pooled data is either one batched tensor or one tensor per request.
_PooledData = Union[torch.Tensor, list[torch.Tensor]]

# (hidden states, pooling metadata) -> pooled data.
PoolingFn = Callable[[_PooledData, PoolingMetadata], _PooledData]
# Callable applied to the pooled tensor by ClassifierPooler.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
30
31
32
33
34


class PoolingType(IntEnum):
    """Strategies for reducing a request's hidden states to a pooled output.

    Members are looked up by name (e.g. ``PoolingType["MEAN"]``) when a
    ``PoolerConfig`` is resolved.
    """

    LAST = 0  # hidden state of the last token
    ALL = 1  # hidden states of every token
    CLS = 2  # hidden state of the first token
    STEP = 3  # step-tag filtered tokens (served by StepPooler)
    MEAN = 4  # mean over the request's tokens
39
40


41
42
43
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """A pooling configuration whose pooling type is resolved to an enum.

    Attributes:
        pooling_type: The ``PoolingType`` member to pool with.
        task: The pooling task this configuration was resolved for.
    """

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Resolve ``pooler_config`` into a config for ``task``.

        Args:
            task: The pooling task being configured.
            pooler_config: Config whose ``pooling_type`` is the name of a
                ``PoolingType`` member.

        Raises:
            ValueError: If ``pooler_config.pooling_type`` is not set.
            KeyError: If ``pooler_config.pooling_type`` does not name a
                ``PoolingType`` member.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an opaque KeyError(None).
        if pooler_config.pooling_type is None:
            raise ValueError(
                f"pooling_type must be set to resolve a pooler for {task=}")
        return cls(task=task,
                   pooling_type=PoolingType[pooler_config.pooling_type])
55
56


57
58
59
60
61
62
63
64
65
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Changes a pooler asks to have applied to each request's params."""

    # When True, prompt token ids are made available to the pooler, so that
    # `get_prompt_token_ids` can be used on the pooling metadata.
    requires_token_ids: bool = False

    def apply(self, params: PoolingParams) -> None:
        """Copy this update's flags onto ``params`` in place."""
        needs_token_ids = self.requires_token_ids
        params.requires_token_ids = needs_token_ids


66
67
68
69
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM.

    A pooler declares the tasks it serves (``get_supported_tasks``), may
    request per-request parameter changes (``get_pooling_updates``), and
    turns hidden states into a ``PoolerOutput`` (``forward``). The static
    ``for_*`` helpers build the default pooler for a task.
    """

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the default pooler for the ``encode`` task.

        ``STEP`` pooling gets a ``StepPooler``; any other configured pooling
        type is replaced by ``ALL`` pooling.
        """
        if pooler_config.pooling_type != "STEP":
            encode_config = ResolvedPoolingConfig(task="encode",
                                                  pooling_type=PoolingType.ALL)
            return SimplePooler.from_config(encode_config)

        return StepPooler()

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the default pooler for the ``embed`` task."""
        return SimplePooler.from_config(
            ResolvedPoolingConfig.from_config(task="embed",
                                              pooler_config=pooler_config))

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
    ):
        """Build the default pooler for the ``classify`` task.

        ``classifier`` (if not None) is handed to ``ClassifierPooler``.
        """
        pooling_type = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        ).pooling_type
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(pooling_type),
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation requests no changes.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` for the requests in ``pooling_metadata``."""
        raise NotImplementedError


125
126
127
128
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt length of every request in the batch.

    ``hidden_states`` is not read; the lengths come from
    ``pooling_metadata.prompt_lens``.
    """
    lengths = pooling_metadata.prompt_lens
    return lengths
130
131


132
133
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return each request's prompt token ids, cut to its prompt length.

    The pooler must have requested token ids (``requires_token_ids=True``);
    otherwise ``pooling_metadata.prompt_token_ids`` is None and this asserts.
    """
    token_ids = pooling_metadata.prompt_token_ids
    assert token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    # Row `req` holds request `req`'s ids; keep only its first `length`
    # entries (presumably the remainder is padding).
    return [
        token_ids[req, :length]
        for req, length in enumerate(pooling_metadata.prompt_lens)
    ]


143
144
def get_pooling_params(
        pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the per-request ``PoolingParams`` carried by the metadata."""
    return pooling_metadata.pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the pooling task of every request, in batch order.

    Asserts that no request is missing its task.
    """
    params_list = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = []
    for params in params_list:
        task = params.task
        if task is not None:
            tasks.append(task)

    # A skipped None would leave the task list misaligned with the batch.
    assert len(params_list) == len(tasks)

    return tasks


161
def get_classification_activation_function(config: PretrainedConfig):
    """Choose the output activation for a sequence-classification model.

    Follows how transformers' ``ForSequenceClassificationLoss`` reads
    ``config.problem_type``:

    * ``"regression"`` -> ``PoolerIdentity`` (raw scores)
    * ``"single_label_classification"`` -> ``PoolerClassify``
    * ``"multi_label_classification"`` -> ``PoolerMultiLabelClassify``
    * anything else (including a missing attribute) -> ``PoolerClassify``
    """
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")

    for known_type, activation_cls in (
        ("regression", PoolerIdentity),
        ("single_label_classification", PoolerClassify),
        ("multi_label_classification", PoolerMultiLabelClassify),
    ):
        if problem_type == known_type:
            return activation_cls()

    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Choose the output activation for a cross-encoder (scoring) model.

    The activation name is taken, in order of precedence, from:

    1. ``config.sentence_transformers["activation_fn"]``
    2. ``config.sbert_ce_default_activation_function``

    The name must be a qualified name under ``torch.nn.modules.``. It is
    resolved, instantiated with no arguments and adapted via
    ``PoolerActivation.wraps``. When neither source is set,
    ``PoolerClassify`` is returned.

    Raises:
        ValueError: If the configured name is not a string under
            ``torch.nn.modules.``.
    """
    function_name: Optional[str] = None

    st_config = getattr(config, "sentence_transformers", None)
    if st_config is not None and "activation_fn" in st_config:
        function_name = st_config["activation_fn"]
    elif getattr(config, "sbert_ce_default_activation_function",
                 None) is not None:
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model's config files, which may be untrusted.
    # This must be a real check rather than an `assert`: asserts are removed
    # under `python -O`, which would allow resolving and instantiating any
    # importable object.
    if (not isinstance(function_name, str)
            or not function_name.startswith("torch.nn.modules.")):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
191

192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207

class PoolingMethod(nn.Module, ABC):
    """Reduces batched hidden states using the batch's ``PoolingCursor``."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method matching ``pooling_type``.

        Raises:
            NotImplementedError: For types with no ``PoolingMethod`` here
                (e.g. ``STEP``, which ``Pooler.for_encode`` routes to
                ``StepPooler``).
        """
        for candidate, method_cls in (
            (PoolingType.LAST, LastPool),
            (PoolingType.ALL, AllPool),
            (PoolingType.CLS, CLSPool),
            (PoolingType.MEAN, MeanPool),
        ):
            if pooling_type == candidate:
                return method_cls()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return per-request parameter changes; none by default."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool ``hidden_states`` for every request located by the cursor."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool ``hidden_states`` using ``pooling_metadata.pooling_cursor``."""
        return self.forward_all(hidden_states,
                                pooling_metadata.pooling_cursor)
230
231


232
233
class CLSPool(PoolingMethod):
    """Pools each request to the hidden state of its first token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling")

        first_indices = pooling_cursor.first_token_indices_gpu
        return hidden_states[first_indices]
246
247


248
class LastPool(PoolingMethod):
    """Pools each request to the hidden state of its last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Unlike the CLS/ALL/MEAN methods, no partial-prefill check is made.
        last_indices = pooling_cursor.last_token_indices_gpu
        return hidden_states[last_indices]
259
260


261
class AllPool(PoolingMethod):
    """Returns every token's hidden states, one tensor per request."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with ALL pooling")

        # Cut dim 0 into one chunk per scheduled request, then pick the
        # chunks in the order given by the cursor's index.
        chunk_sizes = pooling_cursor.num_scheduled_tokens_cpu.tolist()
        chunks = hidden_states.split(chunk_sizes)
        return [chunks[pos] for pos in pooling_cursor.index]
279
280


281
class MeanPool(PoolingMethod):
    """Pools each request to the mean of its tokens' hidden states."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling")

        lengths = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
                                                    non_blocking=True)

        # Accumulate prefix sums in float32; doing it in the model dtype
        # would lose a lot of precision.
        prefix = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = pooling_cursor.first_token_indices_gpu
        last = pooling_cursor.last_token_indices_gpu
        # Sum over [first, last] = prefix[last] - prefix[first] + x[first].
        span_sum = prefix[last] - prefix[first] + hidden_states[first]
        return span_sum / lengths.unsqueeze(1)


308
# Type variable tying an activation's return type to its input type:
# a single tensor or a list of tensors.
_T = TypeVar(
    "_T",
    torch.Tensor,
    list[torch.Tensor],
)
309
310


311
class BasePoolerActivation(nn.Module, ABC):
    """Base class for activations applied to pooled outputs."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation, returning the same container type.

        Typical shapes:
          * classify (and score): (batch_size, num_classes)
          * embed: (batch_size, embedding_dim) or a list of (embedding_dim,);
            with Matryoshka (MRL) dimensions, (batch_size, dimensions) or a
            list of (dimensions,).
        """
        raise NotImplementedError
320

321

322
class PoolerActivation(BasePoolerActivation):
    """Activation applied independently to each pooled tensor."""

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt an ``nn.Module`` into a ``PoolerActivation``.

        ``nn.Identity`` becomes ``PoolerIdentity``, ``nn.Sigmoid`` and
        ``nn.Softmax`` become ``PoolerClassify``, and any other module is
        wrapped in a ``LambdaPoolerActivation``.
        """
        if isinstance(module, nn.Identity):
            wrapped = PoolerIdentity()
        elif isinstance(module, (nn.Sigmoid, nn.Softmax)):
            wrapped = PoolerClassify()
        else:
            wrapped = LambdaPoolerActivation(module)
        return wrapped

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply ``forward_chunk`` to a tensor, or to each tensor of a list."""
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)
        return [self.forward_chunk(chunk) for chunk in pooled_data]
342
343


344
345
346
class PoolerIdentity(PoolerActivation):
    """No-op activation: returns its input unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        unchanged = pooled_data
        return unchanged


350
351
352
class PoolerNormalize(PoolerActivation):
    """L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, dim=-1, p=2)
354
355


356
357
358
class PoolerMultiLabelClassify(PoolerActivation):
    """Element-wise sigmoid, scoring each label independently."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(pooled_data)
360
361


362
363
class PoolerClassify(PoolerActivation):
    """Sigmoid when there are fewer than two labels, else softmax.

    Args:
        static_num_labels: If True, read ``num_labels`` once from the current
            model's HF config (defaulting to 0 when absent). If False, infer
            the label count per call from the input's last dimension.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(vllm_config.model_config.hf_config,
                                      "num_labels", 0)
            if self.num_labels == 0:
                # forward_chunk applies sigmoid for any num_labels < 2, so a
                # missing/zero label count ends up with sigmoid. The message
                # previously claimed softmax and lacked a space between
                # "classification" and "models".
                logger.warning("num_labels should be > 0 for classification "
                               "models, falling back to sigmoid. "
                               "Please check if the configuration is correct.")
        else:
            # None: infer the label count from the input on every call.
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid (num_labels < 2) or softmax over the last dim."""
        num_labels = (self.num_labels if self.num_labels is not None else
                      pooled_data.shape[-1])

        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
386
387
388
389
390
391
392
393
394
395
396
397
398


class LambdaPoolerActivation(PoolerActivation):
    """Adapts an arbitrary tensor-to-tensor callable into an activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        wrapped_fn = self.fn
        return wrapped_fn(pooled_data)


399
400
class PoolerHead(nn.Module):
    """Post-processing stage for pooled data; by default just an activation."""

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply ``self.activation``; ``pooling_metadata`` is not read."""
        activation = self.activation
        return activation(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for the ``embed`` task.

    In order: stack per-request vectors, cast to the head dtype, apply the
    Sentence-Transformers projector if one was loaded, truncate to each
    request's ``dimensions`` (Matryoshka), and L2-normalize the requests
    whose ``normalize`` flag is set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        # Load the Sentence-Transformers projector, if available.
        vllm_config = get_current_vllm_config()
        self.projector: Optional[nn.Module] = (
            _load_st_projector(vllm_config.model_config)
            if vllm_config else None)
        self.head_dtype = vllm_config.model_config.head_dtype

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # [batch_size, hidden_size] from here on

        pooled_data = pooled_data.to(self.head_dtype)

        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # [batch_size, embedding_dim]

        pooling_params = get_pooling_params(pooling_metadata)

        # Matryoshka representation: optionally truncate per request.
        dims = [params.dimensions for params in pooling_params]
        if any(d is not None for d in dims):
            assert len(pooled_data) == len(dims)
            all_same = len(set(dims)) == 1
            if all_same and not isinstance(pooled_data, list):
                pooled_data = pooled_data[..., :dims[0]]
            else:
                pooled_data = [
                    row if d is None else row[..., :d]
                    for row, d in zip(pooled_data, dims)
                ]

        # Normalization, applied only to requests that ask for it.
        normalize_flags = [params.normalize for params in pooling_params]
        if len(set(normalize_flags)) != 1:
            pooled_data = [
                self.activation(row) if flag else row
                for row, flag in zip(pooled_data, normalize_flags)
            ]
        elif normalize_flags[0]:
            pooled_data = self.activation(pooled_data)

        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head for ``encode`` outputs.

    Casts to the head dtype, then applies ``PoolerClassify`` (label count
    inferred from each input) to the requests whose ``softmax`` flag is set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify(static_num_labels=False))

        vllm_config = get_current_vllm_config()
        self.head_dtype = vllm_config.model_config.head_dtype

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        target_dtype = self.head_dtype
        if isinstance(pooled_data, list):
            pooled_data = [chunk.to(target_dtype) for chunk in pooled_data]
        else:
            pooled_data = pooled_data.to(target_dtype)

        pooling_params = get_pooling_params(pooling_metadata)

        # Per-request opt-in to the activation via `softmax`.
        softmax_flags = [params.softmax for params in pooling_params]
        if len(set(softmax_flags)) != 1:
            pooled_data = [
                self.activation(chunk) if flag else chunk
                for chunk, flag in zip(pooled_data, softmax_flags)
            ]
        elif softmax_flags[0]:
            pooled_data = self.activation(pooled_data)

        return pooled_data
502
503


504
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Post-processes the result with a task-specific head (e.g. optional
       normalization for embeddings).
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build a pooler whose head matches ``pooler_config.task``.

        Raises:
            NotImplementedError: If the task is neither ``embed`` nor
                ``encode`` (or the pooling type has no method).
        """
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)

        task = pooler_config.task
        if task == "embed":
            head = EmbeddingPoolerHead()
        elif task == "encode":
            head = RewardPoolerHead()
        else:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")

        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooling method."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
547
548


549
class StepPooler(Pooler):
    """Pooler for ``STEP`` pooling.

    Takes every token's hidden states (``AllPool``), optionally keeps only
    the ``returned_token_ids`` columns and only the positions whose prompt
    token equals ``step_tag_id``, then applies ``RewardPoolerHead``.
    """

    def __init__(self) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the selected hidden states of each request, as a list."""
        per_request = self.pooling(hidden_states, pooling_metadata)
        all_token_ids = get_prompt_token_ids(pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, params in zip(per_request, all_token_ids,
                                             pooling_params):
            step_tag_id = params.step_tag_id
            wanted_columns = params.returned_token_ids

            if wanted_columns is not None and len(wanted_columns) > 0:
                states = states[:, wanted_columns]

            # Keep only rows whose prompt token is the step tag.
            if step_tag_id is not None:
                states = states[token_ids == step_tag_id]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, needed to locate step tags."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        states = self.extract_states(hidden_states, pooling_metadata)
        return self.head(states, pooling_metadata)
598
599


600
def _classifier_accepts_metadata(classifier: Callable) -> bool:
    """Return True if ``classifier`` requires a second positional argument.

    The declared ``ClassifierFn`` contract takes only the pooled tensor, but
    some custom classifiers also need the ``PoolingMetadata``. For an
    ``nn.Module`` its ``forward`` is inspected, because calling a module
    goes through a generic ``*args, **kwargs`` wrapper.
    """
    import inspect

    target = (classifier.forward
              if isinstance(classifier, nn.Module) else classifier)
    try:
        signature = inspect.signature(target)
    except (TypeError, ValueError):
        # No introspectable signature (e.g. some builtins): assume the
        # standard single-argument contract.
        return False

    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    num_required_positional = sum(
        1 for param in signature.parameters.values()
        if param.kind in positional_kinds
        and param.default is inspect.Parameter.empty)
    return num_required_positional >= 2


class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with ``pooling``; a list of per-request
       tensors is right-padded with zeros into one batch tensor.
    2. Optionally applies ``classifier`` to the pooled states. The pooling
       metadata is passed as a second argument only if the classifier
       requires one.
    3. Subtracts the configured ``logit_bias``, if set.
    4. Applies ``act_fn`` to the requests whose ``activation`` flag is set.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Activation for sequence-classification models."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Activation for cross-encoder (scoring) models."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()

        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = act_fn or PoolerClassify()

        self.logit_bias: Optional[
            float] = vllm_config.model_config.pooler_config.logit_bias
        self.head_dtype = vllm_config.model_config.head_dtype

        # Decide once whether the classifier also takes the pooling metadata,
        # so standard single-argument classifiers (e.g. nn.Linear) keep
        # working instead of failing with a TypeError.
        self._classifier_takes_metadata = (
            classifier is not None
            and _classifier_accepts_metadata(classifier))

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            # Per-request tensors may differ in length (e.g. ALL pooling);
            # right-pad with zeros into one batch. For equal-shaped tensors
            # this gives the same result as torch.stack.
            pooled_data = pad_sequence(pooled_data,
                                       batch_first=True,
                                       padding_value=0.0)
        # pooled_data: [batch_size, ...], e.g. [batch_size, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            if self._classifier_takes_metadata:
                pooled_data = self.classifier(pooled_data, pooling_metadata)
            else:
                pooled_data = self.classifier(pooled_data)

        if self.logit_bias is not None:
            # Out-of-place, so a tensor still owned by the pooling function
            # or the classifier is never mutated.
            pooled_data = pooled_data - self.logit_bias

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task.

    Consecutive requests that share a task are forwarded together to that
    task's pooler, and the outputs are concatenated in batch order.
    """

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task in pooler.get_supported_tasks():
                continue
            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {pooler.get_supported_tasks()}")

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooler registered for ``task``."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        registry = self.poolers_by_task

        outputs = list[torch.Tensor]()
        start = 0
        for task, run in groupby(get_tasks(pooling_metadata)):
            pooler = registry.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            stop = start + sum(1 for _ in run)
            run_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[start:stop],
            )
            outputs.extend(run_output)
            start = stop

        return outputs

    def extra_repr(self) -> str:
        return f"supported_task={self.get_supported_tasks()}"