# pooler.py (24.4 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig
16
17
18
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
19
from vllm.pooling_params import PoolingParams
20
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
21
from vllm.tasks import PoolingTask
22
from vllm.utils import resolve_obj_by_qualname
23
24
25
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Pooling metadata may come from either engine generation; helpers in this
# module branch on the concrete type with `isinstance`.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
26
27
28
29
# Callable that pools hidden states (one flat tensor or one tensor per
# prompt) given the pooling metadata, returning either a tensor or a list
# of tensors.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# Callable mapping a tensor to a tensor; used as the classifier applied to
# pooled data in `ClassifierPooler`.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
30
31
32
33
34


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods."""
    # Keep only the hidden state of the last token (see `LastPool`).
    LAST = 0
    # Keep the hidden states of every token (see `AllPool`).
    ALL = 1
    # Keep only the hidden state of the first token (see `CLSPool`).
    CLS = 2
    # All-token pooling followed by step-tag / token-id filtering
    # (see `StepPooler`).
    STEP = 3
    # Average the hidden states over the tokens (see `MeanPool`).
    MEAN = 4
39
40


41
42
43
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Pooling settings after merging a `PoolerConfig` with defaults.

    Every field holds the value from the `PoolerConfig` when that value is
    not `None`, and the caller-supplied default otherwise.
    """

    # Which pooling method to use.
    pooling_type: PoolingType
    # Whether the head applies `PoolerNormalize`.
    normalize: bool
    # Whether the head applies `PoolerClassify`.
    softmax: bool
    # Token id used by `StepPooler` to select rows; `None` disables it.
    step_tag_id: Optional[int]
    # Column indices kept by `StepPooler`; `None` or empty disables it.
    returned_token_ids: Optional[list[int]]

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "ResolvedPoolingConfig":
        """Build a resolved config, letting `pooler_config` override defaults.

        Args:
            pooler_config: Source of overrides; any attribute that is `None`
                falls back to the matching argument below.
            pooling_type: Default pooling type.
            normalize: Default for the `normalize` flag.
            softmax: Default for the `softmax` flag.
            step_tag_id: Default step tag id.
            returned_token_ids: Default list of returned token ids.

        Returns:
            A new frozen `ResolvedPoolingConfig`.

        Raises:
            KeyError: If `pooler_config.pooling_type` is set but is not the
                name of a `PoolingType` member.
        """
        # The configured pooling type is a member *name*, so it is looked up
        # by name rather than by value.
        return cls(
            pooling_type=(PoolingType[pooler_config.pooling_type]
                          if pooler_config.pooling_type is not None else
                          pooling_type),
            normalize=(pooler_config.normalize
                       if pooler_config.normalize is not None else normalize),
            softmax=(pooler_config.softmax
                     if pooler_config.softmax is not None else softmax),
            step_tag_id=(pooler_config.step_tag_id
                         if pooler_config.step_tag_id is not None else
                         step_tag_id),
            returned_token_ids=(pooler_config.returned_token_ids
                                if pooler_config.returned_token_ids
                                is not None else returned_token_ids),
        )
73
74


75
76
77
78
79
80
81
82
83
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Immutable set of changes a pooler requests on `PoolingParams`."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Write the flags held by this update onto `params` in place."""
        wants_token_ids = self.requires_token_ids
        params.requires_token_ids = wants_token_ids


84
85
86
87
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_encode(
        pooler_config: PoolerConfig,
        *,
        default_pooling_type: PoolingType = PoolingType.ALL,
        default_normalize: bool = False,
        default_softmax: bool = False,
        default_step_tag_id: Optional[int] = None,
        default_returned_token_ids: Optional[list[int]] = None,
    ):
        """Build a pooler for the "encode" task.

        Values set on `pooler_config` take precedence over the `default_*`
        arguments. Returns a `StepPooler` when the resolved pooling type is
        `PoolingType.STEP`, otherwise a `SimplePooler`.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=default_pooling_type,
            normalize=default_normalize,
            softmax=default_softmax,
            step_tag_id=default_step_tag_id,
            returned_token_ids=default_returned_token_ids,
        )

        # STEP pooling needs its own pooler; `PoolingMethod.from_pooling_type`
        # has no branch for it.
        if resolved_config.pooling_type == PoolingType.STEP:
            return StepPooler.from_config(resolved_config)

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_embed(
        pooler_config: PoolerConfig,
        *,
        default_pooling_type: PoolingType = PoolingType.LAST,
        default_normalize: bool = True,
        default_softmax: bool = False,
    ):
        """Build a `SimplePooler` for the "embed" task.

        Values set on `pooler_config` take precedence over the `default_*`
        arguments.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=default_pooling_type,
            normalize=default_normalize,
            softmax=default_softmax,
        )

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
        *,
        default_pooling_type: PoolingType = PoolingType.LAST,
        default_normalize: bool = False,
        default_softmax: bool = True,
    ):
        """Build a pooler for classification.

        When `classifier` is `None` the plain `SimplePooler` is returned.
        Otherwise a `ClassifierPooler` is built that reuses the simple
        pooler's pooling method and head activation around `classifier`.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=default_pooling_type,
            normalize=default_normalize,
            softmax=default_softmax,
        )
        base_pooler = SimplePooler.from_config(resolved_config)
        if classifier is None:
            return base_pooler

        return ClassifierPooler(
            pooling=base_pooler.pooling,
            classifier=classifier,
            act_fn=base_pooler.head.activation,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        # Default: request no changes. Subclasses override when needed.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states` and return the result as a `PoolerOutput`."""
        raise NotImplementedError


173
174
175
176
177
178
179
180
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths described by `pooling_metadata`.

    V1 metadata already carries `prompt_lens`, so it is returned as is.
    For any other metadata the lengths are built through
    `PoolingTensors.from_pooling_metadata`, on the device of
    `hidden_states[0]`.
    """
    if isinstance(pooling_metadata, V1PoolingMetadata):
        return pooling_metadata.prompt_lens

    return PoolingTensors.from_pooling_metadata(
        pooling_metadata, hidden_states[0].device).prompt_lens
182
183


184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return one tensor of prompt token ids per prompt.

    For V1 metadata, row `i` of `prompt_token_ids` is cut to the first
    `prompt_lens[i]` entries. For any other metadata, a new tensor is built
    from each entry of `seq_data`.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        per_sequence = pooling_metadata.seq_data.values()
        return [torch.tensor(seq.prompt_token_ids) for seq in per_sequence]

    assert pooling_metadata.prompt_token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    trimmed = []
    for row, length in enumerate(pooling_metadata.prompt_lens):
        trimmed.append(pooling_metadata.prompt_token_ids[row, :length])
    return trimmed


201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the pooling task of every request in `pooling_metadata`.

    Asserts that every pooling parameter carries a task (none is `None`).
    """
    # V0 metadata stores (ids, params) pairs; otherwise the params are
    # exposed directly.
    if not isinstance(pooling_metadata, V0PoolingMetadata):
        pooling_params = pooling_metadata.pooling_params
    else:
        pooling_params = [params for _, params in pooling_metadata.seq_groups]

    tasks: list[PoolingTask] = []
    for params in pooling_params:
        task = params.task
        if task is not None:
            tasks.append(task)

    assert len(pooling_params) == len(tasks)

    return tasks


216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def get_classification_activation_function(config: PretrainedConfig):
    """Return the activation used for classification outputs.

    `config` is accepted but not consulted; a fresh `PoolerClassify` is
    always returned.
    """
    activation = PoolerClassify()
    return activation


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation to apply to cross-encoder outputs.

    The activation's qualified name is read from
    `config.sentence_transformers["activation_fn"]` if present, otherwise
    from `config.sbert_ce_default_activation_function` if that is set.
    When a name is found, the named object is resolved, instantiated with
    no arguments, and wrapped via `PoolerActivation.wraps`. When no name is
    found, a `PoolerScore` is returned.

    Raises:
        AssertionError: If the configured name does not start with
            "torch.nn.modules.".
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is not None:
        # Raised explicitly rather than with `assert`: assert statements are
        # removed under `python -O`, which would silently disable this
        # security restriction. The exception type is kept unchanged.
        if not function_name.startswith("torch.nn.modules."):
            raise AssertionError(
                "Loading of activation functions is restricted to "
                "torch.nn.modules for security reasons")
        fn = resolve_obj_by_qualname(function_name)()
        return PoolerActivation.wraps(fn)

    return PoolerScore()
237

238

239
240
def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
    """Wrap pooled data into a `PoolerOutput`.

    `all_data` is iterated once; each item (a row when it is a tensor, an
    element when it is a list) becomes one `PoolingSequenceGroupOutput`.
    """
    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
    return PoolerOutput(outputs=all_outputs)


class PoolingMethod(nn.Module, ABC):
    """Base class for the token-selection / aggregation step of a pooler.

    Subclasses implement `forward_one` (one prompt) and `forward_all`
    (one flat tensor covering several prompts); `forward` picks between
    them based on the type of `hidden_states`.
    """

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method matching `pooling_type`.

        Raises:
            NotImplementedError: For any type other than LAST, ALL, CLS and
                MEAN (which includes `PoolingType.STEP`).
        """
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter changes needed for `task`.

        The default requests no changes.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Pool the hidden states of a single prompt.

        Note:
            `prompt_len=None` means `prompt_len=len(hidden_states)`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool a single tensor that covers several prompts.

        `prompt_lens` holds the length of each prompt.
        """
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool `hidden_states` using the prompt lengths from the metadata.

        A list input is pooled element by element with `forward_one`; a
        tensor input is handed to `forward_all` in one call.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len)
                for h, prompt_len in zip(hidden_states, prompt_lens)
            ]

        return self.forward_all(hidden_states, prompt_lens)
301
302


303
304
class CLSPool(PoolingMethod):
    """Pooling method that keeps the hidden state of the first token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the first row of `hidden_states`.

        Asserts that `prompt_len`, when given, equals the number of rows.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with CLS pooling"

        return hidden_states[0]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the first-token row of every prompt in a flat tensor.

        `hidden_states` is indexed as the prompts laid end to end along
        dim 0: prompt 0 starts at row 0 and prompt `i` starts at the sum of
        the preceding prompt lengths.
        """
        first_token_flat_indices = torch.zeros_like(prompt_lens)
        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
        return hidden_states[first_token_flat_indices]


328
class LastPool(PoolingMethod):
    """Pooling method that keeps the hidden state of the last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the last row of `hidden_states`; `prompt_len` is unused."""
        return hidden_states[-1]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the last-token row of every prompt in a flat tensor.

        `hidden_states` is indexed as the prompts laid end to end along
        dim 0, so the last row of prompt `i` sits at the running sum of the
        prompt lengths up to and including `i`, minus one.
        """
        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_token_flat_indices]


349
class AllPool(PoolingMethod):
    """Pooling method that keeps the hidden states of every token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"encode"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return `hidden_states` unchanged.

        Asserts that `prompt_len`, when given, equals the number of rows.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with ALL pooling"

        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Split a flat tensor into one chunk per prompt.

        The chunk sizes are the values of `prompt_lens`.
        """
        return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
370
371


372
class MeanPool(PoolingMethod):
    """Pooling method that averages the hidden states over the tokens."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooling method can serve."""
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the float32 mean of `hidden_states` over dim 0.

        Asserts that `prompt_len`, when given, equals the number of rows.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with MEAN pooling"

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the per-prompt mean of a flat tensor of hidden states.

        `hidden_states` is indexed as the prompts laid end to end along
        dim 0. The per-prompt sums are read off a single running sum
        instead of slicing each prompt separately.
        """
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = torch.cat([
            torch.tensor([0], device=hidden_states.device),
            torch.cumsum(prompt_lens[:-1], dim=0)
        ])
        end_indices = torch.cumsum(prompt_lens, dim=0)
        # cumsum[end - 1] - cumsum[start] drops the row at `start` itself,
        # so it is added back to get the sum over rows start..end-1.
        return (cumsum[end_indices - 1] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


405
# Constrained type variable: activations return the same container kind
# (a tensor or a list of tensors) that they were given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
406
407


408
class BasePoolerActivation(nn.Module, ABC):
    """Abstract interface for activations applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to `pooled_data` (a tensor or a list)."""
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
417

418

419
class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor through `forward_chunk`.

    `forward` accepts either one tensor or a list of tensors and applies
    `forward_chunk` to the tensor or to each list element.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Return a `PoolerActivation` standing in for `module`.

        `nn.Identity` maps to `PoolerIdentity`; `nn.Sigmoid` and
        `nn.Softmax` map to `PoolerClassify`; any other module is called
        as is through `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to a tensor or to each tensor of a list."""
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
439
440


441
442
443
class PoolerIdentity(PoolerActivation):
    """Activation that returns its input unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `pooled_data` as is."""
        return pooled_data


447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes pooled data along the last dim."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, then cast back to the input dtype."""
        input_dtype = pooled_data.dtype
        normalized = F.normalize(pooled_data.to(torch.float32), p=2, dim=-1)
        return normalized.to(input_dtype)


class PoolerClassify(PoolerActivation):
    """Activation turning classifier outputs into probabilities.

    Sigmoid is used when the last dim has fewer than two entries,
    softmax over the last dim otherwise.
    """

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Compute in float32 and cast back to the input dtype."""
        as_float = pooled_data.float()
        if pooled_data.shape[-1] >= 2:
            activated = F.softmax(as_float, dim=-1)
        else:
            activated = F.sigmoid(as_float)

        return activated.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Activation for score outputs.

    Sigmoid is applied when the last dim has fewer than two entries;
    otherwise the data is returned untouched.
    """

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid (in float32) to single-label data only."""
        if pooled_data.shape[-1] >= 2:
            return pooled_data

        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Activation that delegates to an arbitrary tensor-to-tensor callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        """Store `fn`, the callable applied to each tensor."""
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `self.fn(pooled_data)`."""
        activated = self.fn(pooled_data)
        return activated


485
486
class PoolerHead(nn.Module):
    """Post-processing applied to pooled data.

    Casts the data to float32, optionally truncates the last dim to the
    per-request `dimensions` (matryoshka representation), then applies the
    configured activation.
    """

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
        """Build a head whose activation follows `pooler_config`.

        `normalize` selects `PoolerNormalize`, `softmax` selects
        `PoolerClassify`, and neither selects `PoolerIdentity`.

        Raises:
            ValueError: If both `normalize` and `softmax` are set.
        """
        if pooler_config.normalize and pooler_config.softmax:
            raise ValueError("`normalize=True` and `softmax=True` should not "
                             "be set together")

        activation: PoolerActivation
        if pooler_config.normalize:
            activation = PoolerNormalize()
        elif pooler_config.softmax:
            activation = PoolerClassify()
        else:
            activation = PoolerIdentity()

        return cls(activation)

    def __init__(self, activation: PoolerActivation) -> None:
        """Store the activation applied at the end of `forward`."""
        super().__init__()

        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Cast, optionally truncate, and activate `pooled_data`.

        Args:
            pooled_data: One tensor, or a list of tensors.
            pooling_metadata: Source of the per-request `dimensions`.

        Returns:
            The output of `self.activation` on the processed data.
        """
        # Using float32 in PoolerHead. A new list is built so the caller's
        # list is not modified in place.
        if isinstance(pooled_data, list):
            pooled_data = [data.to(torch.float32) for data in pooled_data]
        else:
            pooled_data = pooled_data.to(torch.float32)

        # for matryoshka representation
        if isinstance(pooling_metadata, V0PoolingMetadata):
            dimensions_list = [
                pooling_param.dimensions
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        else:
            assert isinstance(pooled_data, list)
            dimensions_list = [
                pooling_param.dimensions
                for pooling_param in pooling_metadata.pooling_params
            ]

        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(
                    pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Mixed dimensions: truncate each entry on its own and
                # leave entries without a requested dimension untouched.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        return self.activation(pooled_data)
545
546


547
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build the pooling method and head described by `pooler_config`."""
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        head = PoolerHead.from_config(pooler_config)

        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        """Store the pooling method and the head applied after it."""
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the updates requested by the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, apply the head, and wrap the result as `PoolerOutput`."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


588
class StepPooler(Pooler):
    """Pooler for `PoolingType.STEP`.

    Pools every token with `AllPool`, then optionally keeps only selected
    columns (`returned_token_ids`) and only the rows whose prompt token
    equals `step_tag_id`, before applying the head.
    """

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
        """Build a step pooler from a config whose pooling type is STEP."""
        assert pooler_config.pooling_type == PoolingType.STEP

        return cls(
            PoolerHead.from_config(pooler_config),
            step_tag_id=pooler_config.step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids,
        )

    def __init__(
        self,
        head: PoolerHead,
        *,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> None:
        """Store the head and the optional row / column filters."""
        super().__init__()

        self.pooling = AllPool()
        self.head = head
        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the filtered all-token hidden states, one entry per prompt.

        For each prompt, the columns listed in `returned_token_ids` are
        selected first (when the list is set and non-empty), then the rows
        where the prompt token id equals `step_tag_id` (when it is set).
        """
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)

        pooled_data = list[torch.Tensor]()
        returned_token_ids = self.returned_token_ids
        step_tag_id = self.step_tag_id

        for data, token_id in zip(pooled_data_lst, prompt_token_ids):
            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract states, apply the head, and wrap as `PoolerOutput`."""
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


652
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Applies a classification layer to the pooled data.
    3. Applies an activation function to the output.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Return the classification activation for `config.hf_config`."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Return the cross-encoder activation for `config.hf_config`."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn,
        act_fn: PoolerActivation,
    ) -> None:
        """Store the pooling function, classifier, and activation."""
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = act_fn

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, classify, activate, and wrap the result as `PoolerOutput`."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        # apply classifier once on the full batch if possible
        if isinstance(pooled_data, torch.Tensor):
            pooled_output = self.classifier(pooled_data)
        elif len({data.shape for data in pooled_data}) == 1:
            # Every entry has the same shape, so they can be stacked. An
            # empty list yields zero shapes and is deliberately excluded,
            # because `torch.stack` rejects an empty sequence.
            pooled_output = self.classifier(torch.stack(pooled_data))
        else:
            pooled_output = [self.classifier(data) for data in pooled_data]

        scores = self.act_fn(pooled_output)

        return build_output(scores)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Store the task-to-pooler mapping after validating it.

        Raises:
            ValueError: If a pooler is registered for a task that it does
                not list in `get_supported_tasks()`.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}")

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the updates requested by the pooler registered for `task`."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool each run of same-task requests with that task's pooler.

        Consecutive requests sharing a task are handed to the matching
        sub-pooler together, and the outputs are concatenated in request
        order.

        Raises:
            ValueError: If a request's task has no registered pooler.
        """
        poolers_by_task = self.poolers_by_task

        if isinstance(hidden_states, list):
            hidden_states_lst = hidden_states
        else:
            # Split the flat tensor into one chunk per prompt so it can be
            # sliced per task group below.
            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        for task, group in groupby(get_tasks(pooling_metadata)):
            # Compare against None explicitly: a registered pooler must not
            # be treated as missing because it evaluates as falsy.
            if (pooler := poolers_by_task.get(task)) is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            num_items = sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                hidden_states_lst[offset:offset + num_items],
                pooling_metadata[offset:offset + num_items],
            )

            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)