pooler.py 23.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union, cast
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig
16
from vllm.logger import init_logger
17
from vllm.pooling_params import PoolingParams
18
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
19
from vllm.tasks import PoolingTask
20
from vllm.utils import current_stream, resolve_obj_by_qualname
21
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
22

23
24
logger = init_logger(__name__)

25
26
27
28
# Signature of a pooling callable: it receives the hidden states (a single
# tensor or a list of tensors) together with the batch's `PoolingMetadata`
# and returns the pooled data in either of those two forms.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# Signature of a classifier applied to the (stacked) pooled data.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
29
30
31
32
33


class PoolingType(IntEnum):
    """Integer identifiers of the supported pooling strategies.

    Members are looked up by name (e.g. ``PoolingType["MEAN"]``) when a
    ``PoolerConfig`` is resolved, so the member names are part of the
    configuration surface; their integer values are fixed as listed.
    """

    LAST = 0  # hidden state of each request's last token
    ALL = 1  # hidden states of every token of each request
    CLS = 2  # hidden state of each request's first token
    STEP = 3  # per-token outputs filtered by a step tag (see StepPooler)
    MEAN = 4  # average of each request's token hidden states
38
39


40
41
42
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """A concrete (pooling type, task) pair derived from user configuration."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Resolve ``pooler_config`` for ``task``.

        ``pooler_config.pooling_type`` must already be populated with the
        name of a ``PoolingType`` member; an ``AssertionError`` is raised
        when it is ``None`` and a ``KeyError`` when the name is unknown.
        """
        type_name = pooler_config.pooling_type
        assert type_name is not None
        resolved_type = PoolingType[type_name]
        return cls(pooling_type=resolved_type, task=task)
54
55


56
57
58
59
60
61
62
63
64
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Adjustments a pooler asks to be made to each request's parameters."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Write the requested flags onto ``params`` (mutates it in place)."""
        wants_token_ids = self.requires_token_ids
        params.requires_token_ids = wants_token_ids


65
66
67
68
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM.

    The static ``for_*`` factories build the default pooler for a given
    task from a ``PoolerConfig``.
    """

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the default pooler for the ``encode`` task.

        A ``"STEP"`` pooling type selects ``StepPooler``. Any other value,
        including ``None``, yields a ``SimplePooler`` over all tokens.
        """
        if pooler_config.pooling_type != "STEP":
            return SimplePooler.from_config(
                ResolvedPoolingConfig(task="encode",
                                      pooling_type=PoolingType.ALL))
        return StepPooler()

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the default pooler for the ``embed`` task."""
        return SimplePooler.from_config(
            ResolvedPoolingConfig.from_config(task="embed",
                                              pooler_config=pooler_config))

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
    ):
        """Build a ``ClassifierPooler`` using the configured pooling type.

        ``classifier`` is applied to the pooled data; pass ``None`` to skip
        that step.
        """
        pooling_type = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        ).pooling_type
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(pooling_type),
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation requests no changes.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` for the batch described by the metadata."""
        raise NotImplementedError


124
125
126
127
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths recorded in ``pooling_metadata``.

    ``hidden_states`` is accepted for signature compatibility only and is
    not read.
    """
    lengths = pooling_metadata.prompt_lens
    return lengths
129
130


131
132
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return each request's prompt token IDs, trimmed to its prompt length.

    Row ``i`` of ``pooling_metadata.prompt_token_ids`` is cut to its first
    ``prompt_lens[i]`` entries. The pooler must have requested token IDs
    through ``get_pooling_updates``; otherwise an ``AssertionError`` is
    raised.
    """
    assert pooling_metadata.prompt_token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    token_id_rows = pooling_metadata.prompt_token_ids
    return [
        token_id_rows[row, :length]
        for row, length in enumerate(pooling_metadata.prompt_lens)
    ]


142
143
def get_pooling_params(
        pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the per-request ``PoolingParams`` stored in the metadata."""
    return pooling_metadata.pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the pooling task of every request, in batch order.

    Every request is expected to already carry a task; if any is ``None``
    the length check below raises ``AssertionError``.
    """
    params_list = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = [
        params.task for params in params_list if params.task is not None
    ]
    assert len(params_list) == len(tasks)

    return tasks


160
def get_classification_activation_function(config: PretrainedConfig):
    """Choose the activation that matches the config's ``problem_type``.

    Aligned with transformers' ``ForSequenceClassificationLoss``:
    https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92

    A missing or unrecognized problem type defaults to ``PoolerClassify``.
    """
    problem_type = getattr(config, "problem_type", "")
    known_types = (
        ("regression", PoolerIdentity),
        ("single_label_classification", PoolerClassify),
        ("multi_label_classification", PoolerMultiLabelClassify),
    )
    for type_name, activation_cls in known_types:
        if problem_type == type_name:
            return activation_cls()
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation applied to cross-encoder scores.

    The activation class name is taken from
    ``config.sentence_transformers["activation_fn"]`` when present, otherwise
    from ``config.sbert_ce_default_activation_function`` when that is not
    ``None``. The class is instantiated with no arguments and wrapped via
    ``PoolerActivation.wraps``. When no name is configured,
    ``PoolerClassify`` is returned.

    Raises:
        ValueError: if the configured name is not under ``torch.nn.modules.``.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model's config and is resolved to an importable
    # object, so this restriction is a security check. It must not be an
    # `assert`: asserts are stripped when Python runs with `-O`.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError("Loading of activation functions is restricted to "
                         "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
190

191

192
193
def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
    """Copy pooled data to the CPU and wrap it in a ``PoolerOutput``.

    This is where pooling models perform their device-to-host transfer. The
    copies are issued with ``non_blocking=True`` and the current stream is
    then synchronized before the CPU tensors are used. Each element of
    ``all_data`` (each list item, or each row of a tensor) becomes one
    ``PoolingSequenceGroupOutput``.
    """
    if isinstance(all_data, list):
        host_data = [chunk.to("cpu", non_blocking=True) for chunk in all_data]
    else:
        host_data = all_data.to("cpu", non_blocking=True)
    current_stream().synchronize()

    group_outputs = [PoolingSequenceGroupOutput(item) for item in host_data]
    return PoolerOutput(outputs=group_outputs)


class PoolingMethod(nn.Module, ABC):
    """Base class for strategies that reduce token hidden states per request."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method registered for ``pooling_type``.

        Raises:
            NotImplementedError: for types with no pooling method here
                (e.g. ``PoolingType.STEP``).
        """
        method_classes = {
            PoolingType.LAST: LastPool,
            PoolingType.ALL: AllPool,
            PoolingType.CLS: CLSPool,
            PoolingType.MEAN: MeanPool,
        }
        method_cls = method_classes.get(pooling_type)
        if method_cls is None:
            raise NotImplementedError(f"Unsupported method: {pooling_type}")
        return method_cls()

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return parameter updates for ``task``; none are needed by default."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool ``hidden_states`` for every request tracked by the cursor."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool using the cursor carried by ``pooling_metadata``."""
        cursor = pooling_metadata.pooling_cursor
        return self.forward_all(hidden_states, cursor)
242
243


244
245
class CLSPool(PoolingMethod):
    """Pool by selecting the hidden state of each request's first token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Requests that are only partially prefilled are rejected here.
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with CLS pooling"

        first_token_rows = pooling_cursor.first_token_indices_gpu
        return hidden_states[first_token_rows]
258
259


260
class LastPool(PoolingMethod):
    """Pool by selecting the hidden state of each request's last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        last_token_rows = pooling_cursor.last_token_indices_gpu
        return hidden_states[last_token_rows]
271
272


273
class AllPool(PoolingMethod):
    """Pool by keeping every token's hidden state, one tensor per request."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with ALL pooling"

        # Cut the flat token axis (dim 0) into per-request pieces, then
        # reorder them as dictated by the cursor's index.
        split_sizes = pooling_cursor.num_scheduled_tokens_cpu.tolist()
        per_request = hidden_states.split(split_sizes)
        return [per_request[position] for position in pooling_cursor.index]
291
292


293
class MeanPool(PoolingMethod):
    """Pool by averaging the hidden states of each request's tokens."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with MEAN pooling"

        lengths = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
                                                    non_blocking=True)

        # A prefix sum over the token axis turns each request's total into a
        # difference of two prefix sums. It is accumulated in float32 because
        # lower precision loses significant accuracy here.
        prefix_sums = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first_rows = pooling_cursor.first_token_indices_gpu
        last_rows = pooling_cursor.last_token_indices_gpu
        # prefix[last] - prefix[first] omits the first token; add it back.
        totals = (prefix_sums[last_rows] - prefix_sums[first_rows] +
                  hidden_states[first_rows])
        return totals / lengths.unsqueeze(1)


320
# Either a single tensor or a list of tensors; activations annotated with _T
# return the same kind of container they were given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
321
322


323
class BasePoolerActivation(nn.Module, ABC):
    """Abstract activation applied to pooled outputs."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation; the result has the same container kind.

        Expected shapes:
          * classify (& score): (batch_size, num_classes)
          * embed: (batch_size, embedding_dim) or list(embedding_dim);
            (batch_size, dimensions) or list(dimensions) if using MRL
        """
        raise NotImplementedError
332

333

334
class PoolerActivation(BasePoolerActivation):
    """Activation applied independently to each chunk of pooled data."""

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt a plain ``nn.Module`` activation into a ``PoolerActivation``.

        ``nn.Identity`` becomes ``PoolerIdentity``; ``nn.Sigmoid`` and
        ``nn.Softmax`` become ``PoolerClassify``; any other module is
        wrapped in a ``LambdaPoolerActivation``.
        """
        if isinstance(module, nn.Identity):
            wrapped = PoolerIdentity()
        elif isinstance(module, (nn.Sigmoid, nn.Softmax)):
            wrapped = PoolerClassify()
        else:
            wrapped = LambdaPoolerActivation(module)
        return wrapped

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply ``forward_chunk`` to a tensor, or to each item of a list."""
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)
        return [self.forward_chunk(chunk) for chunk in pooled_data]
354
355


356
357
358
class PoolerIdentity(PoolerActivation):
    """No-op activation that leaves pooled data untouched."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # The input object itself is returned; no copy is made.
        unchanged = pooled_data
        return unchanged


362
363
364
365
366
367
368
class PoolerNormalize(PoolerActivation):
    """L2-normalize each vector along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, then cast back to the input dtype.
        input_dtype = pooled_data.dtype
        unit_vectors = F.normalize(pooled_data.float(), p=2, dim=-1)
        return unit_vectors.to(input_dtype)


369
370
371
372
373
374
class PoolerMultiLabelClassify(PoolerActivation):
    """Element-wise sigmoid, scoring each label independently."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Computed in float32, returned in the input dtype.
        probabilities = F.sigmoid(pooled_data.float())
        return probabilities.to(pooled_data.dtype)


375
376
class PoolerClassify(PoolerActivation):
    """Sigmoid or softmax activation for classification scores.

    With fewer than two labels a sigmoid is applied; otherwise a softmax over
    the last dimension. Both are computed in float32 and cast back to the
    input dtype.

    Args:
        static_num_labels: If True, read ``num_labels`` once from the current
            model's HF config. If False, or if the config reports
            ``num_labels == 0``, the label count is taken from the last
            dimension of each chunk at call time.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        # None means "infer the label count from the data in forward_chunk".
        self.num_labels: Optional[int] = None
        if static_num_labels:
            from vllm.config import get_current_vllm_config
            vllm_config = get_current_vllm_config()
            configured = getattr(vllm_config.model_config.hf_config,
                                 "num_labels", 0)
            if configured == 0:
                # Previously 0 was kept as the static count, which silently
                # forced a sigmoid for every output even though the warning
                # promised a fallback. Infer from the output shape instead.
                logger.warning(
                    "num_labels should be > 0 for classification models, "
                    "falling back to inferring the number of labels from "
                    "the output shape. "
                    "Please check if the configuration is correct.")
            else:
                self.num_labels = configured

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = (self.num_labels if self.num_labels is not None else
                      pooled_data.shape[-1])

        if num_labels < 2:
            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)

        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
400
401
402
403
404
405
406
407
408
409
410
411
412


class LambdaPoolerActivation(PoolerActivation):
    """Adapt an arbitrary tensor-to-tensor callable into an activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        wrapped_fn = self.fn
        return wrapped_fn(pooled_data)


413
414
class PoolerHead(nn.Module):
    """Post-processing stage that applies an activation to pooled data."""

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply ``self.activation``.

        ``pooling_metadata`` is not read here; it is part of the head
        interface so that subclasses can use per-request parameters.
        """
        activation = self.activation
        return activation(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for embedding outputs.

    The steps run in this order:
    1. Stack list inputs into one tensor.
    2. Apply the ST projector, if one was loaded.
    3. Truncate to per-request ``dimensions`` (matryoshka / MRL).
    4. L2-normalize requests whose ``normalize`` flag is set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        # Load the ST projector; the loader may return None.
        from vllm.config import get_current_vllm_config
        from vllm.model_executor.models.adapters import _load_st_projector

        vllm_config = get_current_vllm_config()
        if vllm_config:
            self.projector = _load_st_projector(vllm_config.model_config)
        else:
            self.projector = None

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data: [batch_size, hidden_dimension]

        if self.projector is not None:
            projector = cast(nn.Module, self.projector)
            # Run the projector in float32, then restore the input dtype.
            input_dtype = pooled_data.dtype
            projected = projector(pooled_data.to(torch.float32))
            pooled_data = projected.to(input_dtype)
        # pooled_data: [batch_size, embedding_dimension]

        params_list = get_pooling_params(pooling_metadata)

        # Matryoshka representation: keep only the requested leading dims.
        dims = [params.dimensions for params in params_list]
        if any(d is not None for d in dims):
            assert len(pooled_data) == len(dims)
            all_same = len(set(dims)) == 1
            if all_same and not isinstance(pooled_data, list):
                pooled_data = pooled_data[..., :dims[0]]
            else:
                pooled_data = [
                    row if d is None else row[..., :d]
                    for row, d in zip(pooled_data, dims)
                ]

        # Normalization: batch-wide when every request agrees, otherwise
        # decided request by request.
        normalize_flags = [params.normalize for params in params_list]
        if len(set(normalize_flags)) != 1:
            pooled_data = [
                self.activation(row) if flag else row
                for row, flag in zip(pooled_data, normalize_flags)
            ]
        elif normalize_flags[0]:
            pooled_data = self.activation(pooled_data)

        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head that applies a classification activation when requested.

    The activation infers the label count from each chunk's last dimension,
    and is applied only to requests whose ``softmax`` flag is set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify(static_num_labels=False))

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        softmax_flags = [
            params.softmax for params in get_pooling_params(pooling_metadata)
        ]

        if len(set(softmax_flags)) != 1:
            # Requests disagree (or the batch is empty): decide per request.
            return [
                self.activation(chunk) if flag else chunk
                for chunk, flag in zip(pooled_data, softmax_flags)
            ]
        if softmax_flags[0]:
            return self.activation(pooled_data)
        return pooled_data
513
514


515
class SimplePooler(Pooler):
    """A pooler composed of a pooling method followed by a head.

    This layer does the following:
    1. Extracts specific tokens or aggregates data via the pooling method.
    2. Post-processes the pooled data with the head (e.g. normalization).
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build a pooler for a resolved ``embed`` or ``encode`` config.

        Raises:
            NotImplementedError: for any other task, or for a pooling type
                without a pooling method.
        """
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)

        head_cls = {
            "embed": EmbeddingPoolerHead,
            "encode": RewardPoolerHead,
        }.get(pooler_config.task)
        if head_cls is None:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")

        return cls(pooling, head_cls())

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooling method."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        processed = self.head(pooled, pooling_metadata)
        return build_output(processed)


560
class StepPooler(Pooler):
    """Pooler that keeps per-token outputs, filtered per request.

    All token states are kept, optionally restricted to selected columns and
    to positions whose prompt token equals a step tag, then passed through a
    ``RewardPoolerHead``.
    """

    def __init__(self) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return one filtered tensor per request.

        For each request:
        * if ``returned_token_ids`` is non-empty, only those indices along
          dim 1 are kept;
        * if ``step_tag_id`` is set, only rows whose prompt token ID equals
          it are kept.
        """
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
        params_list = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, params in zip(per_request_states,
                                             prompt_token_ids, params_list):
            column_ids = params.returned_token_ids
            if column_ids is not None and len(column_ids) > 0:
                states = states[:, column_ids]

            tag_id = params.step_tag_id
            if tag_id is not None:
                states = states[token_ids == tag_id]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Prompt token IDs are needed to locate step-tag positions."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        states = self.extract_states(hidden_states, pooling_metadata)
        processed = self.head(states, pooling_metadata)
        return build_output(processed)


611
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Optionally applies a classifier to the pooled data.
    3. Optionally subtracts the configured logit bias.
    4. Applies the activation function, batch-wide or per request depending
       on each request's ``activation`` flag.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Activation for sequence-classification models (by problem type)."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Activation for cross-encoder models (from ST/SBERT config)."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        super().__init__()

        from vllm.config import get_current_vllm_config
        vllm_config = get_current_vllm_config()

        self.pooling = pooling
        self.classifier = classifier
        # Explicit None check: nn.Module truthiness can be overridden (e.g. via
        # __len__), and a valid activation must never be silently replaced.
        self.act_fn = act_fn if act_fn is not None else PoolerClassify()
        self.logit_bias: Optional[
            float] = vllm_config.model_config.pooler_config.logit_bias

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose. Without a classifier, a custom pooling
            # fn may return a view of `hidden_states`, and DispatchPooler
            # passes the same `hidden_states` to every task group, so an
            # in-place subtraction could corrupt other groups' inputs.
            pooled_data = pooled_data - self.logit_bias

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return build_output(scores)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Register one sub-pooler per task.

        Raises:
            ValueError: if a pooler is registered for a task it does not
                report as supported.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}")

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each contiguous run of same-task requests through its pooler.

        Every sub-pooler receives the full ``hidden_states`` plus the slice of
        ``pooling_metadata`` covering its run. Outputs are concatenated in
        batch order.

        Raises:
            ValueError: if a request's task has no registered pooler.
        """
        poolers_by_task = self.poolers_by_task

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        # groupby merges only *consecutive* equal tasks, so each group maps
        # to one contiguous metadata slice [offset, offset + num_items).
        for task, group in groupby(get_tasks(pooling_metadata)):
            # Explicit None check instead of relying on nn.Module truthiness.
            pooler = poolers_by_task.get(task)
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset:offset + num_items],
            )

            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)