pooler.py 23.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from abc import ABC, abstractmethod
from dataclasses import dataclass
5
from enum import IntEnum
6
from typing import Callable, Literal, Optional, TypeVar, Union
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11
from transformers import PretrainedConfig
12
from typing_extensions import assert_never
13

14
from vllm.config import ModelConfig, PoolerConfig
15
16
17
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
18
from vllm.pooling_params import PoolingParams
19
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
20
from vllm.utils import resolve_obj_by_qualname
21
22
23
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

# Pooling metadata as produced by either the V0 or the V1 code path.
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
24
# Task names that poolers are queried about via `get_pooling_params`.
PoolingTask = Literal["encode", "embed", "classify", "score"]
25
26
27
28
29


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods."""
    # Hidden state of the last token of each prompt.
    LAST = 0
    # All token hidden states of each prompt, unreduced.
    ALL = 1
    # Hidden state of the first token of each prompt.
    CLS = 2
    # Step-wise pooling; handled by `StepPooler` rather than `PoolingMethod`.
    STEP = 3
    # Mean of the token hidden states of each prompt.
    MEAN = 4
34
35


36
37
38
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Pooling settings after merging `PoolerConfig` with fallback values."""
    pooling_type: PoolingType

    normalize: bool
    softmax: bool
    step_tag_id: Optional[int]
    returned_token_ids: Optional[list[int]]

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "ResolvedPoolingConfig":
        """Resolve every setting, preferring values set on `pooler_config`.

        A field of `pooler_config` that is `None` is treated as "not set" and
        the corresponding argument of this method is used instead.

        Args:
            pooler_config: Overrides; `pooling_type` is looked up by member
                name in `PoolingType`.
            pooling_type: Fallback pooling type.
            normalize: Fallback for `normalize`.
            softmax: Fallback for `softmax`.
            step_tag_id: Fallback for `step_tag_id`.
            returned_token_ids: Fallback for `returned_token_ids`.

        Raises:
            KeyError: If `pooler_config.pooling_type` is set but is not the
                name of a `PoolingType` member.
        """
        return cls(
            pooling_type=PoolingType[pooler_config.pooling_type]
            if pooler_config.pooling_type is not None else pooling_type,
            normalize=pooler_config.normalize
            if pooler_config.normalize is not None else normalize,
            softmax=pooler_config.softmax
            if pooler_config.softmax is not None else softmax,
            step_tag_id=pooler_config.step_tag_id
            if pooler_config.step_tag_id is not None else step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids
            if pooler_config.returned_token_ids is not None else
            returned_token_ids,
        )
68
69


70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def from_config_with_defaults(
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "Pooler":
        """Build a pooler from `pooler_config`, falling back to the defaults.

        Values that are set (not `None`) on `pooler_config` take precedence
        over the corresponding arguments of this function.

        Returns:
            A `StepPooler` if the resolved pooling type is
            `PoolingType.STEP`, otherwise a `SimplePooler`.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
            step_tag_id=step_tag_id,
            returned_token_ids=returned_token_ids,
        )

        # Dispatch on the *resolved* pooling type, not on the default
        # argument: `pooler_config.pooling_type` may override the default
        # either to or away from STEP, and the two pooler classes each
        # reject the other's pooling type.
        if resolved_config.pooling_type == PoolingType.STEP:
            return StepPooler.from_config(resolved_config)

        return SimplePooler.from_config(resolved_config)

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """
        Construct the pooling parameters to use for a task,
        or `None` if the task is not supported.
        """
        # Base implementation: no task is supported unless overridden.
        return None

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states` according to `pooling_metadata`."""
        raise NotImplementedError


112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths described by `pooling_metadata`.

    V1 metadata carries the lengths directly. For any other metadata the
    lengths are derived through `PoolingTensors`, created on the device of
    `hidden_states` (which must then be a single tensor).
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        # The device is read from the tensor, so a list is not accepted here.
        assert isinstance(hidden_states, torch.Tensor)
        pooling_tensors = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device)
        return pooling_tensors.prompt_lens

    return pooling_metadata.prompt_lens


def get_classification_activation_function(config: PretrainedConfig):
    """Return the activation applied to classification outputs.

    `config` is accepted but not consulted here; a fresh `PoolerClassify`
    is always returned.
    """
    activation = PoolerClassify()
    return activation


def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation applied to cross-encoder (score) outputs.

    The activation's qualified name is read from
    `config.sentence_transformers["activation_fn"]` if present, otherwise
    from `config.sbert_ce_default_activation_function` if set. When neither
    is available, a `PoolerScore` is returned.

    Raises:
        AssertionError: If the configured name does not start with
            `torch.nn.modules.`.
    """
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerScore()

    # The name comes from a model config and is resolved to an arbitrary
    # object and then called, so this restriction is a security check. It is
    # raised explicitly (instead of via `assert`) so that it is still enforced
    # when Python runs with `-O`; the exception type is kept unchanged.
    if not function_name.startswith("torch.nn.modules."):
        raise AssertionError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
145

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

def build_output(all_data: torch.Tensor) -> PoolerOutput:
    """Wrap each item obtained by iterating `all_data` into a `PoolerOutput`.

    One `PoolingSequenceGroupOutput` is created per item, in iteration order.
    """
    wrapped = list(map(PoolingSequenceGroupOutput, all_data))
    return PoolerOutput(outputs=wrapped)


class PoolingMethod(nn.Module, ABC):
    """Base class for strategies that reduce token hidden states per prompt."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        """Instantiate the pooling method matching `pooling_type`.

        Raises:
            NotImplementedError: For pooling types without a matching
                `PoolingMethod` (e.g. `PoolingType.STEP`).
        """
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return the pooling parameters for `task`, or `None` if the task
        is not supported by this method."""
        raise NotImplementedError

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool the hidden states of a single prompt.

        Note:
            `prompt_len=None` means `prompt_len=len(hidden_states)`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool several prompts given as one tensor plus their lengths."""
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool `hidden_states`, dispatching on how the batch is laid out.

        A list is pooled element by element with `forward_one`; a single
        tensor is handed to `forward_all` together with the prompt lengths.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len)
                for h, prompt_len in zip(hidden_states, prompt_lens)
            ]

        return self.forward_all(hidden_states, prompt_lens)
205
206


207
208
class CLSPool(PoolingMethod):
    """Pools each prompt to the hidden state of its first token."""

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return default `PoolingParams`; every task is supported."""
        # The equalities are split up to keep mypy happy
        if (task == "encode" or task == "embed" or task == "classify"
                or task == "score"):
            return PoolingParams()

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the first row of `hidden_states`.

        `prompt_len`, when given, must equal the number of rows.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with CLS pooling"

        return hidden_states[0]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the first row of every prompt.

        Rows of `hidden_states` are indexed as consecutive prompts of
        lengths `prompt_lens`.
        """
        # Start offset of prompt i is the sum of the lengths before it:
        # 0 for the first prompt, then the running total shifted by one.
        first_token_flat_indices = torch.zeros_like(prompt_lens)
        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
        return hidden_states[first_token_flat_indices]


237
class LastPool(PoolingMethod):
    """Pools each prompt to the hidden state of its last token."""

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return default `PoolingParams`; every task is supported."""
        # The equalities are split up to keep mypy happy
        if (task == "encode" or task == "embed" or task == "classify"
                or task == "score"):
            return PoolingParams()

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the last row of `hidden_states`; `prompt_len` is unused."""
        return hidden_states[-1]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the last row of every prompt.

        Rows of `hidden_states` are indexed as consecutive prompts of
        lengths `prompt_lens`.
        """
        # The running total of lengths is one past each prompt's last row.
        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_token_flat_indices]


263
class AllPool(PoolingMethod):
    """Returns all token hidden states of each prompt, without reduction."""

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return default `PoolingParams` for "encode"; other tasks are
        not supported (`None`)."""
        if task == "encode":
            return PoolingParams()

        # The equalities are split up to keep mypy happy
        if task == "embed" or task == "classify" or task == "score":
            return None

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return `hidden_states` unchanged.

        `prompt_len`, when given, must equal the number of rows.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with ALL pooling"

        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Split `hidden_states` into one slice of rows per prompt.

        Slices are consecutive and have the lengths given by `prompt_lens`.
        """
        offset = 0
        pooled_data = list[torch.Tensor]()

        for prompt_len in prompt_lens:
            pooled_data.append(hidden_states[offset:offset + prompt_len])
            offset += prompt_len

        return pooled_data


300
class MeanPool(PoolingMethod):
    """Pools each prompt to the mean of its token hidden states."""

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return default `PoolingParams`; every task is supported."""
        # The equalities are split up to keep mypy happy
        if (task == "encode" or task == "embed" or task == "classify"
                or task == "score"):
            return PoolingParams()

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the float32 mean over the rows of `hidden_states`.

        `prompt_len`, when given, must equal the number of rows.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with MEAN pooling"

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the per-prompt mean of the rows of `hidden_states`.

        Rows are indexed as consecutive prompts of lengths `prompt_lens`.
        """
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = torch.cat([
            torch.tensor([0], device=hidden_states.device),
            torch.cumsum(prompt_lens[:-1], dim=0)
        ])
        end_indices = torch.cumsum(prompt_lens, dim=0)
        # Sum of rows [start, end) from the inclusive prefix sums:
        # cumsum[end - 1] - cumsum[start] drops row `start` as well, so it
        # is added back before dividing by the prompt length.
        return (cumsum[end_indices - 1] - cumsum[start_indices] +
                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


338
# Constrained type variable: an activation returns the same kind of value
# (a tensor or a list of tensors) that it was given.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
339
340


341
class BasePoolerActivation(nn.Module, ABC):
    """Interface for activations applied to pooled data."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to a tensor or to a list of tensors."""
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
350

351

352
class PoolerActivation(BasePoolerActivation):
    """Activation defined per tensor and applied uniformly to batches.

    Subclasses implement `forward_chunk`; `forward` applies it either to a
    single tensor or to every element of a list.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Convert an `nn.Module` activation into a pooler activation.

        `nn.Identity` maps to `PoolerIdentity`, `nn.Sigmoid` / `nn.Softmax`
        map to `PoolerClassify`, and any other module is wrapped in a
        `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to a tensor, or to each tensor of a list."""
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
372
373


374
375
376
class PoolerIdentity(PoolerActivation):
    """Activation that leaves the pooled data untouched."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return `pooled_data` as is."""
        return pooled_data


380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes the pooled data along the last dim."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, then cast back to the input dtype."""
        input_dtype = pooled_data.dtype
        unit_vectors = F.normalize(pooled_data.float(), p=2, dim=-1)
        return unit_vectors.to(input_dtype)


class PoolerClassify(PoolerActivation):
    """Activation that turns classifier outputs into probabilities."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply softmax over the last dim, or sigmoid for fewer than two
        labels; computed in float32 and cast back to the input dtype."""
        as_float = pooled_data.float()
        if pooled_data.shape[-1] >= 2:
            probs = F.softmax(as_float, dim=-1)
        else:
            probs = F.sigmoid(as_float)

        return probs.to(pooled_data.dtype)


class PoolerScore(PoolerActivation):
    """Activation used for scoring outputs."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the input unchanged when it has two or more labels;
        otherwise apply sigmoid in float32 and cast back to the input
        dtype."""
        if pooled_data.shape[-1] >= 2:
            return pooled_data

        squashed = F.sigmoid(pooled_data.float())
        return squashed.to(pooled_data.dtype)


class LambdaPoolerActivation(PoolerActivation):
    """Adapts an arbitrary tensor-to-tensor callable as a pooler activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        """Store `fn`, the callable applied to each tensor."""
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the result of calling the stored callable on the input."""
        activated = self.fn(pooled_data)
        return activated


418
419
class PoolerHead(nn.Module):
    """Post-processes pooled data: float32 cast, optional per-request
    truncation of the last dimension, then an activation."""

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
        """Build a head whose activation follows `pooler_config`.

        `normalize` selects `PoolerNormalize`, `softmax` selects
        `PoolerClassify`, and neither selects `PoolerIdentity`.

        Raises:
            ValueError: If both `normalize` and `softmax` are set.
        """
        if pooler_config.normalize and pooler_config.softmax:
            raise ValueError("`normalize=True` and `softmax=True` should not "
                             "be set together")

        activation: PoolerActivation
        if pooler_config.normalize:
            activation = PoolerNormalize()
        elif pooler_config.softmax:
            activation = PoolerClassify()
        else:
            activation = PoolerIdentity()

        return cls(activation)

    def __init__(self, activation: PoolerActivation) -> None:
        """Store the activation applied at the end of `forward`."""
        super().__init__()

        self.activation = activation

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
        """Cast, optionally truncate, and activate `pooled_data`.

        The per-request `dimensions` pooling parameter, when set, keeps only
        the first `dimensions` entries of the last dim of that request's data.
        """
        # Using float32 in PoolerHead. A new list is built so that the list
        # passed in by the caller is not modified.
        if isinstance(pooled_data, list):
            pooled_data = [data.to(torch.float32) for data in pooled_data]
        else:
            pooled_data = pooled_data.to(torch.float32)

        # for matryoshka representation
        if isinstance(pooling_metadata, V0PoolingMetadata):
            dimensions_list = [
                pooling_param.dimensions
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        else:
            assert isinstance(pooled_data, list)
            dimensions_list = [
                pooling_param.dimensions
                for pooling_param in pooling_metadata.pooling_params
            ]

        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(
                    pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Requests differ (or the data is already a list): truncate
                # item by item, leaving items without `dimensions` untouched.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        return self.activation(pooled_data)
478
479


480
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config_with_defaults(  # type: ignore[override]
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
    ) -> "SimplePooler":
        """Build a `SimplePooler`, preferring values set on `pooler_config`.

        The resolved pooling type must not be `PoolingType.STEP`.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
        )
        assert resolved_config.pooling_type != PoolingType.STEP

        return cls.from_config(resolved_config)

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
    ) -> "SimplePooler":
        """Build the pooling method and head described by `pooler_config`."""
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        head = PoolerHead.from_config(pooler_config)

        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        """Store the pooling method and the head applied after it."""
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Delegate to the pooling method's `get_pooling_params`."""
        return self.pooling.get_pooling_params(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the hidden states, apply the head, and wrap the result."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


536
class StepPooler(Pooler):
    """Pooler that keeps per-token outputs, optionally restricted to
    selected output columns and to positions holding a step tag token."""

    @classmethod
    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
        """Build a `StepPooler`; the config's pooling type must be STEP."""
        assert pooler_config.pooling_type == PoolingType.STEP

        return cls(
            PoolerHead.from_config(pooler_config),
            step_tag_id=pooler_config.step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids,
        )

    def __init__(
        self,
        head: PoolerHead,
        *,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> None:
        """Store the head and the optional filters; pooling is `AllPool`."""
        super().__init__()

        self.pooling = AllPool()
        self.head = head
        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    def get_prompt_token_ids(
        self,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor]:
        """Return one tensor of prompt token ids per request."""
        if isinstance(pooling_metadata, V1PoolingMetadata):
            # Row i of `prompt_token_ids` is cut to that prompt's length.
            return [
                pooling_metadata.prompt_token_ids[i, :num]
                for i, num in enumerate(pooling_metadata.prompt_lens)
            ]
        return [
            torch.tensor(seq_data_i.prompt_token_ids)
            for seq_data_i in pooling_metadata.seq_data.values()
        ]

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Return the per-request data after applying the optional filters.

        If `returned_token_ids` is non-empty, only those columns are kept.
        If `step_tag_id` is set, only rows whose prompt token equals it are
        kept.
        """
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)

        pooled_data = list[torch.Tensor]()
        returned_token_ids = self.returned_token_ids
        step_tag_id = self.step_tag_id
        # Neither filter setting depends on the request; decide once.
        select_columns = (returned_token_ids is not None
                          and len(returned_token_ids) > 0)

        for data, token_id in zip(pooled_data_lst, prompt_token_ids):
            if select_columns:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return params for "encode" (flagging that token ids are needed);
        other tasks are not supported (`None`)."""
        if task == "encode":
            return PoolingParams(logits_processing_needs_token_ids=True)

        # The equalities are split up to keep mypy happy
        if task == "embed" or task == "classify" or task == "score":
            return None

        assert_never(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract the filtered states, apply the head, wrap the result."""
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


# Callable that pools hidden states (a tensor or a list of tensors) given the
# pooling metadata, returning a tensor or a list of tensors.
PoolingFn = Callable[
    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
    Union[torch.Tensor, list[torch.Tensor]]]
# Callable that maps a tensor to a tensor; used as the classification layer.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]

623

624
625
class ClassifierPooler(nn.Module):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Pools the hidden states with the given pooling function.
    2. Applies a classification layer to the pooled data.
    3. Applies an activation function to the output. In the case of
       classification models it is either sigmoid or softmax. In the
       case of scoring models, the same behavior is configuration
       dependent, as in the sentence-transformers library.
    """

    def __init__(
        self,
        config: ModelConfig,
        pooling: PoolingFn,
        classifier: ClassifierFn,
        act_fn: Optional[PoolerActivation] = None,
    ) -> None:
        """Store the pooling and classifier callables and pick activations.

        When `act_fn` is given it is used for both classification and
        cross-encoder scoring; otherwise each activation is derived from
        `config.hf_config`.
        """
        super().__init__()

        self.pooling = pooling
        self.classifier = classifier

        self.classification_act_fn = get_classification_activation_function(
            config.hf_config) if act_fn is None else act_fn
        self.cross_encoder_act_fn = get_cross_encoder_activation_function(
            config.hf_config) if act_fn is None else act_fn

    def _get_act_fn(self, use_cross_encoder: bool):
        """Return the cross-encoder or the classification activation."""
        return (self.cross_encoder_act_fn
                if use_cross_encoder else self.classification_act_fn)

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return params for "encode", "classify" and "score" (the latter
        with `use_cross_encoder=True`); "embed" is not supported."""
        if task == "encode":
            return PoolingParams()
        if task == "embed":
            return None
        if task == "classify":
            return PoolingParams()
        if task == "score":
            return PoolingParams(use_cross_encoder=True)

        assert_never(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools sentence pair scores from the hidden_states."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)

        # apply classifier once on the full batch if possible
        if isinstance(pooled_data, torch.Tensor):
            pooled_output = self.classifier(pooled_data)
        elif len({data.shape for data in pooled_data}) <= 1:
            pooled_output = self.classifier(torch.stack(pooled_data))
        else:
            pooled_output = [self.classifier(data) for data in pooled_data]

        if isinstance(pooling_metadata, V0PoolingMetadata):
            use_cross_encoder_list = [
                pooling_param.use_cross_encoder
                for _, pooling_param in pooling_metadata.seq_groups
            ]
        else:
            use_cross_encoder_list = [
                pooling_param.use_cross_encoder
                for pooling_param in pooling_metadata.pooling_params
            ]

        # shape of scores: (batch_size, num_labels)
        if all(use_cross_encoder == use_cross_encoder_list[0]
               for use_cross_encoder in use_cross_encoder_list):
            # Every request wants the same activation: apply it once.
            act_fn = self._get_act_fn(use_cross_encoder_list[0])
            scores = act_fn(pooled_output)
        else:
            # Mixed batch: choose the activation per request, then restack.
            scores = torch.stack([
                self._get_act_fn(use_cross_encoder)(vecs)
                for use_cross_encoder, vecs in zip(use_cross_encoder_list,
                                                   pooled_output)
            ])

        return build_output(scores)