Commit 99493a85 authored by Myle Ott

Save number of GPUs in args (and checkpoints)

parent bd46c5ec
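
The change itself: `main()` already records inferred settings on `args` (note the existing comment about saving the inferred languages), and the whole `args` namespace is serialized into every checkpoint. Storing the GPU count as `args.num_gpus` therefore persists it automatically and removes the need to thread a separate `num_gpus`/`ngpus` parameter through `train()` and `validate()`. A minimal sketch of that pattern, assuming a `torch.save`-based checkpoint with an `'args'` entry (an illustration, not fairseq's exact checkpoint format):

    import torch

    def save_checkpoint(path, model, args, extra_state):
        # Everything recorded on `args` (including args.num_gpus) is
        # serialized alongside the weights, so it survives a restart.
        torch.save({
            'args': args,
            'model': model.state_dict(),
            'extra_state': extra_state,
        }, path)

    def load_checkpoint(path):
        state = torch.load(path)
        # The GPU count used at training time is recoverable from args:
        print('| checkpoint was trained with {} GPUs'.format(state['args'].num_gpus))
        return state
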
train.py
@@ -53,18 +53,18 @@ def main():
     # record inferred languages in args, so that it's saved in checkpoints
     args.source_lang, args.target_lang = dataset.src, dataset.dst
+    if not torch.cuda.is_available():
+        raise NotImplementedError('Training on CPU is not supported')
+    args.num_gpus = torch.cuda.device_count()
 
     print(args)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
     for split in splits:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
 
-    if not torch.cuda.is_available():
-        raise NotImplementedError('Training on CPU is not supported')
-    num_gpus = torch.cuda.device_count()
-
     print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
-        num_gpus, args.max_tokens, args.max_sentences))
+        args.num_gpus, args.max_tokens, args.max_sentences))
 
     # Build model and criterion
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
@@ -102,11 +102,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, max_positions_train)
 
         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
@@ -130,7 +130,7 @@ def get_perplexity(loss):
         return float('inf')
 
 
-def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, max_positions):
     """Train the model for one epoch."""
     seed = args.seed + epoch
@@ -152,7 +152,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     lr = trainer.get_lr()
     with utils.build_progress_bar(args, itr, epoch) as t:
-        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
+        for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset):
             loss_dict = trainer.train_step(sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
@@ -222,7 +222,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
     trainer.save_checkpoint(last_filename, extra_state)
 
 
-def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
+def validate(args, epoch, trainer, dataset, max_positions, subset):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.eval_dataloader(
@@ -236,7 +236,7 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     prefix = 'valid on \'{}\' subset'.format(subset)
     with utils.build_progress_bar(args, itr, epoch, prefix) as t:
-        for _, sample in data.skip_group_enumerator(t, ngpus):
+        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
             loss_dict = trainer.valid_step(sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
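
For context on the two call sites above: `data.skip_group_enumerator` yields one group of batches per step, sized to the number of GPUs (one batch per device), and can skip groups that were already consumed, which is how `batch_offset` resumes training mid-epoch. A simplified sketch of that behavior (an illustration of the grouping idea, not fairseq's actual implementation):

    def skip_group_enumerator(iterable, num_gpus, skip=0):
        # Buffer consecutive batches into groups of `num_gpus` (one batch
        # per device) and yield (index, group) pairs, skipping the first
        # `skip` groups, e.g. when resuming mid-epoch from batch_offset.
        group = []
        count = 0
        for item in iterable:
            group.append(item)
            if len(group) == num_gpus:
                if count >= skip:
                    yield count, group
                count += 1
                group = []
        if len(group) > 0 and count >= skip:
            yield count, group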