data_loaders.py 3.77 KB
Newer Older
mohammad's avatar
mohammad committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
17
18
19
20
21
22
23
24
"""Dataloaders."""


import torch

from megatron import get_args
from megatron import mpu


25
def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False):
    """Build a dataloader given an input dataset.

    Args:
        dataset: dataset to wrap; ``len(dataset)`` is used as the number of
            samples per epoch. If ``None``, no dataloader is built.
        consumed_samples: number of samples already consumed, forwarded to
            the sampler so iteration resumes from that point.
        random_sample: if True, the sampler draws indices from a permutation
            seeded by the epoch number instead of in sequential order.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a
        ``MegatronPretrainingSampler`` batch sampler, or ``None`` when
        ``dataset`` is ``None``.
    """
    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler: yields this data-parallel rank's micro-batch slice of
    # each global (micro_batch_size * data_parallel_size) batch.
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size(),
        random_sample=random_sample)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
mohammad's avatar
mohammad committed
46
47
48
49


class MegatronPretrainingSampler:
    """Batch sampler yielding this data-parallel rank's micro-batch indices.

    Dataset indices are grouped into global batches of
    ``micro_batch_size * data_parallel_size`` elements. For each complete
    global batch, the ``micro_batch_size`` slice owned by
    ``data_parallel_rank`` is yielded. Iteration resumes from
    ``consumed_samples`` and advances it as batches are produced, so iterating
    again continues with the next epoch.
    """

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, random_sample=False):
        """Store sampling parameters and sanity-check them.

        Args:
            total_samples: number of samples in one epoch; must be > 0.
            consumed_samples: samples consumed so far, summed over epochs.
                It may exceed ``total_samples``; the current epoch is derived
                from it in ``__iter__``.
            micro_batch_size: number of indices yielded per step for this
                rank; must be > 0.
            data_parallel_rank: index of this rank; must be smaller than
                ``data_parallel_size``.
            data_parallel_size: number of data-parallel ranks; must be > 0.
            random_sample: if True, indices come from a permutation seeded by
                the epoch number; otherwise they are sequential.

        Raises:
            AssertionError: if a sanity check fails (only while assertions
                are enabled).
        """
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.random_sample = random_sample

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        # consumed_samples is deliberately NOT bounded by total_samples:
        # values past one epoch select a later epoch in __iter__.
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        """Return ``total_samples``.

        NOTE(review): this is the per-epoch sample count, not the number of
        micro-batches this sampler yields; confirm that callers using
        ``len()`` expect a sample count.
        """
        return self.total_samples

    def __iter__(self):
        """Yield this rank's list of indices for each complete global batch.

        Iteration starts at offset ``consumed_samples % total_samples``
        within epoch ``consumed_samples // total_samples``. A trailing
        incomplete global batch is dropped. Its size is still added to
        ``consumed_samples``, so the next iteration begins at the following
        epoch boundary.
        """
        self.epoch = self.consumed_samples // self.total_samples
        current_epoch_samples = self.consumed_samples % self.total_samples
        if self.random_sample:
            # The seed depends only on the epoch. Ranks that share the same
            # consumed_samples therefore draw the same permutation
            # (presumably true for all ranks; verify against callers).
            g = torch.Generator()
            g.manual_seed(self.epoch)
            idx_range_total = \
                torch.randperm(self.total_samples, generator=g).tolist()
            idx_range = idx_range_total[current_epoch_samples:]
        else:
            idx_range = range(current_epoch_samples, self.total_samples)

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                # Count the whole global batch, not just this rank's slice.
                self.consumed_samples += len(batch)
                start_idx = self.data_parallel_rank * self.micro_batch_size
                end_idx = start_idx + self.micro_batch_size
                yield batch[start_idx:end_idx]
                batch = []
        # Count the dropped remainder too, so the epoch is marked finished.
        self.consumed_samples += len(batch)