Commit bf3ce751 authored by Mohammad

addressed comments from raul, neel, and jared

parent 8600642e
@@ -76,12 +76,11 @@ class RandomSampler(data.sampler.Sampler):
 class DistributedBatchSampler(data.sampler.BatchSampler):
-    """
-    similar to normal implementation of distributed sampler, except
+    """Similar to normal implementation of distributed sampler, except
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
-    (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
-    """
+    (sequential, random, WeightedRandomSampler, etc.) with this batch
+    sampler."""
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
                  world_size=2, wrap_last=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
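For context on the batch-sampler-level wrapping described in the docstring above, a minimal, illustrative sketch follows. The import path, the toy dataset, and the rank/world_size values are assumptions for the example, not taken from this commit; only the constructor signature comes from the diff.

import torch
from torch.utils import data

from megatron.data.samplers import DistributedBatchSampler  # module path is an assumption

# Toy dataset and an arbitrary base sampler (sequential, weighted, etc. also work).
dataset = data.TensorDataset(torch.arange(1000).unsqueeze(1))
sampler = data.RandomSampler(dataset)

# Wrap the sampler at the batch level and shard batches across ranks.
batch_sampler = DistributedBatchSampler(sampler,
                                        batch_size=8,
                                        drop_last=True,
                                        rank=0,        # this process's rank (illustrative)
                                        world_size=2)  # total number of ranks (illustrative)

loader = data.DataLoader(dataset, batch_sampler=batch_sampler)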
@@ -141,59 +141,60 @@ def _ensure_var_is_not_initialized(var, name):
     assert var is None, '{} is already initialized.'.format(name)
+class _Timer:
+    """Timer."""
+    def __init__(self, name):
+        self.name_ = name
+        self.elapsed_ = 0.0
+        self.started_ = False
+        self.start_time = time.time()
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, 'timer has already been started'
+        torch.cuda.synchronize()
+        self.start_time = time.time()
+        self.started_ = True
+    def stop(self):
+        """Stop the timer."""
+        assert self.started_, 'timer is not started'
+        torch.cuda.synchronize()
+        self.elapsed_ += (time.time() - self.start_time)
+        self.started_ = False
+    def reset(self):
+        """Reset timer."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+    def elapsed(self, reset=True):
+        """Calculate the elapsed time."""
+        started_ = self.started_
+        # If the timing in progress, end it first.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time
+        if reset:
+            self.reset()
+        # If timing was in progress, set it back.
+        if started_:
+            self.start()
+        return elapsed_
 class Timers:
     """Group of timers."""
-    class Timer:
-        """Timer."""
-        def __init__(self, name):
-            self.name_ = name
-            self.elapsed_ = 0.0
-            self.started_ = False
-            self.start_time = time.time()
-        def start(self):
-            """Start the timer."""
-            assert not self.started_, 'timer has already been started'
-            torch.cuda.synchronize()
-            self.start_time = time.time()
-            self.started_ = True
-        def stop(self):
-            """Stop the timer."""
-            assert self.started_, 'timer is not started'
-            torch.cuda.synchronize()
-            self.elapsed_ += (time.time() - self.start_time)
-            self.started_ = False
-        def reset(self):
-            """Reset timer."""
-            self.elapsed_ = 0.0
-            self.started_ = False
-        def elapsed(self, reset=True):
-            """Calculate the elapsed time."""
-            started_ = self.started_
-            # If the timing in progress, end it first.
-            if self.started_:
-                self.stop()
-            # Get the elapsed time.
-            elapsed_ = self.elapsed_
-            # Reset the elapsed time
-            if reset:
-                self.reset()
-            # If timing was in progress, set it back.
-            if started_:
-                self.start()
-            return elapsed_
     def __init__(self):
         self.timers = {}
     def __call__(self, name):
         if name not in self.timers:
-            self.timers[name] = self.Timer(name)
+            self.timers[name] = _Timer(name)
         return self.timers[name]
     def write(self, names, writer, iteration, normalizer=1.0, reset=False):
@@ -212,7 +213,7 @@ class Timers:
         string = 'time (ms)'
         for name in names:
             elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0/ normalizer
+                reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
         if torch.distributed.is_initialized():
             if torch.distributed.get_rank() == 0:
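As context for the _Timer/Timers refactor above, here is a hypothetical usage sketch of the API shown in this diff. The timer names and the import path are assumptions for illustration; start(), stop(), and elapsed() are the methods visible in the hunks.

from megatron.global_vars import Timers  # module path is an assumption

# Note: start()/stop() call torch.cuda.synchronize(), so this assumes a CUDA device.
timers = Timers()

# Timers are created on demand by name via __call__.
timers('forward').start()
# ... run the forward pass ...
timers('forward').stop()

# elapsed() returns seconds and, with reset=True (the default), clears the timer.
forward_ms = timers('forward').elapsed() * 1000.0
print('forward time (ms): {:.2f}'.format(forward_ms))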
@@ -17,8 +17,8 @@
 import random
 import os
-import numpy as np
+import numpy as np
 import torch
 from megatron import get_adlr_autoresume
@@ -31,7 +31,7 @@ from megatron.global_vars import set_global_variables
 def initialize_megatron(extra_args_provider=None, args_defaults={}):
     """Set global variables, initialize distributed, and
     set autoresume and random seeds."""
-    # Male sure cuda is avaiable.
+    # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
     # Parse args, build tokenizer, and set adlr-autoresume,
@@ -45,7 +45,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
     # Autoresume.
     _init_autoresume()
-    # Random seeds for reproducability.
+    # Random seeds for reproducibility.
     args = get_args()
     if args.rank == 0:
         print('> setting random seeds to {} ...'.format(args.seed))
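The hunk above only fixes the "reproducibility" comment; for readers wondering what the seeding step typically covers, a minimal sketch follows. The helper name _set_random_seed is hypothetical and the body is an assumption about standard per-RNG seeding, not code taken from this commit.

import random

import numpy as np
import torch


def _set_random_seed(seed):
    """Hypothetical helper: seed python, numpy, and torch RNGs for reproducibility."""
    assert seed is not None and seed > 0, 'seed must be a positive integer'
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)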
@@ -97,7 +97,7 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     print_rank_0('training ...')
     iteration = 0
-    if args.train_iters > 0:
+    if args.do_train and args.train_iters > 0:
         if args.do_train:
             iteration, _ = train(forward_step_func,
                                  model, optimizer, lr_scheduler,
@@ -151,7 +151,7 @@ def get_model(model_provider_func):
         model = LocalDDP(model)
         return model
-    print_rank_0('Unknown DDP implementation specified: {}. '
+    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                               'Exiting.'.format(args.DDP_impl))
     sys.exit()
@@ -385,8 +385,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           report_memory_flag)
         # Autoresume
-        if (iteration % args.adlr_autoresume_interval == 0) and \
-           args.adlr_autoresume:
+        if args.adlr_autoresume and \
+           (iteration % args.adlr_autoresume_interval == 0):
             check_adlr_autoresume_termination(iteration, model, optimizer,
                                               lr_scheduler)
@@ -109,13 +109,13 @@ def make_gpt2_dataloaders():
     initial_seed = args.seed
     # Build the datasets.
-    def build_dataset_(name):
+    def _build_dataset(name):
         return GPT2Dataset(os.path.join(args.data_path, name),
                            args.input_data_sizes_file,
                            args.seq_length, args.seed)
-    train_ds = build_dataset_('train')
-    valid_ds = build_dataset_('valid')
-    test_ds = build_dataset_('test')
+    train_ds = _build_dataset('train')
+    valid_ds = _build_dataset('valid')
+    test_ds = _build_dataset('test')
     # Dataloaders
     train = make_data_loader(train_ds)