# gritlm.py 9.34 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional, Union
5

6
import numpy as np
7
import torch
8
import torch.nn as nn
9
from typing_extensions import assert_never
10
11
12

from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
13
14
15
16
17
from vllm.model_executor.layers.pooler import (Pooler, PoolerHead,
                                               PoolerNormalize,
                                               PoolingParamsUpdate,
                                               build_output, get_prompt_lens,
                                               get_prompt_token_ids)
18
from vllm.model_executor.models.llama import LlamaForCausalLM
19
20
21
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import PoolerOutput
22
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
23

24
25
from .interfaces import SupportsV0Only

26
27
28
logger = init_logger(__name__)


29
30
class GritLMMeanPool(nn.Module):
    """As `MeanPool`, but only includes non-instruction tokens.

    The instruction prefix of each prompt is located by matching token-ID
    patterns (resolved once from the model's tokenizer in `__init__`), and
    the hidden states of those leading tokens are left out of the mean.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_tokenizer_from_config(self.model_config)

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n").
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
            return np.array([self.token_ids[token] for token in tokens])

        # Token-ID sequences for "<|user|>\n", "\n<|embed|>\n" and
        # "<|embed|>\n" respectively, as spelled by the token lists below.
        self.user_pattern_ids = tokens_to_ids(
            ["▁<", "|", "user", "|", ">", "<0x0A>"])
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"])
        self.embed_pattern_ids = tokens_to_ids(
            ["▁<", "|", "embed", "|", ">", "<0x0A>"])

    def _find_array(
        self,
        arr: np.ndarray,
        target: np.ndarray,
        start_idx: int = 0,
        end_idx: Optional[int] = None,
    ) -> int:
        """
        Find the first occurrence of `target` in `arr` starting from
        `start_idx`.

        Args:
            arr: The array to search within.
            target: The consecutive subsequence to find.
            start_idx: The starting index to search from (inclusive).
            end_idx: The ending index to search from (exclusive). Only match
                positions strictly below this index are considered; defaults
                to the length of `arr`.

        Returns:
            The index of the first occurrence of `target` in `arr`, or -1
            if there is no match in the searched range.

        Raises:
            ValueError: If `start_idx` is negative, or if `arr` or `target`
                is empty.
        """
        if start_idx < 0:
            raise ValueError("`start_idx` must be non-negative")
        if len(arr) == 0 or len(target) == 0:
            raise ValueError("Empty `arr` or `target` not allowed")

        arr_len = len(arr)
        target_len = len(target)

        if end_idx is None:
            end_idx = arr_len

        # The last position at which a full-length match can still start is
        # `arr_len - target_len`; never scan past it.
        last_start = min(end_idx, arr_len - target_len + 1)
        for i in range(start_idx, last_start):
            if (arr[i:i + target_len] == target).all():
                return i

        return -1

    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.

        Args:
            prompt_token_ids: Token IDs of a single prompt.

        Returns:
            The number of leading tokens that make up the instruction
            (BOS token and embed pattern included). 0 if the prompt is empty
            or does not start with the BOS token; 1 (the BOS token only) if
            no embed pattern is found.
        """
        # Return no instruction in case of missing BOS token. An empty prompt
        # is handled the same way so that indexing element 0 cannot raise.
        if (len(prompt_token_ids) == 0
                or prompt_token_ids[0] != self.token_ids["<s>"]):
            logger.warning("BOS token not found in prompt, "
                           "thus using empty string for instruction. "
                           "GritLM requires BOS token in prompt.")
            return 0

        # If user pattern is found in the prompt, that means there should be
        # a newline token before the embed pattern.
        embed_pattern_ids = self.embed_pattern_ids
        if self._find_array(prompt_token_ids,
                            self.user_pattern_ids,
                            start_idx=1,
                            end_idx=2) == 1:
            embed_pattern_ids = self.embed_newline_pattern_ids

        # Find the embed pattern in the prompt.
        found_embed_pattern_idx = self._find_array(prompt_token_ids,
                                                   embed_pattern_ids,
                                                   start_idx=1)

        if found_embed_pattern_idx != -1:
            return found_embed_pattern_idx + len(embed_pattern_ids)

        logger.warning("Query instruction not found in prompt, "
                       "thus using BOS token as instruction instead. "
                       "GritLM requires query instruction in prompt.")
        return 1

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        """Return the pooling-parameter update required for `task`.

        "encode" and "embed" request the prompt token IDs (needed to locate
        the instruction); "classify" and "score" return None.
        """
        # The equalities are split up to keep mypy happy
        if task == "encode" or task == "embed":
            return PoolingParamsUpdate(requires_token_ids=True)

        if task == "classify" or task == "score":
            return None

        assert_never(task)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
        instr_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mean-pool one prompt's hidden states, skipping the instruction.

        Args:
            hidden_states: Hidden states of a single prompt; dimension 0 is
                indexed per token.
            prompt_len: If given, must equal `hidden_states.shape[0]`.
            instr_len: Number of leading rows to leave out of the mean.
                `None` keeps every row.

        Returns:
            The float32 mean over dimension 0 of the remaining rows.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with MEAN pooling"

        # NOTE(review): if the instruction spans the whole prompt this slice
        # is empty; confirm the resulting mean (expected NaN) is acceptable.
        return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
        instr_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Mean-pool every prompt packed back-to-back in `hidden_states`.

        Args:
            hidden_states: Hidden states of all prompts concatenated along
                dimension 0, in the same order as `prompt_lens`.
            prompt_lens: Number of rows belonging to each prompt.
            instr_lens: Number of leading instruction rows to skip for each
                prompt.

        Returns:
            One float32 mean-pooled tensor per prompt.
        """
        offset = 0
        pooled_data: list[torch.Tensor] = []

        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
            # Rows [offset, offset + prompt_len) belong to this prompt; the
            # first `instr_len` of them are the instruction.
            non_instruction = hidden_states[offset + instr_len:offset +
                                            prompt_len]
            pooled_data.append(
                non_instruction.mean(dim=0, dtype=torch.float32))
            offset += prompt_len

        return pooled_data

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool hidden states while ignoring each prompt's instruction.

        Args:
            hidden_states: Either one tensor per prompt (list) or a single
                tensor holding all prompts back-to-back.
            pooling_metadata: Source of the prompt lengths and prompt token
                IDs (read through `get_prompt_lens` and
                `get_prompt_token_ids`).

        Returns:
            The mean-pooled non-instruction hidden states, one per prompt.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
        # Instruction lengths are computed on CPU with numpy pattern matching
        # and then placed on the same device as `prompt_lens`.
        instr_lens = torch.tensor(
            [
                self._get_instruction_len(token_ids.cpu().numpy())
                for token_ids in get_prompt_token_ids(pooling_metadata)
            ],
            device=prompt_lens.device,
        )

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len, instr_len) for h, prompt_len,
                instr_len in zip(hidden_states, prompt_lens, instr_lens)
            ]

        return self.forward_all(hidden_states, prompt_lens, instr_lens)
200
201


202
class GritLMPooler(Pooler):
    """Pooler used by GritLM.

    Combines `GritLMMeanPool` (mean pooling that skips the instruction
    tokens) with a `PoolerHead` built around `PoolerNormalize`.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        # Sub-modules are created in this order: pooling first, then head.
        self.pooling = GritLMMeanPool(model_config)
        self.head = PoolerHead(PoolerNormalize())

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        """Delegate to the pooling layer's `get_pooling_updates`."""
        pooling_layer = self.pooling
        return pooling_layer.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the hidden states, apply the head and wrap the result.

        Args:
            hidden_states: Hidden states handed to the pooling layer.
            pooling_metadata: Metadata passed to both the pooling layer and
                the head.

        Returns:
            The value of `build_output` applied to the head's output.
        """
        mean_pooled = self.pooling(hidden_states, pooling_metadata)
        return build_output(self.head(mean_pooled, pooling_metadata))
224
225


226
class GritLM(LlamaForCausalLM, SupportsV0Only):
    r"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

    The class inherits from LlamaForCausalLM and provides a custom pooling
    layer.

    The main difference between the pooling layer in GritLM and the one in
    LlamaForCausalLM is that GritLM ignores the query instruction in the prompt
    when pooling the hidden states.

    Embedding prompts should be in the following format:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts should be in the following format:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    is_pooling_model = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Build the Llama backbone and attach the GritLM pooler.

        When the runner type is "pooling", the configuration is adjusted in
        place *before* the backbone is constructed: causal attention is
        switched off and every sliding-window setting is removed.

        Args:
            vllm_config: The vLLM configuration; its `model_config.hf_config`
                and `cache_config` may be mutated as described above.
            prefix: Forwarded unchanged to `LlamaForCausalLM.__init__`.
            **kwargs: Forwarded unchanged to `LlamaForCausalLM.__init__`.
        """
        # Use full attention for pooling (this is why V1 is not supported yet)
        if vllm_config.model_config.runner_type == "pooling":
            hf_config = vllm_config.model_config.hf_config
            hf_config.is_causal = False

            vllm_config.cache_config.sliding_window = None

            # Drop the attributes entirely (rather than setting them to
            # None) when the HF config defines them.
            for attr in ("sliding_window", "interleaved_sliding_window"):
                if hasattr(hf_config, attr):
                    delattr(hf_config, attr)

        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        self.pooler = GritLMPooler(vllm_config.model_config)