pooler.py 25 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Callable, Mapping, Set
5
from dataclasses import dataclass
6
from enum import IntEnum
7
from itertools import groupby
8
from typing import TypeVar
9
10
11

import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
from transformers import PretrainedConfig
14

15
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.models.adapters import _load_st_projector
18
19
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
20
from vllm.utils.import_utils import resolve_obj_by_qualname
21
from vllm.v1.outputs import PoolerOutput
22
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
23

24
25
logger = init_logger(__name__)

26
# Hidden states / pooled data travel either as a single tensor or as a list
# of tensors; both pooling inputs and outputs use this union.
_TensorOrList = torch.Tensor | list[torch.Tensor]

# A pooling callable: (hidden states, pooling metadata) -> pooled data.
PoolingFn = Callable[[_TensorOrList, PoolingMetadata], _TensorOrList]

# A classification head: tensor in, tensor out.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
31
32
33
34


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods.

    Members are looked up by name (`PoolingType["MEAN"]`) when a pooler
    configuration is resolved, so the member names are part of the contract.
    """

    # Handled by `LastPool`.
    LAST = 0
    # Handled by `AllPool`.
    ALL = 1
    # Handled by `CLSPool`.
    CLS = 2
    # No `PoolingMethod` exists for this one; token-level poolers switch to
    # `StepPooler` when the configured pooling type is "STEP".
    STEP = 3
    # Handled by `MeanPool`.
    MEAN = 4
41
42


43
44
45
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with the pooling type to use."""

    pooling_type: PoolingType
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build the resolved configuration for `task`.

        `pooler_config.pooling_type` must already be set; it is looked up by
        name in `PoolingType`, so an unknown name raises `KeyError`.
        """
        type_name = pooler_config.pooling_type
        assert type_name is not None
        return cls(pooling_type=PoolingType[type_name], task=task)
56
57


58
59
60
61
62
63
64
65
66
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Settings a pooler wants written onto a request's `PoolingParams`."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Copy this update's fields onto `params`, mutating it in place."""
        params.requires_token_ids = self.requires_token_ids


67
def get_classification_activation_function(config: PretrainedConfig):
68
69
70
71
72
73
74
75
76
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "single_label_classification":
        return PoolerClassify()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()
77
78
79
80
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Choose the activation for cross-encoder (score) outputs.

    The activation's qualified name is read from
    `config.sentence_transformers["activation_fn"]` when that key exists,
    otherwise from `config.sbert_ce_default_activation_function` when it is
    set. The named class is instantiated with no arguments and wrapped via
    `PoolerActivation.wraps`. When neither source names an activation,
    `PoolerClassify` is used.

    Raises:
        ValueError: If the configured name is not under `torch.nn.modules.`.
    """
    function_name: str | None = None
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model's config and is about to be resolved and
    # called, so this guard must survive `python -O` (which strips `assert`).
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
102

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

class PoolingMethod(nn.Module, ABC):
    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

118
    @abstractmethod
119
    def get_supported_tasks(self) -> Set[PoolingTask]:
120
121
        raise NotImplementedError

122
123
124
    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate()

125
126
    @abstractmethod
    def forward_all(
127
        self,
128
        hidden_states: torch.Tensor,
129
        pooling_cursor: PoolingCursor,
130
    ) -> list[torch.Tensor] | torch.Tensor:
131
132
133
134
        raise NotImplementedError

    def forward(
        self,
135
        hidden_states: torch.Tensor,
136
        pooling_metadata: PoolingMetadata,
137
    ) -> list[torch.Tensor] | torch.Tensor:
138
139
        pooling_cursor = pooling_metadata.pooling_cursor
        return self.forward_all(hidden_states, pooling_cursor)
140
141


142
class CLSPool(PoolingMethod):
    """Selects the rows at `pooling_cursor.first_token_indices_gpu`."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this method can serve."""
        return {"embed", "classify", "score", "token_embed", "token_classify"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Index `hidden_states` at each request's first-token position.

        Partial prefill is rejected with an assertion.
        """
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_positions = pooling_cursor.first_token_indices_gpu
        return hidden_states[first_positions]
156
157


158
class LastPool(PoolingMethod):
    """Selects the rows at `pooling_cursor.last_token_indices_gpu`."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this method can serve."""
        return {"embed", "classify", "score", "token_embed", "token_classify"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Index `hidden_states` at each request's last-token position."""
        last_positions = pooling_cursor.last_token_indices_gpu
        return hidden_states[last_positions]
168
169


170
class AllPool(PoolingMethod):
    """Returns a separate chunk of hidden states for every cursor entry."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this method can serve (token-level only)."""
        return {"token_classify", "token_embed"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Split `hidden_states` into chunks and order them by the cursor.

        The split sizes come from `pooling_cursor.num_scheduled_tokens_cpu`;
        the result lists the chunks in the order given by
        `pooling_cursor.index`. Partial prefill is rejected with an assertion.
        """
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with ALL pooling"
        )

        chunk_sizes = pooling_cursor.num_scheduled_tokens_cpu.tolist()
        chunks = list(hidden_states.split(chunk_sizes))
        return [chunks[position] for position in pooling_cursor.index]
187
188


189
class MeanPool(PoolingMethod):
    """Averages hidden states between each request's first and last index."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this method can serve."""
        return {"embed", "classify", "score", "token_embed", "token_classify"}

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        pooling_cursor: PoolingCursor,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Compute per-request means from a single cumulative sum.

        For each request the inclusive range sum is
        `cumsum[last] - cumsum[first] + hidden_states[first]`, which is then
        divided by the request's prompt length. Partial prefill is rejected
        with an assertion.
        """
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        running_total = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = pooling_cursor.first_token_indices_gpu
        last = pooling_cursor.last_token_indices_gpu

        # Subtracting cumsum[first] also removes the first row's own
        # contribution, so it is added back to keep the range inclusive.
        range_sum = running_total[last] - running_total[first] + hidden_states[first]
        return range_sum / lengths.unsqueeze(1)
215
216


217
# Constrained type variable used by the activation classes so that `forward`
# is typed as returning the same container kind it was given: a single
# tensor or a list of tensors.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
218
219


220
221
222
223
224
225
226
227
class BasePoolerActivation(nn.Module, ABC):
    """Interface for activations applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to `pooled_data`.

        Expected shapes:
            classify (& score): (batch_size, num_classes)
            embed: (batch_size, embedding_dim) or list(embedding_dim);
                with MRL, (batch_size, dimensions) or list(dimensions)
        """
        raise NotImplementedError
228

229

230
231
232
233
234
235
236
class PoolerActivation(BasePoolerActivation):
    @staticmethod
    def wraps(module: nn.Module):
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()
237

238
239
240
241
242
243
244
245
246
247
248
        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
249
250


251
252
class PoolerIdentity(PoolerActivation):
    """Activation that leaves the pooled data untouched."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `pooled_data` itself (no copy is made)."""
        return pooled_data


256
257
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `pooled_data` scaled to unit L2 norm over `dim=-1`."""
        normalized = F.normalize(pooled_data, p=2, dim=-1)
        return normalized
259
260


261
262
class PoolerMultiLabelClassify(PoolerActivation):
    """Activation for multi-label outputs: an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the element-wise sigmoid of `pooled_data`."""
        return torch.sigmoid(pooled_data)
264
265


266
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for <2 labels, softmax otherwise.

    Args:
        static_num_labels: When True, the label count is read once from the
            current vLLM config's `hf_config.num_labels` (0 if the attribute
            is missing). When False, it is taken from the last dimension of
            each tensor passed to `forward_chunk`.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(
                vllm_config.model_config.hf_config, "num_labels", 0
            )
            if self.num_labels == 0:
                # NOTE(review): the text says "softmax", but with
                # num_labels == 0 `forward_chunk` takes the `< 2` branch and
                # applies sigmoid. Confirm which one is intended.
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to softmax. "
                    "Please check if the configuration is correct."
                )
        else:
            # Resolved per call from the input's last dimension.
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid when there are fewer than two labels, else softmax.

        Softmax is taken over the last dimension.
        """
        num_labels = self.num_labels
        if num_labels is None:
            num_labels = pooled_data.shape[-1]

        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
293
294
295
296
297
298
299
300
301
302
303
304


class LambdaPoolerActivation(PoolerActivation):
    """Pooler activation that delegates to an arbitrary tensor callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        """Store `fn`, the callable applied to every tensor."""
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `fn(pooled_data)`."""
        wrapped = self.fn
        return wrapped(pooled_data)


305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM.

    The `for_*` static methods are factories that assemble the pooler used
    for each task from a `PoolerConfig`.
    """

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build a token-embedding pooler.

        A `StepPooler` is used when the pooling type is "STEP", otherwise an
        `AllPooler`.
        """
        head = TokenEmbeddingPoolerHead()
        is_step = pooler_config.pooling_type == "STEP"
        pooler_cls = StepPooler if is_step else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a token-classification pooler.

        A `StepPooler` is used when the pooling type is "STEP", otherwise an
        `AllPooler`; both wrap a `TokenClassifierPoolerHead`.
        """
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        is_step = pooler_config.pooling_type == "STEP"
        pooler_cls = StepPooler if is_step else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build a sequence-embedding pooler from the configured pooling type."""
        resolved = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )

        pooling = PoolingMethod.from_pooling_type(resolved.pooling_type)
        head = EmbeddingPoolerHead()
        return SimplePooler(pooling=pooling, head=head)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a `ClassifierPooler` from the configured pooling type."""
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )

        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation requests no changes.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states` as described by `pooling_metadata`."""
        raise NotImplementedError


381
382
383
384
385
386
387
388
389
390
391
392
class DummyPooler(Pooler):
    """Pass-through pooler that returns the hidden states unchanged."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler accepts."""
        return {"score", "plugin"}

    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Return `hidden_states` as-is; `pooling_metadata` is not read."""
        return hidden_states


393
class PoolerHead(nn.Module):
    """Final pooler stage that applies an activation to pooled data."""

    def __init__(self, activation: PoolerActivation) -> None:
        """Store the activation applied in `forward`."""
        super().__init__()
        self.activation = activation

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ):
        """Return `activation(pooled_data)`.

        `pooling_metadata` is part of the signature but is not read by this
        base implementation.
        """
        activation = self.activation
        return activation(pooled_data)


class EmbeddingPoolerHead(PoolerHead):
    """Head for sequence embeddings.

    `forward` casts to the head dtype, applies the optional ST projector,
    truncates to the requested matryoshka dimensions, and normalizes the
    requests that ask for it.
    """

    def __init__(self) -> None:
        super().__init__(activation=PoolerNormalize())

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        # NOTE(review): the projector tolerates a falsy `vllm_config`, but
        # the `head_dtype` lookup below does not.
        self.projector: nn.Module | None = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ):
        """Turn pooled hidden states into embeddings.

        Returns a tensor when every request shares the same truncation and
        normalization settings; otherwise a list with one entry per request.
        """
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        pooling_params = pooling_metadata.pooling_params

        # for matryoshka representation
        dimensions_list = [param.dimensions for param in pooling_params]
        if any(dim is not None for dim in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            all_same = len(set(dimensions_list)) == 1
            if all_same and not isinstance(pooled_data, list):
                # if all dimensions are the same
                shared_dim = dimensions_list[0]
                pooled_data = pooled_data[..., :shared_dim]
            else:
                # Per-request truncation; None keeps the full vector.
                pooled_data = [
                    vecs if dim is None else vecs[..., :dim]
                    for vecs, dim in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        flags = [param.normalize for param in pooling_params]
        if len(set(flags)) != 1:
            # Mixed settings: decide per request.
            return [
                self.activation(vecs) if flag else vecs
                for vecs, flag in zip(pooled_data, flags)
            ]

        if flags[0]:
            pooled_data = self.activation(pooled_data)

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data


465
class SimplePooler(Pooler):
    """A pooler made of a pooling method followed by a head.

    `forward` runs `pooling` on the hidden states and passes the result,
    together with the pooling metadata, to `head`. Supported tasks and
    pooling-parameter updates are delegated to the pooling method.
    """

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by the pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling method's parameter update for `task`."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the hidden states, then apply the head."""
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
494
495


496
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    `forward` performs, in order:
    1. Pooling of the hidden states.
    2. A cast to the head dtype and, if configured, the classifier.
    3. Subtraction of the configured logit bias, if any.
    4. The activation, for requests whose params enable it.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Return the sequence-classification activation for this model."""
        return get_classification_activation_function(model_config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Return the cross-encoder activation for this model."""
        return get_cross_encoder_activation_function(model_config.hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Turn an `act_fn` specification into an activation.

        - None: a new `PoolerClassify` built with `static_num_labels`.
        - "classify" / "score": the model-specific activation for that task.
        - any other string: `ValueError`.
        - anything else: returned unchanged after asserting it is callable.
        """
        if act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)

        if not isinstance(act_fn, str):
            assert callable(act_fn)
            return act_fn

        if act_fn == "classify":
            return ClassifierPooler.act_fn_for_seq_cls(model_config)
        if act_fn == "score":
            return ClassifierPooler.act_fn_for_cross_encoder(model_config)
        raise ValueError(f"act_fn [{act_fn=}] not supported.")

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"score", "classify"}

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool, classify and activate.

        Returns a tensor when all requests agree on `use_activation`,
        otherwise a list with one entry per request.
        """
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        flags = [param.use_activation for param in pooling_metadata.pooling_params]

        if len(set(flags)) != 1:
            # Mixed settings: decide per request.
            return [
                self.act_fn(vecs) if flag else vecs
                for vecs, flag in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        if flags[0]:
            return self.act_fn(pooled_data)
        return pooled_data
585
586


587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Embedding head applied to one request's token states at a time."""

    def forward(
        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
    ) -> torch.Tensor:
        """Project, truncate and optionally normalize one request's tokens."""
        embeddings = pooled_data.to(self.head_dtype)
        # embeddings shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        if self.projector is not None:
            embeddings = self.projector(embeddings)
        # embeddings shape: [n_tokens, embedding_dimension]

        # for matryoshka representation; a `dimensions` of None makes the
        # slice keep the full last dimension.
        embeddings = embeddings[..., : pooling_param.dimensions]

        # for normalize
        if not pooling_param.normalize:
            return embeddings

        # embeddings shape: [n_tokens, embedding_dimension]
        return self.activation(embeddings)


class TokenClassifierPoolerHead(nn.Module):
    """Classification head applied to one request's token states at a time.

    Args:
        classifier: Optional callable mapping hidden states to scores. When
            None, the hidden states themselves are used as scores.
        act_fn: Activation specification, resolved through
            `ClassifierPooler.resolve_act_fn` with `static_num_labels=False`.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this head can serve."""
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_param: PoolingParams,
    ) -> torch.Tensor:
        """Score one request's tokens.

        The input is cast to the head dtype, passed through the classifier
        if there is one, shifted by the logit bias if configured, and
        activated when `pooling_param.use_activation` is set. The input
        tensor is never modified.
        """
        hidden_states = hidden_states.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose: without a classifier and with a
            # matching dtype, `scores` is the caller's tensor (possibly a
            # view produced by `Tensor.split`), and `-=` would write through
            # into the shared hidden-state buffer.
            scores = scores - self.logit_bias

        if pooling_param.use_activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores


class AllPooler(Pooler):
    """Token-level pooler: `AllPool` followed by a per-request head."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"token_classify", "token_embed"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Split the hidden states per request and apply the head to each.

        Each chunk is paired with the pooling params at the same position.
        """
        chunks = self.pooling(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(chunks) == len(params)

        return [self.head(chunk, param) for chunk, param in zip(chunks, params)]


class StepPooler(Pooler):
    """Token-level pooler that can filter tokens before applying the head.

    Per request, the pooling params may restrict the kept columns
    (`returned_token_ids`) and the kept rows (positions whose prompt token
    id equals `step_tag_id`).
    """

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return one filtered tensor per request.

        The hidden states are first split per request by `AllPool`; each
        chunk is then narrowed according to that request's pooling params.
        """
        per_request = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params

        extracted: list[torch.Tensor] = []
        for states, token_ids, param in zip(
            per_request, prompt_token_ids, pooling_params
        ):
            keep_columns = param.returned_token_ids
            if keep_columns is not None and len(keep_columns) > 0:
                # Column selection along dim 1.
                states = states[:, keep_columns]

            tag_id = param.step_tag_id
            if tag_id is not None:
                # Row selection: positions where the prompt token is the tag.
                states = states[token_ids == tag_id]

            extracted.append(states)

        return extracted

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler can serve."""
        return {"token_classify", "token_embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Filter each request's states, then apply the head per request."""
        extracted = self.extract_states(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(extracted) == len(params)

        return [self.head(states, param) for states, param in zip(extracted, params)]


731
732
733
734
735
736
737
738
739
740
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Validate and store the task -> pooler mapping.

        Raises:
            ValueError: If a pooler is registered for a task it does not
                list among its supported tasks.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            supported = pooler.get_supported_tasks()
            if task in supported:
                continue
            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {pooler.get_supported_tasks()}"
            )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Forward to the pooler registered for `task`."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each run of same-task requests through its pooler.

        Consecutive entries of `pooling_metadata.tasks` with the same task
        form one group; the group's pooler receives the full hidden states
        together with the matching slice of the metadata, and the outputs
        are concatenated in order.

        Raises:
            ValueError: If a task has no registered pooler.
        """
        outputs = list[torch.Tensor]()
        start = 0
        for task, group in groupby(pooling_metadata.tasks):
            pooler = self.poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            stop = start + sum(1 for _ in group)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[start:stop],
            )

            outputs.extend(group_output)
            start = stop

        return outputs

    def extra_repr(self) -> str:
        """Return the supported tasks for the module's repr."""
        return f"supported_task={self.get_supported_tasks()}"