test_gritlm.py 8.44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
from __future__ import annotations
3

4
5
6
7
8
9
10
11
12
13
14
import importlib.util
import math
from array import array

import openai
import pytest
import pytest_asyncio
from scipy.spatial.distance import cosine

import vllm
import vllm.config
15
from vllm.utils import STR_BACKEND_ENV_VAR
16
17
18
19

from ....utils import RemoteOpenAIServer

# GritLM's embedding implementation is only supported by the XFormers
# attention backend, so every test in this module is skipped when the
# ``xformers`` package cannot be found.
pytestmark = pytest.mark.skipif(
    importlib.util.find_spec("xformers") is None,
    reason="GritLM requires XFormers",
)

# Checkpoint under test, and the ``max_model_len`` passed to every engine and
# server started in this module.
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN = 4000


def _arr(arr):
    """
    Convert a list of integers to an array of integers.
    """
    return array("i", arr)


34
def test_find_array(monkeypatch: pytest.MonkeyPatch):
    """``GritLMPooler._find_array`` locates a token sub-sequence.

    It should return the index of ``needle`` at or after ``start_idx``, ``-1``
    when there is no such occurrence, and raise ``ValueError`` for a negative
    ``start_idx``.
    """
    with monkeypatch.context() as ctx:
        # GritLM's embedding implementation is only supported by the XFormers
        # attention backend, so pin it before any engine is created.
        ctx.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")

        from vllm.model_executor.models.gritlm import GritLMPooler

        # A full LLM is built only so the pooler receives a populated model
        # config.
        offline_llm = vllm.LLM(MODEL_NAME,
                               task="embed",
                               max_model_len=MAX_MODEL_LEN)
        pooler = GritLMPooler(model_config=offline_llm.llm_engine.model_config)

        haystack = _arr(list(range(10)))

        # (needle, start_idx, expected index)
        cases = (
            ([3, 4, 5], 0, 3),
            ([3, 4, 5], 1, 3),
            ([3, 4, 5], 5, -1),
            ([3, 5], 0, -1),
        )
        for needle, start, expected in cases:
            found = pooler._find_array(haystack, _arr(needle),
                                       start_idx=start)
            assert found == expected

        with pytest.raises(ValueError):
            pooler._find_array(haystack, _arr([3, 4, 5]), start_idx=-1)
54
55
56
57
58


@pytest.fixture(scope="module")
def server_embedding():
    """Start a module-scoped OpenAI-compatible server serving GritLM embeddings.

    GritLM's embedding implementation is only supported by the XFormers
    attention backend, so the backend is selected via the environment *before*
    the server is launched. The server presumably runs as a separate process
    inheriting ``os.environ`` at launch (TODO confirm against
    ``RemoteOpenAIServer``), so setting the variable any later would be too
    late. ``pytest.MonkeyPatch.context()`` is used because the function-scoped
    ``monkeypatch`` fixture cannot be requested from a module-scoped fixture;
    the variable is restored when the fixture is torn down.
    """
    args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
            yield remote_server
62
63
64
65
66
67
68
69
70
71


@pytest.fixture(scope="module")
def server_generate():
    """Start a module-scoped OpenAI-compatible server running GritLM in
    generation mode."""
    cli_args = [
        "--task",
        "generate",
        "--max_model_len",
        str(MAX_MODEL_LEN),
    ]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as server:
        yield server


@pytest_asyncio.fixture
async def client_embedding(server_embedding: RemoteOpenAIServer):
    """Yield an async OpenAI client connected to the embedding server.

    No environment variables are modified here: ``server_embedding`` is
    module-scoped, so pytest sets it up before this fixture body runs, and
    changing the attention backend at this point comes too late to influence
    how that server was launched. Backend selection belongs where the server
    is started.
    """
    async with server_embedding.get_async_client() as async_client:
        yield async_client
78
79
80
81
82
83
84
85


@pytest_asyncio.fixture
async def client_generate(server_generate: RemoteOpenAIServer):
    """Yield an async OpenAI client bound to the generation server."""
    client_cm = server_generate.get_async_client()
    async with client_cm as client:
        yield client


86
87
88
89
90
def run_llm_encode(
    llm: vllm.LLM,
    queries: list[str],
    instruction: str,
) -> list[list[float]]:
    """Embed ``queries`` offline, each prefixed with ``instruction``.

    Args:
        llm: An LLM created with ``task="embed"``.
        queries: Texts to embed.
        instruction: GritLM prompt prefix prepended verbatim to every query.

    Returns:
        One embedding vector per query, in input order.
    """
    prompts = [instruction + q for q in queries]
    # Use the embedding-specific ``LLM.embed`` API rather than the generic
    # ``LLM.encode``, whose ``.outputs.embedding`` accessor is deprecated in
    # recent vLLM releases.
    outputs = llm.embed(prompts)
    return [output.outputs.embedding for output in outputs]


95
96
97
98
99
async def run_client_embeddings(
    client: openai.AsyncOpenAI,
    queries: list[str],
    instruction: str,
) -> list[list[float]]:
    """Embed ``queries`` via the OpenAI-compatible embeddings endpoint.

    Args:
        client: Async OpenAI client connected to an embedding server.
        queries: Texts to embed.
        instruction: GritLM prompt prefix prepended verbatim to every query.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list.
    """
    response = await client.embeddings.create(
        model=MODEL_NAME,
        input=[instruction + q for q in queries],
    )
    return [item.embedding for item in response.data]


def gritlm_instruction(instruction):
    """Wrap ``instruction`` in GritLM's embedding prompt template.

    A non-empty instruction becomes a user turn followed by the ``<|embed|>``
    marker; an empty (falsy) instruction yields the bare marker.
    """
    if not instruction:
        return "<|embed|>\n"
    return "<|user|>\n" + instruction + "\n<|embed|>\n"


def get_test_data():
    """Return the retrieval example from the GritLM README.

    The queries, documents and the reference similarities asserted in
    ``validate_embed_output`` come from README.md in
    https://github.com/ContextualAI/gritlm.

    Returns:
        A 4-tuple ``(queries, q_instruction, documents, d_instruction)``.
        ``q_instruction`` is the prompt prefix for the queries (paper titles)
        and ``d_instruction`` the bare prefix for the documents (abstracts).
    """
    # Documents carry no task instruction, only the embed marker.
    d_prefix = gritlm_instruction("")
    q_prefix = gritlm_instruction(
        "Given a scientific paper title, retrieve the paper's abstract")

    titles = [
        "Bitcoin: A Peer-to-Peer Electronic Cash System",
        "Generative Representational Instruction Tuning",
    ]
    abstracts = [
        # ruff: noqa: E501
        "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
        "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
    ]

    return titles, q_prefix, abstracts, d_prefix


134
def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
    """Check query/document cosine similarities against the GritLM README.

    Args:
        q_rep: Embeddings of the two queries from ``get_test_data``.
        d_rep: Embeddings of the two documents from ``get_test_data``.

    Raises:
        AssertionError: If any similarity deviates from its reference value by
            more than 0.001; the message names the failing pair.
    """
    # (query index, document index) -> reference cosine similarity.
    expected = {
        (0, 0): 0.609,
        (0, 1): 0.101,
        (1, 0): 0.120,
        (1, 1): 0.532,
    }
    for (qi, di), ref in expected.items():
        # scipy's ``cosine`` is a distance, so similarity is 1 - distance.
        sim = 1 - cosine(q_rep[qi], d_rep[di])
        assert math.isclose(sim, ref, abs_tol=0.001), (
            f"cosine(q{qi}, d{di}) = {sim:.4f}, expected {ref}")


148
def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch):
    """Offline ``LLM`` embeddings reproduce the GritLM README similarities."""
    queries, q_instruction, documents, d_instruction = get_test_data()

    with monkeypatch.context() as ctx:
        # GritLM's embedding implementation is only supported by the XFormers
        # attention backend, so select it before the engine is built.
        ctx.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
        embedder = vllm.LLM(MODEL_NAME,
                            task="embed",
                            max_model_len=MAX_MODEL_LEN)

        doc_embeddings = run_llm_encode(embedder, documents, d_instruction)
        query_embeddings = run_llm_encode(embedder, queries, q_instruction)

        validate_embed_output(query_embeddings, doc_embeddings)
169
170
171
172


@pytest.mark.asyncio
async def test_gritlm_api_server_embedding(
        client_embedding: openai.AsyncOpenAI):
    """Embeddings served over the OpenAI API reproduce the README values."""
    queries, q_instruction, documents, d_instruction = get_test_data()

    doc_embeddings = await run_client_embeddings(client_embedding, documents,
                                                 d_instruction)
    query_embeddings = await run_client_embeddings(client_embedding, queries,
                                                   q_instruction)

    validate_embed_output(query_embeddings, doc_embeddings)


def test_gritlm_offline_gen():
    """Offline GritLM generation answers a chat-formatted question exactly."""
    # ``prompt`` rather than ``input`` so the builtin is not shadowed.
    prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"

    llm = vllm.LLM(MODEL_NAME, max_model_len=MAX_MODEL_LEN)
    # temperature=0.0 requests greedy decoding, keeping the exact match stable.
    sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=256)
    outputs = llm.generate(prompt, sampling_params=sampling_params)

    assert outputs[0].outputs[0].text == "The capital of France is Paris."


@pytest.mark.asyncio
async def test_gritlm_api_server_gen(client_generate: openai.AsyncOpenAI):
    """GritLM served over the completions API answers a chat-formatted
    question exactly."""
    # ``prompt`` rather than ``input`` so the builtin is not shadowed.
    prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"

    outputs = await client_generate.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=256,
        # Greedy decoding keeps the exact-match assertion stable.
        temperature=0.0,
    )

    assert outputs.choices[0].text == "The capital of France is Paris."