# pooler.py 23.8 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig
16
17
18
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
19
from vllm.pooling_params import PoolingParams
20
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
21
from vllm.tasks import PoolingTask
22
from vllm.utils import resolve_obj_by_qualname
23
24
25
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Pooling metadata comes from either the V0 or the V1 engine; the helpers in
# this module tell them apart with ``isinstance`` checks.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]

# A pooling callable maps hidden states (a single tensor, or a list of
# per-sequence tensors) plus metadata to pooled data of either form.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# A classifier maps a tensor of pooled data to a tensor of scores; see
# ``ClassifierPooler``, which applies it before its activation function.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods.

    Members are looked up by name from ``PoolerConfig.pooling_type``.
    """
    LAST = 0  # hidden state of the last token
    ALL = 1  # hidden states of every token
    CLS = 2  # hidden state of the first token
    STEP = 3  # all tokens, filtered by step-tag id (handled by StepPooler)
    MEAN = 4  # mean of the hidden states over the sequence


@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """A concrete pooling type paired with the task it was resolved for."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build a config for ``task`` from a user-facing ``PoolerConfig``.

        ``pooler_config.pooling_type`` must already be set. It is looked up
        by name in :class:`PoolingType`, so an unknown name raises
        ``KeyError``.
        """
        assert pooler_config.pooling_type is not None
        return cls(task=task,
                   pooling_type=PoolingType[pooler_config.pooling_type])


@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Adjustments a pooler requests on a request's ``PoolingParams``."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Write the requested flags onto ``params`` in place."""
        params.requires_token_ids = self.requires_token_ids


class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the pooler used for the ``"encode"`` task.

        ``STEP`` pooling gets a dedicated :class:`StepPooler`. Every other
        configuration uses a :class:`SimplePooler` with ``ALL`` pooling.
        """
        if pooler_config.pooling_type == "STEP":
            return StepPooler()

        resolved_config = ResolvedPoolingConfig(task="encode",
                                                pooling_type=PoolingType.ALL)

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the pooler used for the ``"embed"`` task."""
        resolved_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
    ):
        """Build a :class:`ClassifierPooler` for the ``"classify"`` task.

        Args:
            pooler_config: Supplies the pooling type.
            classifier: Applied to the pooled data; ``None`` skips it.
        """
        resolved_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )

        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)

        return ClassifierPooler(
            pooling=pooling,
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` and wrap the results in a ``PoolerOutput``."""
        raise NotImplementedError


def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt length of every sequence in the batch.

    V1 metadata already carries ``prompt_lens``. For V0 metadata they are
    built via ``PoolingTensors`` on the device of ``hidden_states[0]``.
    """
    if isinstance(pooling_metadata, V1PoolingMetadata):
        return pooling_metadata.prompt_lens

    return PoolingTensors.from_pooling_metadata(
        pooling_metadata, hidden_states[0].device).prompt_lens


def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return the prompt token ids of every sequence, one tensor each.

    V1 metadata must carry ``prompt_token_ids``, which the pooler requests via
    ``PoolingParamsUpdate(requires_token_ids=True)``. Row ``i`` is trimmed to
    that sequence's prompt length. For V0 metadata, each sequence's
    ``prompt_token_ids`` is converted to a tensor.
    """
    if isinstance(pooling_metadata, V1PoolingMetadata):
        assert pooling_metadata.prompt_token_ids is not None, (
            "Please set `requires_token_ids=True` in `get_pooling_updates`")

        return [
            pooling_metadata.prompt_token_ids[i, :num]
            for i, num in enumerate(pooling_metadata.prompt_lens)
        ]

    return [
        torch.tensor(seq_data_i.prompt_token_ids)
        for seq_data_i in pooling_metadata.seq_data.values()
    ]


def get_pooling_params(
        pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the ``PoolingParams`` of every entry in the batch, in order."""
    if isinstance(pooling_metadata, V0PoolingMetadata):
        # V0 stores pairs whose second element is the pooling params.
        return [params for _, params in pooling_metadata.seq_groups]

    return pooling_metadata.pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the pooling task of every request in the batch, in order.

    Every request must have ``task`` set; a missing task trips the
    assertion below.
    """
    pooling_params = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = [
        task for pooling_param in pooling_params
        if (task := pooling_param.task) is not None
    ]
    assert len(pooling_params) == len(tasks), (
        "Every request must have `pooling_params.task` set")

    return tasks


def get_classification_activation_function(config: PretrainedConfig):
    """Pick the output activation for a sequence-classification model.

    Follows transformers' ``ForSequenceClassificationLoss`` handling of
    ``config.problem_type``:

    - regression uses identity;
    - multi-label classification uses an element-wise sigmoid;
    - anything else, including an unset or unrecognized ``problem_type``,
      uses :class:`PoolerClassify`.
    """
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "single_label_classification":
        return PoolerClassify()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()

    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Pick the output activation for a cross-encoder (scoring) model.

    The activation class name is read from
    ``config.sentence_transformers["activation_fn"]``. Failing that, it is
    read from ``config.sbert_ce_default_activation_function``. Only classes
    under ``torch.nn.modules`` may be loaded. With neither setting,
    :class:`PoolerScore` is used.

    Raises:
        AssertionError: if the configured activation is outside
            ``torch.nn.modules``.
    """
    function_name: Optional[str] = None
    st_config = getattr(config, "sentence_transformers", None)
    if st_config is not None and "activation_fn" in st_config:
        function_name = st_config["activation_fn"]
    else:
        function_name = getattr(config,
                                "sbert_ce_default_activation_function", None)

    if function_name is None:
        return PoolerScore()

    # The name comes from the model config, so the allow-list check must not
    # be a bare `assert` (stripped under `python -O`). AssertionError is kept
    # so callers see the same exception type as before.
    if not function_name.startswith("torch.nn.modules."):
        raise AssertionError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)

def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]]) -> PoolerOutput:
    """Wrap each per-sequence result in a ``PoolingSequenceGroupOutput``.

    Iterating a tensor yields its rows, so a batched tensor and a list of
    tensors are handled the same way.
    """
    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
    return PoolerOutput(outputs=all_outputs)


class PoolingMethod(nn.Module, ABC):
    """Reduces hidden states to per-sequence pooled data.

    Subclasses implement two paths:

    - :meth:`forward_one` handles one sequence's hidden states;
    - :meth:`forward_all` handles the hidden states of every sequence
      concatenated along dim 0, together with their prompt lengths.
    """

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method for ``pooling_type``.

        Raises:
            NotImplementedError: for types with no method here (e.g. STEP).
        """
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return parameter updates needed for ``task`` (none by default)."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Pool the hidden states of a single sequence.

        Note:
            `prompt_len=None` means `prompt_len=len(hidden_states)`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool every sequence.

        ``hidden_states`` holds all sequences concatenated along dim 0, and
        ``prompt_lens`` gives their lengths.
        """
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool a batch.

        A list of per-sequence hidden states is pooled one sequence at a time
        with :meth:`forward_one`. A single tensor goes to :meth:`forward_all`.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len)
                for h, prompt_len in zip(hidden_states, prompt_lens)
            ]

        return self.forward_all(hidden_states, prompt_lens)


class CLSPool(PoolingMethod):
    """Pools the hidden state of the first token of each sequence."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with CLS pooling"

        return hidden_states[0]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # The flat index of each sequence's first token is the exclusive
        # prefix sum of the prompt lengths.
        first_token_flat_indices = torch.zeros_like(prompt_lens)
        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
        return hidden_states[first_token_flat_indices]


class LastPool(PoolingMethod):
    """Pools the hidden state of the last token of each sequence."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return hidden_states[-1]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # The flat index of each sequence's last token is the inclusive
        # prefix sum of the prompt lengths, minus one.
        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_token_flat_indices]


class AllPool(PoolingMethod):
    """Returns the hidden states of every token, one tensor per sequence."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with ALL pooling"

        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Undo the dim-0 concatenation: one chunk per sequence.
        return list(hidden_states.split_with_sizes(prompt_lens.tolist()))


class MeanPool(PoolingMethod):
    """Pools the mean hidden state of each sequence, computed in float32."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with MEAN pooling"

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        # Each sequence occupies the flat rows [start, end).
        start_indices = torch.cat([
            torch.tensor([0], device=hidden_states.device),
            torch.cumsum(prompt_lens[:-1], dim=0)
        ])
        end_indices = torch.cumsum(prompt_lens, dim=0)
        # cumsum[end - 1] - cumsum[start] sums rows start+1 .. end-1; adding
        # row `start` back yields the full per-sequence sum.
        return (cumsum[end_indices - 1] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


# Pooled data passes through activations either as a single tensor or as a
# list of tensors; an activation returns the same kind it was given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class BasePoolerActivation(nn.Module, ABC):
    """Abstract activation applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError

class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor via :meth:`forward_chunk`.

    :meth:`forward` applies ``forward_chunk`` to a single tensor, or to each
    element of a list of tensors.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt an arbitrary ``nn.Module`` into a ``PoolerActivation``.

        - ``nn.Identity`` maps to :class:`PoolerIdentity`.
        - ``nn.Sigmoid`` and ``nn.Softmax`` map to :class:`PoolerClassify`,
          which chooses sigmoid or softmax from the number of labels.
        - Anything else is wrapped in :class:`LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to one tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)


class PoolerIdentity(PoolerActivation):
    """Returns pooled data unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data


class PoolerNormalize(PoolerActivation):
    """L2-normalizes pooled data along its last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Compute in float32, then cast back to the input dtype.
        normalized = F.normalize(pooled_data.float(), p=2, dim=-1)
        return normalized.to(pooled_data.dtype)


class PoolerMultiLabelClassify(PoolerActivation):
    """Applies an independent sigmoid to every logit (multi-label)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Compute in float32, then cast back to the input dtype.
        probs = F.sigmoid(pooled_data.float())
        return probs.to(pooled_data.dtype)


class PoolerClassify(PoolerActivation):
    """Turns classification logits into probabilities.

    A single logit gets a sigmoid; two or more get a softmax over the last
    dimension. The computation runs in float32 and is cast back.
    """

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        logits = pooled_data.float()
        if pooled_data.shape[-1] < 2:
            probs = F.sigmoid(logits)
        else:
            probs = F.softmax(logits, dim=-1)
        return probs.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Sigmoid for single-logit outputs; multi-logit outputs pass through."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        if pooled_data.shape[-1] >= 2:
            # Several labels: leave the raw scores untouched.
            return pooled_data

        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Adapts a tensor-to-tensor callable into a ``PoolerActivation``."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        # When ``fn`` is an ``nn.Module``, ``nn.Module.__setattr__`` registers
        # it as a submodule.
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        activation = self.fn
        return activation(pooled_data)


class PoolerHead(nn.Module):
    """Applies an activation to pooled data.

    Subclasses may also consult per-request parameters in the metadata.
    """

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        # The base head ignores `pooling_metadata`.
        return self.activation(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Embedding head: optional Matryoshka truncation, then L2 normalization.

    Both steps are controlled per request, by ``PoolingParams.dimensions``
    and ``PoolingParams.normalize`` respectively.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):

        pooling_params = get_pooling_params(pooling_metadata)

        # for matryoshka representation: keep only the first `d` dims
        dimensions_list = [
            pooling_param.dimensions for pooling_param in pooling_params
        ]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(
                    pooled_data, list):
                # if all dimensions are the same, slice the batch at once
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize: apply to the whole batch when all requests agree,
        # otherwise request by request.
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head that applies :class:`PoolerClassify` to requests asking for it.

    The choice is made per request via ``PoolingParams.softmax``.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        pooling_params = get_pooling_params(pooling_metadata)

        # for softmax: apply to the whole batch when all requests agree,
        # otherwise request by request.
        flags = [p.softmax for p in pooling_params]
        if len(set(flags)) == 1:
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return pooled_data


class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build a pooler for an ``"embed"`` or ``"encode"`` config.

        Raises:
            NotImplementedError: for any other task.
        """
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        if pooler_config.task == "embed":
            head = EmbeddingPoolerHead()
        elif pooler_config.task == "encode":
            head = RewardPoolerHead()
        else:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")
        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


class StepPooler(Pooler):
    """Pooler for the ``STEP`` pooling type (``"encode"`` task only).

    The pooler works in four steps:

    1. Take every token's state.
    2. Optionally keep only the ``returned_token_ids`` columns.
    3. Optionally keep only positions whose prompt token equals
       ``step_tag_id``.
    4. Apply :class:`RewardPoolerHead`.
    """

    def __init__(self) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select, per request, the states that STEP pooling returns."""
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)

        pooled_data = list[torch.Tensor]()
        for data, token_id, pooling_param in zip(pooled_data_lst,
                                                 prompt_token_ids,
                                                 pooling_params):
            step_tag_id = pooling_param.step_tag_id
            returned_token_ids = pooling_param.returned_token_ids

            # Keep only the requested columns of the per-token data.
            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]

            # Keep only positions whose prompt token is the step tag.
            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # extract_states needs the prompt token ids.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Applies the classifier (if any) to the pooled data.
    3. Applies an activation function to the output, per request.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Activation for sequence-classification models."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Activation for cross-encoder models."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        """
        Args:
            pooling: Reduces hidden states to per-sequence pooled data.
            classifier: Applied to the pooled data; ``None`` skips it.
            act_fn: Output activation; defaults to :class:`PoolerClassify`.
        """
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = act_fn if act_fn is not None else PoolerClassify()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        if self.classifier is not None:
            # apply classifier once on the full batch if possible
            if isinstance(pooled_data, torch.Tensor):
                pooled_data = self.classifier(pooled_data)
            elif len({data.shape for data in pooled_data}) == 1:
                # torch.stack needs a non-empty list of equal shapes; an
                # empty list falls through to the (empty) per-item path.
                pooled_data = self.classifier(torch.stack(pooled_data))
            else:
                pooled_data = [self.classifier(data) for data in pooled_data]

        # Apply the activation to the whole batch when all requests agree,
        # otherwise request by request.
        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return build_output(scores)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """
        Raises:
            ValueError: if a pooler is registered for a task it does not
                support.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}")

        # NOTE(review): nn.Module does not register modules held in a plain
        # Mapping (unlike nn.ModuleDict), so the sub-poolers are not child
        # modules of this one. Confirm this is intended.
        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        poolers_by_task = self.poolers_by_task

        # Split a flat batch into per-sequence chunks so that runs of
        # requests can be handed to their sub-pooler.
        if isinstance(hidden_states, list):
            hidden_states_lst = hidden_states
        else:
            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        # groupby merges only *consecutive* equal tasks, so an interleaved
        # batch is dispatched in several runs and request order is kept.
        for task, group in groupby(get_tasks(pooling_metadata)):
            pooler = poolers_by_task.get(task)
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task}. "
                    f"Supported tasks: {self.get_supported_tasks()}")

            num_items = sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                hidden_states_lst[offset:offset + num_items],
                pooling_metadata[offset:offset + num_items],
            )

            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)