pooler.py 23.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig
16
17
18
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
19
from vllm.pooling_params import PoolingParams
20
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
21
from vllm.tasks import PoolingTask
22
from vllm.utils import resolve_obj_by_qualname
23
24
25
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Either flavor of pooling metadata accepted by this module: the one imported
# from `vllm.model_executor.pooling_metadata` (V0) or the one imported from
# `vllm.v1.pool.metadata` (V1).
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
26
27
28
29
# Signature of a pooling callable: takes the hidden states (one tensor or a
# list of tensors) plus the pooling metadata and returns the pooled data
# (again one tensor or a list of tensors).
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# Signature of a classifier callable: tensor in, tensor out.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
30
31
32
33
34


class PoolingType(IntEnum):
    """Identifiers for the pooling strategies available in this module."""

    # Last token of each prompt (see `LastPool`).
    LAST = 0
    # Every token of each prompt (see `AllPool`).
    ALL = 1
    # First token of each prompt (see `CLSPool`).
    CLS = 2
    # Step-tag based selection (see `StepPooler`).
    STEP = 3
    # Average over the tokens of each prompt (see `MeanPool`).
    MEAN = 4
39
40


41
42
43
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with the pooling type to use."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config_with_defaults(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
    ) -> "ResolvedPoolingConfig":
        """Build a config, letting `pooler_config` override the default type.

        Args:
            task: The pooling task stored on the result.
            pooler_config: Its `pooling_type`, when not `None`, is looked up
                by name in `PoolingType` and takes precedence.
            pooling_type: Fallback used when `pooler_config.pooling_type`
                is `None`.

        Returns:
            The resolved configuration.
        """
        configured_name = pooler_config.pooling_type
        if configured_name is None:
            chosen_type = pooling_type
        else:
            chosen_type = PoolingType[configured_name]

        return cls(pooling_type=chosen_type, task=task)
56
57


58
59
60
61
62
63
64
65
66
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Flags a pooler wants copied onto the `PoolingParams` of a request."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Write this update's flags onto `params` (mutates `params`)."""
        wants_token_ids = self.requires_token_ids
        params.requires_token_ids = wants_token_ids


67
68
69
70
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM.

    The `for_*` static methods are factories that resolve the effective
    pooling type from a `PoolerConfig` and build a concrete pooler for the
    corresponding task.
    """

    @staticmethod
    def for_encode(
        pooler_config: PoolerConfig,
        *,
        default_pooling_type: PoolingType = PoolingType.ALL,
    ):
        """Build the pooler for the "encode" task.

        Returns a `StepPooler` when the resolved pooling type is
        `PoolingType.STEP`, otherwise a `SimplePooler`.
        """
        resolved = ResolvedPoolingConfig.from_config_with_defaults(
            pooling_type=default_pooling_type,
            pooler_config=pooler_config,
            task="encode",
        )

        if resolved.pooling_type != PoolingType.STEP:
            return SimplePooler.from_config(resolved)

        return StepPooler()

    @staticmethod
    def for_embed(
        pooler_config: PoolerConfig,
        *,
        default_pooling_type: PoolingType = PoolingType.LAST,
    ):
        """Build the `SimplePooler` used for the "embed" task."""
        resolved = ResolvedPoolingConfig.from_config_with_defaults(
            pooling_type=default_pooling_type,
            pooler_config=pooler_config,
            task="embed",
        )

        return SimplePooler.from_config(resolved)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
        *,
        default_pooling_type: PoolingType = PoolingType.LAST,
    ):
        """Build the `ClassifierPooler` used for the "classify" task.

        Args:
            pooler_config: Source of an optional pooling-type override.
            classifier: Optional classifier handed to the pooler as-is.
            default_pooling_type: Pooling type used without an override.
        """
        resolved = ResolvedPoolingConfig.from_config_with_defaults(
            pooling_type=default_pooling_type,
            pooler_config=pooler_config,
            task="classify",
        )

        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation returns a default `PoolingParamsUpdate`.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states` into a `PoolerOutput`; subclasses implement."""
        raise NotImplementedError


141
142
143
144
145
146
147
148
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths described by `pooling_metadata`.

    V1 metadata exposes `prompt_lens` directly. For any other metadata the
    lengths are obtained through `PoolingTensors.from_pooling_metadata`,
    using the device of `hidden_states[0]`.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        device = hidden_states[0].device
        pooling_tensors = PoolingTensors.from_pooling_metadata(
            pooling_metadata, device)
        return pooling_tensors.prompt_lens

    return pooling_metadata.prompt_lens
150
151


152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return one tensor of prompt token ids per prompt.

    For V1 metadata, row `i` of `prompt_token_ids` is cut to the `i`-th
    prompt length; `prompt_token_ids` must be populated. Otherwise a new
    tensor is built from each entry of `pooling_metadata.seq_data`.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        return [
            torch.tensor(seq_data.prompt_token_ids)
            for seq_data in pooling_metadata.seq_data.values()
        ]

    assert pooling_metadata.prompt_token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    per_prompt_ids = []
    for row, length in enumerate(pooling_metadata.prompt_lens):
        per_prompt_ids.append(pooling_metadata.prompt_token_ids[row, :length])
    return per_prompt_ids


169
170
def get_pooling_params(
        pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the pooling parameters carried by `pooling_metadata`.

    V0 metadata stores them as the second element of each `seq_groups`
    entry; any other metadata exposes them as `pooling_params`.
    """
    if not isinstance(pooling_metadata, V0PoolingMetadata):
        return pooling_metadata.pooling_params

    return [params for _, params in pooling_metadata.seq_groups]


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the task of every pooling parameter, in order.

    Asserts that no pooling parameter has `task` set to `None`.
    """
    pooling_params = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = []
    for params in pooling_params:
        task = params.task
        if task is not None:
            tasks.append(task)

    # Any dropped entry means some request had no task assigned.
    assert len(pooling_params) == len(tasks)

    return tasks


190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def get_classification_activation_function(config: PretrainedConfig):
    """Return the activation used for classification outputs.

    `config` is currently not consulted; a `PoolerClassify` is always built.
    """
    activation = PoolerClassify()
    return activation


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Pick the activation applied to cross-encoder scores.

    The activation's qualified name is read from
    `config.sentence_transformers["activation_fn"]` when present, otherwise
    from `config.sbert_ce_default_activation_function` when it is set. If
    neither yields a name, `PoolerScore` is used.

    Args:
        config: Hugging Face model config to inspect.

    Returns:
        A `PoolerActivation` wrapping the configured module, or a
        `PoolerScore` when no activation is configured.

    Raises:
        AssertionError: If the configured name is not under
            `torch.nn.modules.`.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerScore()

    # The name comes from a model config and is resolved to an arbitrary
    # object, so the allow-list check must not be an `assert` statement:
    # those are stripped under `python -O`. AssertionError is still the
    # exception raised, so callers see the same exception type as before.
    if not function_name.startswith("torch.nn.modules."):
        raise AssertionError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
211

212

213
214
def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
    """Wrap each item of `all_data` into a `PoolerOutput`.

    Iterating `all_data` yields one element per output; each is wrapped in
    a `PoolingSequenceGroupOutput`.
    """
    wrapped = list(map(PoolingSequenceGroupOutput, all_data))
    return PoolerOutput(outputs=wrapped)


class PoolingMethod(nn.Module, ABC):
    """Base class for strategies that reduce hidden states per prompt.

    Subclasses implement `forward_one` (a single prompt's hidden states) and
    `forward_all` (the hidden states of all prompts in one tensor);
    `forward` chooses between them based on the input type.
    """

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method registered for `pooling_type`.

        Raises:
            NotImplementedError: If `pooling_type` has no pooling method
                here (e.g. `PoolingType.STEP`).
        """
        method_by_type = {
            PoolingType.LAST: LastPool,
            PoolingType.ALL: AllPool,
            PoolingType.CLS: CLSPool,
            PoolingType.MEAN: MeanPool,
        }

        method_cls = method_by_type.get(pooling_type)
        if method_cls is None:
            raise NotImplementedError(f"Unsupported method: {pooling_type}")

        return method_cls()

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return a default `PoolingParamsUpdate` (no flags set)."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool the hidden states of a single prompt.

        Note:
            `prompt_len=None` means `prompt_len=len(hidden_states)`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool all prompts at once from a single hidden-state tensor."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool `hidden_states`, per prompt for a list, batched otherwise."""
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if not isinstance(hidden_states, list):
            return self.forward_all(hidden_states, prompt_lens)

        return [
            self.forward_one(states, length)
            for states, length in zip(hidden_states, prompt_lens)
        ]
275
276


277
278
class CLSPool(PoolingMethod):
    """Pooling method that keeps the first token of each prompt."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the first row of `hidden_states`.

        Asserts that `prompt_len`, when given, equals the number of rows.
        """
        assert (prompt_len is None
                or prompt_len == hidden_states.shape[0]), (
                    "partial prefill not supported with CLS pooling")

        return hidden_states[0]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Gather the first row of every prompt from the flat tensor."""
        # Start offset of prompt i is the sum of the lengths before it.
        prompt_ends = torch.cumsum(prompt_lens, dim=0)
        prompt_starts = torch.zeros_like(prompt_lens)
        prompt_starts[1:] += prompt_ends[:-1]
        return hidden_states[prompt_starts]


302
class LastPool(PoolingMethod):
    """Pooling method that keeps the last token of each prompt."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the last row of `hidden_states`; `prompt_len` is unused."""
        last_row = hidden_states[-1]
        return last_row

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Gather the last row of every prompt from the flat tensor."""
        # The running total of lengths, minus one, is each prompt's last row.
        prompt_ends = torch.cumsum(prompt_lens, dim=0)
        return hidden_states[prompt_ends - 1]


323
class AllPool(PoolingMethod):
    """Pooling method that keeps every token of each prompt."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return `hidden_states` unchanged.

        Asserts that `prompt_len`, when given, equals the number of rows.
        """
        assert (prompt_len is None
                or prompt_len == hidden_states.shape[0]), (
                    "partial prefill not supported with ALL pooling")

        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Split the flat tensor into one chunk per prompt."""
        chunk_sizes = prompt_lens.tolist()
        chunks = hidden_states.split_with_sizes(chunk_sizes)
        return list(chunks)
344
345


346
class MeanPool(PoolingMethod):
    """Pooling method that averages the tokens of each prompt."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the float32 mean over the rows of `hidden_states`.

        Asserts that `prompt_len`, when given, equals the number of rows.
        """
        assert (prompt_len is None
                or prompt_len == hidden_states.shape[0]), (
                    "partial prefill not supported with MEAN pooling")

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Average every prompt's rows using one running sum."""
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        running_sum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        leading_zero = torch.tensor([0], device=hidden_states.device)
        preceding_totals = torch.cumsum(prompt_lens[:-1], dim=0)
        start_indices = torch.cat([leading_zero, preceding_totals])
        end_indices = torch.cumsum(prompt_lens, dim=0)

        # Sum of rows [start, end) = running[end - 1] - running[start]
        # + row[start], since running[start] already includes row[start].
        prompt_sums = (running_sum[end_indices - 1] -
                       running_sum[start_indices] +
                       hidden_states[start_indices])
        return prompt_sums / prompt_lens.unsqueeze(1)


379
# Type variable constrained to "one tensor" or "list of tensors"; used by the
# activation classes so that `forward` returns the same kind it was given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
380
381


382
class BasePoolerActivation(nn.Module, ABC):
    """Abstract activation applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to `pooled_data`.

        Expected shapes of `pooled_data`:
            classify (& score) -> (batch_size, num_classes)
            embed -> (batch_size, embedding_dim) or list(embedding_dim)
                     (batch_size, dimensions) or list(dimensions) if using MRL
        """
        raise NotImplementedError
391

392

393
class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor through `forward_chunk`.

    `forward` applies `forward_chunk` to a single tensor directly, or to
    each element when given a list of tensors.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Map a torch module to the matching pooler activation.

        `nn.Identity` becomes `PoolerIdentity`, `nn.Sigmoid` / `nn.Softmax`
        become `PoolerClassify`, and anything else is wrapped in a
        `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            wrapped = PoolerIdentity()
        elif isinstance(module, (nn.Sigmoid, nn.Softmax)):
            wrapped = PoolerClassify()
        else:
            wrapped = LambdaPoolerActivation(module)

        return wrapped

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to one tensor; subclasses implement."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to a tensor or to each tensor of a list."""
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)

        return list(map(self.forward_chunk, pooled_data))
413
414


415
416
417
class PoolerIdentity(PoolerActivation):
    """Activation that leaves the pooled data untouched."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `pooled_data` as-is (same object, no copy)."""
        unchanged = pooled_data
        return unchanged


421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Normalize in float and cast back to the input dtype."""
        input_dtype = pooled_data.dtype
        normalized = F.normalize(pooled_data.float(), p=2, dim=-1)
        return normalized.to(input_dtype)


class PoolerClassify(PoolerActivation):
    """Activation for classification outputs: sigmoid or softmax."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid for fewer than two labels, softmax otherwise.

        The size of the last dimension is taken as the number of labels.
        The computation runs in float and is cast back to the input dtype.
        """
        num_labels = pooled_data.shape[-1]
        as_float = pooled_data.float()

        if num_labels >= 2:
            activated = F.softmax(as_float, dim=-1)
        else:
            activated = F.sigmoid(as_float)

        return activated.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Activation for score outputs: sigmoid only for a single label."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the input untouched when it has two or more labels.

        The size of the last dimension is taken as the number of labels.
        With fewer than two, a sigmoid is computed in float and cast back
        to the input dtype.
        """
        if pooled_data.shape[-1] >= 2:
            return pooled_data

        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Activation that delegates to an arbitrary tensor-to-tensor callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        """Store `fn`, the callable applied to each pooled tensor."""
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `self.fn(pooled_data)`."""
        result = self.fn(pooled_data)
        return result


459
460
class PoolerHead(nn.Module):
    """Final stage of a pooler: applies an activation to the pooled data."""

    def __init__(self, activation: PoolerActivation) -> None:
        """Store the activation used by `forward`."""
        super().__init__()
        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the activation; `pooling_metadata` is not used here."""
        activated = self.activation(pooled_data)
        return activated


class EmbeddingPoolerHead(PoolerHead):
    """Head for embeddings: optional truncation, then optional normalize.

    Both steps are driven per request by the pooling parameters
    (`dimensions` and `normalize`).
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Truncate and normalize `pooled_data` as the requests ask."""
        pooling_params = get_pooling_params(pooling_metadata)

        # for matryoshka representation
        requested_dims = [params.dimensions for params in pooling_params]
        pooled_data = self._truncate(pooled_data, requested_dims)

        # for normalize
        normalize_flags = [params.normalize for params in pooling_params]
        return self._normalize(pooled_data, normalize_flags)

    def _truncate(self, pooled_data, requested_dims):
        """Cut the last dimension down to the requested size, if any."""
        if all(dim is None for dim in requested_dims):
            return pooled_data

        # change the output dimension
        assert len(pooled_data) == len(requested_dims)

        same_for_all = len(set(requested_dims)) == 1
        if same_for_all and not isinstance(pooled_data, list):
            # One shared size and a single tensor: slice the whole batch.
            return pooled_data[..., :requested_dims[0]]

        return [
            vecs if dim is None else vecs[..., :dim]
            for vecs, dim in zip(pooled_data, requested_dims)
        ]

    def _normalize(self, pooled_data, normalize_flags):
        """Apply the activation where the matching flag is truthy."""
        if len(set(normalize_flags)) != 1:
            # Mixed (or no) flags: decide element by element.
            return [
                self.activation(vecs) if flag else vecs
                for vecs, flag in zip(pooled_data, normalize_flags)
            ]

        if normalize_flags[0]:
            return self.activation(pooled_data)

        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head whose `PoolerClassify` activation is gated by `softmax` flags."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the activation where the request's `softmax` flag is truthy.

        When every request shares the same flag the data is handled as a
        whole; otherwise the result is a list built element by element.
        """
        # for softmax
        softmax_flags = [
            params.softmax for params in get_pooling_params(pooling_metadata)
        ]

        if len(set(softmax_flags)) != 1:
            return [
                self.activation(vecs) if flag else vecs
                for vecs, flag in zip(pooled_data, softmax_flags)
            ]

        if softmax_flags[0]:
            return self.activation(pooled_data)

        return pooled_data
534
535


536
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Passes the pooled data through the task-specific head.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build a pooler from a resolved config.

        The head is `EmbeddingPoolerHead` for the "embed" task and
        `RewardPoolerHead` for the "encode" task.

        Raises:
            NotImplementedError: For any other task, or when the pooling
                type has no `PoolingMethod`.
        """
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)

        head_by_task = {
            "embed": EmbeddingPoolerHead,
            "encode": RewardPoolerHead,
        }
        head_cls = head_by_task.get(pooler_config.task)
        if head_cls is None:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")

        return cls(pooling, head_cls())

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        """Store the pooling method and the head applied after it."""
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Report the tasks supported by the pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Forward the request to the pooling method."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, run the head, and wrap the result via `build_output`."""
        pooled = self.pooling(hidden_states, pooling_metadata)
        head_output = self.head(pooled, pooling_metadata)
        return build_output(head_output)


581
class StepPooler(Pooler):
    """Pooler for the "encode" task that selects rows by step tag.

    All token states are kept (`AllPool`), optionally restricted to
    `returned_token_ids` columns and to positions whose prompt token id
    equals `step_tag_id`, then passed to a `RewardPoolerHead`.
    """

    def __init__(self) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select, per prompt, the columns and rows the request asks for."""
        per_prompt_states = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
        pooling_params = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, params in zip(per_prompt_states,
                                             prompt_token_ids,
                                             pooling_params):
            keep_columns = params.returned_token_ids
            if keep_columns is not None and len(keep_columns) > 0:
                # Keep only the requested columns of the last dimension.
                states = states[:, keep_columns]

            tag_id = params.step_tag_id
            if tag_id is not None:
                # Keep only positions whose prompt token is the step tag.
                states = states[token_ids == tag_id]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract states, run the head, and wrap via `build_output`."""
        extracted = self.extract_states(hidden_states, pooling_metadata)
        head_output = self.head(extracted, pooling_metadata)
        return build_output(head_output)


632
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the configured pooling callable.
    2. Optionally applies a classification layer to the pooled data.
    3. Applies an activation function to the output of every request whose
       pooling params have a truthy `activation` flag.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Return the classification activation for `config.hf_config`."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Return the cross-encoder activation for `config.hf_config`."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        """Store the pipeline stages.

        Args:
            pooling: Callable that pools the hidden states.
            classifier: Optional callable applied to the pooled data;
                skipped when `None`.
            act_fn: Activation to use; a falsy value selects
                `PoolerClassify()`.
        """
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = act_fn or PoolerClassify()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, classify and activate, returning a `PoolerOutput`."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        if self.classifier is not None:
            # apply classifier once on the full batch if possible
            if isinstance(pooled_data, torch.Tensor):
                pooled_data = self.classifier(pooled_data)
            elif pooled_data and len({data.shape
                                      for data in pooled_data}) == 1:
                # Only a non-empty list can be stacked: `torch.stack`
                # raises on an empty sequence.
                pooled_data = self.classifier(torch.stack(pooled_data))
            else:
                pooled_data = [self.classifier(data) for data in pooled_data]

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            # Every request agrees: handle the batch as a whole.
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return build_output(scores)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Validate and store the task-to-pooler mapping.

        Raises:
            ValueError: If a pooler is registered under a task that it does
                not list in `get_supported_tasks()`.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            supported_tasks = pooler.get_supported_tasks()
            if task in supported_tasks:
                continue

            raise ValueError(f"{pooler=} does not support {task=}. "
                             f"Supported tasks: {supported_tasks}")

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooler registered for `task`."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each run of same-task requests through its pooler.

        Consecutive requests with the same task are handled in one call;
        the outputs are concatenated in request order.

        Raises:
            ValueError: If a request's task has no registered pooler.
        """
        poolers_by_task = self.poolers_by_task

        if not isinstance(hidden_states, list):
            # Split the flat tensor into one chunk per prompt.
            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
        else:
            hidden_states_lst = hidden_states

        outputs: list[PoolingSequenceGroupOutput] = []
        start = 0
        for task, run in groupby(get_tasks(pooling_metadata)):
            pooler = poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            stop = start + sum(1 for _ in run)
            group_output: PoolerOutput = pooler(
                hidden_states_lst[start:stop],
                pooling_metadata[start:stop],
            )

            outputs.extend(group_output.outputs)
            start = stop

        return PoolerOutput(outputs)