# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set

import numpy as np
import torch
import torch.nn as nn

from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
    DispatchPooler,
    Pooler,
    PoolerHead,
    PoolerNormalize,
    PoolingParamsUpdate,
    get_prompt_lens,
    get_prompt_token_ids,
)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces_base import default_pooling_type

logger = init_logger(__name__)


class GritLMMeanPool(nn.Module):
    """As `MeanPool`, but only includes non-instruction tokens.

    For every prompt, the leading instruction span (BOS token, optional
    user/instruction text and the embed marker) is located by matching token
    ID patterns, and only the hidden states after that span are averaged.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_tokenizer_from_config(self.model_config)

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n").
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
            """Map token strings to an array of their cached token IDs."""
            return np.array([self.token_ids[token] for token in tokens])

        # Token ID sequences used by `_get_instruction_len`.
        self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]
        )
        self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"])

    def _find_array(
        self,
        arr: np.ndarray,
        target: np.ndarray,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> int:
        """
        Find the first occurrence of `target` in `arr` starting from
        `start_idx`.

        Args:
            arr: The array to search within.
            target: The consecutive subsequence to find.
            start_idx: The starting index to search from (inclusive).
            end_idx: The ending index to search from (exclusive). Only match
                positions strictly below this index are tried. Defaults to
                `len(arr)`.

        Returns:
            The index of the first occurrence of `target` in `arr`, or -1 if
            `target` does not occur at any candidate position.

        Raises:
            ValueError: If `start_idx` is negative, or if `arr` or `target`
                is empty.
        """
        if start_idx < 0:
            raise ValueError("`start_idx` must be non-negative")
        if len(arr) == 0 or len(target) == 0:
            raise ValueError("Empty `arr` or `target` not allowed")

        arr_len = len(arr)
        target_len = len(target)

        if end_idx is None:
            end_idx = arr_len

        # The last position at which `target` still fits is
        # `arr_len - target_len`, hence the second bound.
        for i in range(start_idx, min(end_idx, arr_len - target_len + 1)):
            if (arr[i : i + target_len] == target).all():
                return i

        return -1

    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.

        Returns:
            0 if the prompt is empty or does not start with the BOS token,
            1 (the BOS token only) if no embed pattern is found, otherwise
            the index just past the end of the matched embed pattern.
        """
        instruction_len = 0

        # Return no instruction in case of missing BOS token. An empty prompt
        # is handled the same way instead of failing on the index below.
        if len(prompt_token_ids) == 0 or prompt_token_ids[0] != self.token_ids["<s>"]:
            logger.warning(
                "BOS token not found in prompt, "
                "thus using empty string for instruction. "
                "GritLM requires BOS token in prompt."
            )
            return instruction_len

        # If user pattern is found in the prompt, that means there should be
        # a newline token before the embed pattern.
        embed_pattern_ids = self.embed_pattern_ids
        # `start_idx=1, end_idx=2` only tests the position right after BOS.
        if (
            self._find_array(
                prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2
            )
            == 1
        ):
            embed_pattern_ids = self.embed_newline_pattern_ids

        # Find the embed pattern in the prompt.
        found_embed_pattern_idx = self._find_array(
            prompt_token_ids, embed_pattern_ids, start_idx=1
        )

        if found_embed_pattern_idx != -1:
            instruction_len = found_embed_pattern_idx + len(embed_pattern_ids)
        else:
            logger.warning(
                "Query instruction not found in prompt, "
                "thus using BOS token as instruction instead. "
                "GritLM requires query instruction in prompt."
            )
            instruction_len = 1

        return instruction_len

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling task names this layer reports as supported."""
        return {"encode", "embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token IDs, which `forward` needs for matching."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor] | torch.Tensor:
        """
        Mean-pool the non-instruction hidden states of every prompt.

        Args:
            hidden_states: Hidden states of all prompts.
                NOTE(review): the slicing below indexes this as a single
                tensor with the prompts laid out back to back along dim 0;
                a `list` input is not handled separately here - confirm
                against callers.
            pooling_metadata: Metadata providing prompt lengths and prompt
                token IDs.

        Returns:
            One float32 tensor per prompt, the mean over that prompt's
            hidden states after its instruction span.
        """
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
        instr_lens = torch.tensor(
            [
                self._get_instruction_len(token_ids.cpu().numpy())
                for token_ids in get_prompt_token_ids(pooling_metadata)
            ],
            device="cpu",
        )

        offset = 0
        pooled_data: list[torch.Tensor] = []
        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
            # Skip the first `instr_len` positions of this prompt's span.
            pooled_data.append(
                hidden_states[offset + instr_len : offset + prompt_len].mean(
                    dim=0, dtype=torch.float32
                )
            )
            offset += prompt_len

        return pooled_data
176
177


178
179
180
class GritLMPooler(Pooler):
    """Pooler that chains `GritLMMeanPool` with a `PoolerHead`.

    Supported tasks and pooling-parameter updates are delegated to the
    wrapped `GritLMMeanPool`.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        # Mean pooling restricted to non-instruction tokens.
        self.pooling = GritLMMeanPool(model_config)
        # Head built around `PoolerNormalize`; presumably normalizes the
        # pooled embeddings - confirm against the pooler module.
        self.head = PoolerHead(PoolerNormalize())

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks reported by the wrapped pooling layer."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update of the wrapped pooling layer."""
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states`, then pass the result through the head."""
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return pooled_data
199
200


201
202
@default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM):
    """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

    The class inherits from LlamaForCausalLM and provides a custom pooling
    layer.

    The main difference between the pooling layer in GritLM and the one in
    LlamaForCausalLM is that GritLM ignores the query instruction in the prompt
    when pooling the hidden states.

    Embedding prompts should be in the following format:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts should be in the following format:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    is_pooling_model = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Build the Llama backbone and, if configured, the GritLM pooler.

        Args:
            vllm_config: Engine configuration. When its runner type is
                "pooling", the HF config and cache config it holds are
                modified in place before the backbone is constructed.
            prefix: Forwarded unchanged to `LlamaForCausalLM.__init__`.
            **kwargs: Forwarded unchanged to `LlamaForCausalLM.__init__`.
        """
        if vllm_config.model_config.runner_type == "pooling":
            # These overrides must happen before `super().__init__` so the
            # backbone is built from the adjusted configuration.
            hf_config = vllm_config.model_config.hf_config
            # Presumably switches to bidirectional attention for embedding
            # - confirm against the attention implementation.
            hf_config.is_causal = False

            # Clear the sliding window on both the cache and the HF config.
            vllm_config.cache_config.sliding_window = None

            hf_config.sliding_window = None

        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        pooler_config = vllm_config.model_config.pooler_config
        # `self.pooler` is only assigned here when a pooler config is present.
        if pooler_config is not None:
            self.pooler = DispatchPooler(
                {
                    "token_embed": Pooler.for_token_embed(pooler_config),
                    "embed": GritLMPooler(vllm_config.model_config),
                }
            )