pooler.py 23.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig
16
17
18
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
19
from vllm.pooling_params import PoolingParams
20
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
21
from vllm.tasks import PoolingTask
22
from vllm.utils import resolve_obj_by_qualname
23
24
25
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Union of the V0 and V1 pooling-metadata classes imported above; helpers in
# this module use `isinstance` on the concrete class to pick the right access
# path.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
26
27
28
29
# Signature of a pooling callable: takes hidden states (one tensor or a list
# of tensors) together with pooling metadata and returns pooled data in either
# form.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# Signature of a classifier callable: maps a tensor to a tensor.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
30
31
32
33
34


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods.

    Members are looked up by name (``PoolingType["MEAN"]``) when a pooler
    configuration is resolved, so the member names are part of the contract.
    """
    # Pool with the hidden state of the last token of each prompt.
    LAST = 0
    # Keep the hidden states of every token of each prompt.
    ALL = 1
    # Pool with the hidden state of the first token of each prompt.
    CLS = 2
    # Step pooling; not handled by `PoolingMethod.from_pooling_type`.
    STEP = 3
    # Pool with the mean of the hidden states of each prompt.
    MEAN = 4
39
40


41
42
43
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with the pooling type used for it."""
    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Resolve ``pooler_config`` for ``task``.

        The configured ``pooling_type`` must be set; it is looked up by name
        in `PoolingType`.
        """
        type_name = pooler_config.pooling_type
        assert type_name is not None

        resolved_type = PoolingType[type_name]
        return cls(pooling_type=resolved_type, task=task)
55
56


57
58
59
60
61
62
63
64
65
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Settings that a pooler wants applied to a request's pooling params."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Copy the settings of this update onto ``params`` in place."""
        needs_token_ids = self.requires_token_ids
        params.requires_token_ids = needs_token_ids


66
67
68
69
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Create the pooler for the "encode" task.

        A configured pooling type of "STEP" yields a `StepPooler`; every other
        configuration yields a `SimplePooler` using `PoolingType.ALL`.
        """
        if pooler_config.pooling_type != "STEP":
            encode_config = ResolvedPoolingConfig(
                task="encode",
                pooling_type=PoolingType.ALL,
            )
            return SimplePooler.from_config(encode_config)

        return StepPooler()

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Create the pooler for the "embed" task from ``pooler_config``."""
        embed_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        return SimplePooler.from_config(embed_config)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
    ):
        """Create a `ClassifierPooler` for the "classify" task.

        The pooling method comes from ``pooler_config``; ``classifier`` is
        passed through unchanged and may be ``None``.
        """
        classify_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        method = PoolingMethod.from_pooling_type(classify_config.pooling_type)
        return ClassifierPooler(pooling=method, classifier=classifier)

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation returns a default `PoolingParamsUpdate`.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError


125
126
127
128
129
130
131
132
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths carried by ``pooling_metadata``.

    V1 metadata exposes the lengths directly. For any other metadata they are
    built through `PoolingTensors` on the device of ``hidden_states[0]``.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        device = hidden_states[0].device
        pooling_tensors = PoolingTensors.from_pooling_metadata(
            pooling_metadata, device)
        return pooling_tensors.prompt_lens

    return pooling_metadata.prompt_lens
134
135


136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return one tensor of prompt token ids per prompt.

    For V1 metadata, row ``i`` of ``prompt_token_ids`` is truncated to the
    i-th prompt length. Otherwise a new tensor is created from the
    ``prompt_token_ids`` of each entry in ``seq_data``.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        token_ids: list[torch.Tensor] = []
        for seq_data in pooling_metadata.seq_data.values():
            token_ids.append(torch.tensor(seq_data.prompt_token_ids))
        return token_ids

    assert pooling_metadata.prompt_token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    return [
        pooling_metadata.prompt_token_ids[row, :length]
        for row, length in enumerate(pooling_metadata.prompt_lens)
    ]


153
154
def get_pooling_params(
        pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the pooling params stored in ``pooling_metadata``.

    V0 metadata stores them as the second item of each ``seq_groups`` entry;
    any other metadata exposes them as ``pooling_params``.
    """
    if isinstance(pooling_metadata, V0PoolingMetadata):
        return [params for _, params in pooling_metadata.seq_groups]

    return pooling_metadata.pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the task of every pooling param, in order.

    Asserts that no pooling param has ``task`` set to ``None``.
    """
    params_list = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = []
    for params in params_list:
        task = params.task
        if task is not None:
            tasks.append(task)

    # A length mismatch means at least one param had no task assigned.
    assert len(params_list) == len(tasks)

    return tasks


174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def get_classification_activation_function(config: PretrainedConfig):
    """Return the activation used for classification outputs.

    ``config`` is accepted but not consulted; a `PoolerClassify` is always
    returned.
    """
    activation = PoolerClassify()
    return activation


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation to use for cross-encoder outputs.

    The activation's qualified name is read from
    ``config.sentence_transformers["activation_fn"]`` when present, otherwise
    from ``config.sbert_ce_default_activation_function`` when that is set.
    If neither names a function, `PoolerScore` is returned.

    Raises:
        ValueError: If the configured name does not start with
            ``torch.nn.modules.``. The name comes from model configuration,
            so the restriction is enforced with an explicit raise rather than
            ``assert``, which is removed when Python runs with ``-O``.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerScore()

    if not function_name.startswith("torch.nn.modules."):
        raise ValueError("Loading of activation functions is restricted to "
                         "torch.nn.modules for security reasons")

    # Resolve the class by its qualified name and instantiate it without
    # arguments before wrapping it as a pooler activation.
    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
195

196

197
198
def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]],
) -> PoolerOutput:
    """Wrap each item obtained by iterating ``all_data`` into a `PoolerOutput`.

    Every item becomes one `PoolingSequenceGroupOutput`.
    """
    wrapped = []
    for data in all_data:
        wrapped.append(PoolingSequenceGroupOutput(data))
    return PoolerOutput(outputs=wrapped)


class PoolingMethod(nn.Module, ABC):
    """Base class for the strategies that pool per-token hidden states."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method matching ``pooling_type``.

        Raises:
            NotImplementedError: If no method exists for ``pooling_type``.
        """
        known_methods = (
            (PoolingType.LAST, LastPool),
            (PoolingType.ALL, AllPool),
            (PoolingType.CLS, CLSPool),
            (PoolingType.MEAN, MeanPool),
        )
        for candidate, method_cls in known_methods:
            if pooling_type == candidate:
                return method_cls()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-param update for ``task`` (default: no change)."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool the hidden states of a single prompt.

        Note:
            `prompt_len=None` means `prompt_len=len(hidden_states)`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool a single tensor holding the hidden states of all prompts."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool ``hidden_states``.

        A list is pooled item by item with `forward_one`; a single tensor is
        handed to `forward_all` together with all prompt lengths.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if not isinstance(hidden_states, list):
            return self.forward_all(hidden_states, prompt_lens)

        pooled = []
        for states, length in zip(hidden_states, prompt_lens):
            pooled.append(self.forward_one(states, length))
        return pooled
259
260


261
262
class CLSPool(PoolingMethod):
    """Pools each prompt to the hidden state of its first token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the first row of ``hidden_states``."""
        if prompt_len is not None:
            assert prompt_len == hidden_states.shape[0], \
                "partial prefill not supported with CLS pooling"

        return hidden_states[0]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Gather the first row of every prompt from the flat tensor."""
        # Prompt i starts at the sum of the lengths of prompts 0..i-1, so the
        # first start offset stays 0 and the rest come from the cumulative sum.
        start_offsets = torch.zeros_like(prompt_lens)
        cumulative_lens = torch.cumsum(prompt_lens, dim=0)
        start_offsets[1:] += cumulative_lens[:-1]
        return hidden_states[start_offsets]


286
class LastPool(PoolingMethod):
    """Pools each prompt to the hidden state of its last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the last row of ``hidden_states``; ``prompt_len`` is unused."""
        final_state = hidden_states[-1]
        return final_state

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Gather the last row of every prompt from the flat tensor."""
        # The cumulative length is one past the final row of each prompt.
        end_offsets = torch.cumsum(prompt_lens, dim=0)
        return hidden_states[end_offsets - 1]


307
class AllPool(PoolingMethod):
    """Keeps the hidden states of every token of each prompt."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``hidden_states`` unchanged."""
        if prompt_len is not None:
            assert prompt_len == hidden_states.shape[0], \
                "partial prefill not supported with ALL pooling"

        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Split the flat tensor into one chunk per prompt length."""
        chunk_sizes = prompt_lens.tolist()
        chunks = hidden_states.split_with_sizes(chunk_sizes)
        return list(chunks)
328
329


330
class MeanPool(PoolingMethod):
    """Pools each prompt to the mean of its token hidden states."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the float32 mean of ``hidden_states`` over dim 0."""
        if prompt_len is not None:
            assert prompt_len == hidden_states.shape[0], \
                "partial prefill not supported with MEAN pooling"

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Compute the per-prompt mean from one flat tensor of hidden states."""
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        running_sum = torch.cumsum(hidden_states,
                                   dim=0,
                                   dtype=torch.float32)

        leading_zero = torch.tensor([0], device=hidden_states.device)
        starts = torch.cat(
            [leading_zero,
             torch.cumsum(prompt_lens[:-1], dim=0)])
        ends = torch.cumsum(prompt_lens, dim=0)

        # running_sum[end - 1] - running_sum[start] omits the row at `start`,
        # so that row is added back to obtain the full per-prompt sum.
        segment_sums = (running_sum[ends - 1] - running_sum[starts] +
                        hidden_states[starts])
        return segment_sums / prompt_lens.unsqueeze(1)


363
# Constrained type variable: either a single tensor or a list of tensors.
# Used in `forward(self, pooled_data: _T) -> _T` signatures so the annotated
# return type matches the kind of container that was passed in.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
364
365


366
class BasePoolerActivation(nn.Module, ABC):
    """Interface for activations applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to ``pooled_data``.

        Shapes:
            classify (& score) -> (batch_size, num_classes)
            embed -> (batch_size, embedding_dim) or list(embedding_dim)
                     (batch_size, dimensions) or list(dimensions) if using MRL
        """
        raise NotImplementedError
375

376

377
class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor and applied to tensors or lists of them."""

    @staticmethod
    def wraps(module: nn.Module):
        """Wrap ``module`` as a pooler activation.

        `nn.Identity` maps to `PoolerIdentity`, `nn.Sigmoid` and `nn.Softmax`
        map to `PoolerClassify`, and any other module is wrapped in a
        `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            wrapped = PoolerIdentity()
        elif isinstance(module, (nn.Sigmoid, nn.Softmax)):
            wrapped = PoolerClassify()
        else:
            wrapped = LambdaPoolerActivation(module)
        return wrapped

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to the tensor, or to each tensor of a list."""
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)

        return [self.forward_chunk(chunk) for chunk in pooled_data]
397
398


399
400
401
class PoolerIdentity(PoolerActivation):
    """Activation that leaves the pooled data untouched."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``pooled_data`` as is."""
        unchanged = pooled_data
        return unchanged


405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Normalize after `.float()` conversion, then cast back to the input dtype."""
        input_dtype = pooled_data.dtype
        normalized = F.normalize(pooled_data.float(), p=2, dim=-1)
        return normalized.to(input_dtype)


class PoolerClassify(PoolerActivation):
    """Sigmoid for fewer than two labels, softmax over the last dim otherwise."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation after `.float()`, then cast back to the input dtype."""
        as_float = pooled_data.float()
        if pooled_data.shape[-1] < 2:
            activated = F.sigmoid(as_float)
        else:
            activated = F.softmax(as_float, dim=-1)

        return activated.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Sigmoid for fewer than two labels; otherwise the data is passed through."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid only when the last dimension is smaller than 2."""
        if pooled_data.shape[-1] >= 2:
            return pooled_data

        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Activation that delegates to an arbitrary tensor-to-tensor callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        """Store ``fn`` as the callable applied to every tensor."""
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``fn(pooled_data)``."""
        result = self.fn(pooled_data)
        return result


443
444
class PoolerHead(nn.Module):
    """Head applied to pooled data; this base version only runs an activation."""

    def __init__(self, activation: PoolerActivation) -> None:
        """Store the activation this head applies."""
        super().__init__()
        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the activation; ``pooling_metadata`` is not used here."""
        activated = self.activation(pooled_data)
        return activated


class EmbeddingPoolerHead(PoolerHead):
    """Head for embeddings: optional truncation, then optional normalization."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the per-request ``dimensions`` and ``normalize`` settings.

        The result stays a single tensor only if the input is one and the
        settings agree across requests; otherwise a list is returned.
        """
        pooling_params = get_pooling_params(pooling_metadata)

        # for matryoshka representation
        requested_dims = [params.dimensions for params in pooling_params]
        if any(dim is not None for dim in requested_dims):
            # change the output dimension
            assert len(pooled_data) == len(requested_dims)

            same_dim_everywhere = len(set(requested_dims)) == 1
            if same_dim_everywhere and not isinstance(pooled_data, list):
                # One slice over the last dim covers the whole batch.
                pooled_data = pooled_data[..., :requested_dims[0]]
            else:
                pooled_data = [
                    vecs if dim is None else vecs[..., :dim]
                    for vecs, dim in zip(pooled_data, requested_dims)
                ]

        # for normalize
        normalize_flags = [params.normalize for params in pooling_params]
        if len(set(normalize_flags)) != 1:
            # Mixed settings: decide item by item.
            return [
                self.activation(vecs) if flag else vecs
                for vecs, flag in zip(pooled_data, normalize_flags)
            ]

        if normalize_flags[0]:
            return self.activation(pooled_data)

        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head that applies `PoolerClassify` where the request asks for softmax."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the activation according to each request's ``softmax`` flag."""
        pooling_params = get_pooling_params(pooling_metadata)

        # for softmax
        softmax_flags = [params.softmax for params in pooling_params]
        if len(set(softmax_flags)) != 1:
            # Mixed settings: decide item by item.
            return [
                self.activation(vecs) if flag else vecs
                for vecs, flag in zip(pooled_data, softmax_flags)
            ]

        if softmax_flags[0]:
            return self.activation(pooled_data)

        return pooled_data
518
519


520
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Passes the pooled data through a head (e.g. normalization for
       embeddings), subject to the per-request pooling parameters.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build the pooler from a resolved config.

        The "embed" task gets an `EmbeddingPoolerHead` and the "encode" task a
        `RewardPoolerHead`.

        Raises:
            NotImplementedError: If the task is neither "embed" nor "encode",
                or if the pooling type has no pooling method.
        """
        method = PoolingMethod.from_pooling_type(pooler_config.pooling_type)

        task = pooler_config.task
        if task == "embed":
            head = EmbeddingPoolerHead()
        elif task == "encode":
            head = RewardPoolerHead()
        else:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")

        return cls(method, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        """Store the pooling method and the head applied after it."""
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooling method."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, run the head, and wrap the result as a `PoolerOutput`."""
        pooled = self.pooling(hidden_states, pooling_metadata)
        head_output = self.head(pooled, pooling_metadata)
        return build_output(head_output)


565
class StepPooler(Pooler):
    """Pooler for the "encode" task that can filter rows by a step tag token."""

    def __init__(self, ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select the requested columns and rows for every prompt.

        For each prompt, ``returned_token_ids`` (when non-empty) indexes the
        second dimension of its pooled states, and ``step_tag_id`` (when set)
        keeps only the rows whose prompt token equals that id.
        """
        per_prompt_states = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)

        extracted: list[torch.Tensor] = []
        for states, token_ids, params in zip(per_prompt_states,
                                             prompt_token_ids,
                                             pooling_params):
            wanted_ids = params.returned_token_ids
            if wanted_ids is not None and len(wanted_ids) > 0:
                states = states[:, wanted_ids]

            tag_id = params.step_tag_id
            if tag_id is not None:
                # Boolean mask over the prompt's tokens.
                states = states[token_ids == tag_id]

            extracted.append(states)

        return extracted

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract the states, run the head, and wrap the result."""
        extracted = self.extract_states(hidden_states, pooling_metadata)
        head_output = self.head(extracted, pooling_metadata)
        return build_output(head_output)


616
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the configured pooling callable.
    2. Optionally applies a classification layer to the pooled data.
    3. Applies an activation function to the output, for the requests whose
       pooling parameters enable it.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Return the activation for sequence classification models."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Return the activation for cross-encoder models."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        """
        Args:
            pooling: Callable that pools hidden states given pooling metadata.
            classifier: Optional callable applied to the pooled data.
            act_fn: Activation for the final output; defaults to
                `PoolerClassify` when not given.
        """
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier
        # Compare against None explicitly so the choice does not depend on
        # the truthiness of the module that was passed in.
        self.act_fn = act_fn if act_fn is not None else PoolerClassify()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, classify and activate, returning a `PoolerOutput`."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        if self.classifier is not None:
            # apply classifier once on the full batch if possible
            if isinstance(pooled_data, torch.Tensor):
                pooled_data = self.classifier(pooled_data)
            elif len({data.shape for data in pooled_data}) == 1:
                # All items share one shape, so they can be stacked.
                pooled_data = self.classifier(torch.stack(pooled_data))
            else:
                # Mixed shapes, or an empty list: `torch.stack` rejects an
                # empty sequence, so the empty case is handled here and
                # yields an empty result.
                pooled_data = [self.classifier(data) for data in pooled_data]

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return build_output(scores)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Store the task-to-pooler mapping after validating it.

        Raises:
            ValueError: If a pooler is registered for a task that it does not
                report as supported.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task in pooler.get_supported_tasks():
                continue

            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {pooler.get_supported_tasks()}")

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooler registered for ``task``."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool each run of consecutive same-task requests with its pooler.

        Raises:
            ValueError: If a request's task has no registered pooler.
        """
        poolers_by_task = self.poolers_by_task

        if isinstance(hidden_states, list):
            per_prompt_states = hidden_states
        else:
            # Split the flat tensor into one chunk per prompt.
            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
            per_prompt_states = list(
                hidden_states.split(prompt_lens.tolist()))

        outputs: list[PoolingSequenceGroupOutput] = []
        start = 0
        for task, group in groupby(get_tasks(pooling_metadata)):
            pooler = poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            stop = start + sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                per_prompt_states[start:stop],
                pooling_metadata[start:stop],
            )

            outputs.extend(group_output.outputs)
            start = stop

        return PoolerOutput(outputs)