"""Unit tests for ``megatron.core.num_microbatches_calculator``.

The positional arguments passed to ``init_num_microbatches_calculator``,
``reconfigure_num_microbatches_calculator`` and
``_build_num_microbatches_calculator`` are presumably
``(rank, rampup_batch_size, global_batch_size, micro_batch_size,
data_parallel_size, decrease_batch_size_if_needed)`` -- TODO confirm against
the library signature.
"""

from typing import List, Optional

import pytest

import megatron.core.num_microbatches_calculator as mb_calculator


def _reset_global_calculator():
    """Clear the module-level calculator so that ``init_*`` can be called again."""
    mb_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


def test_init_num_microbatches_calculator():
    """Check the state reported after initializing the global calculator."""
    _reset_global_calculator()
    mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 2, False)
    assert mb_calculator.get_num_microbatches() == 2
    assert mb_calculator.get_current_global_batch_size() == 32

    # A second init without clearing the global first is expected to fail.
    with pytest.raises(AssertionError):
        mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 2, False)

    # 32 is not a multiple of 8 * 3; with the last flag set, the running
    # global batch size is expected to drop to 24.
    _reset_global_calculator()
    mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 3, True)
    assert mb_calculator.get_num_microbatches() == 1
    assert mb_calculator.get_current_global_batch_size() == 32
    assert mb_calculator.get_current_running_global_batch_size() == 24

    # 33 is not a multiple of 8 * 2; the running global batch size is
    # expected to be 32 while the reported global batch size stays 33.
    _reset_global_calculator()
    mb_calculator.init_num_microbatches_calculator(0, None, 33, 8, 2, True)
    assert mb_calculator.get_num_microbatches() == 2
    assert mb_calculator.get_current_global_batch_size() == 33
    assert mb_calculator.get_current_running_global_batch_size() == 32


def test_reconfigure_num_microbatches_calculator():
    """Check that reconfiguring replaces the state of the global calculator."""
    _reset_global_calculator()
    mb_calculator.init_num_microbatches_calculator(0, None, 32, 8, 2, False)
    assert mb_calculator.get_num_microbatches() == 2
    assert mb_calculator.get_current_global_batch_size() == 32

    mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 8, 2, False)
    assert mb_calculator.get_num_microbatches() == 1
    assert mb_calculator.get_current_global_batch_size() == 16

    # With a ramp-up spec starting at 16, the current global batch size is
    # expected to be 16 rather than the target of 32.
    mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 16, 96], 32, 8, 2, False)
    assert mb_calculator.get_num_microbatches() == 1
    assert mb_calculator.get_current_global_batch_size() == 16


def test_get_num_microbatches():
    """Check ``get_num_microbatches`` for two reconfigured setups."""
    mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 8, 2, False)
    assert mb_calculator.get_num_microbatches() == 1

    mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 4, 3, True)
    assert mb_calculator.get_num_microbatches() == 1


def test_get_current_global_batch_size():
    """Check the reported global and running global batch sizes."""
    mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 4, 2, False)
    assert mb_calculator.get_current_global_batch_size() == 16

    # 16 is not a multiple of 4 * 3; the running size is expected to be 12.
    mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 4, 3, True)
    assert mb_calculator.get_current_global_batch_size() == 16
    assert mb_calculator.get_current_running_global_batch_size() == 12


def test_get_micro_batch_size():
    """Check that ``get_micro_batch_size`` returns the configured value."""
    mb_calculator.reconfigure_num_microbatches_calculator(0, None, 16, 8, 2, False)
    assert mb_calculator.get_micro_batch_size() == 8


def test_update_num_microbatches():
    """Check ``update_num_microbatches`` with and without a ramp-up spec."""
    mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 8, 96], 32, 4, 2, False)
    assert mb_calculator.get_num_microbatches() == 2
    mb_calculator.update_num_microbatches(48, False)
    assert mb_calculator.get_num_microbatches() == 3

    # Same ramp-up spec with micro batch size 8: updating at 49 consumed
    # samples with the second argument True is expected to fail.
    mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 8, 96], 32, 8, 2, False)
    with pytest.raises(AssertionError):
        mb_calculator.update_num_microbatches(49, True)

    # Without a ramp-up spec the update is expected to leave the count at 2.
    mb_calculator.reconfigure_num_microbatches_calculator(0, None, 32, 8, 2, False)
    mb_calculator.update_num_microbatches(16)
    assert mb_calculator.get_num_microbatches() == 2


def test_build_num_microbatches_calculator():
    """Check which calculator class the private builder returns."""
    constant_calc = mb_calculator._build_num_microbatches_calculator(0, None, 32, 8, 2, False)
    assert constant_calc.get() == 2
    assert constant_calc.get_current_global_batch_size() == 32
    assert type(constant_calc) is mb_calculator.ConstantNumMicroBatchesCalculator

    rampup_calc = mb_calculator._build_num_microbatches_calculator(
        0, [16, 16, 48], 32, 8, 2, False
    )
    assert rampup_calc.get() == 1
    assert rampup_calc.get_current_global_batch_size() == 16
    assert type(rampup_calc) is mb_calculator.RampupBatchsizeNumMicroBatchesCalculator


class TestConstantNumMicroBatchesCalculator:
    """Tests for a directly constructed ``ConstantNumMicroBatchesCalculator``."""

    def setup_method(self, method):
        """Create a fresh calculator before each test method."""
        self.mb_calculator = mb_calculator.ConstantNumMicroBatchesCalculator(32, 8, 2, False, 0)

    def test_constructor(self):
        """Check the type and the attributes set by the constructor."""
        calc = self.mb_calculator
        assert type(calc) is mb_calculator.ConstantNumMicroBatchesCalculator
        assert calc.num_micro_batches == 2
        assert calc.current_global_batch_size == 32
        assert calc.micro_batch_size == 8

    def test_get(self):
        """Check that ``get`` returns the number of microbatches."""
        assert self.mb_calculator.get() == 2

    def test_get_current_global_batch_size(self):
        """Check that the current global batch size is the configured one."""
        assert self.mb_calculator.get_current_global_batch_size() == 32


class TestRampupBatchsizeNumMicroBatchesCalculator:
    """Tests for a directly constructed ``RampupBatchsizeNumMicroBatchesCalculator``."""

    def setup_method(self, method):
        """Create a fresh ramp-up calculator before each test method."""
        self.mb_calculator = mb_calculator.RampupBatchsizeNumMicroBatchesCalculator(
            32, 8, 2, False, 0, 16, 16, 48
        )

    def test_constructor(self):
        """Check the type and the attributes set by the constructor."""
        calc = self.mb_calculator
        assert type(calc) is mb_calculator.RampupBatchsizeNumMicroBatchesCalculator
        assert calc.global_batch_size == 32
        assert calc.micro_batch_size == 8
        assert calc.data_parallel_size == 2
        assert calc.start_global_batch_size == 16
        assert calc.batch_size_increment == 16
        # ``ramup_samples`` (sic) is the attribute name read from the library.
        assert calc.ramup_samples == 48
        assert calc.micro_batch_times_data_parallel_size == 16
        assert calc.num_micro_batches == 1

    def test_get(self):
        """Check that ``get`` returns the initial number of microbatches."""
        assert self.mb_calculator.get() == 1

    def test_get_current_global_batch_size(self):
        """Check that the calculator starts at the ramp-up start batch size."""
        assert self.mb_calculator.get_current_global_batch_size() == 16


def test_ramp_up():
    """Step through a ramp-up and check the consumed-sample count at each step.

    The expected counts grow by 16 per step up to 96 and by 32 per step
    afterwards, until 256 samples have been consumed.
    """
    mb_calculator.reconfigure_num_microbatches_calculator(0, [16, 16, 96], 32, 8, 2, False)
    milestones = [0, 16, 32, 48, 64, 80, 96, 128, 160, 192, 224, 256]
    total_consumed = 0
    step = 0
    while total_consumed < 256:
        # Consume one batch at the current size, then let the calculator
        # update itself from the new running total.
        total_consumed += mb_calculator.get_current_global_batch_size()
        step += 1
        assert total_consumed == milestones[step]
        mb_calculator.update_num_microbatches(total_consumed, True)