gen_simple_wikipedia_v1.py 1.65 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import json
import argparse


# Command-line interface: dataset directory is configurable via --root_path.
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str, default='./datasets/simple_wikipedia_v1')
args = parser.parse_args()


def load_sentence_info(file_path):
    """Read *file_path* as UTF-8 text and return a list of its lines.

    Leading/trailing whitespace (including the newline) is stripped from
    every line.

    Raises:
        Exception: any error opening/reading the file is logged to stdout
            and then re-raised for the caller to handle.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as fin:
            # Iterating the handle yields one line at a time; strip each.
            return [line.strip() for line in fin]
    except Exception as e:
        # Log the failure for context, then propagate it unchanged.
        print(f"Failed to load sentence info from {file_path}: {e}")
        raise


def main(root_path=None):
    """Pair aligned simple/unsimplified Wikipedia sentences into a JSONL file.

    Reads ``wiki.simple`` and ``wiki.unsimplified`` from *root_path*, checks
    that they contain the same number of lines, and writes each aligned pair
    as one JSON object per line to ``simple_wiki_pair.txt`` in the same
    directory.

    Args:
        root_path: Dataset directory. Defaults to the ``--root_path``
            command-line argument, preserving the original CLI behavior.
    """
    if root_path is None:
        root_path = args.root_path
    if not os.path.exists(root_path):
        print(f"{root_path} not exist, please check it")
        return
    # Try to load both sentence files; abort with a message on any read error.
    try:
        sentence1 = load_sentence_info(os.path.join(root_path, 'wiki.simple'))
        sentence2 = load_sentence_info(os.path.join(root_path, 'wiki.unsimplified'))
    except Exception as e:
        print(f"Error loading sentence info: {e}")
        return
    # The two files must be line-aligned; a length mismatch means bad data.
    if len(sentence1) != len(sentence2):
        print('The simple_wikipedia_v1 data length is not equal, please check it')
        return
    # Emit one JSON object per aligned pair (JSONL format).
    with open(os.path.join(root_path, 'simple_wiki_pair.txt'), 'w', encoding='utf-8') as wfile:
        for s1, s2 in zip(sentence1, sentence2):
            wfile.write(json.dumps({'sentence1': s1, 'sentence2': s2}, ensure_ascii=False) + '\n')
    return


if __name__ == '__main__':
    main()