gen_data.py 4.79 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
4
5
6
"""
Generate line data for line retrieval task.

Usage:
python3 gen_data.py --number 1000
"""
Liangsheng Yin's avatar
Liangsheng Yin committed
7

Lianmin Zheng's avatar
Lianmin Zheng committed
8
9
import argparse
import json
Liangsheng Yin's avatar
Liangsheng Yin committed
10
from collections import defaultdict
Lianmin Zheng's avatar
Lianmin Zheng committed
11
12

import numpy as np
Liangsheng Yin's avatar
Liangsheng Yin committed
13
from tqdm import tqdm
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45


def generate_lines(random_words, num_lines, redirect_ratio):
    """Generate the line-retrieval prompt pieces and their ground truth.

    Every line gets a unique index made of two words drawn from
    ``random_words`` joined with "-", plus a zero-padded six-digit
    REGISTER_CONTENT value. If ``redirect_ratio > 0``,
    ``int(num_lines * redirect_ratio)`` randomly chosen lines are rewritten to
    say their value is "the same as" another line, and the redirect chains are
    then resolved so ``values[i]`` holds the final value of line ``i``.
    Shows a tqdm progress bar while generating and prints how many lines fall
    into each hop count.

    Args:
        random_words: Pool of words used to build line indices.
        num_lines: Number of lines to generate.
        redirect_ratio: Fraction of lines turned into redirects (0 disables).

    Returns:
        dict with keys:
            "prefix" / "suffix": prompt text around the lines; the suffix has
                its three few-shot placeholders filled in.
            "lines": rendered line strings.
            "indices": index string of each line.
            "values": resolved REGISTER_CONTENT of each line.
            "links": per line, the chain of target line numbers followed
                ([] for plain lines, None when the chain loops).
            "group_by_num_hoops": defaultdict mapping ``len(links[i]) + 1`` to
                line numbers; lines whose chain loops are excluded.
            "contains_ring": sorted line numbers whose chain loops.

    Raises:
        ValueError: if ``random_words`` cannot produce ``num_lines`` distinct
            indices, or too few plain (1-hop) lines exist for the few-shot
            examples.
    """
    prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
    # NOTE(review): "resovling" is a typo in the prompt; kept unchanged so
    # regenerated datasets stay byte-identical to previously generated ones.
    suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resovling the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"

    # At most k*k distinct "w1-w2" indices exist for k distinct words; without
    # this check the uniqueness loop below would spin forever.
    max_unique_indices = len(set(random_words)) ** 2
    if num_lines > max_unique_indices:
        raise ValueError(
            f"random_words yields at most {max_unique_indices} distinct line "
            f"indices; cannot generate {num_lines} lines"
        )

    # Raw lines. None is pre-seeded so the while loop draws at least once.
    seen_indices = {None}
    lines = []
    redirects = []
    indices = []
    values = []
    for _ in tqdm(range(num_lines)):
        line_index = None
        while line_index in seen_indices:
            line_index = "-".join(np.random.choice(random_words, size=(2,)))
        seen_indices.add(line_index)

        # randint's `high` is exclusive, so 999999 itself is never drawn.
        line_value = f"{np.random.randint(low=0, high=999999):06}"

        lines.append(f"Line {line_index}: The REGISTER_CONTENT is {line_value}.")
        redirects.append(None)
        indices.append(line_index)
        values.append(line_value)

    # Add redirects.
    if redirect_ratio > 0:
        num_redirect_lines = int(len(lines) * redirect_ratio)
        redirect_indices = np.random.choice(
            np.arange(len(lines)), size=(num_redirect_lines,), replace=False
        )
        for i in redirect_indices:
            # The target is drawn from the first min(2*i + 100, num_lines)
            # lines, so it may be the line itself or another redirect.
            # int() keeps `links` free of NumPy scalars so json.dump accepts it.
            target_idx = int(np.random.choice(min(i * 2 + 100, num_lines)))
            lines[i] = (
                f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
            )
            redirects[i] = target_idx

    links, contains_ring = _resolve_redirects(redirects, values)

    # Group by hop count: 1 = plain line, 2 = one redirect, and so on.
    group_by_num_hoops = defaultdict(list)
    for i in range(num_lines):
        if i not in contains_ring:
            group_by_num_hoops[len(links[i]) + 1].append(i)

    for num_links in sorted(group_by_num_hoops):
        print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")

    suffix = _fill_few_shot_examples(
        suffix, indices, values, links, group_by_num_hoops
    )

    return {
        "prefix": prefix,
        "suffix": suffix,
        "lines": lines,
        "indices": indices,
        "values": values,
        "links": links,
        "group_by_num_hoops": group_by_num_hoops,
        "contains_ring": sorted(contains_ring),
    }


def _resolve_redirects(redirects, values):
    """Follow each redirect chain, overwriting redirected lines' values in place.

    Args:
        redirects: Per line, the target line number, or None for a plain line.
        values: Per-line REGISTER_CONTENT strings. The entry of every
            redirected line is replaced with the value found where its chain
            ends.

    Returns:
        ``(links, contains_ring)``. ``links[i]`` is the list of line numbers
        visited while following line ``i`` ([] for plain lines, None when the
        chain loops). ``contains_ring`` is the set of line numbers whose chain
        loops.
    """
    links = [[] for _ in range(len(redirects))]
    contains_ring = set()
    for i in range(len(redirects)):
        if redirects[i] is None:
            continue

        chain = []
        cur = i
        visited = set()
        while redirects[cur] is not None:
            visited.add(cur)
            chain.append(redirects[cur])
            cur = redirects[cur]
            if cur in visited:
                contains_ring.add(i)
                chain = None
                break
        # For a looping chain `cur` sits inside the cycle, so this is not a
        # resolved value. Such lines are excluded from the hop groups.
        values[i] = values[cur]
        links[i] = chain
    return links, contains_ring


def _fill_few_shot_examples(suffix, indices, values, links, group_by_num_hoops):
    """Return ``suffix`` with its ``__idxK__``/``__valK__`` placeholders filled.

    Example 0 is always a 1-hop line. Examples 1 and 2 are 2-hop lines when at
    least two exist; otherwise they are further 1-hop lines.

    Raises:
        ValueError: if there are too few 1-hop lines to pick the examples.
    """

    def sort_by_farthest_line(candidates):
        # Ascending by the largest line number touched by the candidate or its
        # chain. sorted() is stable, matching the original list.sort().
        return sorted(candidates, key=lambda c: max([c] + links[c]))

    # Indexing the defaultdict inserts an empty group for a missing hop count.
    # That empty group is kept in the object returned by generate_lines.
    hoop1 = sort_by_farthest_line(group_by_num_hoops[1])
    hoop2 = sort_by_farthest_line(group_by_num_hoops[2])

    # Two 2-hop examples are needed, so fall back to 1-hop examples unless at
    # least two 2-hop lines exist.
    use_two_hop = len(hoop2) >= 2
    needed = 6 if use_two_hop else 11
    if len(hoop1) < needed:
        raise ValueError(
            f"need at least {needed} lines without redirects for the few-shot "
            f"examples, got {len(hoop1)}; increase num_lines or lower redirect_ratio"
        )
    if use_two_hop:
        examples = [hoop1[5], hoop2[0], hoop2[1]]
    else:
        examples = [hoop1[5], hoop1[1], hoop1[10]]

    for slot, line in enumerate(examples):
        suffix = suffix.replace(f"__idx{slot}__", indices[line])
        suffix = suffix.replace(f"__val{slot}__", values[line])
    return suffix


if __name__ == "__main__":
    # CLI entry point: reads random_words.json from the working directory and
    # writes lines_<number>_<ratio>.json next to it.
    parser = argparse.ArgumentParser(
        description="Generate line data for the line retrieval task."
    )
    parser.add_argument(
        "--number", type=int, required=True, help="Number of lines to generate."
    )
    parser.add_argument(
        "--redirect-ratio",
        type=float,
        default=0.0,
        help="Fraction of lines that redirect to another line (default: 0.0).",
    )
    args = parser.parse_args()

    num_lines = args.number

    random_words_filename = "random_words.json"
    with open(random_words_filename, "r") as fin:
        random_words = json.load(fin)

    # Fixed seed so the generated dataset is reproducible.
    np.random.seed(42)
    obj = generate_lines(random_words, num_lines, args.redirect_ratio)

    # NOTE(review): `.1f` means ratios that differ only past the first decimal
    # (e.g. 0.25 and 0.2) map to the same output file.
    output_filename = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
    with open(output_filename, "w") as fout:
        json.dump(obj, fout, indent=2)