# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""


import torch

from megatron import get_args
from megatron import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
    """Buld dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()

    # Megatron sampler
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        global_batch_size=args.global_batch_size,
        rank=mpu.get_data_parallel_rank(),
        world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
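
# A minimal usage sketch (illustrative, not part of the original file; the
# actual call sites live in Megatron's training loop). It assumes `train_ds`
# is an indexable dataset and that training resumes after `consumed` samples:
#
#     train_dataloader = build_pretraining_data_loader(train_ds, consumed)
#     for batch in train_dataloader:
#         ...  # each data-parallel rank receives its slice of the global batch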


class MegatronPretrainingSampler:
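    """Batch sampler that walks the sample indices in order, assembles
    global batches, and yields this data-parallel rank's contiguous slice
    of each global batch."""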


    def __init__(self, total_samples, consumed_samples,
                 global_batch_size, rank, world_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.global_batch_size = global_batch_size
        self.rank = rank

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.global_batch_size > 0, \
            'Unexpected global batch size: {}'.format(self.global_batch_size)
        assert world_size > 0, \
            'non zero world size is expected: {}'.format(world_size)
        assert self.rank < world_size, \
            'rank should be smaller than world size: {}, {}'.format(
                self.rank, world_size)

        # Batch size per rank.
        assert self.global_batch_size % world_size == 0, \
            'global batch size must be divisible by world size: {}, {}'.format(
                self.global_batch_size, world_size)
        self.batch_size_per_rank = self.global_batch_size // world_size


    def __len__(self):
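        # Note: length is reported in samples, not in global batches.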
        return self.total_samples


    def __iter__(self):
        batch = []
        # The last batch is dropped if it is not complete.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.global_batch_size:
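                # Each data-parallel rank takes its own contiguous slice
                # of the assembled global batch.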
                start_idx = self.rank * self.batch_size_per_rank
                end_idx = start_idx + self.batch_size_per_rank
                yield batch[start_idx:end_idx]
                batch = []
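

# A minimal, illustrative sanity check of the sampler's partitioning (an
# addition for exposition, not part of the original file). With 8 samples,
# a global batch size of 4, and a data-parallel world size of 2, rank 0
# should yield [0, 1] then [4, 5] and rank 1 should yield [2, 3] then [6, 7].
if __name__ == '__main__':
    for rank in range(2):
        sampler = MegatronPretrainingSampler(total_samples=8,
                                             consumed_samples=0,
                                             global_batch_size=4,
                                             rank=rank,
                                             world_size=2)
        print('rank {}: {}'.format(rank, list(sampler)))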