gritlm.py 9.07 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Set
4
from typing import Optional, Union
5

6
import numpy as np
7
import torch
8
import torch.nn as nn
9
10
11

from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
12
13
14
15
16
17
18
19
20
from vllm.model_executor.layers.pooler import (
    DispatchPooler,
    Pooler,
    PoolerHead,
    PoolerNormalize,
    PoolingParamsUpdate,
    get_prompt_lens,
    get_prompt_token_ids,
)
21
from vllm.model_executor.models.llama import LlamaForCausalLM
22
from vllm.tasks import PoolingTask
23
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
24
from vllm.v1.outputs import PoolerOutput
25
from vllm.v1.pool.metadata import PoolingMetadata
26

27
from .interfaces_base import default_pooling_type
28

29
30
31
logger = init_logger(__name__)


32
33
class GritLMMeanPool(nn.Module):
    """As `MeanPool`, but only includes non-instruction tokens.

    For every prompt, the leading instruction span (BOS, an optional
    ``<|user|>`` instruction line, and the ``<|embed|>`` marker) is located by
    matching token-ID patterns; only the hidden states after that span are
    averaged.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_tokenizer_from_config(self.model_config)

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n").
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
            return np.array([self.token_ids[token] for token in tokens])

        # "▁<|user|>" + newline; only checked at index 1 (right after BOS).
        self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
        # newline + "<|embed|>" + newline; used when the user pattern is present.
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]
        )
        # "▁<|embed|>" + newline; used when the user pattern is absent.
        self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"])

    def _find_array(
        self,
        arr: np.ndarray,
        target: np.ndarray,
        start_idx: int = 0,
        end_idx: Optional[int] = None,
    ) -> int:
        """
        Find the first occurrence of `target` in `arr` starting from
        `start_idx`.

        Args:
            arr: The array to search within.
            target: The consecutive subsequence to find.
            start_idx: The smallest start position to consider (inclusive).
            end_idx: Exclusive upper bound on the *start* position of a
                match; the match itself may extend past it. Defaults to
                ``len(arr)``.

        Returns:
            The index of the first occurrence of `target` in `arr`, or -1 if
            no match starts within the searched range.

        Raises:
            ValueError: If `start_idx` is negative or `arr`/`target` is empty.
        """
        if start_idx < 0:
            raise ValueError("`start_idx` must be non-negative")
        if len(arr) == 0 or len(target) == 0:
            raise ValueError("Empty `arr` or `target` not allowed")

        arr_len = len(arr)
        target_len = len(target)

        if end_idx is None:
            end_idx = arr_len

        # Candidate start positions are [start_idx, stop); a match must also
        # fit entirely inside `arr`.
        stop = min(end_idx, arr_len - target_len + 1)
        if stop <= start_idx:
            return -1

        # Compare every candidate window in one vectorized pass instead of a
        # Python-level loop; row k of `windows` is
        # arr[start_idx + k : start_idx + k + target_len].
        windows = np.lib.stride_tricks.sliding_window_view(
            arr[start_idx : stop + target_len - 1], target_len
        )
        hits = np.flatnonzero((windows == target).all(axis=1))
        if hits.size == 0:
            return -1
        return start_idx + int(hits[0])

    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.

        Returns:
            The number of leading tokens to exclude from pooling: 0 if the
            prompt is empty or does not start with BOS; the index just past
            the embed pattern if it is found; otherwise 1 (BOS only).
        """
        # Return no instruction in case of missing BOS token. An empty prompt
        # takes the same path instead of raising IndexError on [0].
        bos_id = self.token_ids["<s>"]
        if len(prompt_token_ids) == 0 or prompt_token_ids[0] != bos_id:
            logger.warning(
                "BOS token not found in prompt, "
                "thus using empty string for instruction. "
                "GritLM requires BOS token in prompt."
            )
            return 0

        # If user pattern is found in the prompt, that means there should be
        # a newline token before the embed pattern.
        embed_pattern_ids = self.embed_pattern_ids
        if (
            self._find_array(
                prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2
            )
            == 1
        ):
            embed_pattern_ids = self.embed_newline_pattern_ids

        # Find the embed pattern in the prompt; everything up to and including
        # it is treated as instruction.
        found_embed_pattern_idx = self._find_array(
            prompt_token_ids, embed_pattern_ids, start_idx=1
        )
        if found_embed_pattern_idx != -1:
            return found_embed_pattern_idx + len(embed_pattern_ids)

        logger.warning(
            "Query instruction not found in prompt, "
            "thus using BOS token as instruction instead. "
            "GritLM requires query instruction in prompt."
        )
        return 1

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Tasks this pooling step can serve."""
        return {"encode", "embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token IDs, needed to locate the instruction."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
        instr_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mean-pool one request's hidden states, skipping its instruction.

        Rows before `instr_len` are dropped and the rest are averaged in
        float32. NOTE(review): if `instr_len` covers every row, the slice is
        empty and the mean is NaN.
        """
        assert prompt_len is None or prompt_len == hidden_states.shape[0], (
            "partial prefill not supported with MEAN pooling"
        )

        return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
        instr_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Mean-pool requests whose rows are concatenated along dim 0.

        Request i occupies the next `prompt_lens[i]` rows; its first
        `instr_lens[i]` rows are excluded from the float32 mean.
        """
        offset = 0
        pooled_data = list[torch.Tensor]()

        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
            pooled_data.append(
                hidden_states[offset + instr_len : offset + prompt_len].mean(
                    dim=0, dtype=torch.float32
                )
            )
            offset += prompt_len

        return pooled_data

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        """Pool each request, excluding its instruction tokens.

        `hidden_states` is either a list with one tensor per request, or a
        single tensor holding all requests' rows back to back.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
        # Pattern matching runs on host-side NumPy copies of the token IDs.
        instr_lens = torch.tensor(
            [
                self._get_instruction_len(token_ids.cpu().numpy())
                for token_ids in get_prompt_token_ids(pooling_metadata)
            ],
            device=prompt_lens.device,
        )

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len, instr_len)
                for h, prompt_len, instr_len in zip(
                    hidden_states, prompt_lens, instr_lens
                )
            ]

        return self.forward_all(hidden_states, prompt_lens, instr_lens)
206
207


208
209
210
class GritLMPooler(Pooler):
    """Embedding pooler for GritLM.

    Runs `GritLMMeanPool` (mean pooling that skips instruction tokens) and
    then feeds the result through `PoolerHead(PoolerNormalize())`.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        # Submodules are registered in this order: pooling step, then head.
        self.pooling = GritLMMeanPool(model_config)
        self.head = PoolerHead(PoolerNormalize())

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate task support to the pooling step."""
        pooling_step = self.pooling
        return pooling_step.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate pooling-parameter updates to the pooling step."""
        pooling_step = self.pooling
        return pooling_step.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the hidden states, then apply the head."""
        return self.head(
            self.pooling(hidden_states, pooling_metadata),
            pooling_metadata,
        )
229
230


231
232
@default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM):
    """Embedding model for parasail-ai/GritLM-7B-vllm.

    Reuses LlamaForCausalLM and installs a custom pooler for the "embed"
    task. Unlike the regular mean pooler used by LlamaForCausalLM, the GritLM
    pooler leaves the query instruction in the prompt out of the pooled
    hidden states.

    Embedding prompts should be in the following format:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts should be in the following format:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    is_pooling_model = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        # For pooling runs, mark the HF config as non-causal and clear the
        # sliding window (cache config and HF config) before the Llama
        # layers are constructed by the parent initializer.
        is_pooling_run = vllm_config.model_config.runner_type == "pooling"
        if is_pooling_run:
            hf_cfg = vllm_config.model_config.hf_config
            hf_cfg.is_causal = False
            vllm_config.cache_config.sliding_window = None
            hf_cfg.sliding_window = None

        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        # "encode" uses the stock encode pooler; "embed" uses the
        # instruction-aware GritLM pooler.
        pooler_config = vllm_config.model_config.pooler_config
        if pooler_config is not None:
            task_poolers = {
                "encode": Pooler.for_encode(pooler_config),
                "embed": GritLMPooler(vllm_config.model_config),
            }
            self.pooler = DispatchPooler(task_poolers)