# mmlu_test.py
import argparse
import json
import os
import random
import time

import pandas as pd
import requests

# Route Hugging Face downloads through a mirror and bypass any local proxy.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
hint = 'The following is a single-choice question. Answer the question by replying with A, B, C, or D. No other answers are accepted. Just the letter.'


class DataEvaluator:
    def __init__(self):
        self.data = []

    def load_data(self, file_path):
        """
        Load the MMLU test split from the Hugging Face hub into a list of dicts.
        :param file_path: Dataset repo id on the hub (e.g. "cais/mmlu").
        """
        # Read the Parquet file for the selected split directly from the hub.
        splits = {'test': 'all/test-00000-of-00001.parquet',
                  'validation': 'all/validation-00000-of-00001.parquet',
                  'dev': 'all/dev-00000-of-00001.parquet',
                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
        df = pd.read_parquet("hf://datasets/" + file_path + "/" + splits["test"])
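        # Each record is expected to provide 'question' (str), 'choices'
        # (list of four option strings), and 'answer' (0-based index of the
        # correct choice), matching how the fields are used below.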

        for _, row in df.iterrows():
            self.data.append(row.to_dict())

    def get_prompt(self, record):
        """
        Build the full prompt from a record and the instruction hint.
        :param record: Dictionary with 'question' and 'choices' fields.
        :return: A formatted prompt string.
        """
        # Render the choices as lettered options: A., B., C., D.
        options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
        prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
        return prompt
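    # For illustration, a rendered prompt looks roughly like this (the
    # question text here is hypothetical):
    #   <hint>
    #   Question: What is 2 + 2?
    #   A. 3
    #   B. 4
    #   C. 5
    #   D. 6
    #   Answer: '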
        
    def post_processing(self, text):
        """
        Perform post-processing on the prediction string.
        :param text: The raw prediction string.
        :return: Processed prediction string.
        """
        # Keep only the last line of the reply, then return its final
        # character, which should be the answer letter.
        text = text.lstrip('\n').split('\n')[-1]
        return text[-1:]
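
    # Optional, more tolerant variant (a sketch, not wired into main();
    # extract_choice is our name, not part of the original script): scan the
    # reply for the first standalone A-D letter instead of trusting the last
    # character of the last line.
    def extract_choice(self, text):
        import re  # local import keeps the sketch self-contained
        match = re.search(r'\b([ABCD])\b', text)
        return match.group(1) if match else ''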

    def score(self, pred, answers):
        """
        Exact-match scoring between the prediction and the reference answer(s).
        Note: main() passes the answer as a single-letter string; iterating a
        string yields its characters, so the loop below covers that case too.
        :param pred: The predicted answer letter.
        :param answers: Iterable of acceptable answer strings.
        :return: 1 if the prediction matches any reference answer, else 0.
        """
        for answer in answers:
            if pred == answer:
                return 1

        return 0

# Function to generate text via an OpenAI-compatible chat completions API
def generate_text(api_url, question, model_name, stream=False):
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
        # Add your API key after "Bearer " if the endpoint requires one.
        'Authorization': 'Bearer '
    }
    data = {
        "messages": [{"content": question, "role": "user"}],
        "model": model_name,
        "stream": stream,
        # "temperature": 0.0
    }
    
    print("POST data:", data)
    response = requests.post(api_url, headers=headers, json=data)
    
    if response.status_code == 200:
        result = response.json()
        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
    else:
        print(f"API Request failed with status code {response.status_code}")
        return None
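
# Example call, using the script's own defaults (adjust the endpoint and
# model name to whatever your server exposes):
#   reply = generate_text("http://localhost:10003/v1/chat/completions",
#                         "Question: ...\nAnswer:", "Pro/deepseek-ai/DeepSeek-V3")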

# Main function to run the evaluation loop
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
    start_total_time = time.time()

    total_score = 0

    results = []
    # Fix the random seed so the shuffled question order is reproducible.
    random.seed(42)
    random.shuffle(data_evaluator.data)
    # Despite the parameter name, requests are issued sequentially.
    num_requests = min(concurrent_requests, len(data_evaluator.data))
    for i in range(num_requests):
        # Take the next item from the shuffled data for each request.
        data_item = data_evaluator.data[i]
        question = data_evaluator.get_prompt(data_item)

        # Start the timer for this evaluation
        start_time = time.time()
        try:
            # Generate prediction using the API
            prediction = generate_text(api_url, question, model_name)

            if prediction is None:
                raise Exception(f"Failed to get prediction for {question}")

            answer = chr(data_item['answer'] + 65)
            # Compute the exact-match score against the reference letter
            pred = data_evaluator.post_processing(prediction)
            score = data_evaluator.score(pred, answer)

            # Calculate the time taken
            elapsed_time = time.time() - start_time

            # Collect the result data
            result_data = {
                "question_id": i,
                "answer": answer,
                "prediction": pred,
                "score": score,
                "time": elapsed_time
            }

            # Append the result as one JSON object per line (JSON Lines);
            # indent is omitted so each record stays on a single line.
            with open(result_file, 'a', encoding='utf-8') as f:
                json.dump(result_data, f, ensure_ascii=False)
                f.write("\n")
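            # A line in the result file then looks like (values illustrative):
            # {"question_id": 0, "answer": "B", "prediction": "B", "score": 1, "time": 1.23}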

            results.append(result_data)

            # Aggregate scores
            total_score += score

        except Exception as e:
            print(f"Error processing request {i}: {e}")

    # Calculate total time and throughput over the requests actually issued
    total_time = time.time() - start_total_time
    throughput = num_requests / total_time

    # Log the total time, throughput, and average score (accuracy)
    with open(log_file, 'a', encoding='utf-8') as log_f:
        log_f.write(f"Total Time: {total_time:.2f} seconds\n")
        log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
        log_f.write(f"Average Score: {total_score / num_requests}\n")
        log_f.write('-' * 40 + '\n')

    print(f"Results saved to {result_file}")
    print(f"Log saved to {log_file}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="API Generate Tester")
    parser.add_argument("--concurrent", type=int, default=1000, help="Number of evaluations to run (issued sequentially)")
    parser.add_argument("--file", type=str, default="cais/mmlu", help="Hugging Face dataset repo id for MMLU")
    parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="Path to save the result JSON file")
    parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="Path to save the log file")
    parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
    parser.add_argument("--api_url", type=str, default="http://localhost:10003/v1/chat/completions", help="API URL")
    # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")

    args = parser.parse_args()

    # Load the data from the provided dataset id
    data_evaluator = DataEvaluator()
    data_evaluator.load_data(args.file)

    # Run the main function with the specified number of concurrent evaluations
    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
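
# Example invocation (assumes an OpenAI-compatible server on localhost:10003,
# per the defaults above):
#   python mmlu_test.py --concurrent 100 --result ./mmlu_result.json --log ./mmlu_result.log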