pooler.py 23.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import Callable, Optional, TypeVar, Union, cast
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig
16
17
18
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
19
from vllm.pooling_params import PoolingParams
20
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
21
from vllm.tasks import PoolingTask
22
23
from vllm.utils import current_stream, resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingCursor
24
25
26
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Pooling metadata accepted throughout this module: either the V0 or the V1
# metadata class imported above.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
27
28
29
30
# Callable that maps (hidden states, pooling metadata) to pooled data; both
# the hidden states and the result may be one tensor or a list of tensors.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# Callable that maps a tensor to a tensor; used as the optional classifier
# of `ClassifierPooler`.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
31
32
33
34
35


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods."""
    # Members are also looked up by name (`PoolingType[name]`), so both the
    # names and the integer values must stay stable.
    LAST = 0
    ALL = 1
    CLS = 2
    STEP = 3
    MEAN = 4
40
41


42
43
44
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with the pooling type to use."""

    # Pooling strategy, as a `PoolingType` member.
    pooling_type: PoolingType
    # Pooling task this configuration applies to.
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build a resolved config for `task` from `pooler_config`.

        `pooler_config.pooling_type` must be set. It is looked up by name in
        `PoolingType`, so a name that is not a member raises `KeyError`.
        """
        assert pooler_config.pooling_type is not None
        return cls(task=task,
                   pooling_type=PoolingType[pooler_config.pooling_type])
56
57


58
59
60
61
62
63
64
65
66
@dataclass(frozen=True)
class PoolingParamsUpdate:
    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        # Copy the flag held by this update onto the given params, in place.
        flag = self.requires_token_ids
        params.requires_token_ids = flag


67
68
69
70
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Create the pooler for the "encode" task.

        Returns a `StepPooler` when the configured pooling type is "STEP";
        otherwise a `SimplePooler` built with `PoolingType.ALL`.
        """
        if pooler_config.pooling_type == "STEP":
            return StepPooler()

        resolved_config = ResolvedPoolingConfig(task="encode",
                                                pooling_type=PoolingType.ALL)

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Create the `SimplePooler` for the "embed" task."""
        resolved_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )

        return SimplePooler.from_config(resolved_config)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: Optional[ClassifierFn],
    ):
        """Create the `ClassifierPooler` for the "classify" task.

        The pooling method comes from `pooler_config`; `classifier` is passed
        through unchanged and may be `None`.
        """
        resolved_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )

        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)

        return ClassifierPooler(
            pooling=pooling,
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states` and return the result as a `PoolerOutput`."""
        raise NotImplementedError


126
127
128
129
130
131
132
133
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths recorded in `pooling_metadata`.

    V1 metadata exposes `prompt_lens` directly. For V0 metadata the lengths
    are taken from `PoolingTensors`, built on the device of
    `hidden_states[0]`.
    """
    if isinstance(pooling_metadata, V1PoolingMetadata):
        return pooling_metadata.prompt_lens

    return PoolingTensors.from_pooling_metadata(
        pooling_metadata, hidden_states[0].device).prompt_lens
135
136


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def get_prompt_token_ids(
        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Return one tensor of prompt token ids per request."""
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        # V0: token ids are stored per sequence in `seq_data`.
        return [
            torch.tensor(entry.prompt_token_ids)
            for entry in pooling_metadata.seq_data.values()
        ]

    assert pooling_metadata.prompt_token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`")

    # V1: row `row` of `prompt_token_ids` is cut to that prompt's length.
    return [
        pooling_metadata.prompt_token_ids[row, :length]
        for row, length in enumerate(pooling_metadata.prompt_lens)
    ]


154
155
def get_pooling_params(
        pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the pooling params carried by `pooling_metadata`.

    V0 metadata stores them as the second element of each `seq_groups`
    entry; otherwise they are read from the `pooling_params` attribute.
    """
    if isinstance(pooling_metadata, V0PoolingMetadata):
        pooling_params = [p for _, p in pooling_metadata.seq_groups]
    else:
        pooling_params = pooling_metadata.pooling_params
    return pooling_params


def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return the task of every pooling param, in order.

    Every pooling param must have a task set; the assertion fails if any
    `task` is `None`.
    """
    pooling_params = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = [
        task for pooling_param in pooling_params
        if (task := pooling_param.task) is not None
    ]
    # Entries without a task were filtered out above, so a length mismatch
    # means at least one param had no task.
    assert len(pooling_params) == len(tasks)

    return tasks


175
def get_classification_activation_function(config: PretrainedConfig):
    """Choose the activation for classification from `config.problem_type`.

    "regression" maps to `PoolerIdentity`, "multi_label_classification" to
    `PoolerMultiLabelClassify`; "single_label_classification" and any other
    or missing value map to `PoolerClassify`.
    """
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "single_label_classification":
        return PoolerClassify()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Choose the activation for cross-encoder scoring from `config`.

    The activation's qualified name is read from
    `config.sentence_transformers["activation_fn"]` if present, otherwise
    from `config.sbert_ce_default_activation_function` if set. When neither
    gives a name, `PoolerScore` is returned.

    Raises:
        ValueError: If the configured name is outside `torch.nn.modules`.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerScore()

    # The name comes from the model config and is resolved to an object and
    # called. This is a security check, so it must not be an `assert`:
    # assertions are removed when Python runs with `-O`.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError("Loading of activation functions is restricted to "
                         "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
205

206

207
208
def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
    """Copy pooled data to the CPU and wrap it in a `PoolerOutput`.

    Each element obtained by iterating `all_data` becomes one
    `PoolingSequenceGroupOutput`.
    """
    # Pooling models D2H & synchronize occurs here
    if isinstance(all_data, list):
        all_data = [d.to("cpu", non_blocking=True) for d in all_data]
    else:
        all_data = all_data.to("cpu", non_blocking=True)
    # The copies above were issued as non-blocking, so synchronize the
    # current stream before the data is handed out.
    current_stream().synchronize()

    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
    return PoolerOutput(outputs=all_outputs)


class PoolingMethod(nn.Module, ABC):
    """Base class for pooling strategies driven by a `PoolingCursor`."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method for `pooling_type`.

        Raises:
            NotImplementedError: For any type other than LAST, ALL, CLS or
                MEAN (this includes `PoolingType.STEP`).
        """
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update for `task` (default flags)."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool `hidden_states` using the positions in `pooling_cursor`."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool `hidden_states` using `pooling_metadata.pooling_cursor`."""
        pooling_cursor = pooling_metadata.pooling_cursor
        return self.forward_all(hidden_states, pooling_cursor)
257
258


259
260
class CLSPool(PoolingMethod):
    """Pooling that selects the hidden state of each prompt's first token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Index `hidden_states` at the cursor's first-token indices."""
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with CLS pooling"

        return hidden_states[pooling_cursor.first_token_indices_gpu]
273
274


275
class LastPool(PoolingMethod):
    """Pooling that selects the hidden state at each last-token index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Index `hidden_states` at the cursor's last-token indices."""
        return hidden_states[pooling_cursor.last_token_indices_gpu]
286
287


288
class AllPool(PoolingMethod):
    """Pooling that returns the hidden states of every scheduled token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Split `hidden_states` per request and return the selected chunks.

        The split sizes come from `num_scheduled_tokens_cpu`; the chunks are
        then picked in the order given by `pooling_cursor.index`.
        """
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with ALL pooling"

        hidden_states_lst = list(
            hidden_states.split(
                pooling_cursor.num_scheduled_tokens_cpu.tolist()))
        return [hidden_states_lst[i] for i in pooling_cursor.index]
306
307


308
class MeanPool(PoolingMethod):
    """Pooling that averages the hidden states of each prompt."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed", "classify", "score"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the per-prompt mean of `hidden_states`.

        Sums are taken between the cursor's first- and last-token indices
        (inclusive) and divided by the prompt lengths.
        """
        assert not pooling_cursor.is_partial_prefill(), \
            "partial prefill not supported with MEAN pooling"

        prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
                                                        non_blocking=True)

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu
        # cumsum[end] - cumsum[start] leaves out the row at `start`, so that
        # row is added back to get the inclusive sum over [start, end].
        return (cumsum[end_indices] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


335
# Constrained type variable: a tensor or a list of tensors. Activation
# `forward` methods use it to return the same kind of value they were given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
336
337


338
class BasePoolerActivation(nn.Module, ABC):
    """Interface for activations applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to a tensor or a list of tensors."""
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
347

348

349
class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor through `forward_chunk`."""

    @staticmethod
    def wraps(module: nn.Module):
        """Convert an `nn.Module` activation into a pooler activation.

        `nn.Identity` becomes `PoolerIdentity`; `nn.Sigmoid` and `nn.Softmax`
        become `PoolerClassify`; any other module is wrapped in
        `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to a tensor, or to every tensor in a list."""
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
369
370


371
372
373
class PoolerIdentity(PoolerActivation):
    """Activation that returns its input unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data


377
378
379
380
381
382
383
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, then cast back to the input dtype.
        source_dtype = pooled_data.dtype
        as_float = pooled_data.float()
        unit = F.normalize(as_float, p=2, dim=-1)
        return unit.to(source_dtype)


384
385
386
387
388
389
class PoolerMultiLabelClassify(PoolerActivation):
    """Activation that applies an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Compute in float32, then cast back to the input dtype.
        as_float = pooled_data.float()
        probabilities = F.sigmoid(as_float)
        return probabilities.to(pooled_data.dtype)


390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
class PoolerClassify(PoolerActivation):
    """Softmax over the last dimension, or sigmoid for fewer than 2 labels."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        source_dtype = pooled_data.dtype
        # The size of the last dimension is the number of labels.
        if pooled_data.shape[-1] >= 2:
            return F.softmax(pooled_data.float(), dim=-1).to(source_dtype)

        return F.sigmoid(pooled_data.float()).to(source_dtype)


class PoolerScore(PoolerActivation):
    """Sigmoid for fewer than 2 labels; otherwise the input is unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # The size of the last dimension is the number of labels.
        if pooled_data.shape[-1] >= 2:
            return pooled_data

        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Activation that delegates each tensor to a user-supplied callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()

        # Kept as an attribute; called once per tensor by `forward_chunk`.
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        result = self.fn(pooled_data)
        return result


421
422
class PoolerHead(nn.Module):
    """Head that applies a fixed activation to pooled data."""

    def __init__(self, activation: PoolerActivation) -> None:
        super().__init__()
        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Apply the activation; `pooling_metadata` is not used here."""
        return self.activation(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for embeddings: projection, dimension truncation, normalization."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        # Load ST projector if available
        from vllm.config import get_current_vllm_config
        from vllm.model_executor.models.adapters import _load_st_projector

        vllm_config = get_current_vllm_config()
        self.projector = _load_st_projector(
            vllm_config.model_config) if vllm_config else None

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Post-process pooled embeddings according to the pooling params.

        Steps, in order: apply the projector if one was loaded, stack a list
        input into one tensor, truncate to the requested `dimensions`, and
        apply the normalizing activation where `normalize` is set.

        Returns a tensor, or a list of tensors when the per-request settings
        differ.
        """
        # Apply ST projector
        if self.projector is not None:
            projector = cast(nn.Module, self.projector)

            def _proj(x: torch.Tensor) -> torch.Tensor:
                # Run the projector in float32, then restore the input dtype.
                orig_dtype = x.dtype
                y = projector(x.to(torch.float32))
                return y.to(orig_dtype)

            if isinstance(pooled_data, torch.Tensor):
                pooled_data = _proj(pooled_data)
            else:
                pooled_data = [_proj(t) for t in pooled_data]

        pooling_params = get_pooling_params(pooling_metadata)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # for matryoshka representation
        dimensions_list = [
            pooling_param.dimensions for pooling_param in pooling_params
        ]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            # pooled_data is always a tensor here (lists were stacked above),
            # so only the requested dimensions decide which branch runs.
            if len(set(dimensions_list)) == 1:
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return pooled_data


class RewardPoolerHead(PoolerHead):
    """Head that applies `PoolerClassify` where `softmax` is requested."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify())

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        # for softmax
        softmax_flags = [
            param.softmax for param in get_pooling_params(pooling_metadata)
        ]

        if len(set(softmax_flags)) != 1:
            # Mixed settings: decide per item.
            return [
                self.activation(item) if wanted else item
                for item, wanted in zip(pooled_data, softmax_flags)
            ]

        # Same setting for the whole batch.
        if softmax_flags[0]:
            return self.activation(pooled_data)
        return pooled_data
522
523


524
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Passes the pooled data through the pooler head.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build a pooler from a resolved config.

        The "embed" task uses `EmbeddingPoolerHead` and the "encode" task
        uses `RewardPoolerHead`.

        Raises:
            NotImplementedError: For any other task.
        """
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        if pooler_config.task == "embed":
            head = EmbeddingPoolerHead()
        elif pooler_config.task == "encode":
            head = RewardPoolerHead()
        else:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")
        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, apply the head, and wrap the result as a `PoolerOutput`."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


569
class StepPooler(Pooler):
    """Pooler for the "encode" task that can keep only step-tag positions."""

    def __init__(self, ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Select columns and rows of each request's pooled data.

        For each request: if `returned_token_ids` is non-empty, keep only
        those columns; if `step_tag_id` is set, keep only the rows whose
        prompt token id equals it.
        """
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)

        pooled_data = list[torch.Tensor]()

        pooling_params = get_pooling_params(pooling_metadata)

        for data, token_id, pooling_param in zip(pooled_data_lst,
                                                 prompt_token_ids,
                                                 pooling_params):
            step_tag_id = pooling_param.step_tag_id
            returned_token_ids = pooling_param.returned_token_ids

            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                # Boolean mask over the rows, built from the prompt tokens.
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # `extract_states` needs the prompt token ids, so request them.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract states, apply the head, and wrap as a `PoolerOutput`."""
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


620
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling callable.
    2. Applies the classifier to the pooled data, if one was provided.
    3. Applies an activation function where the pooling params request it.
    """

    @staticmethod
    def act_fn_for_seq_cls(config: ModelConfig):
        """Return the classification activation for `config.hf_config`."""
        return get_classification_activation_function(config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(config: ModelConfig):
        """Return the cross-encoder activation for `config.hf_config`."""
        return get_cross_encoder_activation_function(config.hf_config)

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: Optional[ClassifierFn],
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier
        # `PoolerClassify` is the default when no activation is given.
        self.act_fn = act_fn or PoolerClassify()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, classify, optionally activate, and wrap as `PoolerOutput`."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        if self.classifier is not None:
            # pooled_data is always one tensor at this point (lists were
            # stacked above), so the classifier runs once on the full batch.
            pooled_data = self.classifier(pooled_data)

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        return build_output(scores)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Store the task-to-pooler mapping after validating it.

        Raises:
            ValueError: If a pooler is registered for a task that it does
                not list in `get_supported_tasks()`.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}")

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each run of same-task requests through that task's pooler.

        Consecutive requests with the same task are grouped; each group's
        pooler receives `hidden_states` together with the matching slice of
        `pooling_metadata`, and the group outputs are concatenated in order.

        Raises:
            ValueError: If a request's task has no registered pooler.
        """
        poolers_by_task = self.poolers_by_task

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        for task, group in groupby(get_tasks(pooling_metadata)):
            pooler = poolers_by_task.get(task)
            # Test for a missing entry explicitly rather than relying on the
            # truthiness of a module object.
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset:offset + num_items],
            )

            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)