"tests/vscode:/vscode.git/clone" did not exist on "f53a0586b9c88a78167157296555b7664c398055"
test_gritlm.py 8.54 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
from __future__ import annotations
3

4
5
6
7
8
9
10
11
import importlib.util
import math
from array import array

import openai
import pytest
from scipy.spatial.distance import cosine

12
13
from vllm import LLM, SamplingParams
from vllm.config import ModelConfig
14
from vllm.utils import STR_BACKEND_ENV_VAR
15
16
17
18

from ....utils import RemoteOpenAIServer

# The GritLM embedding implementation is only supported by the XFormers
# backend, so every test in this module is skipped when the ``xformers``
# package cannot be located.
pytestmark = pytest.mark.skipif(
    importlib.util.find_spec("xformers") is None,
    reason="GritLM requires XFormers",
)
21
22
23
24
25
26
27
28
29
30
31
32

# Model identifier used throughout this module (model config, runner, server).
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
# Passed as ``max_model_len`` to the runner and ``--max_model_len`` to the server.
MAX_MODEL_LEN = 4000


def _arr(arr):
    """
    Build an ``array.array`` with typecode ``"i"`` from the given integers.
    """
    typecode = "i"
    return array(typecode, arr)


33
34
def test_find_array():
    """Exercise ``GritLMPooler._find_array`` on hits, misses and a bad start.

    A pooler is built from an embedding ``ModelConfig`` for ``MODEL_NAME`` and
    asked to locate sub-arrays inside the integer array ``0..9``.
    """
    from vllm.model_executor.models.gritlm import GritLMPooler

    model_config = ModelConfig(
        MODEL_NAME,
        task="embed",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="bfloat16",
        seed=0,
    )
    pooler = GritLMPooler(model_config=model_config)

    arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    # A contiguous match is reported by the index where it begins.
    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
    # Starting the search past the beginning of the match yields -1.
    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
    # A target that is not contiguous in ``arr`` is not found.
    assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1

    # A negative start index is rejected.
    with pytest.raises(ValueError):
        pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
56
57


58
def run_llm_encode(
    llm: LLM,
    queries: list[str],
    instruction: str,
) -> list[list[float]]:
    """Embed ``queries`` with ``llm``, prefixing each one with ``instruction``.

    Args:
        llm: Object whose ``embed`` method is called with the list of prompts.
        queries: Texts to embed.
        instruction: Prefix prepended to every query before embedding.

    Returns:
        The ``outputs.embedding`` value of each result, in the order returned
        by ``llm.embed``.
    """
    outputs = llm.embed([instruction + q for q in queries])
    return [output.outputs.embedding for output in outputs]


67
async def run_client_embeddings(
    client: openai.AsyncOpenAI,
    queries: list[str],
    instruction: str,
) -> list[list[float]]:
    """Request embeddings for ``queries`` through an async OpenAI client.

    Args:
        client: Client whose ``embeddings.create`` coroutine is awaited.
        queries: Texts to embed.
        instruction: Prefix prepended to every query before it is sent.

    Returns:
        The ``embedding`` of each entry in the response's ``data`` list.
    """
    outputs = await client.embeddings.create(
        model=MODEL_NAME,
        input=[instruction + q for q in queries],
    )
    return [data.embedding for data in outputs.data]


def gritlm_instruction(instruction):
    """Wrap ``instruction`` in the GritLM embedding prompt markers.

    A falsy instruction produces only the ``<|embed|>`` marker; otherwise the
    instruction is placed in a ``<|user|>`` section ahead of that marker.
    """
    if not instruction:
        return "<|embed|>\n"
    return "<|user|>\n" + instruction + "\n<|embed|>\n"


def get_test_data():
    """
    Return the queries, documents and their embedding instructions.

    The test data and the expected values were grabbed from
    README.md in https://github.com/ContextualAI/gritlm

    Returns:
        A tuple ``(queries, q_instruction, documents, d_instruction)``: two
        paper titles with their instruction prefix, and two abstracts with
        the instruction prefix built from an empty instruction.
    """
    q_instruction = gritlm_instruction(
        "Given a scientific paper title, retrieve the paper's abstract")
    queries = [
        "Bitcoin: A Peer-to-Peer Electronic Cash System",
        "Generative Representational Instruction Tuning",
    ]

    d_instruction = gritlm_instruction("")
    documents = [
        # ruff: noqa: E501
        "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
        "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
    ]

    return queries, q_instruction, documents, d_instruction


106
def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
    """Assert that query/document embeddings match the reference similarities.

    Args:
        q_rep: Embeddings of the two queries, in order.
        d_rep: Embeddings of the two documents, in order.

    Raises:
        AssertionError: If any query/document cosine similarity differs from
            its expected value by more than 0.001; the message names the
            offending pair.
    """
    # (query index, document index, expected cosine similarity)
    expected_similarities = (
        (0, 0, 0.609),
        (0, 1, 0.101),
        (1, 0, 0.120),
        (1, 1, 0.534),
    )
    for q_idx, d_idx, expected in expected_similarities:
        # scipy's ``cosine`` is a distance, so similarity is its complement.
        similarity = 1 - cosine(q_rep[q_idx], d_rep[d_idx])
        assert math.isclose(similarity, expected, abs_tol=0.001), (
            f"cosine similarity of query {q_idx} and document {d_idx} is "
            f"{similarity:.4f}, expected {expected:.3f} (abs_tol=0.001)")
118
119


120
121
def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
                                  vllm_runner):
    """Embed the reference data with an in-process model and validate it.

    Documents and queries are encoded through ``run_llm_encode`` using the
    model created by the ``vllm_runner`` fixture, then compared against the
    expected similarities by ``validate_embed_output``.
    """
    # GritLM embedding implementation is only supported by XFormers backend.
    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")

        queries, q_instruction, documents, d_instruction = get_test_data()

        with vllm_runner(
                MODEL_NAME,
                task="embed",
                max_model_len=MAX_MODEL_LEN,
        ) as vllm_model:
            llm = vllm_model.model

            d_rep = run_llm_encode(
                llm,
                documents,
                d_instruction,
            )
            q_rep = run_llm_encode(
                llm,
                queries,
                q_instruction,
            )

        # The embeddings are plain values, so they can be checked after the
        # runner context has exited.
        validate_embed_output(q_rep, d_rep)


@pytest.mark.asyncio
async def test_gritlm_api_server_embedding():
    """Embed the reference data through a served model and validate it.

    A ``RemoteOpenAIServer`` is started for ``MODEL_NAME`` with the embed
    task, documents and queries are embedded through its async client, and
    the results are checked by ``validate_embed_output``.
    """
    queries, q_instruction, documents, d_instruction = get_test_data()

    # GritLM embedding implementation is only supported by XFormers backend.
    args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
    env_dict = {STR_BACKEND_ENV_VAR: "XFORMERS"}

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
        client_embedding = server.get_async_client()

        d_rep = await run_client_embeddings(
            client_embedding,
            documents,
            d_instruction,
        )
        q_rep = await run_client_embeddings(
            client_embedding,
            queries,
            q_instruction,
        )

    validate_embed_output(q_rep, d_rep)
172
173


174
175
176
177
178
def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
    """Generate from a chat-formatted prompt in-process and check the text.

    The model is loaded through the ``vllm_runner`` fixture with the generate
    task, sampled with temperature 0.0 and at most 256 tokens, and the first
    completion must equal the expected sentence exactly.
    """
    # GritLM embedding implementation is only supported by XFormers backend.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")

        # Named ``prompt`` so the ``input`` builtin is not shadowed.
        prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"

        with vllm_runner(
                MODEL_NAME,
                task="generate",
                max_model_len=MAX_MODEL_LEN,
        ) as vllm_model:
            llm = vllm_model.model

            sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
            outputs = llm.generate(prompt, sampling_params=sampling_params)

        assert outputs[0].outputs[0].text == "The capital of France is Paris."
193
194
195


@pytest.mark.asyncio
async def test_gritlm_api_server_generate():
    """Generate from a chat-formatted prompt via a served model.

    A ``RemoteOpenAIServer`` is started for ``MODEL_NAME`` with the generate
    task; a completion is requested with temperature 0.0 and at most 256
    tokens, and the first choice must equal the expected sentence exactly.
    """
    # Named ``prompt`` so the ``input`` builtin is not shadowed.
    prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"

    # GritLM embedding implementation is only supported by XFormers backend.
    args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
    env_dict = {"VLLM_USE_V1": "0", STR_BACKEND_ENV_VAR: "XFORMERS"}

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
        client_generate = server.get_async_client()

        outputs = await client_generate.completions.create(
            model=MODEL_NAME,
            prompt=prompt,
            max_tokens=256,
            temperature=0.0,
        )

    assert outputs.choices[0].text == "The capital of France is Paris."