pooler.py 2.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
from enum import IntEnum

import torch
import torch.nn as nn

from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                  PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods."""
    # Pool the hidden state of the last token of each prompt.
    LAST = 0
    # Keep the hidden states of every token of each prompt.
    ALL = 1
    # Pool the hidden state of the first token of each prompt.
    CLS = 2
16
17
18
19
20
21
22
23
24
25
26


class Pooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Applies softmax to the output if specified.
    4. Returns structured results as `PoolerOutput`.

    Attributes:
        pooling_type: The type of pooling to use (LAST, ALL, CLS).
        normalize: Whether to L2-normalize the pooled data along dim 1.
        softmax: Whether to apply softmax over the last dimension of the
            pooled data (after normalization, when both are enabled).
    """

    def __init__(self,
                 pooling_type: PoolingType,
                 normalize: bool,
                 softmax: bool = False):
        super().__init__()

        self.pooling_type = pooling_type
        self.normalize = normalize
        self.softmax = softmax

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools specific information from hidden states based on metadata.

        Args:
            hidden_states: Hidden states indexed along dim 0 by flat token
                position, i.e. the tokens of all prompts laid out back to
                back (the per-prompt offsets are derived from the cumulative
                prompt lengths).
            pooling_metadata: Metadata from which the per-prompt lengths are
                obtained via `PoolingTensors.from_pooling_metadata`.

        Returns:
            A `PoolerOutput` holding one `EmbeddingSequenceGroupOutput` per
            prompt, each built from the pooled data converted to nested
            Python lists.

        Raises:
            ValueError: If `self.pooling_type` is not a known `PoolingType`.
        """
        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

        # `==` is used for every branch so that a plain int equal to an enum
        # member is handled the same way regardless of the pooling type.
        if self.pooling_type == PoolingType.CLS:
            # The first token of prompt i sits at the sum of the lengths of
            # all preceding prompts (0 for the first prompt).
            first_token_flat_indices = torch.zeros_like(prompt_lens)
            first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
                                                         dim=0)[:-1]
            pooled_data = hidden_states[first_token_flat_indices]
        elif self.pooling_type == PoolingType.LAST:
            # The last token of prompt i sits one before the running total.
            last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_flat_indices]
        elif self.pooling_type == PoolingType.ALL:
            # Prompts may differ in length, so the result is a list with one
            # tensor per prompt rather than a single stacked tensor.
            # Converting the lengths to Python ints once avoids slicing with
            # 0-d tensors on every iteration.
            offset = 0
            pooled_data = []
            for prompt_len in prompt_lens.tolist():
                pooled_data.append(hidden_states[offset:offset + prompt_len])
                offset += prompt_len
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")

        # For PoolingType.ALL `pooled_data` is a list of tensors, which the
        # functional ops below do not accept; apply them per prompt instead.
        if self.normalize:
            if isinstance(pooled_data, list):
                pooled_data = [
                    nn.functional.normalize(data, p=2, dim=1)
                    for data in pooled_data
                ]
            else:
                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)

        if self.softmax:
            if isinstance(pooled_data, list):
                pooled_data = [
                    nn.functional.softmax(data, dim=-1)
                    for data in pooled_data
                ]
            else:
                pooled_data = nn.functional.softmax(pooled_data, dim=-1)

        pooled_outputs = [
            EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
        ]

        return PoolerOutput(outputs=pooled_outputs)