Commit de0b70a0 authored by mohammad, committed by Deepak Narayanan

Support for ramping up the batch size

parent c30ba0f7
@@ -17,10 +17,12 @@
from abc import ABC
from abc import abstractmethod
import math
import os
import sys
import time
import numpy as np
import torch
from megatron.tokenizer import build_tokenizer
@@ -127,11 +129,11 @@ class NumMicroBatchesCalculator(ABC):
    def __init__(self, name):
        self.name = name
        self.num_micro_batches = None
        super(NumMicroBatchesCalculator, self).__init__()

    def get(self):
        return self.num_micro_batches

    @abstractmethod
    def update(self, consumed_samples):
@@ -149,9 +151,70 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
    def update(self, consumed_samples):
        pass

    def get(self):
        return self.num_micro_batches
class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, start_batch_size, batch_size_increment, rampup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp-up.

        Over
            steps = (global_batch_size - start_batch_size) / batch_size_increment
        increments, the global batch size grows from `start_batch_size` to
        `global_batch_size`; each increment is held for
            rampup_samples / steps
        consumed samples.

        Arguments:
            start_batch_size: global batch size to start with.
            batch_size_increment: global batch size increment per step.
            rampup_samples: number of samples over which to ramp up the
                global batch size from `start_batch_size` to
                `global_batch_size`.
            global_batch_size: global batch size post ramp-up.
            micro_batch_size: micro batch size.
            data_parallel_size: data parallel size.
        """
        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
            self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0
        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)

        num_increments = diff_batch_size // self.batch_size_increment
        # Guard the division below: a zero-length ramp-up (start batch size
        # equal to global batch size) would otherwise divide by zero.
        assert num_increments > 0
        assert rampup_samples >= 0
        self.rampup_samples_per_increment = rampup_samples / num_increments

        # Initialize number of microbatches.
        self.update(0)

    def update(self, consumed_samples):
        steps = int(consumed_samples / self.rampup_samples_per_increment)
        current_global_batch_size = self.start_batch_size + \
            steps * self.batch_size_increment
        current_global_batch_size = min(current_global_batch_size,
                                        self.global_batch_size)
        assert current_global_batch_size % \
            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
            'data parallel size ({})'.format(current_global_batch_size,
                                             self.micro_batch_size,
                                             self.data_parallel_size)
        self.num_micro_batches = current_global_batch_size // \
            self.micro_batch_times_data_parallel_size
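For orientation, a minimal usage sketch of the class above (not part of the commit; every concrete number below is an illustrative assumption, not a value from the source):

# Sketch: ramp from a global batch size of 32 to 1024 in increments of 16,
# i.e. (1024 - 32) / 16 = 62 increments, each held for
# 620000 / 62 = 10000 consumed samples.
calculator = RampupBatchsizeNumMicroBatches(
    start_batch_size=32,
    batch_size_increment=16,
    rampup_samples=620000,
    global_batch_size=1024,
    micro_batch_size=2,
    data_parallel_size=8)

for consumed_samples in (0, 9999, 10000, 310000, 620000):
    calculator.update(consumed_samples)
    # num_micro_batches = current global batch size / (2 * 8)
    print(consumed_samples, calculator.get())
# -> 2, 2, 3, 33, 64; once 620000 samples have been consumed the schedule
#    saturates and the global batch size stays capped at 1024.

Note that every intermediate batch size (32 + 16 * k here) must be divisible by micro_batch_size * data_parallel_size (16 here), otherwise update() trips the divisibility assertion.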
def _build_tokenizer(args):
......