Commit de0b70a0 authored by mohammad, committed by Deepak Narayanan

Support for ramping up the batch size

parent c30ba0f7
@@ -17,10 +17,12 @@
 from abc import ABC
 from abc import abstractmethod
+import math
 import os
 import sys
 import time
 
+import numpy as np
 import torch
 
 from megatron.tokenizer import build_tokenizer
@@ -127,11 +129,11 @@ class NumMicroBatchesCalculator(ABC):
 
     def __init__(self, name):
         self.name = name
+        self.num_micro_batches = None
         super(NumMicroBatchesCalculator, self).__init__()
 
-    @abstractmethod
     def get(self):
-        pass
+        return self.num_micro_batches
 
     @abstractmethod
     def update(self, consumed_samples):
@@ -149,9 +151,70 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
 
     def update(self, consumed_samples):
         pass
 
-    def get(self):
-        return self.num_micro_batches
+
+class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
+
+    def __init__(self, start_batch_size, batch_size_increment, rampup_samples,
+                 global_batch_size, micro_batch_size, data_parallel_size):
+        """Batch size ramp-up.
+
+        Over
+            steps = (global_batch_size - start_batch_size) / batch_size_increment
+        increments, the global batch size is increased from start_batch_size
+        to global_batch_size, using rampup_samples / steps samples per
+        increment.
+
+        Arguments:
+            start_batch_size: global batch size to start with.
+            batch_size_increment: global batch size increment per step.
+            rampup_samples: number of samples over which to ramp up the global
+                batch size from `start_batch_size` to `global_batch_size`.
+            global_batch_size: global batch size after ramp-up.
+            micro_batch_size: micro batch size.
+            data_parallel_size: data parallel size.
+        """
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_size = data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
+            self.data_parallel_size
+        assert self.micro_batch_times_data_parallel_size > 0
+
+        assert start_batch_size > 0
+        self.start_batch_size = start_batch_size
+
+        assert global_batch_size > 0
+        self.global_batch_size = global_batch_size
+        diff_batch_size = self.global_batch_size - self.start_batch_size
+        assert diff_batch_size >= 0
+        assert batch_size_increment > 0
+        self.batch_size_increment = batch_size_increment
+        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
+            'global batch size interval ({}) to be divisible by global batch ' \
+            'size increment ({})'.format(diff_batch_size, batch_size_increment)
+
+        num_increments = diff_batch_size // self.batch_size_increment
+        assert rampup_samples >= 0
+        self.rampup_samples_per_increment = rampup_samples / num_increments
+
+        # Initialize number of micro-batches.
+        self.update(0)
+
+    def update(self, consumed_samples):
+        steps = int(consumed_samples / self.rampup_samples_per_increment)
+        current_global_batch_size = self.start_batch_size + \
+            steps * self.batch_size_increment
+        current_global_batch_size = min(current_global_batch_size,
+                                        self.global_batch_size)
+        assert current_global_batch_size % \
+            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
+            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
+            'data parallel size ({})'.format(current_global_batch_size,
+                                             self.micro_batch_size,
+                                             self.data_parallel_size)
+        self.num_micro_batches = current_global_batch_size // \
+            self.micro_batch_times_data_parallel_size
 
 
 def _build_tokenizer(args):
......
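
For context, the ramp-up arithmetic that update() performs can be checked in isolation. The sketch below re-implements the same calculation outside Megatron-LM with made-up numbers; rampup_num_micro_batches is a hypothetical helper written only for illustration and is not part of this commit.

# Illustrative, stand-alone sketch of the batch-size ramp-up schedule.
# All parameter values are made up; rampup_num_micro_batches is a
# hypothetical helper, not a function from Megatron-LM.
def rampup_num_micro_batches(consumed_samples,
                             start_batch_size=32,
                             batch_size_increment=16,
                             rampup_samples=300_000,
                             global_batch_size=512,
                             micro_batch_size=2,
                             data_parallel_size=8):
    # Number of increments needed to go from the starting to the final
    # global batch size (30 with these numbers).
    num_increments = (global_batch_size - start_batch_size) // batch_size_increment
    # Samples consumed per increment (10,000 here).
    samples_per_increment = rampup_samples / num_increments
    # How many increments have completed so far.
    steps = int(consumed_samples / samples_per_increment)
    # Current global batch size, capped at the post-ramp-up value.
    current = min(start_batch_size + steps * batch_size_increment,
                  global_batch_size)
    # Each micro-batch covers micro_batch_size * data_parallel_size samples.
    return current // (micro_batch_size * data_parallel_size)

for consumed in (0, 25_000, 150_000, 400_000):
    print(consumed, rampup_num_micro_batches(consumed))

With these numbers the global batch size grows from 32 to 512 in steps of 16 over 300,000 samples, so the printed number of micro-batches is 2, 4, 17, and 32 respectively; every intermediate global batch size is a multiple of micro_batch_size * data_parallel_size = 16, which is exactly what the divisibility assert in update() enforces.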