Commit 6c521eb1 authored by Jared Casper

Merge branch 'staging_args' into 'staging'

Major refactor: args, global variables, tokenizer

See merge request ADLR/megatron-lm!36
parents 37ff534f b07f1d0a
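The diff below drops the explicit args, timers, tokenizer, and writer parameters that were previously threaded through every task function and instead reads them from module-level accessors. A minimal sketch of the resulting call pattern, assuming the megatron package exposes the accessors exactly as the new imports in this diff suggest (the helper name below is illustrative, not part of the merge request):

# Sketch only: get_args/get_timers/get_tokenizer/print_rank_0 are assumed to be
# the global accessors that the new imports in the diff below rely on.
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0


def build_split(split):
    """Illustrative helper: no args/tokenizer parameters are passed in."""
    args = get_args()            # globally parsed command-line arguments
    tokenizer = get_tokenizer()  # globally built tokenizer
    print_rank_0('building {} split with seq_length {} ...'.format(
        split, args.seq_length))
    return split, tokenizer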
@@ -17,22 +17,23 @@
 import torch

+from megatron import get_args
+from megatron import get_timers
 from megatron import mpu
-from megatron.data.tokenizer import add_tokenizer_to_args
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
-from megatron.training import initialize_megatron
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
 from megatron.training import training_log
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import load_checkpoint
-from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
-from megatron.utils import save_checkpoint


-def process_batch(batch, args):
+def process_batch(batch):
     """Process batch and produce inputs for the model."""
+    args = get_args()

     tokens = batch['text'].long().cuda().contiguous()
     types = batch['types'].long().cuda().contiguous()
@@ -44,8 +45,9 @@ def process_batch(batch, args):

     return tokens, types, labels, attention_mask


-def _cross_entropy_forward_step(batch, model, args, timers):
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
+    timers = get_timers()

     # Get the batch.
     timers('batch generator').start()
@@ -53,7 +55,7 @@ def _cross_entropy_forward_step(batch, model, args, timers):
         batch_ = next(batch)
     except:
         batch_ = batch
-    tokens, types, labels, attention_mask = process_batch(batch_, args)
+    tokens, types, labels, attention_mask = process_batch(batch_)
     timers('batch generator').stop()

     # Forward model.
@@ -101,8 +103,9 @@ def _build_infinite_size_dataloader(dataloader):
     iterator = dataloader.__iter__()


-def _build_train_valid_dataloaders(train_dataset, valid_dataset, args):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     """Traing and validation dataloaders."""
+    args = get_args()
     print_rank_0('building train and validation dataloaders ...')

     # Training dataset.
@@ -121,9 +124,10 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset, args):

 def _train(model, optimizer, lr_scheduler, forward_step,
-           train_dataloader, valid_dataloader,
-           end_of_epoch_callback, timers, args, writer):
+           train_dataloader, valid_dataloader, end_of_epoch_callback):
     """Train the model."""
+    args = get_args()
+    timers = get_timers()

     # Turn on training mode which enables dropout.
     model.train()
@@ -157,95 +161,99 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0

             # Train for one step.
-            losses_dict, _ = train_step(forward_step, batch, model, optimizer,
-                                        lr_scheduler, args, timers)
+            losses_dict, _ = train_step(forward_step, batch, model,
+                                        optimizer, lr_scheduler)
             iteration += 1

             # Logging.
             report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                               optimizer.param_groups[0]['lr'],
                                               iteration, optimizer.loss_scale,
-                                              report_memory_flag, writer,
-                                              args, timers)
+                                              report_memory_flag)

             # Autoresume
             if args.adlr_autoresume and \
                (iteration % args.adlr_autoresume_interval == 0):
-                check_adlr_autoresume_termination(iteration, model, optimizer,
-                                                  lr_scheduler, args)
+                check_adlr_autoresume_termination(iteration, model,
+                                                  optimizer, lr_scheduler)

             # Checkpointing
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
-                save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
                 prefix = 'iteration {}'.format(iteration)
                 evaluate_and_print_results(prefix, forward_step,
-                                           valid_dataloader, model, args,
-                                           writer, iteration, timers, False)
+                                           valid_dataloader, model,
+                                           iteration, False)

         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
-            end_of_epoch_callback(model, args, epoch)
+            end_of_epoch_callback(model, epoch)


-def finetune(args, train_valid_datasets_provider, model_provider,
+def finetune(train_valid_datasets_provider, model_provider,
              forward_step=_cross_entropy_forward_step,
              end_of_epoch_callback_provider=None):
     """Main finetune function used across all tasks."""
+    args = get_args()
+    timers = get_timers()

-    # Initialize megatron and get args, timers, and Tensorboard writer.
-    timers, writer = initialize_megatron(
-        'finetune model for {} ...'.format(args.task), args)
-
-    # Add tokenizer to the args.
-    add_tokenizer_to_args(args, args.tokenizer_type)
-
     # Train and validation data loaders.
+    timers('train/valid/test dataset/dataloder').start()
     if args.epochs > 0:
-        train_dataset, valid_dataset = train_valid_datasets_provider(args)
+        train_dataset, valid_dataset = train_valid_datasets_provider()
         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
-            train_dataset, valid_dataset, args)
+            train_dataset, valid_dataset)
+    timers('train/valid/test dataset/dataloder').stop()

     # Build calback function.
+    timers('callback function').start()
     end_of_epoch_callback = None
     if end_of_epoch_callback_provider is not None:
-        end_of_epoch_callback = end_of_epoch_callback_provider(args)
+        end_of_epoch_callback = end_of_epoch_callback_provider()
+    timers('callback function').stop()

     # Build model, optimizer and learning rate scheduler.
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
-                                                               args)
+    timers('model and optimizer').start()
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    timers('model and optimizer').stop()

     # If pretrained checkpoint is provided and we have not trained for
     # any iteration (i.e., iteration is zero), then load the pretrained
     # checkpoint.
+    timers('pretrained checkpoint').start()
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
         original_load = args.load
         args.load = args.pretrained_checkpoint
-        _ = load_checkpoint(model, None, None, args)
+        _ = load_checkpoint(model, None, None)
         args.load = original_load
         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
         if args.fp16:
             optimizer._model_params_to_master_params()
+    timers('pretrained checkpoint').stop()
+
+    # Print setup timing.
+    print_rank_0('done with setups ...')
+    timers.log(['train/valid/test dataset/dataloder', 'callback function',
+                'model and optimizer', 'pretrained checkpoint'])
+    print_rank_0('training ...')

     # Finetune the model.
     if args.epochs > 0:
         _train(model, optimizer, lr_scheduler, forward_step,
-               train_dataloader, valid_dataloader,
-               end_of_epoch_callback, timers, args, writer)
+               train_dataloader, valid_dataloader, end_of_epoch_callback)

     # Or just evaluate.
     else:
         if end_of_epoch_callback is not None:
             print_rank_0('evaluation only mode, setting epoch to -1')
-            end_of_epoch_callback(model, args, epoch=-1,
-                                  output_predictions=True)
+            end_of_epoch_callback(model, epoch=-1, output_predictions=True)

     print_rank_0('done :-)')
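In the new finetune(), each setup phase is wrapped in a named timer and the whole set is logged once before training starts. A minimal sketch of that usage, assuming the object returned by get_timers() supports exactly the calls that appear above (timers(name).start(), timers(name).stop(), timers.log()); the helper name is illustrative, not part of the merge request:

# Sketch only: times one setup phase the same way the new finetune() does.
from megatron import get_timers
from megatron import print_rank_0


def timed_phase(name, build_fn):
    """Illustrative helper: run build_fn() under a named timer and log it."""
    timers = get_timers()
    timers(name).start()
    result = build_fn()
    timers(name).stop()
    timers.log([name])
    print_rank_0('done with {} ...'.format(name))
    return result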
@@ -20,7 +20,7 @@ from abc import abstractmethod

 from torch.utils.data import Dataset

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from tasks.data_utils import build_sample
 from tasks.data_utils import build_tokens_types_paddings_from_text
@@ -15,60 +15,63 @@

 """GLUE finetuning/evaluation."""

-from megatron.utils import print_rank_0
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import print_rank_0
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune


-def glue_classification(args, num_classes, Dataset,
+def glue_classification(num_classes, Dataset,
                         name_from_datapath_func):

-    def train_valid_datasets_provider(args):
+    def train_valid_datasets_provider():
         """Build train and validation dataset."""
+        args = get_args()
+        tokenizer = get_tokenizer()
         train_dataset = Dataset('training', args.train_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)
         valid_dataset = Dataset('validation', args.valid_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)
         return train_dataset, valid_dataset

-    def model_provider(args):
+    def model_provider():
         """Build the model."""
+        args = get_args()
         print_rank_0('building classification model for {} ...'.format(
             args.task))
-        return Classification(
-            num_classes=num_classes,
-            num_layers=args.num_layers,
-            vocab_size=args.vocab_size,
-            hidden_size=args.hidden_size,
-            num_attention_heads=args.num_attention_heads,
-            embedding_dropout_prob=args.hidden_dropout,
-            attention_dropout_prob=args.attention_dropout,
-            output_dropout_prob=args.hidden_dropout,
-            max_sequence_length=args.max_position_embeddings,
-            checkpoint_activations=args.checkpoint_activations)
+        return Classification(num_classes=num_classes, num_tokentypes=2)

-    def metrics_func_provider(args):
+    def metrics_func_provider():
         """Privde metrics callback function."""
-        def single_dataset_provider(datapath, args):
+        def single_dataset_provider(datapath):
+            args = get_args()
+            tokenizer = get_tokenizer()
             name = name_from_datapath_func(datapath)
-            return Dataset(name, [datapath], args.tokenizer, args.seq_length)
-        return accuracy_func_provider(args, single_dataset_provider)
+            return Dataset(name, [datapath], tokenizer, args.seq_length)
+        return accuracy_func_provider(single_dataset_provider)

     """Finetune/evaluate."""
-    finetune(args, train_valid_datasets_provider, model_provider,
+    finetune(train_valid_datasets_provider, model_provider,
              end_of_epoch_callback_provider=metrics_func_provider)


-def main(args):
+def main():
+    args = get_args()

     if args.task == 'MNLI':
         num_classes = 3
-        from .mnli import MNLIDataset as Dataset
+        from tasks.glue.mnli import MNLIDataset as Dataset

         def name_from_datapath(datapath):
             return datapath.split('MNLI')[-1].strip(
                 '.tsv').strip('/').replace('_', '-')
@@ -76,7 +79,7 @@ def main(args):

     elif args.task == 'QQP':
         num_classes = 2
-        from .qqp import QQPDataset as Dataset
+        from tasks.glue.qqp import QQPDataset as Dataset

         def name_from_datapath(datapath):
             return datapath.split('QQP')[-1].strip(
                 '.tsv').strip('/').replace('_', '-')
@@ -85,4 +88,4 @@ def main(args):
         raise NotImplementedError('GLUE task {} is not implemented.'.format(
             args.task))

-    glue_classification(args, num_classes, Dataset, name_from_datapath)
+    glue_classification(num_classes, Dataset, name_from_datapath)
@@ -15,7 +15,7 @@

 """MNLI dataset."""

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from tasks.data_utils import clean_text
 from .data import GLUEAbstractDataset
@@ -15,7 +15,7 @@

 """QQP dataset."""

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from tasks.data_utils import clean_text
 from .data import GLUEAbstractDataset
@@ -20,29 +20,38 @@ import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))

-from arguments import get_args
+from megatron import get_args
+from megatron.initialize import initialize_megatron


 def get_tasks_args(parser):
     """Provide extra arguments required for tasks."""
-    group = parser.add_argument_group('tasks', 'tasks configurations')
+    group = parser.add_argument_group(title='tasks')

-    parser.add_argument('--task', type=str, required=True,
-                        help='task name.')
+    group.add_argument('--task', type=str, required=True,
+                       help='Task name.')
     group.add_argument('--epochs', type=int, required=True,
-                       help='number of finetunning epochs. Zero results in '
+                       help='Number of finetunning epochs. Zero results in '
                        'evaluation only.')
-    parser.add_argument('--pretrained-checkpoint', type=str, default=None,
-                        help='pretrained checkpoint used for finetunning.')
+    group.add_argument('--pretrained-checkpoint', type=str, default=None,
+                       help='Pretrained checkpoint used for finetunning.')
     group.add_argument('--keep-last', action='store_true',
-                       help='keep the last batch (maybe incomplete) in'
+                       help='Keep the last batch (maybe incomplete) in'
                        'the data loader')
+    group.add_argument('--train-data', nargs='+', default=None,
+                       help='Whitespace separated paths or corpora names '
+                       'for training.')
+    group.add_argument('--valid-data', nargs='*', default=None,
+                       help='path(s) to the validation data.')

     return parser


 if __name__ == '__main__':

-    args = get_args(extra_args_provider=get_tasks_args)
+    initialize_megatron(extra_args_provider=get_tasks_args)
+
+    args = get_args()

     if args.task == 'RACE':
         from race.finetune import main
     elif args.task in ['MNLI', 'QQP']:
@@ -51,4 +60,4 @@ if __name__ == '__main__':
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))

-    main(args)
+    main()
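The entry point now defers all argument parsing to Megatron's initializer and injects task-specific flags through extra_args_provider. A minimal sketch of that flow, assuming initialize_megatron and get_args behave as they are used above; the extra flag and provider name below are hypothetical examples, not part of this merge request:

# Sketch only: register task flags via extra_args_provider, then read them back
# through the global get_args() after initialization.
from megatron import get_args
from megatron import print_rank_0
from megatron.initialize import initialize_megatron


def my_extra_args(parser):
    """Hypothetical extra-argument provider for a downstream task."""
    group = parser.add_argument_group(title='my task')
    group.add_argument('--my-task-flag', type=str, default=None,
                       help='Hypothetical task-specific option.')
    return parser


if __name__ == '__main__':
    initialize_megatron(extra_args_provider=my_extra_args)
    args = get_args()
    print_rank_0('my-task-flag = {}'.format(args.my_task_flag))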
@@ -6,7 +6,7 @@ import time

 from torch.utils.data import Dataset

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from tasks.data_utils import build_sample
 from tasks.data_utils import build_tokens_types_paddings_from_ids
 from tasks.data_utils import clean_text
+# coding=utf-8
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 """Race."""

+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import print_rank_0
 from megatron.model.multiple_choice import MultipleChoice
-from megatron.utils import print_rank_0
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset


-def train_valid_datasets_provider(args):
+def train_valid_datasets_provider():
     """Provide train and validation datasets."""
+    args = get_args()
+    tokenizer = get_tokenizer()
     train_dataset = RaceDataset('training', args.train_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)
     valid_dataset = RaceDataset('validation', args.valid_data,
-                                args.tokenizer, args.seq_length)
+                                tokenizer, args.seq_length)
     return train_dataset, valid_dataset


-def model_provider(args):
+def model_provider():
     """Build the model."""
     print_rank_0('building multichoice model for RACE ...')
-    return MultipleChoice(
-        num_layers=args.num_layers,
-        vocab_size=args.vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations)
+    return MultipleChoice(num_tokentypes=2)


-def metrics_func_provider(args):
+def metrics_func_provider():
     """Privde metrics callback function."""
+    args = get_args()
+    tokenizer = get_tokenizer()

-    def single_dataset_provider(datapath, args):
+    def single_dataset_provider(datapath):
         name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
-        return RaceDataset(name, [datapath], args.tokenizer, args.seq_length)
-    return accuracy_func_provider(args, single_dataset_provider)
+        return RaceDataset(name, [datapath], tokenizer, args.seq_length)
+    return accuracy_func_provider(single_dataset_provider)


-def main(args):
-    finetune(args, train_valid_datasets_provider, model_provider,
+def main():
+    finetune(train_valid_datasets_provider, model_provider,
              end_of_epoch_callback_provider=metrics_func_provider)