gritlm.py 8.01 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Set
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8
9
10

from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
11
12
13
14
15
16
17
from vllm.model_executor.layers.pooler import (
    DispatchPooler,
    Pooler,
    PoolerHead,
    PoolerNormalize,
    PoolingParamsUpdate,
)
18
from vllm.model_executor.models.llama import LlamaForCausalLM
19
from vllm.tasks import PoolingTask
20
from vllm.tokenizers import cached_tokenizer_from_config
21
from vllm.v1.outputs import PoolerOutput
22
from vllm.v1.pool.metadata import PoolingMetadata
23

24
from .interfaces_base import default_pooling_type
25

26
27
28
# Module-level logger; used to warn about GritLM prompts that lack the
# expected BOS token or embed-instruction markers.
logger = init_logger(__name__)


29
30
class GritLMMeanPool(nn.Module):
    """As `MeanPool`, but only includes non-instruction tokens.

    For every request, the length of the leading instruction (from BOS up to
    and including the `<|embed|>` marker and its trailing newline) is located
    by matching token-ID patterns. Only the hidden states after it are
    averaged.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_tokenizer_from_config(self.model_config)

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n").
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
            return np.array([self.token_ids[token] for token in tokens])

        # "<|user|>\n"; only looked for right after BOS (see
        # `_get_instruction_len`), where it signals a prompt with an
        # instruction.
        self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
        # "\n<|embed|>\n": the embed marker when it follows an instruction.
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]
        )
        # "<|embed|>\n": the embed marker when there is no instruction.
        self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"])

    def _find_array(
        self,
        arr: np.ndarray,
        target: np.ndarray,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> int:
        """
        Find the first occurrence of `target` in `arr` starting from
        `start_idx`.

        Args:
            arr: The 1-D array to search within.
            target: The consecutive subsequence to find.
            start_idx: The smallest match start position to consider
                (inclusive).
            end_idx: Upper bound (exclusive) on the match *start* position;
                the matched span itself may extend past it. Defaults to
                `len(arr)`.

        Returns:
            The index of the first occurrence of `target` in `arr`, or -1 if
            `target` does not start anywhere in the searched range.

        Raises:
            ValueError: If `start_idx` is negative or either array is empty.
        """
        if start_idx < 0:
            raise ValueError("`start_idx` must be non-negative")
        if len(arr) == 0 or len(target) == 0:
            raise ValueError("Empty `arr` or `target` not allowed")

        target_len = len(target)
        if end_idx is None:
            end_idx = len(arr)

        # One past the last admissible start: the match must fit inside `arr`.
        stop = min(end_idx, len(arr) - target_len + 1)
        if stop <= start_idx:
            return -1

        # Compare all candidate windows at once instead of looping in Python.
        # Row k of `windows` is `arr[k : k + target_len]`.
        windows = np.lib.stride_tricks.sliding_window_view(arr, target_len)
        hits = np.flatnonzero((windows[start_idx:stop] == target).all(axis=1))
        return int(start_idx + hits[0]) if hits.size else -1

    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.

        Returns:
            The return value depends on what the matching finds:

            - 0 if the prompt is empty or does not start with BOS.
            - The index just past the embed marker, when one is found.
            - 1 (the BOS token only), when no embed marker is found.
        """
        # Return no instruction in case of missing BOS token. An empty prompt
        # takes this path too instead of failing on the `[0]` index.
        bos_id = self.token_ids["<s>"]
        if len(prompt_token_ids) == 0 or prompt_token_ids[0] != bos_id:
            logger.warning(
                "BOS token not found in prompt, "
                "thus using empty string for instruction. "
                "GritLM requires BOS token in prompt."
            )
            return 0

        # If the user pattern sits right after BOS, the embed marker is
        # preceded by a newline token, so match the newline-prefixed variant.
        if (
            self._find_array(
                prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2
            )
            == 1
        ):
            embed_pattern_ids = self.embed_newline_pattern_ids
        else:
            embed_pattern_ids = self.embed_pattern_ids

        # Find the embed pattern in the prompt.
        found_embed_pattern_idx = self._find_array(
            prompt_token_ids, embed_pattern_ids, start_idx=1
        )
        if found_embed_pattern_idx == -1:
            logger.warning(
                "Query instruction not found in prompt, "
                "thus using BOS token as instruction instead. "
                "GritLM requires query instruction in prompt."
            )
            return 1

        return found_embed_pattern_idx + len(embed_pattern_ids)

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode", "embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token IDs are needed to locate each instruction in `forward`.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Mean-pool each request's hidden states, skipping its instruction.

        Args:
            hidden_states: Hidden states for all requests. The code slices
                them with running offsets, so it assumes the per-token states
                of every prompt are concatenated along dim 0 in request order.
                NOTE(review): the `list[torch.Tensor]` form in the annotation
                is not handled here; confirm it never reaches this pooler.
            pooling_metadata: Supplies each request's prompt length and
                prompt token IDs.

        Returns:
            One float32 vector per request. If a request has no tokens after
            its instruction, the whole prompt is averaged. This avoids the NaN
            that a mean over an empty slice would produce.
        """
        prompt_lens = pooling_metadata.prompt_lens
        if isinstance(prompt_lens, torch.Tensor):
            # Plain Python ints for slicing: one conversion up front instead
            # of one per slice bound.
            prompt_lens = prompt_lens.tolist()

        instr_lens = [
            self._get_instruction_len(token_ids.cpu().numpy())
            for token_ids in pooling_metadata.get_prompt_token_ids()
        ]

        offset = 0
        pooled_data: list[torch.Tensor] = []
        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
            if instr_len >= prompt_len:
                logger.warning(
                    "No tokens found after the GritLM instruction, "
                    "thus mean-pooling over the whole prompt instead."
                )
                instr_len = 0
            pooled_data.append(
                hidden_states[offset + instr_len : offset + prompt_len].mean(
                    dim=0, dtype=torch.float32
                )
            )
            offset += prompt_len

        return pooled_data
174
175


176
177
178
class GritLMPooler(Pooler):
    """Embedding pooler for GritLM.

    Runs `GritLMMeanPool`, which averages each request's hidden states while
    skipping its instruction prefix. The pooled vectors are then passed
    through a `PoolerHead` built around `PoolerNormalize()`.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        # Stage 1: instruction-aware mean pooling (needs the tokenizer via
        # `model_config`).
        self.pooling = GritLMMeanPool(model_config)
        # Stage 2: post-processing head applied to the pooled vectors.
        self.head = PoolerHead(PoolerNormalize())

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Report the tasks served, as declared by the pooling stage."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-param overrides the pooling stage needs."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool per-request vectors, then hand them to the head."""
        per_request_vectors = self.pooling(hidden_states, pooling_metadata)
        return self.head(per_request_vectors, pooling_metadata)
197
198


199
200
@default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM):
    """Embedding model for parasail-ai/GritLM-7B-vllm.

    A `LlamaForCausalLM` subclass that plugs in a custom pooler. Unlike the
    default pooling, GritLM excludes the query instruction in the prompt
    from the pooled hidden states.

    Embedding prompts are expected to look like:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts are expected to look like:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    is_pooling_model = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        model_cfg = vllm_config.model_config
        if model_cfg.runner_type == "pooling":
            # Adjusted before super().__init__() so the parent constructor
            # sees the modified configs. Setting is_causal=False presumably
            # switches attention to bidirectional for embedding; confirm
            # against the Llama implementation.
            hf_cfg = model_cfg.hf_config
            hf_cfg.is_causal = False
            vllm_config.cache_config.sliding_window = None
            hf_cfg.sliding_window = None

        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        pooler_cfg = vllm_config.model_config.pooler_config
        if pooler_cfg is not None:
            # "embed" goes through the instruction-skipping GritLM pooler;
            # per-token embeddings use the generic factory.
            poolers_by_task = {
                "token_embed": Pooler.for_token_embed(pooler_cfg),
                "embed": GritLMPooler(vllm_config.model_config),
            }
            self.pooler = DispatchPooler(poolers_by_task)