pooler.py 22.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union, cast
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig
16
from vllm.pooling_params import PoolingParams
17
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
18
from vllm.tasks import PoolingTask
19
from vllm.utils import current_stream, resolve_obj_by_qualname
20
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
21

22
23
24
25
# Signature of a pooling callable: takes the hidden states (one tensor or a
# list of tensors) plus the pooling metadata and returns the pooled data
# (again one tensor or a list of tensors).
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# Signature of a classifier callable: maps one tensor to another tensor.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
26
27
28
29
30


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods.

    Members are looked up by name (``PoolingType["MEAN"]``) when a pooler
    configuration is resolved, so the member names are part of the contract.
    """

    LAST = 0
    ALL = 1
    CLS = 2
    STEP = 3
    MEAN = 4
35
36


37
38
39
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with its resolved pooling type."""

    # Pooling strategy to use, as a `PoolingType` member.
    pooling_type: PoolingType
    # Pooling task this configuration was resolved for.
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Resolve ``pooler_config`` for ``task``.

        ``pooler_config.pooling_type`` must already be set; it is looked up
        by name in `PoolingType`, so an unknown name raises ``KeyError``.
        """
        assert pooler_config.pooling_type is not None
        return cls(task=task,
                   pooling_type=PoolingType[pooler_config.pooling_type])
51
52


53
54
55
56
57
58
59
60
61
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Settings a pooler wants written onto a request's pooling params."""

    # Set this flag to enable `get_prompt_token_ids` for your pooler.
    requires_token_ids: bool = False

    def apply(self, params: PoolingParams) -> None:
        """Copy this update's settings onto ``params`` in place."""
        needs_token_ids = self.requires_token_ids
        params.requires_token_ids = needs_token_ids


62
63
64
65
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the pooler for the "encode" task.

        A ``"STEP"`` pooling type selects `StepPooler`; every other
        configuration uses a `SimplePooler` with ``PoolingType.ALL``.
        """
        if pooler_config.pooling_type == "STEP":
            return StepPooler()

        resolved_config = ResolvedPoolingConfig(task="encode",
                                                pooling_type=PoolingType.ALL)

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the `SimplePooler` for the "embed" task."""
        resolved_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
    ):
        """Build the `ClassifierPooler` for the "classify" task.

        ``classifier`` may be ``None``, in which case the pooled data is
        passed to the activation without a classification layer.
        """
        resolved_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )

        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)

        return ClassifierPooler(
            pooling=pooling,
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` as described by ``pooling_metadata``."""
        raise NotImplementedError


121
122
123
124
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths stored on ``pooling_metadata``.

    ``hidden_states`` is part of the signature but is not read here.
    """
    return pooling_metadata.prompt_lens
126
127


128
129
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return one tensor of prompt token ids per request.

    Row ``i`` of ``pooling_metadata.prompt_token_ids`` is cut to the first
    ``pooling_metadata.prompt_lens[i]`` entries, so each returned tensor
    holds only that request's own prompt tokens.

    Raises:
        AssertionError: if the metadata carries no prompt token ids.
    """
    assert pooling_metadata.prompt_token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    return [
        pooling_metadata.prompt_token_ids[i, :num]
        for i, num in enumerate(pooling_metadata.prompt_lens)
    ]


139
140
def get_pooling_params(
        pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the pooling params stored on ``pooling_metadata``."""
    return pooling_metadata.pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the task of every request in ``pooling_metadata``, in order.

    Raises:
        AssertionError: if any request's pooling params have no task set.
    """
    pooling_params = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = [
        task for pooling_param in pooling_params
        if (task := pooling_param.task) is not None
    ]
    # A length mismatch means at least one `task` was None and got filtered.
    assert len(pooling_params) == len(tasks)

    return tasks


157
def get_classification_activation_function(config: PretrainedConfig):
    """Pick the activation matching ``config.problem_type``.

    ``"regression"`` maps to `PoolerIdentity` and
    ``"multi_label_classification"`` to `PoolerMultiLabelClassify`. Every
    other value, including ``"single_label_classification"`` and a missing
    attribute, maps to `PoolerClassify`.
    """
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()

    # "single_label_classification" and anything unrecognised share the
    # default classification activation.
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Pick the activation for a cross-encoder from its HF config.

    The activation's qualified name is read from
    ``config.sentence_transformers["activation_fn"]`` if present, otherwise
    from ``config.sbert_ce_default_activation_function``. If neither names
    an activation, `PoolerScore` is returned.

    Raises:
        AssertionError: if the configured name is outside
            ``torch.nn.modules``.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerScore()

    # The name comes from the model's config and is resolved to a real
    # object, so this allow-list is a security check. It is raised
    # explicitly (not via `assert`) so it still runs under `python -O`;
    # the exception type is kept for backward compatibility.
    if not function_name.startswith("torch.nn.modules."):
        raise AssertionError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
187

188

189
190
def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
    """Copy pooled data to the CPU and wrap it as a `PoolerOutput`.

    Iterating ``all_data`` yields one item per output, so a tensor
    contributes one `PoolingSequenceGroupOutput` per entry along its first
    dimension and a list contributes one per element.
    """
    # Pooling models D2H & synchronize occurs here
    if isinstance(all_data, list):
        all_data = [d.to("cpu", non_blocking=True) for d in all_data]
    else:
        all_data = all_data.to("cpu", non_blocking=True)
    # The copies above were issued as non-blocking, so wait on the current
    # stream before the data is handed on.
    current_stream().synchronize()

    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
    return PoolerOutput(outputs=all_outputs)


class PoolingMethod(nn.Module, ABC):
    """Base class for strategies that reduce hidden states per request."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method for ``pooling_type``.

        Raises:
            NotImplementedError: for any type without a method here
                (``PoolingType.STEP`` among them).
        """
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-param updates for ``task`` (defaults only)."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool ``hidden_states`` using the positions in ``pooling_cursor``."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool ``hidden_states`` with the cursor held by the metadata."""
        pooling_cursor = pooling_metadata.pooling_cursor
        return self.forward_all(hidden_states, pooling_cursor)
239
240


241
242
class CLSPool(PoolingMethod):
    """Pools each request to the hidden state of its first token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Gather the rows at ``pooling_cursor.first_token_indices_gpu``."""
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with CLS pooling"

        return hidden_states[pooling_cursor.first_token_indices_gpu]
255
256


257
class LastPool(PoolingMethod):
    """Pools each request to the hidden state of its last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Gather the rows at ``pooling_cursor.last_token_indices_gpu``."""
        return hidden_states[pooling_cursor.last_token_indices_gpu]
268
269


270
class AllPool(PoolingMethod):
    """Keeps every token's hidden state, grouped into one tensor per entry."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Split ``hidden_states`` by scheduled-token counts and select.

        The tensor is split along its first dimension into chunks sized by
        ``pooling_cursor.num_scheduled_tokens_cpu``; the chunks named by
        ``pooling_cursor.index`` are returned in that order.
        """
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with ALL pooling"

        hidden_states_lst = list(
            hidden_states.split(
                pooling_cursor.num_scheduled_tokens_cpu.tolist()))
        return [hidden_states_lst[i] for i in pooling_cursor.index]
288
289


290
class MeanPool(PoolingMethod):
    """Pools each request to the mean of its tokens' hidden states."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Average the hidden states between each first and last token.

        One running sum over the whole batch is taken, and each request's
        total is read off from it rather than summing request by request.
        """
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with MEAN pooling"

        prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
                                                        non_blocking=True)

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu
        # cumsum[end] - cumsum[start] leaves out the start row itself, so it
        # is added back to make the sum cover start..end inclusive.
        return (cumsum[end_indices] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


317
# Constrained to either a single tensor or a list of tensors, so an
# activation annotated `_T -> _T` returns the same kind it was given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
318
319


320
class BasePoolerActivation(nn.Module, ABC):
    """Interface for activations applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to ``pooled_data``."""
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
329

330

331
class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor and applied to a tensor or a list."""

    @staticmethod
    def wraps(module: nn.Module):
        """Wrap an ``nn.Module`` activation as a pooler activation.

        ``nn.Identity`` becomes `PoolerIdentity`, ``nn.Sigmoid`` and
        ``nn.Softmax`` become `PoolerClassify`, and any other module is
        called as-is through `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to the tensor, or to each list element."""
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
351
352


353
354
355
class PoolerIdentity(PoolerActivation):
    """Activation that returns its input unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data


359
360
361
362
363
364
365
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Normalize in float and cast back to the input's dtype."""
        source_dtype = pooled_data.dtype
        upcast = pooled_data.float()
        unit_vectors = F.normalize(upcast, p=2, dim=-1)
        return unit_vectors.to(source_dtype)


366
367
368
369
370
371
class PoolerMultiLabelClassify(PoolerActivation):
    """Activation that applies an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid in float and cast back to the input's dtype."""
        source_dtype = pooled_data.dtype
        probabilities = F.sigmoid(pooled_data.float())
        return probabilities.to(source_dtype)


372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
class PoolerClassify(PoolerActivation):
    """Activation that turns logits into probabilities."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Sigmoid for fewer than two labels, otherwise softmax.

        The label count is the size of the last dimension. The math runs
        in float and the result is cast back to the input's dtype.
        """
        upcast = pooled_data.float()
        if pooled_data.shape[-1] < 2:
            probabilities = F.sigmoid(upcast)
        else:
            probabilities = F.softmax(upcast, dim=-1)
        return probabilities.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Activation for scores: sigmoid for one label, else pass-through."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the input as-is unless its last dimension is below two."""
        if pooled_data.shape[-1] >= 2:
            return pooled_data

        # Fewer than two labels: sigmoid in float, cast back afterwards.
        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Activation that delegates to an arbitrary tensor-to-tensor callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        """Store ``fn``, the callable applied to each tensor."""
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``fn`` applied to ``pooled_data``."""
        activated = self.fn(pooled_data)
        return activated


403
404
class PoolerHead(nn.Module):
    """Final stage of a pooler: applies an activation to the pooled data."""

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the activation; ``pooling_metadata`` is not read here."""
        return self.activation(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for embeddings: optional projection, truncation, normalization."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        # Load ST projector if available
        # NOTE(review): imported locally rather than at module level,
        # presumably to avoid an import cycle -- confirm.
        from vllm.config import get_current_vllm_config
        from vllm.model_executor.models.adapters import _load_st_projector

        vllm_config = get_current_vllm_config()
        self.projector = _load_st_projector(
            vllm_config.model_config) if vllm_config else None

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Project, truncate and normalize the pooled embeddings.

        Truncation uses each request's ``dimensions`` and normalization its
        ``normalize`` flag. The result is a single tensor when every
        request shares the same settings, otherwise a list with one tensor
        per request.
        """
        # Apply ST projector
        if self.projector is not None:
            projector = cast(nn.Module, self.projector)

            def _proj(x: torch.Tensor) -> torch.Tensor:
                # Run the projector in float32, then restore the dtype.
                orig_dtype = x.dtype
                y = projector(x.to(torch.float32))
                return y.to(orig_dtype)

            if isinstance(pooled_data, torch.Tensor):
                pooled_data = _proj(pooled_data)
            else:
                pooled_data = [_proj(t) for t in pooled_data]

        pooling_params = get_pooling_params(pooling_metadata)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # for matryoshka representation
        dimensions_list = [
            pooling_param.dimensions for pooling_param in pooling_params
        ]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            # pooled_data is always a tensor here (lists were stacked
            # above), so only the dimensions decide which branch runs.
            if len(set(dimensions_list)) == 1:
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head that applies `PoolerClassify` where ``softmax`` is requested."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the activation according to each request's ``softmax`` flag.

        With one shared flag the whole batch is activated (or returned
        untouched) in a single step; otherwise the result is a list built
        item by item.
        """
        # for softmax
        softmax_flags = [
            param.softmax for param in get_pooling_params(pooling_metadata)
        ]

        if len(set(softmax_flags)) != 1:
            return [
                self.activation(item) if wanted else item
                for item, wanted in zip(pooled_data, softmax_flags)
            ]

        if softmax_flags[0]:
            return self.activation(pooled_data)
        return pooled_data
504
505


506
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build a pooler from a resolved config.

        The "embed" task gets an `EmbeddingPoolerHead` and the "encode"
        task a `RewardPoolerHead`.

        Raises:
            NotImplementedError: for any other task.
        """
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        head: PoolerHead
        if pooler_config.task == "embed":
            head = EmbeddingPoolerHead()
        elif pooler_config.task == "encode":
            head = RewardPoolerHead()
        else:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")

        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, apply the head, and wrap the result as a `PoolerOutput`."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


551
class StepPooler(Pooler):
    """Pooler for the "encode" task that can keep only step-tag positions.

    It needs the prompt token ids, so `get_pooling_updates` asks for them.
    """

    def __init__(self, ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select the columns and token rows each request asks for.

        For every request, ``returned_token_ids`` (if non-empty) picks
        columns of the pooled data, and ``step_tag_id`` (if set) keeps only
        the rows whose prompt token equals that id.
        """
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)

        pooled_data = list[torch.Tensor]()

        pooling_params = get_pooling_params(pooling_metadata)

        for data, token_id, pooling_param in zip(pooled_data_lst,
                                                 prompt_token_ids,
                                                 pooling_params):
            step_tag_id = pooling_param.step_tag_id
            returned_token_ids = pooling_param.returned_token_ids

            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract states, apply the head, and wrap as a `PoolerOutput`."""
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


602
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states.
    2. Applies the classification layer to the pooled data, if one is set.
    3. Applies an activation function to the output, per request flag.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier
        # Compare against None rather than relying on module truthiness.
        self.act_fn = act_fn if act_fn is not None else PoolerClassify()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, classify and activate, returning a `PoolerOutput`.

        The activation is applied only to requests whose pooling params
        have ``activation`` set.
        """
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        if self.classifier is not None:
            # pooled_data is a single tensor after the stack above, so the
            # classifier runs once on the full batch.
            pooled_data = self.classifier(pooled_data)

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return build_output(scores)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Store the task-to-pooler mapping after validating it.

        Raises:
            ValueError: if a pooler is registered for a task it does not
                list in `get_supported_tasks`.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}")

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each run of same-task requests through its pooler.

        Consecutive requests with the same task form one group. Each group
        is pooled with the matching slice of ``pooling_metadata`` and the
        outputs are concatenated in request order.

        Raises:
            ValueError: if a request's task has no registered pooler.
        """
        poolers_by_task = self.poolers_by_task

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        for task, group in groupby(get_tasks(pooling_metadata)):
            pooler = poolers_by_task.get(task)
            # Test for a missing entry explicitly, not by module truthiness.
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            num_items = sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset:offset + num_items],
            )

            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)