import logging
from typing import List, Optional

from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import (
    _reconfigure_microbatch_calculator,
    get_micro_batch_size,
    get_num_microbatches,
    get_current_global_batch_size,
    update_num_microbatches,
)
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class MicrobatchCalculatorTestBase:
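    """Shared checks for apex's global microbatch calculator.

    For every data-parallel size that divides the world size, the calculator is
    reconfigured and its reported micro batch size, number of microbatches, and
    (possibly ramped-up) global batch size are verified.
    """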

    GLOBAL_BATCH_SIZE: int = 1024
    MICRO_BATCH_SIZE: int = 1

    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        for data_parallel_size in range(1, self.world_size + 1):

            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
            expected_micro_batch_size = self.MICRO_BATCH_SIZE
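            # `rampup_batch_size` is a 3-element schedule: the starting global batch
            # size, the per-step increment of the global batch size, and the ramp-up
            # threshold in consumed samples.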
            if rampup_batch_size:
                expected_global_batch_size = rampup_batch_size[0]
                num_consumed_samples = 0
                step_of_global_batch_size = rampup_batch_size[1]
                threshold = rampup_batch_size[2]

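            # Only exercise data-parallel sizes that divide the world size evenly;
            # odd sizes greater than 1 are skipped as well.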
            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue
            with self.subTest(data_parallel_size=data_parallel_size):
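                # Use all remaining ranks for tensor model parallelism so that
                # tensor_parallel_size * data_parallel_size == world_size.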
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=self.world_size // data_parallel_size,
                    pipeline_model_parallel_size_=1,
                )
                self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size())

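                # Reinitialize the global microbatch calculator for this
                # (global batch size, micro batch size, data-parallel size) setup.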
                _reconfigure_microbatch_calculator(
                    self.rank,
                    rampup_batch_size,
                    self.GLOBAL_BATCH_SIZE,
                    self.MICRO_BATCH_SIZE,
                    data_parallel_size,
                )

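                # The calculator should report the configured micro batch size and the
                # per-data-parallel-rank number of microbatches.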
                self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
                self.assertEqual(
                    get_num_microbatches(),
                    expected_global_batch_size / expected_micro_batch_size / data_parallel_size,
                )
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(current_global_batch_size, expected_global_batch_size)

                # Make sure the global batch size reaches the final global batch size
                # after a certain number of updates.
                if rampup_batch_size:
                    update_num_microbatches(current_global_batch_size)
                    for _ in range(100):
                        current_global_batch_size = get_current_global_batch_size()
                        update_num_microbatches(current_global_batch_size)
                    current_global_batch_size = get_current_global_batch_size()
                    self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE)
                parallel_state.destroy_model_parallel()

    def test_constant_microbatch_calculator(self):
        self._test(rampup_batch_size=None)

    def test_dynamic_microbatch_calculator(self):
        self._test(rampup_batch_size=[256, 128, 500])


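# Instantiate the shared test case for both the NCCL and UCC distributed backends.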
class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass
class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass


if __name__ == "__main__":
    common_utils.run_tests()