# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set

import numpy as np
import torch

from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
    DispatchPooler,
    Pooler,
    PoolerNormalize,
    PoolingMethod,
    PoolingParamsUpdate,
    TokenPoolerHeadOutput,
    TokenPoolingMethodOutput,
)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces_base import default_pooling_type

logger = init_logger(__name__)


class GritLMMeanPool(PoolingMethod):
    """As `MeanPool`, but only includes non-instruction tokens."""

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_tokenizer_from_config(self.model_config)

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n").
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
            return np.array([self.token_ids[token] for token in tokens])

        # "<|user|>\n", looked for directly after BOS.
        self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
        # "\n<|embed|>\n", the embed marker as it follows a user instruction.
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]
        )
        # "<|embed|>\n", the embed marker when there is no user instruction.
        self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"])

    def _find_array(
        self,
        arr: np.ndarray,
        target: np.ndarray,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> int:
        """
        Find the first occurrence of `target` in `arr` starting from
        `start_idx`.

        Args:
            arr: The array to search within.
            target: The consecutive subsequence to find.
            start_idx: The first candidate start position (inclusive).
            end_idx: Upper bound (exclusive) on the candidate start
                position; a match starting before it may extend past it.
                Defaults to `len(arr)`.

        Returns:
            The index of the first occurrence of `target` in `arr`, or -1 if
            there is no occurrence in the searched range.

        Raises:
            ValueError: If `start_idx` is negative, or if `arr` or `target`
                is empty.
        """
        if start_idx < 0:
            raise ValueError("`start_idx` must be non-negative")
        if len(arr) == 0 or len(target) == 0:
            raise ValueError("Empty `arr` or `target` not allowed")

        arr = np.asarray(arr)
        target = np.asarray(target)
        arr_len = len(arr)
        target_len = len(target)

        if end_idx is None:
            end_idx = arr_len

        # Exclusive bound on start positions: limited both by `end_idx` and by
        # the requirement that the whole `target` window fits inside `arr`.
        stop = min(end_idx, arr_len - target_len + 1)
        if stop <= start_idx:
            return -1

        # Only positions holding `target[0]` can begin a match. Finding them
        # with one vectorized comparison avoids a Python-level
        # slice-and-compare at every position of a long prompt.
        candidates = np.flatnonzero(arr[start_idx:stop] == target[0]) + start_idx
        for i in candidates:
            if np.array_equal(arr[i : i + target_len], target):
                return int(i)

        return -1

    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.

        Returns:
            The number of leading tokens to exclude from pooling:
            - 0 if the prompt is empty or does not start with BOS;
            - the index just past the embed marker if it is found;
            - 1 (BOS only) if no embed marker is found.
        """
        instruction_len = 0

        # Return no instruction in case of missing BOS token. The length check
        # keeps an empty prompt on this path instead of raising IndexError.
        if (
            len(prompt_token_ids) == 0
            or prompt_token_ids[0] != self.token_ids["<s>"]
        ):
            logger.warning(
                "BOS token not found in prompt, "
                "thus using empty string for instruction. "
                "GritLM requires BOS token in prompt."
            )
            return instruction_len

        # If user pattern is found in the prompt, that means there should be
        # a newline token before the embed pattern.
        embed_pattern_ids = self.embed_pattern_ids
        if (
            self._find_array(
                prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2
            )
            == 1
        ):
            embed_pattern_ids = self.embed_newline_pattern_ids

        # Find the embed pattern in the prompt.
        found_embed_pattern_idx = self._find_array(
            prompt_token_ids, embed_pattern_ids, start_idx=1
        )

        if found_embed_pattern_idx != -1:
            instruction_len = found_embed_pattern_idx + len(embed_pattern_ids)
        else:
            logger.warning(
                "Query instruction not found in prompt, "
                "thus using BOS token as instruction instead. "
                "GritLM requires query instruction in prompt."
            )
            instruction_len = 1

        return instruction_len

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token IDs are needed to locate the instruction in `forward`.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        """Mean-pool each prompt's hidden states, skipping its instruction.

        The offset arithmetic below treats `hidden_states` as every prompt's
        token states concatenated along dim 0 in batch order. Each pooled
        vector is accumulated in float32.
        """
        prompt_lens = pooling_metadata.prompt_lens
        instr_lens = torch.tensor(
            [
                self._get_instruction_len(token_ids.cpu().numpy())
                for token_ids in pooling_metadata.get_prompt_token_ids()
            ],
            device="cpu",
        )

        offset = 0
        pooled_data = list[torch.Tensor]()
        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
            # NOTE(review): if the instruction spans the whole prompt
            # (instr_len == prompt_len) this slice is empty and the mean is
            # presumably NaN. Confirm whether that case needs guarding.
            pooled_data.append(
                hidden_states[offset + instr_len : offset + prompt_len].mean(
                    dim=0, dtype=torch.float32
                )
            )
            offset += prompt_len

        return pooled_data


class GritLMPooler(Pooler):
    """Embedding pooler for GritLM.

    Runs `GritLMMeanPool` (mean over non-instruction tokens) and then applies
    `PoolerNormalize` to the pooled vectors.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.pooling = GritLMMeanPool(model_config)
        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def head(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        # `pooling_metadata` is unused here; it is kept for the head signature.
        return self.activation(pooled_data)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return pooled_data


@default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM):
    """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

    The class inherits from LlamaForCausalLM and provides a custom pooling
    layer.

    The main difference between the pooling layer in GritLM and the one in
    LlamaForCausalLM is that GritLM ignores the query instruction in the prompt
    when pooling the hidden states.

    Embedding prompts should be in the following format:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts should be in the following format:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    is_pooling_model = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        # With the pooling runner, mark the model non-causal and disable
        # sliding-window attention in both the cache config and the HF config.
        # This must happen before the base class builds the model from them.
        if vllm_config.model_config.runner_type == "pooling":
            hf_config = vllm_config.model_config.hf_config
            hf_config.is_causal = False

            vllm_config.cache_config.sliding_window = None

            hf_config.sliding_window = None

        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        # Route "embed" to the instruction-aware GritLM pooler and
        # "token_embed" to the generic per-token pooler.
        pooler_config = vllm_config.model_config.pooler_config
        if pooler_config is not None:
            self.pooler = DispatchPooler(
                {
                    "token_embed": Pooler.for_token_embed(pooler_config),
                    "embed": GritLMPooler(vllm_config.model_config),
                }
            )