# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import io
import tempfile
from collections import namedtuple
from distutils.version import LooseVersion

import torch
import torchtext
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive

# torchtext 0.10 moved the legacy vocab API (which exposes ``vocab.stoi``,
# used below) under ``torchtext.legacy``.
if LooseVersion(torchtext.__version__) >= LooseVersion("0.10.0"):
    from torchtext.legacy.vocab import build_vocab_from_iterator
else:
    from torchtext.vocab import build_vocab_from_iterator


def _batchify(data, batch_size):
    """Reshape a flat token stream into a sequence-major (nbatch, batch_size) block."""
    data = torch.tensor(data)
    # Work out how many full columns of batch_size tokens fit in the stream.
    nbatch = data.size(0) // batch_size
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * batch_size)
    # Evenly divide the data across the batch_size columns.
    data = data.view(batch_size, -1).t().contiguous()
    return data
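

# Illustrative shape check (an added sketch, not from the original file):
# with batch_size=4, a stream of 10 tokens drops the 2-token remainder and
# comes back sequence-major, i.e. (nbatch, batch_size) = (2, 4):
#
#   _batchify(list(range(10)), 4)
#   # tensor([[0, 2, 4, 6],
#   #         [1, 3, 5, 7]])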


def _get_total_batch_size(benchmark_config, model_specs):
    """Tokens per dataloader batch: one seq_len-long sequence per sample."""
    return model_specs["seq_len"] * benchmark_config["batch_size"]


DatasetsInfo = namedtuple("DatasetsInfo", ["ntokens", "train_dataset", "valid_dataset", "test_dataset"])


def get_real_datasets():
    """Download WikiText-2, build a vocab and return tokenized train/valid/test tensors."""
    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
    tmpdir = tempfile.TemporaryDirectory()
    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir.name))

    tokenizer = get_tokenizer("basic_english")
    vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding="utf8"))))

    def data_process(raw_text_iter):
        # Map each line to a tensor of token ids, drop empty lines and
        # concatenate everything into one flat token stream.
        data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    train_dataset = data_process(iter(io.open(train_filepath, encoding="utf8")))
    valid_dataset = data_process(iter(io.open(valid_filepath, encoding="utf8")))
    test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8")))

    return DatasetsInfo(len(vocab.stoi), train_dataset, valid_dataset, test_dataset)


def get_dataloaders(datasets_info, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Wrap the given train/valid/test datasets in distributed dataloaders."""
    ntokens, train_dataset, valid_dataset, test_dataset = datasets_info

    def batchify(data):
        batch_size = benchmark_config["batch_size"]
        return _batchify(data, batch_size)

    # The DataLoader batch size is expressed in tokens; batchify later folds
    # each flat batch back into a (seq_len, batch_size) block.
    total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=DistributedSampler(train_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        sampler=DistributedSampler(valid_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    test_dataloader = DataLoader(
        test_dataset,
        sampler=DistributedSampler(test_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    return train_dataloader, valid_dataloader, test_dataloader


def get_real_dataloaders(args, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Return real WikiText-2 dataloaders for training, validation and testing."""
    dataset_info = get_real_datasets()
    train_dataloader, valid_dataloader, test_dataloader = get_dataloaders(
        dataset_info, benchmark_config, model_specs, num_replicas, rank
    )
    return dataset_info.ntokens, train_dataloader, valid_dataloader, test_dataloader


def get_synthetic_datasets():
    """Return a random token stream shaped like the real WikiText-2 data."""
    # vocab_size is 10000 and the length of the real data is 2049990 tokens.
    lm_dataset = torch.randint(1, 10000, (2049990,))
    return DatasetsInfo(10000, lm_dataset, lm_dataset, lm_dataset)


def get_synthetic_dataloaders(args, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Return synthetic dataloaders for training, testing and validation."""
    return get_dataloaders(get_synthetic_datasets(), benchmark_config, model_specs, num_replicas, rank)
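

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration added here, not part of the original
# benchmark): the config dicts below are hypothetical, but "batch_size" and
# "seq_len" are the only keys these helpers actually read. Passing
# num_replicas/rank explicitly lets DistributedSampler run without
# torch.distributed being initialized.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    benchmark_config = {"batch_size": 8}  # hypothetical values
    model_specs = {"seq_len": 32}
    train, valid, test = get_synthetic_dataloaders(
        None, benchmark_config, model_specs, num_replicas=1, rank=0
    )
    # batchify folds each flat batch of seq_len * batch_size tokens into a
    # sequence-major (seq_len, batch_size) block.
    batch = next(iter(train))
    print(batch.shape)  # torch.Size([32, 8])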