# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Dataloaders."""


import random
import torch
import numpy as np
from torch.utils.data import Dataset
from megatron.training import get_args
from megatron.core import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
            data_sharding=args.data_sharding)
    elif args.dataloader_type == "external":
        # External dataloaders are passed through unchanged. The user is expected to
        # provide a torch-compatible dataloader and to define samplers, if needed.
        return dataset
    else:
        raise Exception('{} dataloader type is not supported.'.format(
                args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True,
                                       persistent_workers=args.num_workers > 0,
                                       )
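

# Illustrative usage sketch (not part of this module): how a pretraining loop might
# build and consume this dataloader. `train_ds` and `consumed` are placeholders, and
# Megatron's global args are assumed to be initialized.
#
#     train_dataloader = build_pretraining_data_loader(train_ds, consumed)
#     data_iterator = iter(train_dataloader)
#     batch = next(data_iterator)
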

class MegatronPretrainingSampler:
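    """Sequential batch sampler that shards each global batch across data parallel ranks.

    Each yielded list holds the micro_batch_size sample indices assigned to this
    data_parallel_rank; iteration resumes from consumed_samples.
    """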

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def get_start_end_idx(self):
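        # Each rank takes a contiguous micro_batch_size slice of the global batch.
        # Example: with micro_batch_size=2 and data_parallel_rank=3, this rank
        # receives positions [6, 8), i.e. the 7th and 8th samples of the batch.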
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # The last, partial batch will be dropped unless drop_last is set to False.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Yield the leftover partial batch if drop_last is set to False.
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]


class RandomSeedDataset(Dataset):
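    """Wrapper that reseeds torch, random, and numpy before returning each sample.

    set_epoch(epoch) shifts the seed by the epoch number, so a given
    (epoch, index) pair always produces the same random augmentations.
    """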

    def __init__(self, dataset):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = args.seed
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        return self.dataset[idx]
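

# Illustrative sketch (not part of this module): wrapping a map-style dataset so that
# random augmentations become reproducible per (epoch, index). `my_dataset`, `epoch`,
# and `idx` are placeholders.
#
#     seeded_ds = RandomSeedDataset(my_dataset)
#     seeded_ds.set_epoch(epoch)
#     sample = seeded_ds[idx]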


class MegatronPretrainingRandomSampler:
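    """Shuffling batch sampler used for the 'cyclic' dataloader type.

    Sample order is reshuffled every epoch with the epoch number as the seed,
    optionally within per-rank shards when data_sharding is enabled, and
    iteration resumes mid-epoch from consumed_samples.
    """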

    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, data_sharding):
        # Keep a copy of input params for later use.
        self.dataset = dataset
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.data_sharding = data_sharding
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        if isinstance(self.dataset, RandomSeedDataset):
            self.dataset.set_epoch(self.epoch)

        # Data sharding and random sampling.
        if self.data_sharding:
            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                           * self.micro_batch_size
            bucket_offset = current_epoch_samples // self.data_parallel_size
            start_idx = self.data_parallel_rank * bucket_size

            g = torch.Generator()
            g.manual_seed(self.epoch)
            random_idx = torch.randperm(bucket_size, generator=g).tolist()
            idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
        else:
            full_bucket_size = (self.total_samples // self.micro_batch_size) \
                                * self.micro_batch_size
            full_bucket_offset = current_epoch_samples
            g = torch.Generator()
            g.manual_seed(self.epoch)
            idx_range_total = \
                torch.randperm(full_bucket_size, generator=g).tolist()
            idx_range_active = idx_range_total[full_bucket_offset:]
            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]

        batch = []
        # The last batch will be dropped if it is incomplete.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
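

# Illustrative sketch (assumption, not part of this module): constructing the random
# sampler directly; in training this normally happens inside
# build_pretraining_data_loader(). `ds` and `ckpt_consumed` are placeholders.
#
#     sampler = MegatronPretrainingRandomSampler(
#         ds, total_samples=len(ds), consumed_samples=ckpt_consumed,
#         micro_batch_size=4, data_parallel_rank=0, data_parallel_size=8,
#         data_sharding=True)
#     for micro_batch_indices in sampler:
#         ...  # each yielded list holds micro_batch_size dataset indices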