# test_gritlm.py 7.26 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import numpy as np
4
5
6
import pytest
from scipy.spatial.distance import cosine

7
from vllm import LLM, SamplingParams
8
from vllm.config import ModelConfig
9
10

from ....utils import RemoteOpenAIServer
11
from .embed_utils import run_client_embeddings
12
13
14

# Hugging Face model id of the GritLM checkpoint exercised by every test below.
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
# Context length passed to both the offline runner and the API server.
MAX_MODEL_LEN = 4000
# Absolute tolerance when comparing cosine similarities with the reference values.
ATOL = 0.002


def _arr(arr):
    """Convert a sequence of integers into a NumPy array.

    Args:
        arr: Any sequence accepted by ``np.array`` (here, lists of ints).

    Returns:
        The resulting ``np.ndarray``.
    """
    return np.array(arr)


25
def test_find_array():
    """Check the sub-array search of ``GritLMMeanPool._find_array``.

    Pins the return value (match index or -1) for several ``start_idx`` /
    ``end_idx`` bounds, and that a negative ``start_idx`` raises ``ValueError``.
    """
    from vllm.model_executor.models.gritlm import GritLMMeanPool

    model_config = ModelConfig(
        MODEL_NAME,
        runner="pooling",
        dtype="bfloat16",
        seed=0,
    )
    pooling = GritLMMeanPool(model_config=model_config)

    arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    pattern = _arr([3, 4, 5])

    assert pooling._find_array(arr, pattern, start_idx=0) == 3
    assert pooling._find_array(arr, pattern, start_idx=1) == 3
    # Starting past the only occurrence finds nothing.
    assert pooling._find_array(arr, pattern, start_idx=5) == -1
    # end_idx=3 excludes the match starting at index 3; end_idx=4 includes it.
    assert pooling._find_array(arr, pattern, end_idx=3) == -1
    assert pooling._find_array(arr, pattern, end_idx=4) == 3
    # A non-contiguous pattern must not match.
    assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1

    with pytest.raises(ValueError):
        pooling._find_array(arr, pattern, start_idx=-1)


49
def run_llm_encode(
    llm: "LLM",
    queries: list[str],
    instruction: str,
) -> list[list[float]]:
    """Embed each query with an instruction prefix using an offline LLM.

    Args:
        llm: The vLLM engine whose ``embed`` method is called.
        queries: Texts to embed.
        instruction: Prefix prepended verbatim to every query.

    Returns:
        The ``outputs.embedding`` of each result returned by ``llm.embed``.
    """
    prompts = [instruction + q for q in queries]
    outputs = llm.embed(prompts)
    return [output.outputs.embedding for output in outputs]


def gritlm_instruction(instruction: str) -> str:
    """Build the GritLM embedding prompt prefix for ``instruction``.

    A non-empty instruction is wrapped in a user turn followed by the embed
    marker. An empty (or otherwise falsy) instruction yields only the embed
    marker.
    """
    if not instruction:
        return "<|embed|>\n"
    return "<|user|>\n" + instruction + "\n<|embed|>\n"


def get_test_data() -> tuple[list[str], str, list[str], str]:
    """Return the retrieval fixtures from the GritLM README.

    The test data and the expected values were taken from README.md in
    https://github.com/ContextualAI/gritlm.

    Returns:
        ``(queries, q_instruction, documents, d_instruction)``. The two
        instructions are prompt prefixes built with ``gritlm_instruction``;
        documents use an empty instruction.
    """
    q_instruction = gritlm_instruction(
        "Given a scientific paper title, retrieve the paper's abstract",
    )
    queries = [
        "Bitcoin: A Peer-to-Peer Electronic Cash System",
        "Generative Representational Instruction Tuning",
    ]

    d_instruction = gritlm_instruction("")
    documents = [
        # ruff: noqa: E501
        "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
        "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
    ]

    return queries, q_instruction, documents, d_instruction


87
def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
    """Assert query/document cosine similarities match the GritLM reference.

    Args:
        q_rep: Embeddings of the two queries (index 0 and 1).
        d_rep: Embeddings of the two documents (index 0 and 1).

    Raises:
        AssertionError: if any similarity differs from its reference value
            by more than ``ATOL``. Pairs are checked in the order
            (0, 0), (0, 1), (1, 0), (1, 1).
    """
    # (query index, document index) -> expected cosine similarity.
    expected = {
        (0, 0): 0.609,
        (0, 1): 0.101,
        (1, 0): 0.120,
        (1, 1): 0.534,
    }
    for (q_idx, d_idx), want in expected.items():
        # scipy's ``cosine`` is a distance, so similarity = 1 - distance.
        got = 1 - cosine(q_rep[q_idx], d_rep[d_idx])
        assert got == pytest.approx(want, abs=ATOL), (
            f"cosine_sim(q{q_idx}, d{d_idx}) = {got:.4f}, "
            f"expected {want} +/- {ATOL}"
        )


101
102
def test_gritlm_offline_embedding(vllm_runner):
    """Offline ``LLM.embed`` embeddings reproduce the GritLM README scores."""
    queries, q_instruction, documents, d_instruction = get_test_data()

    with vllm_runner(
        MODEL_NAME,
        runner="pooling",
        max_model_len=MAX_MODEL_LEN,
    ) as vllm_model:
        llm = vllm_model.llm

        d_rep = run_llm_encode(
            llm,
            documents,
            d_instruction,
        )
        q_rep = run_llm_encode(
            llm,
            queries,
            q_instruction,
        )

    validate_embed_output(q_rep, d_rep)


@pytest.mark.asyncio
async def test_gritlm_api_server_embedding():
    """OpenAI-compatible server embeddings reproduce the GritLM README scores."""
    queries, q_instruction, documents, d_instruction = get_test_data()

    args = ["--runner", "pooling", "--max_model_len", str(MAX_MODEL_LEN)]

    with RemoteOpenAIServer(MODEL_NAME, args) as server:
        client_embedding = server.get_async_client()

        d_rep = await run_client_embeddings(
            client_embedding,
            MODEL_NAME,
            documents,
            d_instruction,
        )
        q_rep = await run_client_embeddings(
            client_embedding,
            MODEL_NAME,
            queries,
            q_instruction,
        )

    validate_embed_output(q_rep, d_rep)


150
def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
    """GritLM loaded with the generate runner answers a simple chat prompt.

    NOTE(review): ``monkeypatch`` is requested but unused in this body; it is
    kept so the test signature stays unchanged.
    """
    # Named ``prompt`` rather than ``input`` to avoid shadowing the builtin.
    prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"

    with vllm_runner(
        MODEL_NAME,
        runner="generate",
        max_model_len=MAX_MODEL_LEN,
    ) as vllm_model:
        llm = vllm_model.llm

        # temperature=0.0 selects greedy decoding, so the exact-match assert is stable.
        sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
        outputs = llm.generate(prompt, sampling_params=sampling_params)

    assert outputs[0].outputs[0].text == "The capital of France is Paris."


@pytest.mark.asyncio
async def test_gritlm_api_server_generate():
    """The completions endpoint answers a simple chat prompt with GritLM."""
    # Named ``prompt`` rather than ``input`` to avoid shadowing the builtin.
    prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"

    args = ["--runner", "generate", "--max_model_len", str(MAX_MODEL_LEN)]

    with RemoteOpenAIServer(MODEL_NAME, args) as server:
        client_generate = server.get_async_client()

        # temperature=0.0 keeps the completion deterministic for the exact match.
        outputs = await client_generate.completions.create(
            model=MODEL_NAME,
            prompt=prompt,
            max_tokens=256,
            temperature=0.0,
        )

    assert outputs.choices[0].text == "The capital of France is Paris."