test_num_microbatches_calculator.py 6.17 KB
Newer Older
liangjing's avatar
liangjing committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from typing import List, Optional

import pytest

import megatron.core.num_microbatches_calculator as mb_calculator


def test_init_num_microbatches_calculator():
    """Initialise the module-level calculator and check what its getters report.

    Also checks that a second initialisation, made without first clearing the
    module-level calculator, raises ``AssertionError``.
    """
    constant_args = (0, None, 32, 8, 2, False)

    # The module-level calculator is cleared before every init call; the
    # double-init check below shows that init asserts when it is left in place.
    mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
    mb_calculator.init_num_microbatches_calculator(*constant_args)
    assert mb_calculator.get_num_microbatches() == 2
    assert mb_calculator.get_current_global_batch_size() == 32

    with pytest.raises(AssertionError):
        mb_calculator.init_num_microbatches_calculator(*constant_args)

    # Cases with the last init argument set to True: the running global batch
    # size is expected to be smaller than the configured global batch size.
    # Columns: third, fourth and fifth init arguments (presumably global batch
    # size, micro batch size and data-parallel size -- confirm against the
    # calculator module), then expected microbatch count and running batch size.
    rounded_cases = (
        (32, 8, 3, 1, 24),
        (33, 8, 2, 2, 32),
    )
    for global_bs, micro_bs, dp_size, want_num_mb, want_running_bs in rounded_cases:
        mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
        mb_calculator.init_num_microbatches_calculator(
            0, None, global_bs, micro_bs, dp_size, True
        )
        assert mb_calculator.get_num_microbatches() == want_num_mb
        assert mb_calculator.get_current_global_batch_size() == global_bs
        assert mb_calculator.get_current_running_global_batch_size() == want_running_bs


def test_reconfigure_num_microbatches_calculator():
    """Reconfigure an initialised calculator and check the values it reports."""
    mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
    mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 2, False)
    assert mb_calculator.get_num_microbatches() == 2
    assert mb_calculator.get_current_global_batch_size() == 32

    # Both reconfigurations are expected to report one microbatch and a current
    # global batch size of 16: first with no ramp-up spec and 16 as the third
    # argument, then with the spec [16, 16, 96] and 32 as the third argument.
    reconfigure_cases = (
        (None, 16),
        ([16, 16, 96], 32),
    )
    for rampup, global_bs in reconfigure_cases:
        mb_calculator.reconfigure_num_microbatches_calculator(
            0, rampup, global_bs, 8, 2, False
        )
        assert mb_calculator.get_num_microbatches() == 1
        assert mb_calculator.get_current_global_batch_size() == 16


def test_get_num_microbatches():
    """Check ``get_num_microbatches`` after two different reconfigurations."""
    # Each row holds the fourth, fifth and sixth reconfigure arguments; every
    # case is expected to report exactly one microbatch.
    cases = (
        (8, 2, False),
        (4, 3, True),
    )
    for fourth_arg, fifth_arg, last_flag in cases:
        mb_calculator.reconfigure_num_microbatches_calculator(
            0, None, 16, fourth_arg, fifth_arg, last_flag
        )
        assert mb_calculator.get_num_microbatches() == 1


def test_get_current_global_batch_size():
    """Check the current and running global batch size getters after reconfiguring."""
    reconfigure = mb_calculator.reconfigure_num_microbatches_calculator

    reconfigure(0, None, 16, 4, 2, False)
    assert mb_calculator.get_current_global_batch_size() == 16

    # With the last argument True, the running size (12) is expected to be
    # below the current global batch size (16).
    reconfigure(0, None, 16, 4, 3, True)
    current_bs = mb_calculator.get_current_global_batch_size()
    assert current_bs == 16
    running_bs = mb_calculator.get_current_running_global_batch_size()
    assert running_bs == 12


def test_get_micro_batch_size():
    """Check that ``get_micro_batch_size`` reports 8 after reconfiguring."""
    reconfigure_args = (0, None, 16, 8, 2, False)
    mb_calculator.reconfigure_num_microbatches_calculator(*reconfigure_args)

    reported = mb_calculator.get_micro_batch_size()
    assert reported == 8


def test_update_num_microbatches():
    """Check ``update_num_microbatches`` with and without a ramp-up spec."""
    reconfigure = mb_calculator.reconfigure_num_microbatches_calculator
    update = mb_calculator.update_num_microbatches

    # Ramp-up spec [16, 8, 96]: two microbatches right after reconfiguring,
    # three once the calculator is updated with 48 as its first argument.
    reconfigure(0, [16, 8, 96], 32, 4, 2, False)
    assert mb_calculator.get_num_microbatches() == 2
    update(48, False)
    assert mb_calculator.get_num_microbatches() == 3

    # Updating with 49 and the second argument set to True must assert.
    reconfigure(0, [16, 8, 96], 32, 8, 2, False)
    with pytest.raises(AssertionError):
        update(49, True)

    # No ramp-up spec: two microbatches are expected after an update.
    reconfigure(0, None, 32, 8, 2, False)
    update(16)
    assert mb_calculator.get_num_microbatches() == 2


def test_build_num_microbatches_calculator():
    """Check the calculator class and values produced by the private builder."""
    build = mb_calculator._build_num_microbatches_calculator

    # Second argument None: a constant calculator is expected.
    constant_calc = build(0, None, 32, 8, 2, False)
    assert constant_calc.get() == 2
    assert constant_calc.get_current_global_batch_size() == 32
    # Exact type match on purpose (not isinstance), as in the other checks here.
    assert type(constant_calc) is mb_calculator.ConstantNumMicroBatchesCalculator

    # Second argument [16, 16, 48]: a ramp-up calculator is expected, reporting
    # one microbatch and a current global batch size of 16.
    rampup_calc = build(0, [16, 16, 48], 32, 8, 2, False)
    assert rampup_calc.get() == 1
    assert rampup_calc.get_current_global_batch_size() == 16
    assert type(rampup_calc) is mb_calculator.RampupBatchsizeNumMicroBatchesCalculator


class TestConstantNumMicroBatchesCalculator:
    """Tests for ``ConstantNumMicroBatchesCalculator`` built with ``(32, 8, 2, False, 0)``."""

    def setup_method(self, method):
        """Build the calculator under test (pytest runs this before each test method)."""
        constructor_args = (32, 8, 2, False, 0)
        self.mb_calculator = mb_calculator.ConstantNumMicroBatchesCalculator(*constructor_args)

    def test_constructor(self):
        """Check the concrete type and the attributes set by the constructor."""
        calc = self.mb_calculator
        assert type(calc) is mb_calculator.ConstantNumMicroBatchesCalculator

        expected_attrs = {
            "num_micro_batches": 2,
            "current_global_batch_size": 32,
            "micro_batch_size": 8,
        }
        for attr_name, expected in expected_attrs.items():
            assert getattr(calc, attr_name) == expected

    def test_get(self):
        """``get()`` is expected to return 2."""
        num_micro_batches = self.mb_calculator.get()
        assert num_micro_batches == 2

    def test_get_current_global_batch_size(self):
        """``get_current_global_batch_size()`` is expected to return 32."""
        current_bs = self.mb_calculator.get_current_global_batch_size()
        assert current_bs == 32


class TestRampupBatchsizeNumMicroBatchesCalculator:
    """Tests for ``RampupBatchsizeNumMicroBatchesCalculator`` with fixed constructor arguments."""

    def setup_method(self, method):
        """Build the calculator under test (pytest runs this before each test method)."""
        constructor_args = (32, 8, 2, False, 0, 16, 16, 48)
        self.mb_calculator = mb_calculator.RampupBatchsizeNumMicroBatchesCalculator(
            *constructor_args
        )

    def test_constructor(self):
        """Check the concrete type and the attributes set by the constructor."""
        calc = self.mb_calculator
        assert type(calc) is mb_calculator.RampupBatchsizeNumMicroBatchesCalculator

        # Attribute name -> expected value. "ramup_samples" is the spelling of
        # the attribute this test reads from the calculator; keep it as is.
        expected_attrs = {
            "global_batch_size": 32,
            "micro_batch_size": 8,
            "data_parallel_size": 2,
            "start_global_batch_size": 16,
            "batch_size_increment": 16,
            "ramup_samples": 48,
            "micro_batch_times_data_parallel_size": 16,
            "num_micro_batches": 1,
        }
        for attr_name, expected in expected_attrs.items():
            assert getattr(calc, attr_name) == expected

    def test_get(self):
        """``get()`` is expected to return 1 right after construction."""
        num_micro_batches = self.mb_calculator.get()
        assert num_micro_batches == 1

    def test_get_current_global_batch_size(self):
        """``get_current_global_batch_size()`` is expected to return 16 right after construction."""
        current_bs = self.mb_calculator.get_current_global_batch_size()
        assert current_bs == 16


def test_ramp_up():
    """Step through a batch-size ramp-up and check the cumulative consumed samples.

    Each step adds the current global batch size to the running total, compares
    the total with the expected value, and then updates the calculator with it.
    """
    mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 16, 96], 32, 8, 2, False)

    # Index 0 is the starting total; the totals grow by 16 per step up to 96
    # and by 32 per step afterwards, ending at 256.
    expected_totals = [0, 16, 32, 48, 64, 80, 96, 128, 160, 192, 224, 256]

    total_consumed = 0
    for expected_total in expected_totals[1:]:
        total_consumed += mb_calculator.get_current_global_batch_size()
        assert total_consumed == expected_total
        mb_calculator.update_num_microbatches(total_consumed, True)