# test_indexed_dataset.py
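"""Manual sanity check for megatron.data.indexed_dataset.

Loads an indexed dataset from a file prefix, prints its document/sentence
bookkeeping, and decodes the first document back into BERT wordpiece tokens.

Example invocation (paths and file names are hypothetical):

    python test_indexed_dataset.py \
        --data /path/to/my-corpus_text_sentence \
        --vocab /path/to/vocab.txt \
        --dataset-impl mmap
"""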
import argparse
import os
import sys

import torch

# Make the repository root importable when this file is run directly.
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))

from megatron.data import indexed_dataset, FullBertTokenizer

def test_indexed_dataset(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
    # doc_idx[i] is the index of the first sentence of document i, with a
    # trailing sentinel, so doc_idx[-1] should equal len(ds).
    print(len(ds.doc_idx))
    print(len(ds))
    print(ds.doc_idx[-1])
    if ds.supports_prefetch:
        # Prefetch everything up front; the test assumes the dataset is small.
        ds.prefetch(range(len(ds)))
    # Only the first document is decoded; widen this range to inspect more.
    for i in range(1):
        start = ds.doc_idx[i]
        end = ds.doc_idx[i + 1]
        print(start, end)
        for j in range(start, end):
            ids = ds[j].data.tolist()
            print(ids)
            tokens = tokenizer.convert_ids_to_tokens(ids)
            print(tokens)
        print("******** END DOCUMENT **********")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='Prefix of the indexed dataset files')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    args = parser.parse_args()

    # 'infer' inspects the index file on disk to pick the stored format.
    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    test_indexed_dataset(args)

if __name__ == "__main__":
    main()