tf_dl.py 5.38 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""

Neel Kant's avatar
Neel Kant committed
17
18
import numpy as np
import torch
Raul Puri's avatar
Raul Puri committed
19
20
21
import queue
import threading

Raul Puri's avatar
Raul Puri committed
22
23
import tensorflow as tf
tf.enable_eager_execution()
Neel Kant's avatar
Neel Kant committed
24

Raul Puri's avatar
Raul Puri committed
25
26

class TFRecordDataLoader(object):
    """Iterable over batches of BERT pretraining examples stored in TFRecord files.

    The constructor builds a `tf.data` pipeline that parses each record with a
    `Record2Example` converter and batches the results. Iterating the loader
    yields each batch converted by `convert_tf_example_to_torch_tensors`.
    Both the train and eval datasets are repeated, so iteration does not stop
    on its own.
    """

    def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq,
                 train, num_workers=2, seed=1, threaded_dl=False):
        """Build the TFRecord input pipeline.

        Args:
            records: path of one TFRecord file (str) or a list of such paths.
            batch_size: number of examples per batch.
            max_seq_len: length of the `input_ids`, `input_mask` and
                `segment_ids` features.
            max_preds_per_seq: length of the `masked_lm_*` features. Must not
                be None.
            train: when true, file names and examples are shuffled, files are
                read through a sloppy parallel interleave, and incomplete
                batches are dropped. Otherwise the files are read in order.
            num_workers: number of parallel batches for `map_and_batch`, upper
                bound on the interleave cycle length, and the value forwarded
                to `MultiprocessLoader` when `threaded_dl` is set.
            seed: value passed to `tf.set_random_seed`.
            threaded_dl: when true, `__iter__` pulls batches through
                `MultiprocessLoader` instead of converting them inline.

        Raises:
            AssertionError: if `max_preds_per_seq` is None.
        """
        assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
        tf.set_random_seed(seed)
        if isinstance(records, str):
            records = [records]

        self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
                                                "input_mask": tf.FixedLenFeature([max_seq_len], tf.int64),
                                                "segment_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
                                                "masked_lm_positions": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
                                                "masked_lm_ids": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
                                                "masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32),
                                                "next_sentence_labels": tf.FixedLenFeature([1], tf.int64)})

        # Instantiate dataset according to original BERT implementation
        if train:
            # Shuffle at file granularity first; the buffer holds every file name.
            self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records))
            self.dataset = self.dataset.repeat()
            self.dataset = self.dataset.shuffle(buffer_size=len(records))

            # use sloppy tfrecord dataset
            self.dataset = self.dataset.apply(
                tf.contrib.data.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=train,
                    cycle_length=min(num_workers, len(records))))
            # Second shuffle mixes individual records coming out of the interleave.
            self.dataset = self.dataset.shuffle(buffer_size=100)
        else:
            self.dataset = tf.data.TFRecordDataset(records)
            self.dataset = self.dataset.repeat()

        # Instantiate dataloader (do not drop remainder for eval)
        loader_args = {'batch_size': batch_size,
                       'num_parallel_batches': num_workers,
                       'drop_remainder': train}
        self.dataloader = self.dataset.apply(
            tf.contrib.data.map_and_batch(
                self.record_converter, **loader_args))
        self.threaded_dl = threaded_dl
        self.num_workers = num_workers

    def __iter__(self):
        """Yield batches converted by `convert_tf_example_to_torch_tensors`.

        With `threaded_dl` set, conversion happens inside `MultiprocessLoader`
        and the already-converted batches are passed through unchanged.
        """
        if self.threaded_dl:
            data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers))
            for item in data_iter:
                yield item
        else:
            data_iter = iter(self.dataloader)
            for item in data_iter:
                yield convert_tf_example_to_torch_tensors(item)
Raul Puri's avatar
Raul Puri committed
78

Neel Kant's avatar
Neel Kant committed
79

Raul Puri's avatar
Raul Puri committed
80
81
82
83
84
85
86
87
88
89
90
91
class Record2Example(object):
    """Callable that parses a serialized BERT TF record with a fixed feature spec."""

    def __init__(self, feature_map):
        # Feature spec handed to tf.parse_single_example on every call.
        self.feature_map = feature_map

    def __call__(self, record):
        """Decodes a BERT TF record to a TF example.

        Features parsed as int64 are cast to int32; all other features are
        returned as parsed.
        """
        parsed = tf.parse_single_example(record, self.feature_map)
        return {
            name: tf.to_int32(tensor) if tensor.dtype == tf.int64 else tensor
            for name, tensor in parsed.items()
        }

Neel Kant's avatar
Neel Kant committed
92

Raul Puri's avatar
Raul Puri committed
93
def convert_tf_example_to_torch_tensors(example):
    """Convert one batched example into the dict of torch tensors used for training.

    Args:
        example: mapping from feature name to an object exposing `.numpy()`.
            The keys read are `input_ids`, `input_mask`, `segment_ids`,
            `masked_lm_positions`, `masked_lm_ids`, `masked_lm_weights` and
            `next_sentence_labels`. The masked-LM features are indexed as 2-D
            (batch, prediction) arrays.

    Returns:
        dict of torch tensors with keys:
            `text`: the input ids.
            `types`: the segment ids.
            `is_random`: the next-sentence labels.
            `pad_mask`: `1 - input_mask`.
            `mask`: 1 at every position listed in `masked_lm_positions` whose
                weight is non-zero, 0 elsewhere.
            `mask_labels`: the matching `masked_lm_ids` value at those
                positions, -1 elsewhere.
    """
    item = {k: v.numpy() for k, v in example.items()}
    input_ids = item['input_ids']
    mask = np.zeros_like(input_ids)
    mask_labels = np.full_like(input_ids, -1)

    # Only prediction slots with a non-zero weight are real; the rest are padding.
    rows, slots = np.nonzero(item['masked_lm_weights'])
    positions = item['masked_lm_positions'].astype(int)[rows, slots]

    # Scatter all real predictions in one vectorized assignment.
    # NOTE(review): assumes weighted positions within a row are distinct;
    # NumPy does not guarantee which write wins for repeated targets.
    mask[rows, positions] = 1
    mask_labels[rows, positions] = item['masked_lm_ids'][rows, slots]

    output = {'text': input_ids, 'types': item['segment_ids'], 'is_random': item['next_sentence_labels'],
              'pad_mask': 1 - item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
    return {k: torch.from_numpy(v) for k, v in output.items()}

Raul Puri's avatar
Raul Puri committed
106
107
108
109

class MultiprocessLoader(object):
    """Iterate a dataloader on a background thread through a bounded queue.

    Despite the name, the producer is a `threading.Thread`, not a separate
    process. It runs `_multiproc_iter` over the wrapped dataloader.
    """

    # Seconds to wait on an empty queue before re-checking the worker thread.
    _POLL_INTERVAL = 0.1

    def __init__(self, dataloader, num_workers=2):
        """Store the dataloader and size the hand-off queue.

        Args:
            dataloader: iterable handed to the worker thread.
            num_workers: the queue holds at most `2 * num_workers` batches.
        """
        self.dl = dataloader
        self.queue_size = 2 * num_workers

    def __iter__(self):
        """Start a daemon worker thread and yield the batches it enqueues.

        Iteration ends once the worker has exited and the queue is empty. A
        `RuntimeError` describing the exit is then printed (not raised).
        """
        output_queue = queue.Queue(self.queue_size)
        output_thread = threading.Thread(target=_multiproc_iter,
                                         args=(self.dl, output_queue))
        output_thread.daemon = True
        output_thread.start()

        while True:
            try:
                # A timed get avoids blocking forever if the worker dies
                # while the queue is empty.
                item = output_queue.get(block=True, timeout=self._POLL_INTERVAL)
            except queue.Empty:
                if output_thread.is_alive():
                    continue
                # The worker may have enqueued a last batch between the
                # timeout and the liveness check; pick it up before stopping.
                try:
                    item = output_queue.get_nowait()
                except queue.Empty:
                    break
            yield item

        print(RuntimeError('TF record data loader thread exited unexpectedly'))
Raul Puri's avatar
Raul Puri committed
123

Neel Kant's avatar
Neel Kant committed
124

Raul Puri's avatar
Raul Puri committed
125
126
127
128
def _multiproc_iter(dl, output_queue):
    data_iter = iter(dl)
    for item in data_iter:
        tensors = convert_tf_example_to_torch_tensors(item)
Neel Kant's avatar
Neel Kant committed
129
        output_queue.put(tensors, block=True)