pack_dataset.py 2.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""Fill examples with bitext up to max_tokens without breaking up examples.
[['I went', 'yo fui'],
['to the store', 'a la tienda']
]
=> ['I went to the store', 'yo fui a la tienda']
"""

import argparse
from pathlib import Path

from tqdm import tqdm

from transformers import AutoTokenizer


def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
    """Greedily merge aligned source/target examples into longer examples.

    Pairs are first sorted by source length (in characters). Consecutive pairs
    are then joined with a single space for as long as BOTH the packed source
    and the packed target tokenize to at most ``max_tokens`` tokens. An
    individual example is never split, so an example that is already longer
    than ``max_tokens`` on its own is emitted by itself, unchanged.

    Args:
        tok: tokenizer callable; ``tok(text).input_ids`` must be the sequence of
            token ids for ``text`` (e.g. a ``transformers`` tokenizer).
        src_examples: source-side strings.
        tgt_examples: target-side strings aligned with ``src_examples``. Pairing
            uses ``zip``, so extra items in the longer sequence are ignored.
        max_tokens: maximum token count allowed for a packed source or target.

    Returns:
        ``(packed_src, packed_tgt)``: two lists of equal length.

    Raises:
        ValueError: if the final packed source is non-empty but its target is empty.
    """

    def is_too_big(text):
        # Count tokens directly from the id list; no tensor is needed just to measure length.
        return len(tok(text).input_ids) > max_tokens

    finished_src, finished_tgt = [], []

    # Sort by source length so packing walks from the shortest pair to the longest.
    sorted_examples = sorted(zip(src_examples, tgt_examples), key=lambda pair: len(pair[0]))
    if not sorted_examples:  # nothing to pack; avoids IndexError on sorted_examples[0]
        return finished_src, finished_tgt

    new_src, new_tgt = sorted_examples[0]
    for src, tgt in tqdm(sorted_examples[1:]):
        cand_src = new_src + " " + src
        cand_tgt = new_tgt + " " + tgt
        if is_too_big(cand_src) or is_too_big(cand_tgt):  # can't fit: finalize current example
            finished_src.append(new_src)
            finished_tgt.append(new_tgt)
            new_src, new_tgt = src, tgt
        else:  # can fit: keep adding
            new_src, new_tgt = cand_src, cand_tgt

    # Flush the example still being built when the loop ends.
    if new_src:
        if not new_tgt:
            raise ValueError("packed source example is non-empty but its target is empty")
        finished_src.append(new_src)
        finished_tgt.append(new_tgt)
    return finished_src, finished_tgt


45
46
47
48
49
50
51
52
53
54
def minify(src_dir: Path, dest_dir: Path, n: int):
    """Write the first ``n`` lines of each file in ``src_dir`` to ``dest_dir``.

    Each output file keeps the name of its source file. Trailing whitespace is
    stripped from every kept line, and the lines are written newline-joined
    without a trailing newline. Subdirectories of ``src_dir`` are skipped.
    Files are read and written as UTF-8.

    Args:
        src_dir: directory whose files are truncated (``str`` or ``Path``).
        dest_dir: output directory; created (with parents) if missing.
        n: number of leading lines to keep, with ``list`` slice semantics (``[:n]``).
    """
    src_dir, dest_dir = Path(src_dir), Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    for path in src_dir.iterdir():
        if not path.is_file():  # iterdir() also yields subdirectories, which cannot be opened as text
            continue
        # Use a context manager so the handle is closed deterministically.
        with path.open(encoding="utf-8") as f:
            new = [x.rstrip() for x in f.readlines()[:n]]
        dest_path = dest_dir.joinpath(path.name)
        print(dest_path)
        # write_text opens, writes, flushes and closes in one step.
        dest_path.write_text("\n".join(new), encoding="utf-8")


55
56
57
58
59
def pack_data_dir(tok, data_dir: Path, max_tokens, save_path):
    """Pack the ``val``, ``test`` and ``train`` splits of a seq2seq data directory.

    For each split, reads ``{split}.source`` and ``{split}.target`` from
    ``data_dir`` (one example per line, trailing whitespace stripped), packs
    them with ``pack_examples``, prints a short summary, and writes the packed
    examples newline-joined to files of the same names under ``save_path``.
    All files are read and written as UTF-8.

    Args:
        tok: tokenizer passed through to ``pack_examples``.
        data_dir: directory containing the split files (``str`` or ``Path``).
        max_tokens: token budget passed through to ``pack_examples``.
        save_path: output directory (``str`` or ``Path``); created with parents if missing.

    Raises:
        FileNotFoundError: if any of the six expected split files is missing.
    """
    data_dir, save_path = Path(data_dir), Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    for split in ["val", "test", "train"]:
        src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
        # Context managers close the handles deterministically instead of relying on GC.
        with src_path.open(encoding="utf-8") as f:
            src_docs = [x.rstrip() for x in f]
        with tgt_path.open(encoding="utf-8") as f:
            tgt_docs = [x.rstrip() for x in f]
        packed_src, packed_tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens)
        print(f"packed {split} split from {len(src_docs)} examples -> {len(packed_src)}.")
        (save_path / f"{split}.source").write_text("\n".join(packed_src), encoding="utf-8")
        (save_path / f"{split}.target").write_text("\n".join(packed_tgt), encoding="utf-8")
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80


def packer_cli():
    """Command-line entry point: parse arguments, load a tokenizer, and pack a data directory.

    Returns whatever ``pack_data_dir`` returns.
    """
    parser = argparse.ArgumentParser(
        description="Pack short seq2seq examples into longer ones without splitting any example."
    )
    # These three have no usable default: without them, Path(None) / from_pretrained(None)
    # would fail with a confusing error, so argparse rejects the invocation up front instead.
    parser.add_argument(
        "--tok_name", type=str, required=True, help="like facebook/bart-large-cnn,t5-base, etc."
    )
    parser.add_argument(
        "--max_seq_len", type=int, default=128, help="maximum tokens per packed source or target"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="directory containing {val,test,train}.source and {val,test,train}.target",
    )
    parser.add_argument(
        "--save_path", type=str, required=True, help="output directory for the packed splits"
    )
    args = parser.parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.tok_name)
    return pack_data_dir(tokenizer, Path(args.data_dir), args.max_seq_len, args.save_path)


if __name__ == "__main__":
    # Run the command-line packer only when this file is executed as a script, not on import.
    packer_cli()