"vscode:/vscode.git/clone" did not exist on "b18547fdeb6b9e13bc304fd8ba2167663cafdbe9"
Commit 00ac56ab authored by mohammad, committed by Deepak Narayanan

Address Jared's comments

parent 22ab91bb
megatron/arguments.py

@@ -264,9 +264,11 @@ def _add_training_args(parser):
     group.add_argument('--micro-batch-size', type=int, default=None,
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
-                       'parallel size.')
+                       'parallel size times number of micro batches.')
     group.add_argument('--global-batch-size', type=int, default=None,
-                       help='Training batch size. If this value is None, then '
+                       help='Training batch size. If set, it should be a '
+                       'multiple of micro-batch-size times data-parallel-size. '
+                       'If this value is None, then '
                        'use micro-batch-size * data-parallel-size as the '
                        'global batch size. This choice will result in 1 for '
                        'number of micro-batches.')
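The updated help text spells out the invariant the trainer checks: the global batch size must be a multiple of micro-batch-size times data-parallel-size, and the quotient is the number of micro-batches. A quick sketch with made-up values (not defaults from this diff):

    # Illustrative numbers only: 8-way data parallelism, micro-batch of 4.
    micro_batch_size = 4
    data_parallel_size = 8
    global_batch_size = 512

    # Same divisibility check that ConstantNumMicroBatches asserts below.
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)
    print(num_micro_batches)  # 16
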
megatron/global_vars.py

@@ -15,18 +15,15 @@
 """Megatron global variables."""
 
-from abc import ABC
-from abc import abstractmethod
-import math
 import os
 import sys
 import time
 
-import numpy as np
 import torch
 
 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
+from .microbatches import build_num_microbatches_calculator
 
 
 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -104,132 +101,8 @@ def _build_num_microbatches_calculator(args):
     _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                    'num microbatches calculator')
 
-    # Constant num micro-batches.
-    if args.rampup_batch_size is None:
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
-            args.global_batch_size, args.micro_batch_size,
-            args.data_parallel_size)
-        if args.rank == 0:
-            print('setting number of micro-batches to constant {}'.format(
-                _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()), flush=True)
-    else:
-        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
-            'format: --rampup-batch-size <start batch size> ' \
-            '<batch size increment> <ramp-up samples>'
-        start_batch_size = int(args.rampup_batch_size[0])
-        batch_size_increment = int(args.rampup_batch_size[1])
-        ramup_samples = int(args.rampup_batch_size[2])
-        if args.rank == 0:
-            print('will use batch size rampup starting from global batch '
-                  'size {} to global batch size {} with batch size increments '
-                  '{} over {} samples.'.format(start_batch_size,
-                                               args.global_batch_size,
-                                               batch_size_increment,
-                                               ramup_samples), flush=True)
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
-            start_batch_size, batch_size_increment, ramup_samples,
-            args.global_batch_size, args.micro_batch_size,
-            args.data_parallel_size)
-
-
-class NumMicroBatchesCalculator(ABC):
-
-    def __init__(self):
-        self.num_micro_batches = None
-
-    def get(self):
-        return self.num_micro_batches
-
-    @abstractmethod
-    def update(self, consumed_samples):
-        pass
-
-
-class ConstantNumMicroBatches(NumMicroBatchesCalculator):
-
-    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
-        micro_batch_times_data_parallel = micro_batch_size * \
-            data_parallel_size
-        assert global_batch_size % micro_batch_times_data_parallel == 0, \
-            'global batch size ({}) is not divisible by micro batch size ({})' \
-            ' times data parallel size ({})'.format(global_batch_size,
-                                                    micro_batch_size,
-                                                    data_parallel_size)
-        self.num_micro_batches = global_batch_size // \
-            micro_batch_times_data_parallel
-        assert self.num_micro_batches >= 1
-
-    def update(self, consumed_samples):
-        pass
-
-
-class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
-
-    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
-                 global_batch_size, micro_batch_size, data_parallel_size):
-        """Batch size ramp up.
-
-        Over
-            steps = (global-batch-size - start-batch-size) / batch_size_increment
-        increments, move from start-batch-size to global-batch-size, spending
-            rampup-samples / steps
-        samples at each intermediate batch size.
-
-        Arguments:
-            start_batch_size: global batch size to start with.
-            batch_size_increment: global batch size increment.
-            ramup_samples: number of samples to use to ramp up the global
-                batch size from `start_batch_size` to `global_batch_size`.
-            global_batch_size: global batch size post ramp-up.
-            micro_batch_size: micro batch size.
-            data_parallel_size: data parallel size.
-        """
-        self.micro_batch_size = micro_batch_size
-        self.data_parallel_size = data_parallel_size
-        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
-            self.data_parallel_size
-        assert self.micro_batch_times_data_parallel_size > 0
-
-        assert start_batch_size > 0
-        self.start_batch_size = start_batch_size
-
-        assert global_batch_size > 0
-        self.global_batch_size = global_batch_size
-        diff_batch_size = self.global_batch_size - self.start_batch_size
-        assert diff_batch_size >= 0
-
-        assert batch_size_increment > 0
-        self.batch_size_increment = batch_size_increment
-        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
-            'global batch size interval ({}) to be divisible by global batch ' \
-            'size increment ({})'.format(diff_batch_size, batch_size_increment)
-
-        num_increments = diff_batch_size // self.batch_size_increment
-        self.ramup_samples = ramup_samples
-        assert self.ramup_samples >= 0
-        self.rampup_samples_per_increment = self.ramup_samples / num_increments
-
-        # Initialize number of micro-batches.
-        self.update(0)
-
-    def update(self, consumed_samples):
-        if consumed_samples > self.ramup_samples:
-            current_global_batch_size = self.global_batch_size
-        else:
-            steps = int(consumed_samples / self.rampup_samples_per_increment)
-            current_global_batch_size = self.start_batch_size + \
-                steps * self.batch_size_increment
-            assert current_global_batch_size <= self.global_batch_size
-
-        assert current_global_batch_size % \
-            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
-            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
-            'data parallel size ({})'.format(current_global_batch_size,
-                                             self.micro_batch_size,
-                                             self.data_parallel_size)
-        self.num_micro_batches = current_global_batch_size // \
-            self.micro_batch_times_data_parallel_size
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
+        args)
 
 
 def _build_tokenizer(args):
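With this hunk, global_vars.py only stores the calculator in a module-level global; choosing between the constant and ramp-up variants now lives in megatron/microbatches.py. A minimal sketch of exercising the factory directly, assuming a Megatron-LM checkout on the path; the SimpleNamespace stands in for Megatron's parsed args and carries only the fields the factory reads:

    from types import SimpleNamespace

    from megatron.microbatches import build_num_microbatches_calculator

    # Only the fields build_num_microbatches_calculator touches are set here;
    # rampup_batch_size=None selects the constant-num-micro-batches path.
    args = SimpleNamespace(rank=0, rampup_batch_size=None,
                           global_batch_size=512, micro_batch_size=4,
                           data_parallel_size=8)

    calculator = build_num_microbatches_calculator(args)
    print(calculator.get())  # 16 = 512 // (4 * 8)
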
megatron/microbatches.py (new file)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron number of micro-batches calculators."""

from abc import ABC
from abc import abstractmethod


def build_num_microbatches_calculator(args):

    # Constant num micro-batches.
    if args.rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatches(
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)
        if args.rank == 0:
            print('setting number of micro-batches to constant {}'.format(
                num_microbatches_calculator.get()), flush=True)

    else:
        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
            'format: --rampup-batch-size <start batch size> ' \
            '<batch size increment> <ramp-up samples>'
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        ramup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
            print('will use batch size rampup starting from global batch '
                  'size {} to global batch size {} with batch size increments '
                  '{} over {} samples.'.format(start_batch_size,
                                               args.global_batch_size,
                                               batch_size_increment,
                                               ramup_samples), flush=True)
        num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
            start_batch_size, batch_size_increment, ramup_samples,
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)

    return num_microbatches_calculator


class NumMicroBatchesCalculator(ABC):

    def __init__(self):
        self.num_micro_batches = None

    def get(self):
        return self.num_micro_batches

    @abstractmethod
    def update(self, consumed_samples):
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
        micro_batch_times_data_parallel = micro_batch_size * \
            data_parallel_size
        assert global_batch_size % micro_batch_times_data_parallel == 0, \
            'global batch size ({}) is not divisible by micro batch size ({})' \
            ' times data parallel size ({})'.format(global_batch_size,
                                                    micro_batch_size,
                                                    data_parallel_size)
        self.num_micro_batches = global_batch_size // \
            micro_batch_times_data_parallel
        assert self.num_micro_batches >= 1

    def update(self, consumed_samples):
        pass


class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp up.

        Over
            steps = (global-batch-size - start-batch-size) / batch_size_increment
        increments, move from start-batch-size to global-batch-size, spending
            rampup-samples / steps
        samples at each intermediate batch size.

        Arguments:
            start_batch_size: global batch size to start with.
            batch_size_increment: global batch size increment.
            ramup_samples: number of samples to use to ramp up the global
                batch size from `start_batch_size` to `global_batch_size`.
            global_batch_size: global batch size post ramp-up.
            micro_batch_size: micro batch size.
            data_parallel_size: data parallel size.
        """
        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
            self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0

        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)

        num_increments = diff_batch_size // self.batch_size_increment
        self.ramup_samples = ramup_samples
        assert self.ramup_samples >= 0
        self.rampup_samples_per_increment = self.ramup_samples / num_increments

        # Initialize number of micro-batches.
        self.update(0)

    def update(self, consumed_samples):
        if consumed_samples > self.ramup_samples:
            current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            current_global_batch_size = self.start_batch_size + \
                steps * self.batch_size_increment
            assert current_global_batch_size <= self.global_batch_size

        assert current_global_batch_size % \
            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
            'data parallel size ({})'.format(current_global_batch_size,
                                             self.micro_batch_size,
                                             self.data_parallel_size)
        self.num_micro_batches = current_global_batch_size // \
            self.micro_batch_times_data_parallel_size
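To make the ramp-up schedule concrete, here is an illustrative run of RampupBatchsizeNumMicroBatches (all numbers invented for the example). Starting at global batch size 32 and incrementing by 32 up to 128 gives (128 - 32) / 32 = 3 increments, so over 3,000,000 ramp-up samples the batch size steps up once every 1,000,000 consumed samples:

    calc = RampupBatchsizeNumMicroBatches(
        start_batch_size=32, batch_size_increment=32, ramup_samples=3_000_000,
        global_batch_size=128, micro_batch_size=4, data_parallel_size=8)

    # micro_batch_size * data_parallel_size = 32, so num_micro_batches
    # climbs 1 -> 2 -> 3 -> 4 as samples are consumed.
    for consumed in (0, 999_999, 1_000_000, 2_500_000, 3_000_001):
        calc.update(consumed)
        print(consumed, calc.get())
    # 0 1
    # 999999 1
    # 1000000 2
    # 2500000 3
    # 3000001 4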