# pooler.py (23 KB)
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.models.adapters import _load_st_projector
18
19
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
20
21
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput
22
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
23

24
25
logger = init_logger(__name__)

# Signature of a pooling callable: (hidden states, batch metadata) -> pooled
# output. Both the input states and the result may be a single tensor or a
# list of tensors.
PoolingFn = Callable[
    [
        Union[torch.Tensor, list[torch.Tensor]],
        PoolingMetadata,
    ],
    Union[torch.Tensor, list[torch.Tensor]],
]

# Signature of a classifier head: pooled features in, scores out.
ClassifierFn = Callable[
    [torch.Tensor],
    torch.Tensor,
]
31
32
33
34


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods.

    Members are looked up by name (e.g. ``PoolingType["MEAN"]``) when a
    pooler config is resolved, so the names are part of the contract.
    """

    LAST = 0  # hidden state of each request's last token
    ALL = 1  # every token's hidden state
    CLS = 2  # hidden state of each request's first token
    STEP = 3  # hidden states at step-tag tokens (see StepPooler)
    MEAN = 4  # mean over each request's tokens
42


43
44
45
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """A pooling task paired with the concrete pooling method to use for it."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Resolve ``pooler_config.pooling_type`` (a member name) for ``task``.

        The pooling type must already be set on ``pooler_config``; the name is
        looked up in :class:`PoolingType`, so an unknown name raises
        ``KeyError``.
        """
        type_name = pooler_config.pooling_type
        assert type_name is not None
        resolved_type = PoolingType[type_name]
        return cls(pooling_type=resolved_type, task=task)
56
57


58
59
60
61
62
63
64
65
66
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Adjustments a pooler requests on a request's ``PoolingParams``."""

    # Set this flag to enable `get_prompt_token_ids` for your pooler.
    requires_token_ids: bool = False

    def apply(self, params: PoolingParams) -> None:
        """Write the requested settings onto ``params`` in place."""
        setattr(params, "requires_token_ids", self.requires_token_ids)


67
68
69
70
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the pooler for the ``encode`` task.

        ``STEP`` pooling gets a dedicated :class:`StepPooler`; for any other
        configured type the ``ALL`` pooling method is used.
        """
        if pooler_config.pooling_type == "STEP":
            return StepPooler()

        return SimplePooler.from_config(
            ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL)
        )

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the pooler for the ``embed`` task from ``pooler_config``."""
        return SimplePooler.from_config(
            ResolvedPoolingConfig.from_config(
                task="embed",
                pooler_config=pooler_config,
            )
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
    ):
        """Build a :class:`ClassifierPooler` for the ``classify`` task.

        Args:
            pooler_config: Supplies the pooling type used to reduce tokens.
            classifier: Optional head applied to the pooled features.
        """
        config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        method = PoolingMethod.from_pooling_type(config.pooling_type)
        return ClassifierPooler(pooling=method, classifier=classifier)

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The default requests no changes.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        raise NotImplementedError


127
128
129
130
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths recorded in ``pooling_metadata``.

    ``hidden_states`` is not used.
    """
    lens = pooling_metadata.prompt_lens
    return lens
132
133


134
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return each request's prompt token ids, cut to its prompt length.

    Row ``i`` of ``pooling_metadata.prompt_token_ids`` is sliced to
    ``prompt_lens[i]`` entries. Raises ``AssertionError`` when the metadata
    carries no token ids (the pooler must request them via
    ``get_pooling_updates``).
    """
    token_ids = pooling_metadata.prompt_token_ids
    assert token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`"
    )

    trimmed: list[torch.Tensor] = []
    for row, length in enumerate(pooling_metadata.prompt_lens):
        trimmed.append(token_ids[row, :length])
    return trimmed


145
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the per-request ``PoolingParams`` stored on ``pooling_metadata``."""
    return pooling_metadata.pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return every request's pooling task, in batch order.

    Raises ``AssertionError`` if any request's ``PoolingParams.task`` is
    ``None``.
    """
    all_params = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = [
        params.task for params in all_params if params.task is not None
    ]
    # Every request must carry a task; a dropped None would shrink the list.
    assert len(all_params) == len(tasks)

    return tasks


163
def get_classification_activation_function(config: PretrainedConfig):
    """Choose the classification activation from ``config.problem_type``.

    Aligned with transformers' ForSequenceClassificationLoss:
    https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92

    - ``"regression"`` -> :class:`PoolerIdentity`
    - ``"single_label_classification"`` -> :class:`PoolerClassify`
    - ``"multi_label_classification"`` -> :class:`PoolerMultiLabelClassify`
    - missing or anything else -> :class:`PoolerClassify`
    """
    problem_type = getattr(config, "problem_type", "")

    candidates = (
        ("regression", PoolerIdentity),
        ("single_label_classification", PoolerClassify),
        ("multi_label_classification", PoolerMultiLabelClassify),
    )
    for name, activation_cls in candidates:
        if problem_type == name:
            return activation_cls()

    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the score activation for a cross-encoder model.

    The activation's qualified name is read, in order of preference, from
    ``config.sentence_transformers["activation_fn"]`` or
    ``config.sbert_ce_default_activation_function``. The named class is
    instantiated and wrapped with :meth:`PoolerActivation.wraps`. When
    neither is set, :class:`PoolerClassify` is returned.

    Raises:
        ValueError: if the configured name is not under ``torch.nn.modules.``.
    """
    function_name: Optional[str] = None

    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model's config files and is resolved to an
    # importable object, so this allow-list is a security boundary. It must
    # be an explicit check: `assert` statements are removed under `python -O`.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons "
            f"(got {function_name!r})"
        )

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
198

199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

class PoolingMethod(nn.Module, ABC):
    """Selects or aggregates token hidden states into per-request outputs."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method registered for ``pooling_type``.

        Raises:
            NotImplementedError: for types with no method here (e.g. ``STEP``).
        """
        factories = {
            PoolingType.LAST: LastPool,
            PoolingType.ALL: AllPool,
            PoolingType.CLS: CLSPool,
            PoolingType.MEAN: MeanPool,
        }
        factory = factories.get(pooling_type)
        if factory is None:
            raise NotImplementedError(f"Unsupported method: {pooling_type}")
        return factory()

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request no changes to the pooling parameters by default."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        raise NotImplementedError

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool using the cursor carried by ``pooling_metadata``."""
        return self.forward_all(hidden_states, pooling_metadata.pooling_cursor)
236
237


238
class CLSPool(PoolingMethod):
    """Pools each request to the hidden state at its first token index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score", "embed", "encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_indices = pooling_cursor.first_token_indices_gpu
        return hidden_states[first_indices]
252
253


254
class LastPool(PoolingMethod):
    """Pools each request to the hidden state at its last token index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score", "embed", "encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        last_indices = pooling_cursor.last_token_indices_gpu
        return hidden_states[last_indices]
264
265


266
class AllPool(PoolingMethod):
    """Returns every token's hidden state, one tensor per request."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with ALL pooling"
        )

        # Cut the packed states into per-request chunks by scheduled token
        # counts, then order the chunks by the cursor's index.
        chunk_sizes = pooling_cursor.num_scheduled_tokens_cpu.tolist()
        chunks = hidden_states.split(chunk_sizes)
        return [chunks[idx] for idx in pooling_cursor.index]
283
284


285
class MeanPool(PoolingMethod):
    """Averages hidden states over each request's tokens."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score", "embed", "encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        running = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = pooling_cursor.first_token_indices_gpu
        last = pooling_cursor.last_token_indices_gpu

        # running[last] - running[first] sums tokens (first, last]; adding the
        # first token back makes the sum inclusive of [first, last].
        segment_sums = running[last] - running[first] + hidden_states[first]
        return segment_sums / lengths.unsqueeze(1)
311
312


313
# Either a single tensor or a list of tensors; activations return the same kind.
_T = TypeVar(
    "_T",
    torch.Tensor,
    list[torch.Tensor],
)


class BasePoolerActivation(nn.Module, ABC):
    """Abstract activation applied to pooled outputs."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # Expected shapes:
        #   classify (& score) -> (batch_size, num_classes)
        #   embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #            (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
324

325

326
327
328
329
330
331
332
class PoolerActivation(BasePoolerActivation):
    """Activation defined on one tensor; lists are handled element-wise."""

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt an ``nn.Module`` into a :class:`PoolerActivation`.

        ``nn.Identity`` maps to :class:`PoolerIdentity`, ``nn.Sigmoid`` and
        ``nn.Softmax`` map to :class:`PoolerClassify`, and any other module is
        wrapped so that it is called directly on each tensor.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()
        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)
        return [self.forward_chunk(chunk) for chunk in pooled_data]
345
346


347
348
class PoolerIdentity(PoolerActivation):
    """No-op activation: each tensor is returned unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        unchanged = pooled_data
        return unchanged


352
353
class PoolerNormalize(PoolerActivation):
    """L2-normalizes each tensor along its last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        normalized = F.normalize(pooled_data, p=2, dim=-1)
        return normalized
355
356


357
358
class PoolerMultiLabelClassify(PoolerActivation):
    """Independent per-label probabilities via element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        probabilities = F.sigmoid(pooled_data)
        return probabilities
360
361


362
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for fewer than 2 labels, else softmax.

    Args:
        static_num_labels: If True, read ``num_labels`` once from the current
            model's HF config. If False, or if the config reports
            ``num_labels == 0``, the label count is taken from the last
            dimension of each input at call time.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        # None means "infer the label count from the logits at call time".
        self.num_labels: Optional[int] = None

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            num_labels = getattr(vllm_config.model_config.hf_config, "num_labels", 0)
            if num_labels == 0:
                # An unknown label count must not be read as "fewer than 2
                # labels", which would silently force sigmoid; infer it from
                # the output shape instead so the fallback matches the log.
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to inferring it from the "
                    "output shape. "
                    "Please check if the configuration is correct."
                )
            else:
                self.num_labels = num_labels

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = (
            self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
        )

        # A single score is treated as binary (sigmoid); several scores form a
        # distribution over labels (softmax on the last dim).
        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
389
390
391
392
393
394
395
396
397
398
399
400


class LambdaPoolerActivation(PoolerActivation):
    """Activation backed by an arbitrary tensor-to-tensor callable ``fn``."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        apply_fn = self.fn
        return apply_fn(pooled_data)


401
class PoolerHead(nn.Module):
    """Applies ``activation`` to pooled data; ``pooling_metadata`` is unused."""

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(
        self,
        pooled_data: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ):
        activate = self.activation
        return activate(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for embedding outputs.

    Steps: stack to a batch tensor, cast to the model's head dtype, apply the
    ST projector if one was loaded, truncate to each request's requested
    ``dimensions`` (matryoshka), and normalize the requests whose
    ``PoolingParams.normalize`` flag is set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        vllm_config = get_current_vllm_config()

        # Load ST projector if available
        if vllm_config:
            projector = _load_st_projector(vllm_config.model_config)
        else:
            projector = None
        self.projector: Optional[nn.Module] = projector
        self.head_dtype = vllm_config.model_config.head_dtype

    def forward(
        self,
        pooled_data: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ):
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # [batch_size, hidden_dimension]
        pooled_data = pooled_data.to(self.head_dtype)

        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # [batch_size, embedding_dimension]

        params_list = get_pooling_params(pooling_metadata)

        # Matryoshka: keep only the leading `dimensions` features per request.
        dims = [params.dimensions for params in params_list]
        if any(dim is not None for dim in dims):
            assert len(pooled_data) == len(dims)
            uniform = len(set(dims)) == 1
            if uniform and not isinstance(pooled_data, list):
                # One shared width: slice the whole batch at once.
                pooled_data = pooled_data[..., : dims[0]]
            else:
                pooled_data = [
                    row if dim is None else row[..., :dim]
                    for row, dim in zip(pooled_data, dims)
                ]

        # Normalize the whole batch when all requests agree, else per row.
        normalize_flags = [params.normalize for params in params_list]
        if len(set(normalize_flags)) == 1:
            if normalize_flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(row) if flag else row
                for row, flag in zip(pooled_data, normalize_flags)
            ]

        # [batch_size, embedding_dimension]
        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head used by the ``encode``-task poolers.

    Casts to the model's head dtype, then applies :class:`PoolerClassify`
    (label count inferred from each tensor's last dimension) to the requests
    whose ``PoolingParams.softmax`` flag is set.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify(static_num_labels=False))

        self.head_dtype = get_current_vllm_config().model_config.head_dtype

    def forward(
        self,
        pooled_data: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ):
        dtype = self.head_dtype
        if isinstance(pooled_data, list):
            pooled_data = [chunk.to(dtype) for chunk in pooled_data]
        else:
            pooled_data = pooled_data.to(dtype)

        softmax_flags = [
            params.softmax for params in get_pooling_params(pooling_metadata)
        ]

        # Mixed flags: decide per request.
        if len(set(softmax_flags)) != 1:
            return [
                self.activation(chunk) if flag else chunk
                for chunk, flag in zip(pooled_data, softmax_flags)
            ]

        if softmax_flags[0]:
            pooled_data = self.activation(pooled_data)
        return pooled_data
505
506


507
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Post-processes the pooled data with its head (e.g. normalization).
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build a pooler for the ``embed`` or ``encode`` task.

        Raises:
            NotImplementedError: for any other task.
        """
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)

        task = pooler_config.task
        if task == "embed":
            head: PoolerHead = EmbeddingPoolerHead()
        elif task == "encode":
            head = RewardPoolerHead()
        else:
            raise NotImplementedError(f"Unknown task: {task}")

        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
550
551


552
class StepPooler(Pooler):
    """Pooler for ``STEP`` pooling.

    Gathers every token's hidden state, optionally keeps only the
    ``returned_token_ids`` columns, keeps only positions whose prompt token
    equals ``step_tag_id`` (when set), then applies the reward head.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        per_request_tokens = get_prompt_token_ids(pooling_metadata)
        per_request_params = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, params in zip(
            per_request_states, per_request_tokens, per_request_params
        ):
            keep_columns = params.returned_token_ids
            if keep_columns is not None and len(keep_columns) > 0:
                states = states[:, keep_columns]

            tag = params.step_tag_id
            if tag is not None:
                states = states[token_ids == tag]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token ids are needed to locate the step-tag positions.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        states = self.extract_states(hidden_states, pooling_metadata)
        return self.head(states, pooling_metadata)
602
603


604
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Optionally applies a classification layer to the pooled features.
    3. Subtracts the configured logit bias, if any.
    4. Applies the activation function to requests that ask for it.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Activation for sequence classification, chosen by ``problem_type``."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Activation for cross-encoders, read from the model's config."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()

        self.pooling = pooling
        self.classifier = classifier
        # Test for None explicitly: module truthiness is not a reliable
        # "was an activation given" check (e.g. an empty nn.Sequential is falsy).
        self.act_fn = act_fn if act_fn is not None else PoolerClassify()
        self.logit_bias: Optional[float] = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose: with no classifier, `.to` returns the
            # pooling function's own tensor when the dtype already matches,
            # and that tensor may alias the caller's hidden states.
            pooled_data = pooled_data - self.logit_bias

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
673
674
675
676
677
678
679
680
681
682
683
684


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """
        Raises:
            ValueError: if a sub-pooler does not support the task it is
                registered under.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        # NOTE(review): a plain mapping is not registered by nn.Module, so the
        # sub-poolers are not child modules of this one (state_dict / .to()
        # do not see them). Left as-is; confirm this is intended.
        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the sub-pooler for ``task`` (``KeyError`` if absent)."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        poolers_by_task = self.poolers_by_task

        outputs = list[torch.Tensor]()
        offset = 0
        # groupby yields runs of consecutive requests sharing a task; each run
        # is sliced out of the metadata and handed to that task's pooler.
        for task, group in groupby(get_tasks(pooling_metadata)):
            # Explicit None check: a registered pooler must never be rejected
            # just because it happens to be falsy.
            pooler = poolers_by_task.get(task)
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s