# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from distutils.version import LooseVersion
import io
import operator
import tempfile

import torch
from torch.utils.data import DataLoader

import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive
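# From torchtext 0.10 the old frequency-ordered Vocab (whose .stoi mapping is used
# below) lives under torchtext.legacy, hence the version-dependent import that follows.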

if operator.ge(torchtext.__version__, LooseVersion("0.10.0")):
    from torchtext.legacy.vocab import build_vocab_from_iterator
else:
    from torchtext.vocab import build_vocab_from_iterator


def _batchify(data, batch_size):
    data = torch.tensor(data)
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // batch_size
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * batch_size)
    # Evenly divide the data across the bsz batches.
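    # E.g. with batch_size=4 and 26 tokens: nbatch=6, 24 tokens are kept, viewed as
    # (4, 6) and transposed to (6, 4), so each column is one contiguous token stream.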
    data = data.view(batch_size, -1).t().contiguous()
    return data


def _get_total_batch_size(benchmark_config, model_specs):
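    # One DataLoader "batch" is seq_len * batch_size raw token ids; batchify() below
    # reshapes each full chunk into a (seq_len, batch_size) tensor.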
    return model_specs["seq_len"] * benchmark_config["batch_size"]


def get_real_dataloaders(args, benchmark_config, model_specs):
    """Return real dataloaders for training, testing and validation."""

    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
    tmpdir = tempfile.TemporaryDirectory()
    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir.name))
    tokenizer = get_tokenizer("basic_english")

    def data_process(raw_text_iter):
        data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding="utf8"))))
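    # The vocab is built from the training split only; tokens seen only in the
    # validation/test splits should fall back to the <unk> index when looked up below.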

    train_dataset = data_process(iter(io.open(train_filepath, encoding="utf8")))
    valid_dataset = data_process(iter(io.open(valid_filepath, encoding="utf8")))
    test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8")))
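    # Each *_dataset is a single flat 1-D LongTensor of token ids for its split; the
    # sequential DataLoaders below cut it into chunks of total_batch_size tokens.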

    def batchify(data):
        batch_size = benchmark_config["batch_size"]
        return _batchify(data, batch_size)

    total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
    train_dataloader = DataLoader(train_dataset, batch_size=total_batch_size, collate_fn=batchify)
    valid_dataloader = DataLoader(valid_dataset, batch_size=total_batch_size, collate_fn=batchify)
    test_dataloader = DataLoader(test_dataset, batch_size=total_batch_size, collate_fn=batchify)
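    # len(vocab.stoi) counts every token string (specials such as <unk> included);
    # presumably the caller uses it to size the model's embedding table.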
    return len(vocab.stoi), train_dataloader, valid_dataloader, test_dataloader


def get_synthetic_dataloaders(args, benchmark_config, model_specs):
    """Return synthetic dataloaders for training, testing and validation."""

    def batchify(data):
        batch_size = benchmark_config["batch_size"]
        return _batchify(data, batch_size)

    total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
    # vocab_size is 10000 and length of the real data is 2049990.
    lm_dataset = torch.randint(1, 10000, (2049990,))

    lm_dataloader = DataLoader(
        lm_dataset, batch_size=total_batch_size, shuffle=True, num_workers=0, collate_fn=batchify
    )
    return lm_dataloader, lm_dataloader, lm_dataloader
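

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the benchmark itself). It assumes only the
# config keys read above -- benchmark_config["batch_size"] and
# model_specs["seq_len"] -- and exercises the synthetic path so nothing is
# downloaded; `args` is unused on that path, so None is passed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_benchmark_config = {"batch_size": 8}
    demo_model_specs = {"seq_len": 32}
    train_loader, _, _ = get_synthetic_dataloaders(None, demo_benchmark_config, demo_model_specs)
    batch = next(iter(train_loader))
    # Each full batch is a (seq_len, batch_size) LongTensor of token ids.
    print(batch.shape)  # torch.Size([32, 8])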