data_pre_precess_gpqa.py 2.08 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Data preprocessing for GPQA: convert the CSV into multiple-choice JSON records.
import csv
import json
import random
from tqdm import tqdm

# Paths to data
data_path = './GPQA/gpqa_extended.csv'
# NOTE(review): the input is the *extended* GPQA split but the output is named
# diamond.json -- confirm which split is intended.
output_path = './GPQA/diamond.json'

# Seed for the answer shuffle, so rebuilding the JSON reproduces the same
# A-D assignment every time. Pass seed=None to build_records()/main() for an
# unseeded shuffle (the script's original behaviour).
SHUFFLE_SEED = 0

# Keys copied into each output record, in output order. 'id' is the 0-based
# row index; every other entry is a CSV column name.
keys_to_keep = [
    'id',
    'Question',
    'Subdomain',
    'High-level domain',
    'Correct Answer',
    'Incorrect Answer 1',
    'Incorrect Answer 2',
    'Incorrect Answer 3'
]

# Answer columns that are shuffled into the lettered choices.
ANSWER_KEYS = [
    'Correct Answer',
    'Incorrect Answer 1',
    'Incorrect Answer 2',
    'Incorrect Answer 3',
]
# Letters assigned to the shuffled answers, in display order.
CHOICE_LABELS = ['A', 'B', 'C', 'D']


def build_records(rows, seed=SHUFFLE_SEED):
    """Turn GPQA CSV rows into shuffled multiple-choice records.

    Each record keeps only the ``keys_to_keep`` fields. The four answers are
    shuffled and appended to 'Question' as "(A) ..." to "(D) ..." lines, and
    the letter of the correct answer is stored under 'Correct Choice'.

    Args:
        rows: Iterable of dicts keyed by CSV column name (e.g. the rows of a
            ``csv.DictReader``). Every column in ``keys_to_keep`` other than
            'id' must be present; the rows themselves are not modified.
        seed: Seed for the answer shuffle. The same seed and rows always
            yield the same records; ``None`` gives a fresh, unseeded order.

    Returns:
        A list with one dict per input row, in input order.

    Raises:
        KeyError: if a row lacks one of the required columns.
    """
    # Private RNG so the shuffle is reproducible without touching the global
    # `random` state.
    rng = random.Random(seed)
    records = []
    for idx, row in enumerate(rows):
        # Copy rather than mutate the caller's row; 'id' overrides any CSV
        # column of the same name, exactly as before.
        row = {**row, 'id': idx}
        record = {key: row[key] for key in keys_to_keep}

        answers = [(key, record[key]) for key in ANSWER_KEYS]
        rng.shuffle(answers)

        # Letter the shuffled answers A-D and remember where the correct one landed.
        choice_lines = []
        correct_choice = None
        for letter, (key, text) in zip(CHOICE_LABELS, answers):
            choice_lines.append(f"({letter}) {text}")
            if key == 'Correct Answer':
                correct_choice = letter

        formatted_choices = "\n".join(choice_lines)
        record['Question'] = f"{record['Question']} Choices:\n{formatted_choices}\n"
        record['Correct Choice'] = correct_choice
        records.append(record)
    return records


def main(in_path=data_path, out_path=output_path, seed=SHUFFLE_SEED):
    """Read the GPQA CSV at *in_path* and write formatted records to *out_path*.

    Args:
        in_path: Path of the source CSV file (UTF-8).
        out_path: Path of the JSON file to write (overwritten if present).
        seed: Shuffle seed forwarded to ``build_records``.

    Returns:
        The list of records that was written.
    """
    # newline='' is required by the csv module so that line breaks inside
    # quoted fields (multi-line questions/answers) are read verbatim.
    with open(in_path, mode='r', encoding='utf-8', newline='') as csv_file:
        records = build_records(tqdm(csv.DictReader(csv_file)), seed=seed)

    # Write the updated data to JSON
    with open(out_path, mode='w', encoding='utf-8') as json_file:
        json.dump(records, json_file, indent=4, ensure_ascii=False)
    return records


if __name__ == "__main__":
    main()