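"""Ad-hoc tests for Megatron's indexed datasets and AlbertDataset.

Example invocation (the data prefix and vocab path below are placeholders
for your own files):

    python test_indexed_dataset.py --data my-corpus_text_sentence \
        --vocab vocab.txt --dataset-impl mmap
"""
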
import argparse
import os
import sys

import torch

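# Make the repository root importable when this file is run directly.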
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))

from megatron.data import indexed_dataset, FullBertTokenizer, AlbertDataset

def test_indexed_dataset(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
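
    # doc_idx[i] is the index of the first sentence of document i, so it has
    # one more entry than there are documents; doc_idx[-1] should equal len(ds).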
    print(len(ds.doc_idx))
    print(len(ds))
    print(ds.doc_idx[-1])
    if ds.supports_prefetch:
        # Just prefetch everything up front in the test (assumes the dataset is small).
        ds.prefetch(range(len(ds)))
    for i in range(len(ds.doc_idx) - 1):
        start = ds.doc_idx[i]
        end = ds.doc_idx[i + 1]
        ids = ds[start:end]
        for s in ids:
            assert len(s) > 0
            token_ids = s.data.tolist()
            tokens = tokenizer.convert_ids_to_tokens(token_ids)
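            # A newline inside a token would indicate a preprocessing problem.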
            for t in tokens:
                if '\n' in t:
                    print("Newline in string!")
        print(i)

def test_albert_dataset(args):
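    # The dataset can also be assembled manually from its parts, e.g.: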
    # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
    # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    # ds = AlbertDataset(idataset, tokenizer)
    ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
                                  args.epochs, args.max_num_samples,
                                  args.masked_lm_prob, args.seq_length,
                                  args.short_seq_prob, args.seed)
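    # Each sample is a dict; 'text' holds the token ids of one sequence.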
    for s in ds:
        ids = s['text']
        tokens = ds.tokenizer.convert_ids_to_tokens(ids)
        print(tokens)
        # Only inspect the first sample here.
        return

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')
    parser.add_argument('--vocab', type=str, help='path to vocab.txt')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs to plan for')
    parser.add_argument('--max-num-samples', type=int, default=None,
                        help='Maximum number of samples to plan for')
    parser.add_argument('--masked-lm-prob', type=float, default=0.15,
                        help='probability of masking tokens')
    parser.add_argument('--seq-length', type=int, default=512,
                        help='maximum sequence length')
    parser.add_argument('--short-seq-prob', type=float, default=0.1,
                        help='probability of creating a short sequence')
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed')
    args = parser.parse_args()

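    # Resolve 'infer' to the concrete on-disk format before building datasets.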
    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    test_albert_dataset(args)
    # test_indexed_dataset(args)

if __name__ == "__main__":
    main()