gen_data.py 4.75 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
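Generate line data for the line retrieval task.

The script reads the word list from random_words.json in the current directory
and writes the result to lines_<number>_<redirect-ratio>.json.

Usage:
python3 gen_data.py --number 1000
python3 gen_data.py --number 1000 --redirect-ratio 0.5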
"""
import argparse
from collections import defaultdict
import json

from tqdm import tqdm
import numpy as np


def generate_lines(random_words, num_lines: int, redirect_ratio: float) -> dict:
    """Build the data for one line-retrieval prompt, optionally with redirects.

    Every line gets a unique index made of two random words joined by "-" and
    a zero-padded six-digit REGISTER_CONTENT value. When ``redirect_ratio`` is
    positive, that fraction of the lines is rewritten to point at another line
    instead of holding a value. Redirect chains are then resolved so each
    line's final value is known, and three solved examples are substituted
    into the prompt suffix.

    Random numbers are drawn from NumPy's global random state, so seed it
    before calling if reproducible output is needed.

    Args:
        random_words: Sequence of words to draw line indices from.
        num_lines: Number of lines to generate.
        redirect_ratio: Fraction of lines turned into redirects; values
            ``<= 0`` disable redirection.

    Returns:
        A dict with the keys:
            "prefix": instruction text placed before the lines.
            "suffix": question text with the few-shot examples filled in.
            "lines": the rendered text of every line.
            "indices": the two-word index of every line.
            "values": the resolved value of every line (not meaningful for
                lines listed in "contains_ring").
            "links": per line, the positions visited while following its
                redirects (``[]`` for a plain line, ``None`` for a ring).
            "group_by_num_hoops": mapping of ``len(links) + 1`` to the
                positions of the ring-free lines with that many hops.
            "contains_ring": sorted positions whose redirects loop forever.

    Raises:
        ValueError: If ``num_lines`` exceeds the number of distinct indices
            that ``random_words`` can form, or if no line holding a literal
            value is left to use as a few-shot example.
    """
    prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
    # NOTE: the prompt text (including the "resovling" typo) is kept verbatim
    # so that regenerated datasets stay identical to previously generated ones.
    suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resovling the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"

    # An index is an ordered pair of words, so no more than
    # len(unique words) ** 2 distinct indices exist. Asking for more would make
    # the rejection loop below spin forever, so fail early instead.
    max_unique_indices = len(set(random_words)) ** 2
    if num_lines > max_unique_indices:
        raise ValueError(
            f"num_lines={num_lines} exceeds the {max_unique_indices} distinct "
            "two-word indices that random_words can produce"
        )

    # Raw lines
    seen_indices = set()
    lines = []
    redirects = []  # redirects[k] is the target position of line k, or None
    indices = []
    values = []
    for _ in tqdm(range(num_lines)):
        # Redraw until the two-word index has not been used before.
        while True:
            line_index = "-".join(np.random.choice(random_words, size=(2,)))
            if line_index not in seen_indices:
                break
        seen_indices.add(line_index)

        # ``high`` is exclusive, so values span 000000..999998. The bound is
        # left unchanged to keep the random stream (and old datasets) stable.
        line_value = f"{np.random.randint(low=0, high=999999):06}"

        lines.append(f"Line {line_index}: The REGISTER_CONTENT is {line_value}.")
        redirects.append(None)
        indices.append(line_index)
        values.append(line_value)

    # Add redirect
    if redirect_ratio > 0:
        num_redirect_lines = int(len(lines) * redirect_ratio)
        redirect_indices = np.random.choice(
            np.arange(len(lines)), size=(num_redirect_lines,), replace=False
        )
        for source in redirect_indices:
            # Work with plain ints so the positions stored in ``redirects``
            # (and later in ``links``) are always JSON-serializable.
            source = int(source)
            # The target is drawn from the first min(2 * source + 100,
            # num_lines) lines; it may be the line itself or another redirect,
            # which is why rings have to be detected below.
            target = int(np.random.choice(min(source * 2 + 100, num_lines)))
            lines[source] = f"Line {indices[source]}: The REGISTER_CONTENT is the same as Line {indices[target]}."
            redirects[source] = target

    # Build links and find sources
    links = [[] for _ in range(num_lines)]
    contains_ring = set()
    for i in range(num_lines):
        if redirects[i] is None:
            continue

        chain = []
        cur = i
        visited = set()
        while redirects[cur] is not None:
            visited.add(cur)
            cur = redirects[cur]
            chain.append(cur)

            if cur in visited:
                # The walk came back to a line it already passed: a ring.
                contains_ring.add(i)
                chain = None
                break
        # For a ring-free chain ``cur`` is the line holding the literal value.
        # For a ring the copied value has no meaning; the assignment is kept
        # so the emitted "values" list matches earlier outputs exactly.
        values[i] = values[cur]
        links[i] = chain

    # Group by num_links
    group_by_num_hoops = defaultdict(list)
    for i in range(num_lines):
        if i in contains_ring:
            continue
        group_by_num_hoops[len(links[i]) + 1].append(i)

    for num_links in sorted(group_by_num_hoops):
        print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")

    # Append few-shot examples
    def latest_position(candidate):
        # Largest line position touched when resolving ``candidate``; sorting
        # on it puts the examples that depend only on early lines first.
        return max([candidate] + links[candidate])

    def pick(candidates, position):
        # Clamp so that short candidate lists reuse their last entry instead
        # of raising IndexError.
        return candidates[min(position, len(candidates) - 1)]

    hoop1_candidates = sorted(group_by_num_hoops[1], key=latest_position)
    hoop2_candidates = sorted(group_by_num_hoops[2], key=latest_position)

    if not hoop1_candidates:
        raise ValueError(
            "cannot build few-shot examples: no line holds a literal "
            "REGISTER_CONTENT value (num_lines is 0 or every line redirects)"
        )

    if hoop2_candidates:
        examples = [
            pick(hoop1_candidates, 5),
            pick(hoop2_candidates, 0),
            pick(hoop2_candidates, 1),
        ]
    else:
        examples = [
            pick(hoop1_candidates, 5),
            pick(hoop1_candidates, 1),
            pick(hoop1_candidates, 10),
        ]
    for slot, line_pos in enumerate(examples):
        suffix = suffix.replace(f"__idx{slot}__", indices[line_pos])
        suffix = suffix.replace(f"__val{slot}__", values[line_pos])

    obj = {
        "prefix": prefix,
        "suffix": suffix,
        "lines": lines,
        "indices": indices,
        "values": values,
        "links": links,
        "group_by_num_hoops": group_by_num_hoops,
        "contains_ring": sorted(contains_ring),
    }
    return obj


if __name__ == "__main__":
    # Command-line entry point: generate one dataset and dump it as JSON.
    parser = argparse.ArgumentParser()
    # --number is mandatory; without it the generator would receive None and
    # fail with an unhelpful TypeError.
    parser.add_argument("--number", type=int, required=True,
                        help="number of lines to generate")
    parser.add_argument("--redirect-ratio", type=float, default=0.0,
                        help="fraction of lines that redirect to another line")
    args = parser.parse_args()

    num_lines = args.number

    # Load the word list with a context manager so the handle is closed.
    random_words_filename = "random_words.json"
    with open(random_words_filename, "r") as fin:
        random_words = json.load(fin)

    # Fixed seed so repeated runs with the same arguments give the same data.
    np.random.seed(42)
    obj = generate_lines(random_words, num_lines, args.redirect_ratio)

    # Keep the output path and the file object under separate names.
    output_filename = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
    with open(output_filename, "w") as fout:
        json.dump(obj, fout, indent=2)