"applications/llm/count/src/bin/mock_worker.rs" did not exist on "6e0cfbd967147e4d48ab0542127760939c0a2b68"
generate_conversation_dataset.py 2.66 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
    """Convert the tatsu-lab/alpaca dataset into one-round conversation records.

    Each Alpaca sample has "instruction", "input" and "output" fields; the
    instruction (with the optional input appended) becomes the human turn and
    the output becomes the gpt turn.

    Returns:
        list[dict]: records with type/language/dataset metadata and a
        two-turn "conversations" list.
    """
    conversation_dataset = []
    dataset = load_dataset("tatsu-lab/alpaca", split="train")

    instructions = dataset["instruction"]
    inputs = dataset["input"]
    outputs = dataset["output"]

    assert len(instructions) == len(inputs) == len(outputs)

    for instruction, extra_input, output in zip(instructions, inputs, outputs):
        # The "input" field is optional; append it to the instruction only when non-empty.
        human_utterance = instruction + "\n\n" + extra_input if extra_input else instruction
        human = {"from": "human", "value": human_utterance}
        gpt = {"from": "gpt", "value": output}

        conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
        conversation_dataset.append(conversation)

    return conversation_dataset


def generate_sharegpt():
    """Convert the ShareGPT (Vicuna unfiltered) dataset into conversation records.

    ShareGPT samples already contain multi-turn conversations; each turn is
    kept as-is except that the unused "markdown" and "text" fields are
    stripped.

    Returns:
        list[dict]: records with type/language/dataset metadata and the
        original multi-turn "conversations" list.
    """
    conversation_dataset = []
    dataset = load_dataset(
        "anon8231489123/ShareGPT_Vicuna_unfiltered",
        data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
        split="train",
    )

    conversations = dataset["conversations"]

    for sample_turns in conversations:
        for turn in sample_turns:
            # We don't need the markdown and text values; pop() with a default
            # tolerates turns where the key is already absent (del would raise KeyError).
            turn.pop("markdown", None)
            turn.pop("text", None)

        conversation = dict(
            type="conversation", language="Multilingual", dataset="ShareGPT", conversations=sample_turns
        )
        conversation_dataset.append(conversation)

    return conversation_dataset


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert Alpaca and/or ShareGPT data into a conversation-format JSON dataset."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="All",
        choices=["Alpaca", "ShareGPT", "All"],
        help="which dataset to convert, All will combine Alpaca and ShareGPT",
    )
    parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")

    args = parser.parse_args()

    conversation_dataset = []

    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())

    # Assign 1-based ids after merging so they are unique across both sources.
    for idx, sample in enumerate(conversation_dataset, start=1):
        sample["id"] = idx

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which
    # would fail or be mangled under a non-UTF-8 locale default encoding.
    with open(args.save_path, mode="w", encoding="utf-8") as f:
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)