"vllm/distributed/eplb/policy/__init__.py" did not exist on "e10c84e06af7264d5c0b3e7ec5604ada2eee7094"
gritlm.py 8.11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Set
4

5
import numpy as np
6
7
import torch

8
from vllm.config import ModelConfig, VllmConfig
9
from vllm.logger import init_logger
10
11
12
13
from vllm.model_executor.layers.pooler import (
    DispatchPooler,
    PoolingParamsUpdate,
)
14
15
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
from vllm.model_executor.layers.pooler.seqwise import (
16
    EmbeddingPoolerHead,
17
18
19
    SequencePooler,
    SequencePoolingMethod,
    SequencePoolingMethodOutput,
20
    get_seq_pooling_method,
21
22
)
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
23
from vllm.model_executor.models.llama import LlamaForCausalLM
24
from vllm.tasks import PoolingTask
25
from vllm.tokenizers import cached_tokenizer_from_config
26
from vllm.v1.pool.metadata import PoolingMetadata
27

28
from .interfaces_base import default_pooling_type
29

30
31
32
logger = init_logger(__name__)


33
class GritLMMeanPool(SequencePoolingMethod):
    """As `MeanPool`, but only includes non-instruction tokens.

    For each request the prompt token IDs are scanned for GritLM's
    ``<|embed|>`` marker (optionally preceded by a ``<|user|>`` instruction
    block right after BOS), and only the hidden states after that marker are
    averaged.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_tokenizer_from_config(self.model_config)

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n").
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
            return np.array([self.token_ids[token] for token in tokens])

        # "<|user|>\n" -- checked for directly after BOS.
        self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
        # "\n<|embed|>\n" -- searched for when a user block was found.
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]
        )
        # "<|embed|>\n" -- searched for when there is no user block.
        self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"])

    def _find_array(
        self,
        arr: np.ndarray,
        target: np.ndarray,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> int:
        """
        Find the first occurrence of `target` in `arr` starting from
        `start_idx`.

        Args:
            arr: The array to search within.
            target: The consecutive subsequence to find.
            start_idx: The first candidate start position (inclusive).
            end_idx: Upper bound on the candidate start position (exclusive).
                It limits where a match may *begin*, not where it must end.
                Defaults to ``len(arr)``.

        Returns:
            The index of the first occurrence of `target` in `arr`, or -1 if
            there is no occurrence in the searched range.

        Raises:
            ValueError: If `start_idx` is negative, or `arr` or `target` is
                empty.
        """
        if start_idx < 0:
            raise ValueError("`start_idx` must be non-negative")
        if len(arr) == 0 or len(target) == 0:
            raise ValueError("Empty `arr` or `target` not allowed")

        arr_len = len(arr)
        target_len = len(target)

        if end_idx is None:
            end_idx = arr_len

        # Exclusive bound on start positions; a match must also fit in `arr`.
        stop_idx = min(end_idx, arr_len - target_len + 1)
        # Guard explicitly: a negative `stop_idx` would otherwise be treated
        # by numpy slicing below as an offset from the end of `arr`.
        if stop_idx <= start_idx:
            return -1

        # Only positions holding target[0] can start a match, so compare full
        # windows at those positions instead of at every position.
        candidates = np.flatnonzero(arr[start_idx:stop_idx] == target[0])
        for i in candidates + start_idx:
            if np.array_equal(arr[i : i + target_len], target):
                return int(i)

        return -1

    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.

        Returns:
            0 if the prompt is empty or does not start with BOS; the index
            just past the embed pattern if it is found; otherwise 1 (only the
            BOS token is treated as instruction).
        """
        instruction_len = 0

        # Return no instruction in case of missing BOS token (this also
        # covers an empty prompt, which would otherwise raise IndexError).
        bos_id = self.token_ids["<s>"]
        if len(prompt_token_ids) == 0 or prompt_token_ids[0] != bos_id:
            logger.warning(
                "BOS token not found in prompt, "
                "thus using empty string for instruction. "
                "GritLM requires BOS token in prompt."
            )
            return instruction_len

        # If user pattern is found in the prompt, that means there should be
        # a newline token before the embed pattern.
        embed_pattern_ids = self.embed_pattern_ids
        if (
            self._find_array(
                prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2
            )
            == 1
        ):
            embed_pattern_ids = self.embed_newline_pattern_ids

        # Find the embed pattern in the prompt.
        found_embed_pattern_idx = self._find_array(
            prompt_token_ids, embed_pattern_ids, start_idx=1
        )

        if found_embed_pattern_idx != -1:
            instruction_len = found_embed_pattern_idx + len(embed_pattern_ids)
        else:
            logger.warning(
                "Query instruction not found in prompt, "
                "thus using BOS token as instruction instead. "
                "GritLM requires query instruction in prompt."
            )
            instruction_len = 1

        return instruction_len

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token IDs are needed to locate the instruction in `forward`.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Mean-pool each request's non-instruction hidden states.

        Each request's rows are sliced out of `hidden_states` along dim 0 by
        a running offset over the prompt lengths. Returns one float32 vector
        per request.
        """
        prompt_lens = pooling_metadata.prompt_lens
        prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()

        offset = 0
        pooled_data = list[torch.Tensor]()
        for prompt_len, token_ids in zip(prompt_lens, prompt_token_ids):
            prompt_len = int(prompt_len)
            instr_len = self._get_instruction_len(token_ids.numpy())
            if instr_len >= prompt_len:
                # The instruction covers the whole prompt (e.g. "<|embed|>\n"
                # followed by nothing). Averaging an empty slice would give a
                # NaN embedding, so pool over the entire prompt instead.
                instr_len = 0
            pooled_data.append(
                hidden_states[offset + instr_len : offset + prompt_len].mean(
                    dim=0, dtype=torch.float32
                )
            )
            offset += prompt_len

        return pooled_data
179
180


181
class GritLMPooler(SequencePooler):
    """Sequence pooler used for GritLM's "embed" task.

    When the configured sequence pooling type is "MEAN", the instruction-aware
    `GritLMMeanPool` is used; any other type falls back to the generic method
    returned by `get_seq_pooling_method`. The head is an `EmbeddingPoolerHead`
    with `PoolerNormalize` as its activation.
    """

    def __init__(self, model_config: ModelConfig):
        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        pooling_type = pooler_config.seq_pooling_type
        if pooling_type == "MEAN":
            pooling = GritLMMeanPool(model_config)
        else:
            pooling = get_seq_pooling_method(pooling_type)

        head = EmbeddingPoolerHead(
            head_dtype=model_config.head_dtype,
            activation=PoolerNormalize(),
        )

        super().__init__(pooling=pooling, head=head)
197
198


199
@default_pooling_type(seq_pooling_type="MEAN")
class GritLM(LlamaForCausalLM):
    """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

    It reuses the LlamaForCausalLM backbone and attaches a custom pooler.

    Compared with LlamaForCausalLM's pooling, GritLM leaves the query
    instruction in the prompt out of the pooled hidden states.

    Embedding prompts should be in the following format:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts should be in the following format:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    is_pooling_model = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        model_config = vllm_config.model_config

        # For the pooling runner, mark the HF config as non-causal and clear
        # sliding-window settings on both the HF and cache configs. This must
        # happen before the backbone is constructed by super().__init__().
        if model_config.runner_type == "pooling":
            hf_config = model_config.hf_config
            hf_config.is_causal = False
            vllm_config.cache_config.sliding_window = None
            hf_config.sliding_window = None

        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        # Without a pooler config, no pooler is attached here.
        pooler_config = vllm_config.model_config.pooler_config
        if pooler_config is None:
            return

        self.pooler = DispatchPooler(
            {
                "token_embed": pooler_for_token_embed(pooler_config),
                "embed": GritLMPooler(vllm_config.model_config),
            }
        )