test_prompt_template_with_reasoning.py 3.74 KB
Newer Older
zzg_666's avatar
zzg_666 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from dataflow.operators.reasoning import (
    ReasoningQuestionGenerator,
    ReasoningAnswerGenerator,
)
from dataflow.operators.reasoning import ReasoningQuestionFilter, ReasoningAnswerNgramFilter, ReasoningAnswerModelJudgeFilter
from dataflow.utils.storage import FileStorage
from dataflow.serving import APILLMServing_request
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.general import (
    GeneralQuestionFilterPrompt,
    GeneralAnswerGeneratorPrompt,
    GeneralQuestionSynthesisPrompt,
)
from dataflow.prompts.model_evaluation.general import AnswerJudgePrompt
class GeneralReasoning_APIPipeline:
    """Run the general-reasoning question filter against an LLM backend.

    Reads ``pipeline_general.json`` through a ``FileStorage`` cache (JSONL
    files under ``./cache_local``) and applies ``ReasoningQuestionFilter`` to
    the ``instruction`` field. The later stages of the pipeline (question
    synthesis, answer generation, model judging, n-gram filtering) are kept
    as commented-out scaffolding.

    Args:
        llm_serving: LLM backend shared by the operators. When ``None`` (the
            default), an ``APILLMServing_request`` client for the hard-coded
            endpoint in ``__init__`` is created.
    """

    def __init__(self, llm_serving: LLMServingABC = None):
        self.storage = FileStorage(
            first_entry_file_name="../dataflow/example/ReasoningPipeline/pipeline_general.json",
            cache_path="./cache_local",
            file_name_prefix="dataflow_cache_step",
            cache_type="jsonl",
        )

        # Honor a caller-supplied backend. Previously this argument was
        # ignored and the API client below was always constructed.
        if llm_serving is None:
            # NOTE(review): endpoint and model are hard-coded; consider
            # moving them to configuration.
            llm_serving = APILLMServing_request(
                api_url="http://123.129.219.111:3000/v1/chat/completions",
                model_name="gpt-4o",
                max_workers=30,
            )
        self.llm_serving = llm_serving

        # Step 1: filter questions with the built-in general filter prompt.
        # Alternative prompt_template values kept for manual experiments:
        #     prompt_template="sdasdsa"        # a plain string
        #     prompt_template=CustomPrompt()   # a user-defined prompt class:
        # from dataflow.core.prompt import DIYPromptABC
        # class CustomPrompt(DIYPromptABC):
        #     def __init__(self):
        #         super().__init__()
        #     def build_prompt(self, a):
        #         print("This is a custom prompt")
        self.question_filter_step1 = ReasoningQuestionFilter(
            system_prompt="You are an expert in evaluating mathematical problems. Follow the user's instructions strictly and output your final judgment in the required JSON format.",
            llm_serving=self.llm_serving,
            prompt_template=GeneralQuestionFilterPrompt(),
        )

        # Print the operator's ALLOWED_PROMPTS attribute (presumably the
        # prompt classes it accepts -- confirm), then a blank line.
        print(self.question_filter_step1.ALLOWED_PROMPTS)
        print()

        # Disabled steps 2-5.
        # NOTE(review): QuestionGenerator / AnswerGenerator / AnswerModelJudge /
        # AnswerNgramFilter are not imported in this file. The imported
        # counterparts appear to be ReasoningQuestionGenerator,
        # ReasoningAnswerGenerator, ReasoningAnswerModelJudgeFilter and
        # ReasoningAnswerNgramFilter. Confirm the names and constructor
        # arguments before re-enabling.
        # self.question_gen_step2 = QuestionGenerator(
        #     num_prompts=1,
        #     llm_serving=self.llm_serving,
        #     prompt_template=GeneralQuestionSynthesisPrompt()
        # )
        # self.answer_generator_step3 = AnswerGenerator(
        #     llm_serving=self.llm_serving,
        #     prompt_template=GeneralAnswerGeneratorPrompt()
        # )
        # self.answer_model_judge_step4 = AnswerModelJudge(
        #     llm_serving=self.llm_serving,
        #     prompt_template=AnswerJudgePrompt(),
        #     keep_all_samples=True
        # )
        # self.answer_ngram_filter_step5 = AnswerNgramFilter(
        #     min_score = 0.1,
        #     max_score = 1.0,
        #     ngrams = 5
        # )

    def forward(self):
        """Run the enabled pipeline steps in order.

        Only step 1 is active. It runs the question filter on the
        ``instruction`` key. Each step receives ``self.storage.step()``.
        """
        self.question_filter_step1.run(
            storage=self.storage.step(),
            input_key="instruction",
        )

        # Disabled steps (see the NOTE in __init__ before re-enabling):
        # self.question_gen_step2.run(
        #     storage = self.storage.step(),
        #     input_key = "instruction",
        # )
        # self.answer_generator_step3.run(
        #     storage = self.storage.step(),
        #     input_key = "instruction",
        #     output_key = "generated_cot"
        # )
        # self.answer_model_judge_step4.run(
        #     storage = self.storage.step(),
        #     input_question_key = "instruction",
        #     input_answer_key = "generated_cot",
        #     input_reference_key = "golden_answer"
        # )
        # self.answer_ngram_filter_step5.run(
        #     storage = self.storage.step(),
        #     input_question_key = "instruction",
        #     input_answer_key = "generated_cot"
        # )

if __name__ == "__main__":
    # Script entry point: build the pipeline with its default settings and
    # run its forward pass once.
    GeneralReasoning_APIPipeline().forward()