# pooler.py 23.7 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from abc import ABC, abstractmethod
from dataclasses import dataclass
5
from enum import IntEnum
6
from typing import Callable, Optional, TypeVar, Union
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11
from transformers import PretrainedConfig
12
from typing_extensions import assert_never
13

14
from vllm.config import ModelConfig, PoolerConfig
15
16
17
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
18
from vllm.pooling_params import PoolingParams, PoolingTask
19
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
20
from vllm.utils import resolve_obj_by_qualname
21
22
23
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Pooling metadata from either engine: V0 (exposes `seq_groups` /
# `seq_data`) or V1 (exposes `prompt_lens` / `pooling_params`).
PoolingMetadata = Union[
    V0PoolingMetadata,
    V1PoolingMetadata,
]
24
25
26
27
28


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods.

    Members are looked up by *name* from ``PoolerConfig.pooling_type``, so
    the member names are part of the configuration surface.
    """
    LAST = 0  # hidden state of the last token of each request
    ALL = 1  # hidden states of every token of each request
    CLS = 2  # hidden state of the first token of each request
    STEP = 3  # ALL, then filtered by step tag / returned token ids
    MEAN = 4  # mean of the hidden states over each request
33
34


35
36
37
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Pooling settings after merging the user config with model defaults.

    For every field, an explicit (non-``None``) value from ``PoolerConfig``
    wins; otherwise the model-provided default is used.
    """
    pooling_type: PoolingType
    normalize: bool
    softmax: bool
    step_tag_id: Optional[int]
    returned_token_ids: Optional[list[int]]

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "ResolvedPoolingConfig":
        """Resolve each setting, preferring ``pooler_config`` over defaults.

        Args:
            pooler_config: User-facing pooler config; a ``None`` field means
                "not set".
            pooling_type: Default pooling type for the model.
            normalize: Default for L2-normalizing the pooled output.
            softmax: Default for applying the classification activation.
            step_tag_id: Default step tag token id (used by STEP pooling).
            returned_token_ids: Default ids of the columns to keep (used by
                STEP pooling).

        Raises:
            KeyError: If ``pooler_config.pooling_type`` is not the name of a
                ``PoolingType`` member.
        """

        def pick(override, default):
            # Only ``None`` means "not configured"; falsy values such as
            # ``False`` or ``[]`` are honoured as explicit choices.
            return default if override is None else override

        resolved_type = (pooling_type if pooler_config.pooling_type is None
                         else PoolingType[pooler_config.pooling_type])
        return cls(
            pooling_type=resolved_type,
            normalize=pick(pooler_config.normalize, normalize),
            softmax=pick(pooler_config.softmax, softmax),
            step_tag_id=pick(pooler_config.step_tag_id, step_tag_id),
            returned_token_ids=pick(pooler_config.returned_token_ids,
                                    returned_token_ids),
        )
67
68


69
70
71
72
73
74
75
76
77
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Per-task adjustments a pooler requests on its `PoolingParams`."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Write the requested flags onto ``params`` (mutates it in place)."""
        flag = self.requires_token_ids
        params.requires_token_ids = flag


78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def from_config_with_defaults(
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "Pooler":
        """Build the pooler described by ``pooler_config`` and the defaults.

        The defaults are merged with ``pooler_config`` first (explicit config
        values win). The *resolved* pooling type then selects the
        implementation: ``StepPooler`` for ``PoolingType.STEP``,
        ``SimplePooler`` otherwise.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
            step_tag_id=step_tag_id,
            returned_token_ids=returned_token_ids,
        )

        # Dispatch on the resolved type, not the model default. A user
        # override (e.g. STEP -> MEAN or MEAN -> STEP) must select the
        # matching pooler; otherwise StepPooler.from_config's assertion or
        # PoolingMethod.from_pooling_type's NotImplementedError would fire.
        if resolved_config.pooling_type == PoolingType.STEP:
            return StepPooler.from_config(resolved_config)

        return SimplePooler.from_config(resolved_config)

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        """
        Construct the pooling parameters to use for a task,
        or `None` if the task is not supported.

        The base implementation supports no task.
        """
        return None

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` for the batch described by
        ``pooling_metadata``."""
        raise NotImplementedError


123
124
125
126
127
128
129
130
131
132
133
134
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt length of every request in the batch.

    V1 metadata already carries ``prompt_lens``. For V0 metadata they are
    built with ``PoolingTensors`` on the device of ``hidden_states``, which
    must then be a single tensor.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        assert isinstance(hidden_states, torch.Tensor)
        v0_tensors = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device)
        return v0_tensors.prompt_lens

    return pooling_metadata.prompt_lens


135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return one tensor of prompt token ids per request.

    For V1 metadata, row ``i`` of ``prompt_token_ids`` is trimmed to the
    request's prompt length. Those ids are only present when the pooler
    requested them via ``PoolingParamsUpdate(requires_token_ids=True)``.
    For V0 metadata, the ids are read from ``seq_data``.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        return [
            torch.tensor(seq.prompt_token_ids)
            for seq in pooling_metadata.seq_data.values()
        ]

    token_ids = pooling_metadata.prompt_token_ids
    assert token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    return [
        token_ids[row, :length]
        for row, length in enumerate(pooling_metadata.prompt_lens)
    ]


152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def get_classification_activation_function(config: PretrainedConfig):
    """Return the activation used for classification outputs.

    ``config`` is currently unused: the result is always a
    ``PoolerClassify`` (sigmoid for a single label, softmax otherwise).
    """
    activation = PoolerClassify()
    return activation


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation for cross-encoder (score) outputs.

    The activation class name is taken, in order of preference, from
    ``config.sentence_transformers["activation_fn"]`` or from
    ``config.sbert_ce_default_activation_function``. When neither is set,
    ``PoolerScore`` is used (sigmoid for a single label, identity otherwise).

    Raises:
        ValueError: If the configured activation name does not start with
            ``torch.nn.modules.``.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerScore()

    # The name comes from the model config and is resolved and instantiated
    # dynamically, so restrict it to torch.nn modules. This must be an
    # explicit raise: an `assert` is compiled out under `python -O`, which
    # would silently disable the check.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError("Loading of activation functions is restricted to "
                         "torch.nn.modules for security reasons")
    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
173

174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194

def build_output(all_data: torch.Tensor) -> PoolerOutput:
    """Wrap each per-request entry of ``all_data`` into a ``PoolerOutput``.

    ``all_data`` is iterated along its first dimension (or element by
    element when it is a list), producing one ``PoolingSequenceGroupOutput``
    per entry.
    """
    return PoolerOutput(
        outputs=list(map(PoolingSequenceGroupOutput, all_data)))


class PoolingMethod(nn.Module, ABC):
    """Strategy for reducing per-token hidden states to pooled data.

    Subclasses implement ``forward_one`` (one request's hidden states) and
    ``forward_all`` (all requests concatenated along dim 0). ``forward``
    dispatches between them based on the type of ``hidden_states``.
    """

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method for ``pooling_type``.

        Raises:
            NotImplementedError: For types without a plain pooling method
                (``PoolingType.STEP`` is served by ``StepPooler`` instead).
        """
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        """Return the params update needed for ``task``, or ``None`` if this
        pooling method does not support it."""
        raise NotImplementedError

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Pool the hidden states of a single request.

        Note:
            `prompt_len=None` means `prompt_len=len(hidden_states)`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool a batch whose requests are concatenated along dim 0, where
        ``prompt_lens[i]`` rows belong to request ``i``."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool every request in the batch.

        A list of per-request tensors is pooled one request at a time with
        ``forward_one``; a single flat tensor goes through ``forward_all``.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len)
                for h, prompt_len in zip(hidden_states, prompt_lens)
            ]

        return self.forward_all(hidden_states, prompt_lens)
236
237


238
239
class CLSPool(PoolingMethod):
    """Pool each request to the hidden state of its first token."""

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        # The equalities are split up to keep mypy happy
        if (task == "encode" or task == "embed" or task == "classify"
                or task == "score"):
            return PoolingParamsUpdate()

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Reject chunks that do not cover the whole prompt.
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with CLS pooling"

        return hidden_states[0]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Request i starts at row sum(prompt_lens[:i]) of the flat batch.
        first_token_flat_indices = torch.zeros_like(prompt_lens)
        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
        return hidden_states[first_token_flat_indices]


271
class LastPool(PoolingMethod):
    """Pool each request to the hidden state of its last token."""

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        # The equalities are split up to keep mypy happy
        if (task == "encode" or task == "embed" or task == "classify"
                or task == "score"):
            return PoolingParamsUpdate()

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # `prompt_len` is not checked; the last row of the given states is
        # returned as-is.
        return hidden_states[-1]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Request i ends at row sum(prompt_lens[:i + 1]) - 1 of the flat batch.
        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_token_flat_indices]


300
class AllPool(PoolingMethod):
    """Return every token's hidden state, one tensor per request."""

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        if task == "encode":
            return PoolingParamsUpdate()

        # Other tasks are not supported by ALL pooling.
        # The equalities are split up to keep mypy happy
        if task == "embed" or task == "classify" or task == "score":
            return None

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Reject chunks that do not cover the whole prompt.
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with ALL pooling"

        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        pooled_data = list[torch.Tensor]()
        offset = 0
        for prompt_len in prompt_lens:
            # Rows [offset, offset + prompt_len) belong to this request.
            pooled_data.append(hidden_states[offset:offset + prompt_len])
            offset += prompt_len

        return pooled_data


340
class MeanPool(PoolingMethod):
    """Pool each request to the mean of its hidden states (in float32)."""

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        # The equalities are split up to keep mypy happy
        if (task == "encode" or task == "embed" or task == "classify"
                or task == "score"):
            return PoolingParamsUpdate()

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Reject chunks that do not cover the whole prompt.
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with MEAN pooling"

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        # Request i covers rows [start_i, end_i) of the flat batch. The starts
        # are the ends minus the lengths, which avoids building and
        # concatenating a separate zero tensor.
        end_indices = torch.cumsum(prompt_lens, dim=0)
        start_indices = end_indices - prompt_lens
        # sum(x[start:end]) == cumsum[end - 1] - cumsum[start] + x[start]
        # (avoids indexing cumsum at start - 1, which is -1 for request 0).
        return (cumsum[end_indices - 1] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


381
# "Tensor or list of tensors": activations return the same container kind
# they are given.
_T = TypeVar(
    "_T",
    torch.Tensor,
    list[torch.Tensor],
)
382
383


384
class BasePoolerActivation(nn.Module, ABC):
    """Activation applied to pooled data (the final step of PoolerHead)."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
393

394

395
class PoolerActivation(BasePoolerActivation):
    """Activation applied independently to each tensor ("chunk").

    Subclasses implement ``forward_chunk``; ``forward`` maps it over a list
    or applies it to a single tensor.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt an ``nn.Module`` activation into a ``PoolerActivation``.

        ``nn.Identity`` maps to ``PoolerIdentity``. ``nn.Sigmoid`` and
        ``nn.Softmax`` both map to ``PoolerClassify``, which picks sigmoid or
        softmax from the number of labels. NOTE: any ``dim`` configured on
        the ``nn.Softmax`` instance is not carried over. Any other module is
        called as-is through ``LambdaPoolerActivation``.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
415
416


417
418
419
class PoolerIdentity(PoolerActivation):
    """No-op activation: returns the pooled data unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data


423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
class PoolerNormalize(PoolerActivation):
    """L2-normalize each vector along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, then cast back to the input dtype.
        orig_dtype = pooled_data.dtype
        normalized = F.normalize(pooled_data.float(), p=2, dim=-1)
        return normalized.to(orig_dtype)


class PoolerClassify(PoolerActivation):
    """Sigmoid for single-label outputs, softmax over labels otherwise."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Compute in float32, then cast back to the input dtype.
        logits = pooled_data.float()
        probs = (F.softmax(logits, dim=-1)
                 if pooled_data.shape[-1] >= 2 else F.sigmoid(logits))
        return probs.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Sigmoid for single-label scores; multi-label outputs pass through."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        if pooled_data.shape[-1] >= 2:
            # Multi-label outputs are returned untouched (no dtype cast).
            return pooled_data

        return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Adapter that turns an arbitrary tensor callable into an activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        # If `fn` is an nn.Module, nn.Module.__setattr__ registers it as a
        # submodule.
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        activation = self.fn
        return activation(pooled_data)


461
462
class PoolerHead(nn.Module):
    """Post-processing for pooled data.

    Steps: cast to float32, optionally truncate each request's output to its
    requested ``dimensions`` (Matryoshka), then apply the activation.
    """

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
        """Build a head whose activation follows ``normalize``/``softmax``.

        Raises:
            ValueError: If both ``normalize`` and ``softmax`` are set.
        """
        if pooler_config.normalize and pooler_config.softmax:
            raise ValueError("`normalize=True` and `softmax=True` should not "
                             "be set together")

        activation: PoolerActivation
        if pooler_config.normalize:
            activation = PoolerNormalize()
        elif pooler_config.softmax:
            activation = PoolerClassify()
        else:
            activation = PoolerIdentity()

        return cls(activation)

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()

        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        # Using float32 in PoolerHead. Build a new list instead of
        # overwriting the caller's list elements in place.
        if isinstance(pooled_data, list):
            pooled_data = [data.to(torch.float32) for data in pooled_data]
        else:
            pooled_data = pooled_data.to(torch.float32)

        # for matryoshka representation: each request may ask for its output
        # to be truncated to the first `dimensions` features.
        if isinstance(pooling_metadata, V0PoolingMetadata):
            dimensions_list = [
                pooling_param.dimensions
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        else:
            # This path expects a list of per-request tensors.
            assert isinstance(pooled_data, list)
            dimensions_list = [
                pooling_param.dimensions
                for pooling_param in pooling_metadata.pooling_params
            ]

        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)

            if len(set(dimensions_list)) == 1 and not isinstance(
                    pooled_data, list):
                # All requests share one (non-None) dimension: slice the
                # batch tensor in a single op.
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        return self.activation(pooled_data)
521
522


523
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Applies the head: float32 cast, optional per-request truncation
       (`dimensions`), and the configured normalize/softmax activation.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config_with_defaults(  # type: ignore[override]
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
    ) -> "SimplePooler":
        """Merge ``pooler_config`` with the defaults and build the pooler.

        STEP pooling is not supported here; use ``StepPooler`` for it.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
        )
        assert resolved_config.pooling_type != PoolingType.STEP, (
            "STEP pooling is handled by StepPooler, not SimplePooler")

        return cls.from_config(resolved_config)

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build the pooling method and head described by ``pooler_config``."""
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        head = PoolerHead.from_config(pooler_config)

        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        # Task support is decided entirely by the pooling method.
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


582
class StepPooler(Pooler):
    """Pooler for ``PoolingType.STEP``.

    Returns per-token outputs (ALL pooling). These are optionally restricted
    to the columns in ``returned_token_ids`` and to the positions whose
    prompt token equals ``step_tag_id``.
    """

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
        assert pooler_config.pooling_type == PoolingType.STEP

        return cls(
            PoolerHead.from_config(pooler_config),
            step_tag_id=pooler_config.step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids,
        )

    def __init__(
        self,
        head: PoolerHead,
        *,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head
        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return one tensor per request with the selected rows/columns.

        Columns are filtered by ``returned_token_ids`` (when non-empty).
        Rows are filtered to the positions where the prompt token equals
        ``step_tag_id`` (when set).
        """
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)

        pooled_data = list[torch.Tensor]()
        returned_token_ids = self.returned_token_ids
        step_tag_id = self.step_tag_id

        for data, token_ids in zip(pooled_data_lst, prompt_token_ids):
            # `None` and `[]` both mean "keep every column".
            if returned_token_ids:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                data = data[token_ids == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        if task == "encode":
            # Row filtering by step tag needs the prompt token ids.
            return PoolingParamsUpdate(requires_token_ids=True)

        # The equalities are split up to keep mypy happy
        if task == "embed" or task == "classify" or task == "score":
            return None

        assert_never(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


# Maps (hidden states, batch metadata) to pooled data; used by
# ClassifierPooler as its pooling step.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]],
]
# Maps pooled data to classifier outputs; used by ClassifierPooler.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]

658

659
660
class ClassifierPooler(nn.Module):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Applies a classification layer to the pooled data.
    3. Applies an activation function to the output. In the case of
       classification models it is either sigmoid or softmax. In the
       case of scoring models, the same behavior is configuration
       dependent, as in the sentence-transformers library.
    """

    def __init__(
        self,
        config: ModelConfig,
        pooling: PoolingFn,
        classifier: ClassifierFn,
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier

        # An explicit `act_fn` is used for every task; otherwise each
        # task-specific activation is derived from the HF config.
        self.classification_act_fn = get_classification_activation_function(
            config.hf_config) if act_fn is None else act_fn
        self.cross_encoder_act_fn = get_cross_encoder_activation_function(
            config.hf_config) if act_fn is None else act_fn

    def _get_act_fn(self, task: PoolingTask):
        """Return the activation for ``task``.

        Raises:
            ValueError: For tasks without an activation here (``"embed"``).
        """
        if task == "encode" or task == "classify":
            return self.classification_act_fn
        if task == "score":
            return self.cross_encoder_act_fn

        raise ValueError(f"Unsupported task: {task!r}")

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        # The equalities are split up to keep mypy happy
        if task == "encode" or task == "classify" or task == "score":
            return PoolingParamsUpdate()

        if task == "embed":
            return None

        assert_never(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools sentence pair scores from the hidden_states."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        # apply classifier once on the full batch if possible
        if isinstance(pooled_data, torch.Tensor):
            pooled_output = self.classifier(pooled_data)
        elif len({data.shape for data in pooled_data}) <= 1:
            pooled_output = self.classifier(torch.stack(pooled_data))
        else:
            pooled_output = [self.classifier(data) for data in pooled_data]

        # V0 and V1 metadata expose the per-request pooling params
        # differently.
        if isinstance(pooling_metadata, V0PoolingMetadata):
            pooling_params = [
                param for _, param in pooling_metadata.seq_groups
            ]
        else:
            pooling_params = pooling_metadata.pooling_params

        task_list: list[PoolingTask] = [
            task for pooling_param in pooling_params
            if (task := pooling_param.task) is not None
        ]

        assert len(task_list) == len(pooled_output), (
            "expected exactly one pooling task per pooled output")

        # shape of scores: (batch_size, num_labels)
        if len(set(task_list)) <= 1:
            # Homogeneous batch: one activation for the whole output.
            act_fn = self._get_act_fn(task_list[0])
            scores = act_fn(pooled_output)
        else:
            scores = torch.stack([
                self._get_act_fn(task)(vecs)
                for task, vecs in zip(task_list, pooled_output)
            ])

        return build_output(scores)