"""
Adapted from
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
"""
import argparse

import dspy
from dspy.datasets import HotPotQA


class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


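# A minimal retrieval-augmented generation pipeline: fetch the top-k passages
# for a question, then answer with chain-of-thought over the retrieved context.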
class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)


def main(args):
    # lm = dspy.OpenAI(model='gpt-3.5-turbo')
    if args.backend == "tgi":
        lm = dspy.HFClientTGI(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
            url="http://localhost")
    elif args.backend == "sglang":
        lm = dspy.HFClientSGLang(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
            url="http://localhost")
    elif args.backend == "vllm":
        lm = dspy.HFClientVLLM(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
            url="http://localhost")
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

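    # Register the LM and a hosted ColBERTv2 retriever over 2017 Wikipedia
    # abstracts as DSPy's default language and retrieval models.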
    colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
    dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)

    # Load the dataset.
    dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size,
        test_size=0)

    # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
    trainset = [x.with_inputs('question') for x in dataset.train]
    devset = [x.with_inputs('question') for x in dataset.dev]

    print(len(trainset), len(devset))

    train_example = trainset[0]
    print(f"Question: {train_example.question}")
    print(f"Answer: {train_example.answer}")

    dev_example = devset[18]
    print(f"Question: {dev_example.question}")
    print(f"Answer: {dev_example.answer}")
    print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")

    print(f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}")
    print(f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}")

    # Define the predictor.
    generate_answer = dspy.Predict(BasicQA)
    
    # Call the predictor on a particular input.
    pred = generate_answer(question=dev_example.question)
    
    # Print the input and the prediction.
    print(f"Question: {dev_example.question}")
    print(f"Predicted Answer: {pred.answer}")

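    # Show the last prompt/completion exchange sent to the LM.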
    lm.inspect_history(n=1)

    # Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
    generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
    
    # Call the predictor on the same input.
    pred = generate_answer_with_chain_of_thought(question=dev_example.question)
    
    # Print the input, the chain of thought, and the prediction.
    print(f"Question: {dev_example.question}")
    print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
    print(f"Predicted Answer: {pred.answer}")

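    # Query the retriever directly to see which passages it returns.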
    retrieve = dspy.Retrieve(k=3)
    topK_passages = retrieve(dev_example.question).passages

    print(f"Top {retrieve.k} passages for question: {dev_example.question} \n", '-' * 30, '\n')

    for idx, passage in enumerate(topK_passages):
        print(f'{idx+1}]', passage, '\n')

    retrieve("When was the first FIFA World Cup held?").passages[0]

    from dspy.teleprompt import BootstrapFewShot
    
    # Validation logic: check that the predicted answer is correct.
    # Also check that the retrieved context does actually contain that answer.
    def validate_context_and_answer(example, pred, trace=None):
        answer_EM = dspy.evaluate.answer_exact_match(example, pred)
        answer_PM = dspy.evaluate.answer_passage_match(example, pred)
        return answer_EM and answer_PM
    
    # Set up a basic teleprompter, which will compile our RAG program.
    teleprompter = BootstrapFewShot(metric=validate_context_and_answer)

    # Compile!
    compiled_rag = teleprompter.compile(RAG(), trainset=trainset)

    # Ask any question you like to this simple RAG program.
    my_question = "What castle did David Gregory inherit?"
    
    # Get the prediction. This contains `pred.context` and `pred.answer`.
    pred = compiled_rag(my_question)
    
    # Print the contexts and the answer.
    print(f"Question: {my_question}")
    print(f"Predicted Answer: {pred.answer}")
    print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")

    from dspy.evaluate.evaluate import Evaluate

    # Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
    evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=args.num_threads, display_progress=True, display_table=5)
    
    # Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
    metric = dspy.evaluate.answer_exact_match
    evaluate_on_hotpotqa(compiled_rag, metric=metric)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int)
    parser.add_argument("--num-threads", type=int, default=32)
    parser.add_argument("--dev-size", type=int, default=150)
    parser.add_argument("--backend", type=str, choices=["sglang", "tgi", "vllm"],
        default="sglang")
    args = parser.parse_args()

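    # Fall back to each backend's conventional default port when --port is unset.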
    if args.port is None:
        default_port = {
            "vllm": 21000,
            "lightllm": 22000,
            "tgi": 24000,
            "sglang": 30000,
        }
        args.port = default_port.get(args.backend, None)

    main(args)