sampler.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatorn Sampler."""


class MegatronPretrainingSampler:
    """Batch sampler that walks a contiguous range of sample indices and
    yields each rank's slice of every global batch."""

    def __init__(self, total_samples, consumed_samples,
                 global_batch_size, rank, world_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.global_batch_size = global_batch_size
        self.rank = rank

        # Sanity checks.
        assert self.total_samples > 0, \
            'no samples to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.global_batch_size > 0, \
            'unexpected global batch size: {}'.format(self.global_batch_size)
        assert world_size > 0, \
            'world size must be positive: {}'.format(world_size)
        assert self.rank < world_size, \
            'rank should be smaller than world size: {}, {}'.format(
                self.rank, world_size)

        # Batch size per rank.
        assert self.global_batch_size % world_size == 0, \
            'global batch size must be divisible by world size: {}, {}'.format(
                self.global_batch_size, world_size)
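        # Example (assumed values): global_batch_size=8 and world_size=4
        # give batch_size_per_rank=2, so rank r owns batch[2*r : 2*r + 2].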
        self.batch_size_per_rank = self.global_batch_size // world_size


    def __len__(self):
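        # Note: length in samples, not in batches.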
        return self.total_samples


    def __iter__(self):
        batch = []
        # Drop the last batch if it is not complete.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.global_batch_size:
                start_idx = self.rank * self.batch_size_per_rank
                end_idx = start_idx + self.batch_size_per_rank
                yield batch[start_idx:end_idx]
                batch = []
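

# Minimal usage sketch (illustrative only; the dataset, batch sizes, and rank
# values below are assumptions, not part of Megatron's pretraining setup).
# It shows the sampler plugging into a PyTorch DataLoader as a batch_sampler,
# so each rank receives its batch_size_per_rank slice of every global batch.
if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(1000))  # toy dataset of 1000 samples
    sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,        # set > 0 when resuming mid-epoch
        global_batch_size=8,
        rank=0,                    # this process's data-parallel rank
        world_size=4)              # number of data-parallel ranks

    # Each iteration yields the 2-sample slice (8 // 4) owned by rank 0.
    loader = DataLoader(dataset, batch_sampler=sampler)
    print(next(iter(loader)))     # e.g. [tensor([0, 1])]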