gritlm.py 9.01 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Set
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8
9
10

from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
11
12
13
14
15
16
17
18
19
from vllm.model_executor.layers.pooler import (
    DispatchPooler,
    Pooler,
    PoolerHead,
    PoolerNormalize,
    PoolingParamsUpdate,
    get_prompt_lens,
    get_prompt_token_ids,
)
20
from vllm.model_executor.models.llama import LlamaForCausalLM
21
from vllm.tasks import PoolingTask
22
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
23
from vllm.v1.outputs import PoolerOutput
24
from vllm.v1.pool.metadata import PoolingMetadata
25

26
from .interfaces_base import default_pooling_type
27

28
29
30
# Module-level logger; used below to warn when a prompt does not follow the
# token layout that the GritLM instruction matcher expects.
logger = init_logger(__name__)


31
32
class GritLMMeanPool(nn.Module):
    """As `MeanPool`, but only includes non-instruction tokens.

    The instruction prefix of every prompt is located by matching fixed
    token-ID patterns against the prompt token IDs; the hidden states of that
    prefix are skipped and the remaining ones are averaged.
    """

    def __init__(self, model_config: ModelConfig):
        """
        Resolve the token IDs needed to locate the instruction in a prompt.

        Args:
            model_config: Model configuration used to obtain the tokenizer.
        """
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_tokenizer_from_config(self.model_config)

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n").
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
            return np.array([self.token_ids[token] for token in tokens])

        # Token-ID sequences for "<|user|>" + newline and for "<|embed|>" +
        # newline. The "newline" variant of the embed pattern is the one that
        # is preceded by a newline token instead of starting a new word.
        self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]
        )
        self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"])

    def _find_array(
        self,
        arr: np.ndarray,
        target: np.ndarray,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> int:
        """
        Find the first occurrence of `target` in `arr` starting from
        `start_idx`.

        Args:
            arr: The array to search within.
            target: The consecutive subsequence to find.
            start_idx: The starting index to search from (inclusive).
            end_idx: The ending index to search from (exclusive). Only match
                start positions below this index are considered. Defaults to
                the length of `arr`.

        Returns:
            The index of the first occurrence of `target` in `arr`, or -1 if
            `target` does not occur in the searched range.

        Raises:
            ValueError: If `start_idx` is negative, or if `arr` or `target`
                is empty.
        """
        if start_idx < 0:
            raise ValueError("`start_idx` must be non-negative")
        if len(arr) == 0 or len(target) == 0:
            raise ValueError("Empty `arr` or `target` not allowed")

        arr_len = len(arr)
        target_len = len(target)

        if end_idx is None:
            end_idx = arr_len

        # A match cannot start later than `arr_len - target_len`, so the scan
        # stops there even when `end_idx` is larger.
        for i in range(start_idx, min(end_idx, arr_len - target_len + 1)):
            if (arr[i : i + target_len] == target).all():
                return i

        return -1

    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.

        Args:
            prompt_token_ids: Token IDs of a single prompt.

        Returns:
            The number of leading tokens that belong to the instruction:
            0 if the prompt is empty or does not start with the BOS token,
            1 (the BOS token only) if no embed pattern is found, otherwise
            the index just past the end of the embed pattern.
        """
        instruction_len = 0

        # Return no instruction in case of missing BOS token. An empty prompt
        # is handled the same way instead of failing on the index access.
        if (
            len(prompt_token_ids) == 0
            or prompt_token_ids[0] != self.token_ids["<s>"]
        ):
            logger.warning(
                "BOS token not found in prompt, "
                "thus using empty string for instruction. "
                "GritLM requires BOS token in prompt."
            )
            return instruction_len

        # If user pattern is found in the prompt, that means there should be
        # a newline token before the embed pattern.
        embed_pattern_ids = self.embed_pattern_ids
        if (
            self._find_array(
                prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2
            )
            == 1
        ):
            embed_pattern_ids = self.embed_newline_pattern_ids

        # Find the embed pattern in the prompt.
        found_embed_pattern_idx = self._find_array(
            prompt_token_ids, embed_pattern_ids, start_idx=1
        )

        if found_embed_pattern_idx != -1:
            instruction_len = found_embed_pattern_idx + len(embed_pattern_ids)
        else:
            logger.warning(
                "Query instruction not found in prompt, "
                "thus using BOS token as instruction instead. "
                "GritLM requires query instruction in prompt."
            )
            instruction_len = 1

        return instruction_len

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this layer supports."""
        return {"encode", "embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling parameter update for `task`.

        Token IDs are requested for every task because `forward` reads the
        prompt token IDs to locate the instruction.
        """
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: torch.Tensor | None = None,
        instr_len: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Mean-pool the non-instruction hidden states of a single prompt.

        Args:
            hidden_states: Hidden states of one prompt, one row per token
                along dim 0.
            prompt_len: Length of the prompt; when given it must equal the
                number of rows in `hidden_states`.
            instr_len: Number of leading rows to skip. `None` skips nothing.

        Returns:
            The float32 mean over dim 0 of the rows after the instruction.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], (
            "partial prefill not supported with MEAN pooling"
        )

        return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
        instr_lens: torch.Tensor,
    ) -> list[torch.Tensor] | torch.Tensor:
        """
        Mean-pool the non-instruction hidden states of several prompts.

        `hidden_states` is read as consecutive per-prompt segments along
        dim 0, the i-th segment being `prompt_lens[i]` rows long.

        Args:
            hidden_states: Hidden states of all prompts, back to back.
            prompt_lens: Length of each prompt.
            instr_lens: Instruction length of each prompt.

        Returns:
            One float32 pooled tensor per prompt, in input order.
        """
        offset = 0
        pooled_data: list[torch.Tensor] = []

        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
            pooled_data.append(
                hidden_states[offset + instr_len : offset + prompt_len].mean(
                    dim=0, dtype=torch.float32
                )
            )
            offset += prompt_len

        return pooled_data

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor] | torch.Tensor:
        """
        Pool the hidden states while ignoring each prompt's instruction.

        Args:
            hidden_states: Either one tensor holding all prompts back to
                back, or a list with one tensor per prompt.
            pooling_metadata: Metadata from which the prompt lengths and the
                prompt token IDs are obtained.

        Returns:
            One pooled tensor per prompt.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
        # The instruction length is derived on the host from the token IDs of
        # each prompt, then moved next to `prompt_lens`.
        instr_lens = torch.tensor(
            [
                self._get_instruction_len(token_ids.cpu().numpy())
                for token_ids in get_prompt_token_ids(pooling_metadata)
            ],
            device=prompt_lens.device,
        )

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len, instr_len)
                for h, prompt_len, instr_len in zip(
                    hidden_states, prompt_lens, instr_lens
                )
            ]

        return self.forward_all(hidden_states, prompt_lens, instr_lens)
205
206


207
208
209
class GritLMPooler(Pooler):
    """Pooler that chains `GritLMMeanPool` with a `PoolerHead`.

    The pooling stage averages the non-instruction hidden states of each
    prompt; its output is then passed through a head built from
    `PoolerNormalize` (presumably a normalization step - confirm against the
    pooler module).
    """

    def __init__(self, model_config: ModelConfig):
        """Build the pooling stage and the head from `model_config`."""
        super().__init__()

        self.pooling = GritLMMeanPool(model_config)
        self.head = PoolerHead(PoolerNormalize())

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Report the tasks supported by the underlying pooling stage."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Forward the pooling parameter update of the pooling stage."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states`, then apply the head to the pooled result."""
        return self.head(
            self.pooling(hidden_states, pooling_metadata),
            pooling_metadata,
        )
228
229


230
231
@default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM):
    r"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

    It is a `LlamaForCausalLM` with a custom pooling layer. Unlike the pooling
    layer of `LlamaForCausalLM`, the GritLM one leaves the query instruction
    of the prompt out when pooling the hidden states.

    Embedding prompts should be in the following format:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts should be in the following format:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    is_pooling_model = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """
        Configure the model and, when a pooler config exists, its pooler.

        Args:
            vllm_config: The vLLM configuration. When the runner type is
                "pooling" it is modified in place before the base model is
                built.
            prefix: Forwarded unchanged to the base class.
            **kwargs: Forwarded unchanged to the base class.
        """
        model_config = vllm_config.model_config

        if model_config.runner_type == "pooling":
            # For pooling, turn off the causal flag and any sliding window,
            # both on the HF config and on the cache config, before the base
            # model reads them.
            hf_config = model_config.hf_config
            hf_config.is_causal = False

            vllm_config.cache_config.sliding_window = None

            hf_config.sliding_window = None

        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        pooler_config = model_config.pooler_config
        if pooler_config is not None:
            # "embed" uses the instruction-aware GritLM pooler; "encode" uses
            # the stock encode pooler.
            poolers_by_task = {
                "encode": Pooler.for_encode(pooler_config),
                "embed": GritLMPooler(model_config),
            }
            self.pooler = DispatchPooler(poolers_by_task)