Commit bf3ce751 authored by Mohammad

Addressed comments from Raul, Neel, and Jared

parent 8600642e
@@ -76,12 +76,11 @@ class RandomSampler(data.sampler.Sampler):
 class DistributedBatchSampler(data.sampler.BatchSampler):
-    """
-    similar to normal implementation of distributed sampler, except
+    """Similar to normal implementation of distributed sampler, except
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
-    (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
-    """
+    (sequential, random, WeightedRandomSampler, etc.) with this batch
+    sampler."""
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
                  world_size=2, wrap_last=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
...
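For orientation, here is a minimal usage sketch (not part of this commit) of how an arbitrary PyTorch sampler can be wrapped by DistributedBatchSampler, per the docstring above; `dataset`, `rank`, and `world_size` are placeholders assumed to be supplied by the caller.

    # Sketch only: wrap a plain sampler so each rank gets its own shard of batches.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=8,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    # DataLoader takes the batch sampler directly; batch_size/shuffle are then
    # controlled by the wrapper rather than by DataLoader itself.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_sampler=batch_sampler,
                                         num_workers=2,
                                         pin_memory=True)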
@@ -141,59 +141,60 @@ def _ensure_var_is_not_initialized(var, name):
     assert var is None, '{} is already initialized.'.format(name)
 
 
+class _Timer:
+    """Timer."""
+
+    def __init__(self, name):
+        self.name_ = name
+        self.elapsed_ = 0.0
+        self.started_ = False
+        self.start_time = time.time()
+
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, 'timer has already been started'
+        torch.cuda.synchronize()
+        self.start_time = time.time()
+        self.started_ = True
+
+    def stop(self):
+        """Stop the timer."""
+        assert self.started_, 'timer is not started'
+        torch.cuda.synchronize()
+        self.elapsed_ += (time.time() - self.start_time)
+        self.started_ = False
+
+    def reset(self):
+        """Reset timer."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+
+    def elapsed(self, reset=True):
+        """Calculate the elapsed time."""
+        started_ = self.started_
+        # If the timing in progress, end it first.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time
+        if reset:
+            self.reset()
+        # If timing was in progress, set it back.
+        if started_:
+            self.start()
+        return elapsed_
+
+
 class Timers:
     """Group of timers."""
 
-    class Timer:
-        """Timer."""
-
-        def __init__(self, name):
-            self.name_ = name
-            self.elapsed_ = 0.0
-            self.started_ = False
-            self.start_time = time.time()
-
-        def start(self):
-            """Start the timer."""
-            assert not self.started_, 'timer has already been started'
-            torch.cuda.synchronize()
-            self.start_time = time.time()
-            self.started_ = True
-
-        def stop(self):
-            """Stop the timer."""
-            assert self.started_, 'timer is not started'
-            torch.cuda.synchronize()
-            self.elapsed_ += (time.time() - self.start_time)
-            self.started_ = False
-
-        def reset(self):
-            """Reset timer."""
-            self.elapsed_ = 0.0
-            self.started_ = False
-
-        def elapsed(self, reset=True):
-            """Calculate the elapsed time."""
-            started_ = self.started_
-            # If the timing in progress, end it first.
-            if self.started_:
-                self.stop()
-            # Get the elapsed time.
-            elapsed_ = self.elapsed_
-            # Reset the elapsed time
-            if reset:
-                self.reset()
-            # If timing was in progress, set it back.
-            if started_:
-                self.start()
-            return elapsed_
-
     def __init__(self):
         self.timers = {}
 
     def __call__(self, name):
         if name not in self.timers:
-            self.timers[name] = self.Timer(name)
+            self.timers[name] = _Timer(name)
         return self.timers[name]
 
     def write(self, names, writer, iteration, normalizer=1.0, reset=False):
@@ -212,7 +213,7 @@ class Timers:
         string = 'time (ms)'
         for name in names:
             elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0/ normalizer
+                reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
         if torch.distributed.is_initialized():
             if torch.distributed.get_rank() == 0:
...
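As a quick sketch of the timer API touched above (not part of the commit), the following shows typical use of the named `_Timer` instances handed out by `Timers.__call__`; `model` and `batch` are placeholder names for whatever work is being timed.

    timers = Timers()

    timers('forward').start()
    output = model(batch)   # placeholder workload
    timers('forward').stop()

    # elapsed() returns seconds and, by default, resets the timer.
    print('forward: {:.2f} ms'.format(timers('forward').elapsed() * 1000.0))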
@@ -17,8 +17,8 @@
 import random
 import os
-import numpy as np
 
+import numpy as np
 import torch
 
 from megatron import get_adlr_autoresume
@@ -31,7 +31,7 @@ from megatron.global_vars import set_global_variables
 def initialize_megatron(extra_args_provider=None, args_defaults={}):
     """Set global variables, initialize distributed, and
     set autoresume and random seeds."""
-    # Male sure cuda is avaiable.
+    # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
 
     # Parse args, build tokenizer, and set adlr-autoresume,
@@ -45,7 +45,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
     # Autoresume.
     _init_autoresume()
 
-    # Random seeds for reproducability.
+    # Random seeds for reproducibility.
     args = get_args()
     if args.rank == 0:
         print('> setting random seeds to {} ...'.format(args.seed))
...
@@ -97,7 +97,7 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     print_rank_0('training ...')
 
     iteration = 0
-    if args.train_iters > 0:
-        if args.do_train:
+    if args.do_train and args.train_iters > 0:
         iteration, _ = train(forward_step_func,
                              model, optimizer, lr_scheduler,
@@ -151,7 +151,7 @@ def get_model(model_provider_func):
         model = LocalDDP(model)
         return model
 
-    print_rank_0('Unknown DDP implementation specified: {}. '
+    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                  'Exiting.'.format(args.DDP_impl))
     sys.exit()
...
@@ -385,8 +385,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                               report_memory_flag)
 
         # Autoresume
-        if (iteration % args.adlr_autoresume_interval == 0) and \
-           args.adlr_autoresume:
+        if args.adlr_autoresume and \
+           (iteration % args.adlr_autoresume_interval == 0):
             check_adlr_autoresume_termination(iteration, model, optimizer,
                                               lr_scheduler)
...
@@ -109,13 +109,13 @@ def make_gpt2_dataloaders():
     initial_seed = args.seed
 
     # Build the datasets.
-    def build_dataset_(name):
+    def _build_dataset(name):
         return GPT2Dataset(os.path.join(args.data_path, name),
                            args.input_data_sizes_file,
                            args.seq_length, args.seed)
-    train_ds = build_dataset_('train')
-    valid_ds = build_dataset_('valid')
-    test_ds = build_dataset_('test')
+    train_ds = _build_dataset('train')
+    valid_ds = _build_dataset('valid')
+    test_ds = _build_dataset('test')
 
     # Dataloaders
     train = make_data_loader(train_ds)
...