Commit 601b19b7 authored by Mohammad's avatar Mohammad
Browse files

tasks tested

parent 259062c2
......@@ -20,26 +20,28 @@ import time
import torch
from megatron import get_args
from megatron import mpu
from megatron import print_rank_0
from .finetune_utils import build_data_loader
from .finetune_utils import process_batch
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
def accuracy_func_provider(args, single_dataset_provider):
def accuracy_func_provider(single_dataset_provider):
"""Provide function that calculates accuracies."""
args = get_args()
# Build dataloaders.
datapaths = args.valid_data
dataloaders = []
for datapath in datapaths:
dataset = single_dataset_provider(datapath, args)
dataset = single_dataset_provider(datapath)
dataloader = build_data_loader(
dataset, args.batch_size, num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1))
dataloaders.append((dataset.dataset_name, dataloader))
def metrics_func(model, args_, epoch, output_predictions=False):
def metrics_func(model, epoch, output_predictions=False):
print_rank_0('calculating metrics ...')
correct = 0
total = 0
......@@ -48,7 +50,7 @@ def accuracy_func_provider(args, single_dataset_provider):
named_predictions = []
names = 'predictions'
for name, dataloader in dataloaders:
output = calculate_correct_answers(name, model, dataloader, args_,
output = calculate_correct_answers(name, model, dataloader,
epoch, output_predictions)
if not output_predictions:
correct_ans, total_count = output
......@@ -70,7 +72,7 @@ def accuracy_func_provider(args, single_dataset_provider):
return metrics_func
def calculate_correct_answers(name, model, dataloader, args,
def calculate_correct_answers(name, model, dataloader,
epoch, output_predictions):
"""Calculate correct over total answers and return prediction if the
`output_predictions` is true."""
......@@ -89,7 +91,7 @@ def calculate_correct_answers(name, model, dataloader, args,
ids = []
for _, batch in enumerate(dataloader):
# Run the model forward.
tokens, types, labels_, attention_mask = process_batch(batch, args)
tokens, types, labels_, attention_mask = process_batch(batch)
logits = model(tokens, attention_mask, types)
# Add output predictions.
if output_predictions:
......
......@@ -17,22 +17,23 @@
import torch
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron.data.tokenizer import add_tokenizer_to_args
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import initialize_megatron
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import load_checkpoint
from megatron import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import save_checkpoint
def process_batch(batch, args):
def process_batch(batch):
"""Process batch and produce inputs for the model."""
args = get_args()
tokens = batch['text'].long().cuda().contiguous()
types = batch['types'].long().cuda().contiguous()
......@@ -44,8 +45,9 @@ def process_batch(batch, args):
return tokens, types, labels, attention_mask
def _cross_entropy_forward_step(batch, model, args, timers):
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers('batch generator').start()
......@@ -53,7 +55,7 @@ def _cross_entropy_forward_step(batch, model, args, timers):
batch_ = next(batch)
except:
batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_, args)
tokens, types, labels, attention_mask = process_batch(batch_)
timers('batch generator').stop()
# Forward model.
......@@ -101,8 +103,9 @@ def _build_infinite_size_dataloader(dataloader):
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset, args):
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
......@@ -121,9 +124,10 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset, args):
def _train(model, optimizer, lr_scheduler, forward_step,
train_dataloader, valid_dataloader,
end_of_epoch_callback, timers, args, writer):
train_dataloader, valid_dataloader, end_of_epoch_callback):
"""Train the model."""
args = get_args()
timers = get_timers()
# Turn on training mode which enables dropout.
model.train()
......@@ -157,95 +161,99 @@ def _train(model, optimizer, lr_scheduler, forward_step,
start_iteration = 0
# Train for one step.
losses_dict, _ = train_step(forward_step, batch, model, optimizer,
lr_scheduler, args, timers)
losses_dict, _ = train_step(forward_step, batch, model,
optimizer, lr_scheduler)
iteration += 1
# Logging.
report_memory_flag = training_log(losses_dict, losses_dict_sum,
optimizer.param_groups[0]['lr'],
iteration, optimizer.loss_scale,
report_memory_flag, writer,
args, timers)
report_memory_flag)
# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler)
# Checkpointing
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
valid_dataloader, model, args,
writer, iteration, timers, False)
valid_dataloader, model,
iteration, False)
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
end_of_epoch_callback(model, args, epoch)
end_of_epoch_callback(model, epoch)
def finetune(args, train_valid_datasets_provider, model_provider,
def finetune(train_valid_datasets_provider, model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None):
"""Main finetune function used across all tasks."""
# Initialize megatron and get args, timers, and Tensorboard writer.
timers, writer = initialize_megatron(
'finetune model for {} ...'.format(args.task), args)
# Add tokenizer to the args.
add_tokenizer_to_args(args, args.tokenizer_type)
args = get_args()
timers = get_timers()
# Train and validation data loaders.
timers('train/valid/test dataset/dataloder').start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider(args)
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset, args)
train_dataset, valid_dataset)
timers('train/valid/test dataset/dataloder').stop()
# Build calback function.
timers('callback function').start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider(args)
end_of_epoch_callback = end_of_epoch_callback_provider()
timers('callback function').stop()
# Build model, optimizer and learning rate scheduler.
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
args)
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers('pretrained checkpoint').start()
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
_ = load_checkpoint(model, None, None, args)
_ = load_checkpoint(model, None, None)
args.load = original_load
# This is critical when only model is loaded. We should make sure
# master parameters are also updated.
if args.fp16:
optimizer._model_params_to_master_params()
timers('pretrained checkpoint').stop()
# Print setup timing.
print_rank_0('done with setups ...')
timers.log(['train/valid/test dataset/dataloder', 'callback function',
'model and optimizer', 'pretrained checkpoint'])
print_rank_0('training ...')
# Finetune the model.
if args.epochs > 0:
_train(model, optimizer, lr_scheduler, forward_step,
train_dataloader, valid_dataloader,
end_of_epoch_callback, timers, args, writer)
train_dataloader, valid_dataloader, end_of_epoch_callback)
# Or just evaluate.
else:
if end_of_epoch_callback is not None:
print_rank_0('evaluation only mode, setting epoch to -1')
end_of_epoch_callback(model, args, epoch=-1,
output_predictions=True)
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0('done :-)')
......@@ -15,32 +15,41 @@
"""GLUE finetuning/evaluation."""
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
def glue_classification(args, num_classes, Dataset,
def glue_classification(num_classes, Dataset,
name_from_datapath_func):
def train_valid_datasets_provider(args):
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = Dataset('training', args.train_data,
args.tokenizer, args.seq_length)
tokenizer, args.seq_length)
valid_dataset = Dataset('validation', args.valid_data,
args.tokenizer, args.seq_length)
tokenizer, args.seq_length)
return train_dataset, valid_dataset
def model_provider(args):
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building classification model for {} ...'.format(
args.task))
return Classification(
num_classes=num_classes,
num_layers=args.num_layers,
vocab_size=args.vocab_size,
vocab_size=args.padded_vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
......@@ -50,25 +59,29 @@ def glue_classification(args, num_classes, Dataset,
checkpoint_activations=args.checkpoint_activations)
def metrics_func_provider(args):
def metrics_func_provider():
"""Privde metrics callback function."""
def single_dataset_provider(datapath, args):
def single_dataset_provider(datapath):
args = get_args()
tokenizer = get_tokenizer()
name = name_from_datapath_func(datapath)
return Dataset(name, [datapath], args.tokenizer, args.seq_length)
return accuracy_func_provider(args, single_dataset_provider)
return Dataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate."""
finetune(args, train_valid_datasets_provider, model_provider,
finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider)
def main(args):
def main():
args = get_args()
if args.task == 'MNLI':
num_classes = 3
from .mnli import MNLIDataset as Dataset
from tasks.glue.mnli import MNLIDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('MNLI')[-1].strip(
'.tsv').strip('/').replace('_', '-')
......@@ -76,7 +89,7 @@ def main(args):
elif args.task == 'QQP':
num_classes = 2
from .qqp import QQPDataset as Dataset
from tasks.glue.qqp import QQPDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('QQP')[-1].strip(
'.tsv').strip('/').replace('_', '-')
......@@ -85,4 +98,4 @@ def main(args):
raise NotImplementedError('GLUE task {} is not implemented.'.format(
args.task))
glue_classification(args, num_classes, Dataset, name_from_datapath)
glue_classification(num_classes, Dataset, name_from_datapath)
......@@ -20,29 +20,38 @@ import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from arguments import get_args
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group('tasks', 'tasks configurations')
parser.add_argument('--task', type=str, required=True,
help='task name.')
group = parser.add_argument_group(title='tasks')
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument('--epochs', type=int, required=True,
help='number of finetunning epochs. Zero results in '
help='Number of finetunning epochs. Zero results in '
'evaluation only.')
parser.add_argument('--pretrained-checkpoint', type=str, default=None,
help='pretrained checkpoint used for finetunning.')
group.add_argument('--pretrained-checkpoint', type=str, default=None,
help='Pretrained checkpoint used for finetunning.')
group.add_argument('--keep-last', action='store_true',
help='keep the last batch (maybe incomplete) in'
help='Keep the last batch (maybe incomplete) in'
'the data loader')
group.add_argument('--train-data', nargs='+', default=None,
help='Whitespace separated paths or corpora names '
'for training.')
group.add_argument('--valid-data', nargs='*', default=None,
help='path(s) to the validation data.')
return parser
if __name__ == '__main__':
args = get_args(extra_args_provider=get_tasks_args)
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.task == 'RACE':
from race.finetune import main
elif args.task in ['MNLI', 'QQP']:
......@@ -51,4 +60,4 @@ if __name__ == '__main__':
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main(args)
main()
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Race."""
from megatron.model.multiple_choice import MultipleChoice
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
def train_valid_datasets_provider(args):
def train_valid_datasets_provider():
"""Provide train and validation datasets."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = RaceDataset('training', args.train_data,
args.tokenizer, args.seq_length)
tokenizer, args.seq_length)
valid_dataset = RaceDataset('validation', args.valid_data,
args.tokenizer, args.seq_length)
tokenizer, args.seq_length)
return train_dataset, valid_dataset
def model_provider(args):
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building multichoice model for RACE ...')
return MultipleChoice(
num_layers=args.num_layers,
vocab_size=args.vocab_size,
vocab_size=args.padded_vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
......@@ -35,17 +55,19 @@ def model_provider(args):
checkpoint_activations=args.checkpoint_activations)
def metrics_func_provider(args):
def metrics_func_provider():
"""Privde metrics callback function."""
args = get_args()
tokenizer = get_tokenizer()
def single_dataset_provider(datapath, args):
def single_dataset_provider(datapath):
name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
return RaceDataset(name, [datapath], args.tokenizer, args.seq_length)
return RaceDataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(args, single_dataset_provider)
return accuracy_func_provider(single_dataset_provider)
def main(args):
def main():
finetune(args, train_valid_datasets_provider, model_provider,
finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment