"""
Adapted from
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
"""

import argparse

import dspy
from dspy.datasets import HotPotQA


# Signature that maps a bare `question` to a short `answer`, with no retrieved
# context field. It is used below with both `dspy.Predict` and
# `dspy.ChainOfThought`.
# NOTE(review): DSPy presumably treats the class docstring as the task
# instruction given to the LM, which would make it runtime text rather than
# plain documentation -- confirm against the DSPy Signature docs before
# rewording it.
class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    # Input field: the question to answer.
    question = dspy.InputField()
    # Output field; `desc` is a free-text hint about the expected answer length
    # (presumably surfaced in the prompt -- verify against DSPy).
    answer = dspy.OutputField(desc="often between 1 and 5 words")


# Signature used by the `RAG` module: maps retrieved `context` plus a
# `question` to a short `answer`.
# NOTE(review): DSPy presumably treats the class docstring as the task
# instruction given to the LM and the field declaration order as the prompt
# order -- confirm against the DSPy Signature docs before editing either.
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    # Input field: the passages handed over by the retriever.
    context = dspy.InputField(desc="may contain relevant facts")
    # Input field: the question to answer.
    question = dspy.InputField()
    # Output field; `desc` is a free-text hint about the expected answer length.
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class RAG(dspy.Module):
    """Retrieve-then-answer program built from two DSPy components.

    `forward` first fetches passages for the question with `dspy.Retrieve`,
    then feeds them, together with the question, to a `dspy.ChainOfThought`
    predictor that uses the `GenerateAnswer` signature.
    """

    def __init__(self, num_passages: int = 3):
        """Create the retriever and the answer generator.

        Args:
            num_passages: Value passed as `k` to `dspy.Retrieve`, i.e. how
                many passages are requested per question.
        """
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question: str):
        """Answer `question` using retrieved passages as context.

        Args:
            question: The question to answer.

        Returns:
            A `dspy.Prediction` carrying `context` (the retrieved passages)
            and `answer` (the answer attribute of the generator's output).
        """
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)


def main(args):
    """Run the DSPy intro walkthrough against a locally served LM backend.

    Steps, in order: configure the LM client and the ColBERTv2 retriever,
    load a HotPotQA split, print a few sample examples, run a plain
    `dspy.Predict` and a `dspy.ChainOfThought` predictor on one dev example,
    show retrieved passages, compile the `RAG` program with
    `BootstrapFewShot`, and evaluate it on the dev set with exact-match.

    Args:
        args: Parsed command-line namespace. Reads `backend`, `port`,
            `dev_size` and `num_threads`.

    Raises:
        ValueError: If `args.backend` is not a supported backend, or if the
            dev split is empty.
    """
    # Map each backend to the name of its dspy client class. The class is
    # looked up by name so that only the selected client has to exist in the
    # installed dspy version.
    client_names = {
        "tgi": "HFClientTGI",
        "sglang": "HFClientSGLang",
        "vllm": "HFClientVLLM",
    }
    if args.backend not in client_names:
        raise ValueError(f"Invalid backend: {args.backend}")
    lm = getattr(dspy, client_names[args.backend])(
        model="meta-llama/Llama-2-7b-chat-hf",
        port=args.port,
        url="http://localhost",
    )

    colbertv2_wiki17_abstracts = dspy.ColBERTv2(
        url="http://20.102.90.50:2017/wiki17_abstracts"
    )
    dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)

    # Load the dataset.
    dataset = HotPotQA(
        train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
    )

    # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
    trainset = [x.with_inputs("question") for x in dataset.train]
    devset = [x.with_inputs("question") for x in dataset.dev]

    print(len(trainset), len(devset))

    train_example = trainset[0]
    print(f"Question: {train_example.question}")
    print(f"Answer: {train_example.answer}")

    if not devset:
        raise ValueError("The dev split is empty; use --dev-size >= 1.")
    # The walkthrough showcases dev example 18; clamp the index so that a
    # small --dev-size does not raise IndexError.
    dev_example = devset[min(18, len(devset) - 1)]
    print(f"Question: {dev_example.question}")
    print(f"Answer: {dev_example.answer}")
    print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")

    print(
        f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
    )
    print(
        f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
    )

    # Define the predictor.
    generate_answer = dspy.Predict(BasicQA)

    # Call the predictor on a particular input.
    pred = generate_answer(question=dev_example.question)

    # Print the input and the prediction.
    print(f"Question: {dev_example.question}")
    print(f"Predicted Answer: {pred.answer}")

    lm.inspect_history(n=1)

    # Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
    generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)

    # Call the predictor on the same input.
    pred = generate_answer_with_chain_of_thought(question=dev_example.question)

    # Print the input, the chain of thought, and the prediction.
    # The text up to the first period is dropped, as in the original
    # walkthrough. If the rationale contains no period at all, show it whole
    # instead of failing.
    rationale = pred.rationale
    _, first_dot, after_first_sentence = rationale.partition(".")
    thought = after_first_sentence.strip() if first_dot else rationale.strip()
    print(f"Question: {dev_example.question}")
    print(f"Thought: {thought}")
    print(f"Predicted Answer: {pred.answer}")

    retrieve = dspy.Retrieve(k=3)
    topK_passages = retrieve(dev_example.question).passages

    print(
        f"Top {retrieve.k} passages for question: {dev_example.question} \n",
        "-" * 30,
        "\n",
    )

    for rank, passage in enumerate(topK_passages, start=1):
        print(f"{rank}]", passage, "\n")

    # Leftover from the notebook, where the cell displayed this value. The
    # result is unused here, but the call is kept so the script issues the
    # same retrieval requests as before.
    retrieve("When was the first FIFA World Cup held?").passages[0]

    from dspy.teleprompt import BootstrapFewShot

    # Validation logic: check that the predicted answer is correct.
    # Also check that the retrieved context does actually contain that answer.
    def validate_context_and_answer(example, pred, trace=None):
        answer_EM = dspy.evaluate.answer_exact_match(example, pred)
        answer_PM = dspy.evaluate.answer_passage_match(example, pred)
        return answer_EM and answer_PM

    # Set up a basic teleprompter, which will compile our RAG program.
    teleprompter = BootstrapFewShot(metric=validate_context_and_answer)

    # Compile!
    compiled_rag = teleprompter.compile(RAG(), trainset=trainset)

    # Ask any question you like to this simple RAG program.
    my_question = "What castle did David Gregory inherit?"

    # Get the prediction. This contains `pred.context` and `pred.answer`.
    pred = compiled_rag(my_question)

    # Print the contexts and the answer.
    print(f"Question: {my_question}")
    print(f"Predicted Answer: {pred.answer}")
    print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")

    from dspy.evaluate.evaluate import Evaluate

    # Set up the `evaluate_on_hotpotqa` function.
    evaluate_on_hotpotqa = Evaluate(
        devset=devset,
        num_threads=args.num_threads,
        display_progress=True,
        display_table=5,
    )

    # Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
    metric = dspy.evaluate.answer_exact_match
    evaluate_on_hotpotqa(compiled_rag, metric=metric)


if __name__ == "__main__":
    # Command-line entry point: parse the options, fill in a backend-specific
    # default port when none was given, then run the walkthrough.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--port",
        type=int,
        help="Port of the local LM server (default depends on --backend).",
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=32,
        help="Number of threads used for the final evaluation.",
    )
    parser.add_argument(
        "--dev-size",
        type=int,
        default=150,
        help="Number of HotPotQA dev examples to load.",
    )
    parser.add_argument(
        "--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
    )
    args = parser.parse_args()

    if args.port is None:
        # Default port per backend. "lightllm" cannot be selected through
        # --backend; the entry is kept unchanged from the original table.
        default_port = {
            "vllm": 21000,
            "lightllm": 22000,
            "tgi": 24000,
            "sglang": 30000,
        }
        args.port = default_port.get(args.backend)

    main(args)