Commit 90ef2e28 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi Committed by Raul Puri
Browse files

BERT downstream tasks: RACE, MNLI, QQP

parent 41141526
...@@ -346,10 +346,10 @@ def add_data_args(parser): ...@@ -346,10 +346,10 @@ def add_data_args(parser):
help='path used to save/load sentencepiece tokenization ' help='path used to save/load sentencepiece tokenization '
'models') 'models')
group.add_argument('--tokenizer-type', type=str, group.add_argument('--tokenizer-type', type=str,
default='BertWordPieceTokenizer', default='BertWordPieceLowerCase',
choices=['CharacterLevelTokenizer', choices=['CharacterLevelTokenizer',
'SentencePieceTokenizer', 'SentencePieceTokenizer',
'BertWordPieceTokenizer', 'BertWordPieceLowerCase',
'GPT2BPETokenizer'], 'GPT2BPETokenizer'],
help='what type of tokenizer to use') help='what type of tokenizer to use')
group.add_argument("--cache-dir", default=None, type=str, group.add_argument("--cache-dir", default=None, type=str,
...@@ -358,7 +358,7 @@ def add_data_args(parser): ...@@ -358,7 +358,7 @@ def add_data_args(parser):
return parser return parser
def get_args(): def get_args(extra_args_provider=None):
"""Parse all the args.""" """Parse all the args."""
parser = argparse.ArgumentParser(description='PyTorch BERT Model') parser = argparse.ArgumentParser(description='PyTorch BERT Model')
...@@ -368,6 +368,8 @@ def get_args(): ...@@ -368,6 +368,8 @@ def get_args():
parser = add_evaluation_args(parser) parser = add_evaluation_args(parser)
parser = add_text_generate_args(parser) parser = add_text_generate_args(parser)
parser = add_data_args(parser) parser = add_data_args(parser)
if extra_args_provider is not None:
parser = extra_args_provider(parser)
args = parser.parse_args() args = parser.parse_args()
......
"""Megatron tokenizer."""
from abc import ABC
from abc import abstractmethod
from megatron.utils import vocab_size_with_padding
from .bert_tokenization import FullTokenizer as FullBertTokenizer
def add_tokenizer_to_args(args, tokenizer_type):
    """Build the tokenizer named by `tokenizer_type` and attach it to `args`.

    Also sets `args.vocab_size` to the padded vocabulary size so the model
    embedding table can be sized consistently.
    """
    # Guard against double initialization.
    if hasattr(args, 'tokenizer'):
        raise Exception('args already has a tokenizer')
    # Only the lower-cased BERT wordpiece tokenizer is supported here.
    if tokenizer_type != 'BertWordPieceLowerCase':
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(tokenizer_type))
    args.tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab,
                                             lower_case=True)
    # Padded vocab size used by the model.
    args.vocab_size = vocab_size_with_padding(args.tokenizer.vocab_size, args)
class AbstractTokenizer(ABC):
    """Common interface that every Megatron tokenizer must implement."""

    def __init__(self, name):
        super().__init__()
        # Human-readable tokenizer name, used in error messages below.
        self.name = name

    @property
    @abstractmethod
    def vocab_size(self):
        """Number of tokens in the vocabulary."""

    @abstractmethod
    def tokenize(self, text):
        """Convert a text string into a list of token ids."""

    # Special-token accessors default to raising; concrete tokenizers
    # override the ones they actually provide.
    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
    """Wrapper around the original BERT wordpiece tokenizer."""

    def __init__(self, vocab_file, lower_case=True):
        name = 'BERT Lower Case' if lower_case else 'BERT Upper Case'
        super().__init__(name)
        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
        # Cache the ids of the special tokens used by downstream tasks.
        vocab = self.tokenizer.vocab
        self.cls_id = vocab['[CLS]']
        self.sep_id = vocab['[SEP]']
        self.pad_id = vocab['[PAD]']

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size()

    def tokenize(self, text):
        """Wordpiece-split `text` and map the tokens to vocabulary ids."""
        pieces = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(pieces)

    @property
    def cls(self):
        return self.cls_id

    @property
    def sep(self):
        return self.sep_id

    @property
    def pad(self):
        return self.pad_id
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classification model."""
import torch
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron.utils import print_rank_0
class Classification(MegatronModule):
    """Sequence classification head on top of the Megatron language model.

    The transformer is run with a pooler; the pooled output is passed
    through dropout and a linear layer to produce `num_classes` logits
    per input sequence.
    """

    def __init__(self,
                 num_classes,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=2,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):
        super(Classification, self).__init__()
        self.num_classes = num_classes
        init_method = init_method_normal(init_method_std)

        # Backbone transformer; add_pooler=True so forward also returns a
        # pooled representation for classification.
        self.language_model, self._language_model_key = get_language_model(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            attention_mask_func=bert_attention_mask_func,
            checkpoint_activations=checkpoint_activations,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
            residual_connection_post_layernorm=False,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)

        # Classification head: dropout followed by a linear projection
        # from hidden_size to num_classes.
        self.classification_dropout = torch.nn.Dropout(output_dropout_prob)
        self.classification_head = get_linear_layer(hidden_size,
                                                    self.num_classes,
                                                    init_method)
        self._classification_head_key = 'classification_head'

    def forward(self, input_ids, attention_mask, tokentype_ids):
        """Return classification logits of shape [batch, num_classes]."""
        # Build the extended mask in the dtype of the model parameters
        # (handles fp16 vs fp32).
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Ensure the result is shaped [batch, num_classes].
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        # Language model and head are saved under separate keys so the
        # backbone can be loaded into other task models.
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._classification_head_key] \
            = self.classification_head.state_dict(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        # Head weights may be absent (e.g. when loading a pretraining
        # checkpoint); fall back to the random initialization and warn.
        if self._classification_head_key in state_dict:
            self.classification_head.load_state_dict(
                state_dict[self._classification_head_key], strict=strict)
        else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._classification_head_key))
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multiple choice model."""
import torch
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron.utils import print_rank_0
class MultipleChoice(MegatronModule):
    """Multiple-choice head (e.g. RACE) on top of the Megatron language model.

    Each input packs all the choices for a question; choices are flattened
    into the batch dimension, scored with a single-output linear head on
    the pooled representation, and reshaped back to [batch, choices].
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=2,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):
        super(MultipleChoice, self).__init__()
        init_method = init_method_normal(init_method_std)

        # Backbone transformer; add_pooler=True so forward also returns a
        # pooled representation for scoring each choice.
        self.language_model, self._language_model_key = get_language_model(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            attention_mask_func=bert_attention_mask_func,
            checkpoint_activations=checkpoint_activations,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
            residual_connection_post_layernorm=False,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)

        # Multi-choice head: dropout followed by a linear layer that
        # produces one score per choice.
        self.multichoice_dropout = torch.nn.Dropout(output_dropout_prob)
        self.multichoice_head = get_linear_layer(hidden_size, 1, init_method)
        self._multichoice_head_key = 'multichoice_head'

    def forward(self, input_ids, attention_mask, tokentype_ids):
        """Return per-choice logits of shape [batch, choices]."""
        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        # transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
        assert len(input_ids.shape) == 3
        assert len(attention_mask.shape) == 3
        assert len(tokentype_ids.shape) == 3

        # Reshape and treat choice dimension the same as batch.
        num_choices = input_ids.shape[1]
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        # Build the extended mask in the dtype of the model parameters
        # (handles fp16 vs fp32).
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        multichoice_logits = multichoice_logits.view(-1, num_choices)

        return multichoice_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        # Language model and head are saved under separate keys so the
        # backbone can be loaded into other task models.
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._multichoice_head_key] \
            = self.multichoice_head.state_dict(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        # Head weights may be absent (e.g. when loading a pretraining
        # checkpoint); fall back to the random initialization and warn.
        if self._multichoice_head_key in state_dict:
            self.multichoice_head.load_state_dict(
                state_dict[self._multichoice_head_key], strict=strict)
        else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._multichoice_head_key))
...@@ -43,7 +43,7 @@ from megatron.utils import Timers ...@@ -43,7 +43,7 @@ from megatron.utils import Timers
def run(top_level_message, train_val_test_data_provider, def run(top_level_message, train_val_test_data_provider,
model_provider, forward_step_func): model_provider, forward_step_func, extra_args_provider=None):
"""Main training program. """Main training program.
This function will run the followings in the order provided: This function will run the followings in the order provided:
...@@ -71,17 +71,9 @@ def run(top_level_message, train_val_test_data_provider, ...@@ -71,17 +71,9 @@ def run(top_level_message, train_val_test_data_provider,
function add `batch generator` to the timers class. function add `batch generator` to the timers class.
""" """
# Arguments. # Initalize and get arguments, timers, and Tensorboard writer.
args = get_args() args = get_args(extra_args_provider=extra_args_provider)
timers, writer = initialize_megatron(top_level_message, args)
# Timer.
timers = Timers()
# Tensorboard writer
writer = get_tensorboard_writer(args)
# Initalize.
initialize_megatron(top_level_message, args, writer)
# Data stuff. # Data stuff.
train_data, val_data, test_data = train_val_test_data_provider(args) train_data, val_data, test_data = train_val_test_data_provider(args)
...@@ -124,9 +116,15 @@ def run(top_level_message, train_val_test_data_provider, ...@@ -124,9 +116,15 @@ def run(top_level_message, train_val_test_data_provider,
args, None, 0, timers, True) args, None, 0, timers, True)
def initialize_megatron(message, args, writer): def initialize_megatron(message, args):
""""Initialize distributed, random seed, and autoresume.""" """"Initialize distributed, random seed, and autoresume."""
# Timer.
timers = Timers()
# Tensorboard writer.
writer = get_tensorboard_writer(args)
# Pytorch distributed. # Pytorch distributed.
initialize_distributed(args) initialize_distributed(args)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -141,6 +139,8 @@ def initialize_megatron(message, args, writer): ...@@ -141,6 +139,8 @@ def initialize_megatron(message, args, writer):
# Random seeds for reproducability. # Random seeds for reproducability.
set_random_seed(args.seed) set_random_seed(args.seed)
return timers, writer
def get_model(model_provider_func, args): def get_model(model_provider_func, args):
"""Build the model.""" """Build the model."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tasks data utility."""
import re
import numpy as np
def clean_text(text):
    """Collapse newlines/whitespace runs and tighten ' . ' into '. '.

    The dot adjustment is applied three times to catch patterns produced
    by earlier passes.
    """
    flattened = re.sub(r'\s+', ' ', text.replace("\n", " "))
    for _ in range(3):
        flattened = flattened.replace(' . ', '. ')
    return flattened
def build_sample(ids, types, paddings, label, unique_id):
    """Package one example as int64 numpy arrays for the batch producer."""
    return {
        'text': np.array(ids, dtype=np.int64),
        'types': np.array(types, dtype=np.int64),
        'padding_mask': np.array(paddings, dtype=np.int64),
        'label': int(label),
        'uid': int(unique_id),
    }
def build_tokens_types_paddings_from_text(text_a, text_b,
                                          tokenizer, max_seq_length):
    """Tokenize one text (or a pair) and build ids, types, and paddings."""
    text_a_ids = tokenizer.tokenize(text_a)
    text_b_ids = tokenizer.tokenize(text_b) if text_b is not None else None
    return build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids,
                                                max_seq_length, tokenizer.cls,
                                                tokenizer.sep, tokenizer.pad)
def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Assemble [CLS] A [SEP] (B [SEP]) ids with segment types and a padding
    mask, trimming to `max_seq_length` and right-padding as needed."""
    # [CLS] + segment A + [SEP]; all of these carry segment type 0.
    ids = [cls_id] + list(text_a_ids) + [sep_id]
    types = [0] * len(ids)
    paddings = [1] * len(ids)

    # Optional segment B, carrying segment type 1.
    if text_b_ids is not None:
        ids.extend(text_b_ids)
        types.extend([1] * len(text_b_ids))
        paddings.extend([1] * len(text_b_ids))

    # Trim so there is room for the trailing [SEP].
    trimmed = False
    if len(ids) >= max_seq_length:
        ids = ids[:max_seq_length - 1]
        types = types[:max_seq_length - 1]
        paddings = paddings[:max_seq_length - 1]
        trimmed = True

    # Close the (possibly trimmed) sequence with [SEP].
    if (text_b_ids is not None) or trimmed:
        ids.append(sep_id)
        types.append(0 if text_b_ids is None else 1)
        paddings.append(1)

    # Right-pad out to max_seq_length. Note the type ids are padded with
    # pad_id (not 0), matching the original behavior.
    pad_count = max_seq_length - len(ids)
    if pad_count > 0:
        ids.extend([pad_id] * pad_count)
        types.extend([pad_id] * pad_count)
        paddings.extend([0] * pad_count)

    return ids, types, paddings
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
import os
import time
import torch
from megatron import mpu
from megatron.utils import print_rank_0
from .finetune_utils import build_data_loader
from .finetune_utils import process_batch
def accuracy_func_provider(args, single_dataset_provider):
    """Provide function that calculates accuracies.

    Builds one dataloader per path in `args.valid_data` up front and
    returns a `metrics_func(model, args_, epoch, output_predictions)`
    closure that evaluates the model on each dataset, prints per-dataset
    and overall accuracy, and optionally saves predictions.
    """
    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath, args)
        dataloader = build_data_loader(
            dataset, args.batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, args_, epoch, output_predictions=False):
        """Evaluate `model` on every dataloader built above."""
        print_rank_0('calculating metrics ...')
        correct = 0
        total = 0
        if output_predictions:
            # Prediction dumping is only supported without data parallelism.
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader, args_,
                                               epoch, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            correct += correct_ans
            total += total_count
        percent = float(correct) * 100.0 / float(total)
        print_rank_0(' >> |epoch: {}| overall: correct / total = {} / {} = '
                     '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and torch.distributed.get_rank() == 0:
            # NOTE(review): this reads the outer `args` captured by the
            # provider, not the `args_` passed to metrics_func — confirm
            # that is intentional.
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func
def calculate_correct_answers(name, model, dataloader, args,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true.

    Counts are all-reduced over the data-parallel group before the
    accuracy is printed. When `output_predictions` is set, also returns
    (softmaxes, labels, ids) gathered on this rank.
    """
    start_time = time.time()
    # Switch to eval mode (disables dropout) for the duration of the pass.
    model.eval()
    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # Run the model forward.
            tokens, types, labels_, attention_mask = process_batch(batch, args)
            logits = model(tokens, attention_mask, types)
            # Add output predictions.
            if output_predictions:
                softmaxes.extend(torch.nn.Softmax(dim=-1)(
                    logits.float()).data.cpu().numpy().tolist())
                labels.extend(labels_.data.cpu().numpy().tolist())
                ids.extend(batch['uid'].cpu().numpy().tolist())
            # Compute the correct answers.
            predicted = torch.argmax(logits, dim=-1)
            corrects = (predicted == labels_)
            # Add to the counters.
            total += labels_.size(0)
            correct += corrects.sum().item()
    # Restore training mode for the caller.
    model.train()

    # Reduce the counters across the data-parallel group.
    unreduced = torch.cuda.LongTensor([correct, total])
    torch.distributed.all_reduce(unreduced,
                                 group=mpu.get_data_parallel_group())

    # Print on screen.
    correct_ans = unreduced[0].item()
    total_count = unreduced[1].item()
    percent = float(correct_ans) * 100.0 / float(total_count)
    elapsed_time = time.time() - start_time
    print_rank_0(' > |epoch: {}| metrics for {}: correct / total '
                 '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                     epoch, name, correct_ans, total_count,
                     percent, elapsed_time))

    if output_predictions:
        return correct_ans, total_count, (softmaxes, labels, ids)
    return correct_ans, total_count
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetune utilities."""
import torch
from megatron import mpu
from megatron.data.tokenizer import add_tokenizer_to_args
from megatron.training import evaluate_and_print_results
from megatron.training import initialize_megatron
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import load_checkpoint
from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import save_checkpoint
def process_batch(batch, args):
    """Move a batch to the GPU and unpack the tensors the model consumes."""
    def _to_cuda_long(key):
        return batch[key].long().cuda().contiguous()

    tokens = _to_cuda_long('text')
    types = _to_cuda_long('types')
    labels = _to_cuda_long('label')
    attention_mask = batch['padding_mask'].float().cuda().contiguous()
    if args.fp16:
        # Match the half-precision model parameters.
        attention_mask = attention_mask.half()
    return tokens, types, labels, attention_mask
def _cross_entropy_forward_step(batch, model, args, timers):
    """Simple forward step with cross-entropy loss.

    `batch` may be either an iterator over batches or a single batch dict:
    an iterator is advanced one step, anything else is used as-is.

    Returns the loss tensor and a dict of reduced losses for logging.
    """
    # Get the batch.
    timers('batch generator').start()
    try:
        batch_ = next(batch)
    except (TypeError, StopIteration):
        # `batch` is not an iterator (or is exhausted): treat it as the
        # batch itself. Deliberately narrower than the previous bare
        # `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_, args)
    timers('batch generator').stop()

    # Forward model.
    logits = model(tokens, attention_mask, types)

    # Cross-entropy loss.
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(logits.contiguous().float(), labels)

    # Reduce loss for logging.
    reduced_loss = reduce_losses([loss])

    return loss, {'lm loss': reduced_loss[0]}
def build_data_loader(dataset, batch_size, num_workers, drop_last):
    """Build a distributed data loader; `batch_size` is per GPU."""
    # Shard the dataset across the data-parallel group.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank())
    # The sampler handles shuffling, so the loader itself does not shuffle.
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       sampler=sampler,
                                       shuffle=False,
                                       num_workers=num_workers,
                                       drop_last=drop_last,
                                       pin_memory=True)
def _build_infinite_size_dataloader(dataloader):
"""Build a looped dataloader with infinite size."""
iterator = dataloader.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset, args):
    """Build training and validation dataloaders.

    Side effect: sets `args.train_iters_per_epoch` and `args.train_iters`
    from the training dataloader length and `args.epochs`.
    """
    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.batch_size,
                                         args.num_workers, not args.keep_last)
    # Derive the total number of training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation does not need shuffling, so wrap a plain loader in a
    # simple infinite iterator.
    valid_dataloader = _build_infinite_size_dataloader(
        build_data_loader(valid_dataset, args.batch_size,
                          args.num_workers, not args.keep_last))
    return train_dataloader, valid_dataloader
def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader,
           end_of_epoch_callback, timers, args, writer):
    """Train the model.

    Resumes from `args.iteration` (skipping already-seen batches in the
    current epoch), and handles logging, autoresume, checkpointing,
    periodic evaluation, and the end-of-epoch callback.
    """
    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration, derived from the saved iteration count.
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch+1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, _ = train_step(forward_step, batch, model, optimizer,
                                        lr_scheduler, args, timers)
            iteration += 1

            # Logging.
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration, optimizer.loss_scale,
                                              report_memory_flag, writer,
                                              args, timers)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  lr_scheduler, args)

            # Checkpointing
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model, args,
                                           writer, iteration, timers, False)

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, args, epoch)
def finetune(args, train_valid_datasets_provider, model_provider,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None):
    """Main finetune function used across all tasks.

    Initializes Megatron, builds dataloaders (when `args.epochs > 0`),
    sets up the model/optimizer, optionally loads a pretrained checkpoint,
    then either trains or runs the end-of-epoch callback in
    evaluation-only mode.
    """
    # Initialize megatron and get timers and the Tensorboard writer.
    timers, writer = initialize_megatron(
        'finetune model for {} ...'.format(args.task), args)

    # Add tokenizer to the args.
    add_tokenizer_to_args(args, args.tokenizer_type)

    # Train and validation data loaders. Note this also sets
    # args.train_iters, which setup below depends on.
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider(args)
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset, args)

    # Build callback function.
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider(args)

    # Build model, optimizer and learning rate scheduler.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               args)

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        # Temporarily point args.load at the pretrained checkpoint.
        original_load = args.load
        args.load = args.pretrained_checkpoint
        _ = load_checkpoint(model, None, None, args)
        args.load = original_load
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16:
            optimizer._model_params_to_master_params()

    # Finetune the model.
    if args.epochs > 0:
        _train(model, optimizer, lr_scheduler, forward_step,
               train_dataloader, valid_dataloader,
               end_of_epoch_callback, timers, args, writer)
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            end_of_epoch_callback(model, args, epoch=-1,
                                  output_predictions=True)

    print_rank_0('done :-)')
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE dataset."""
from abc import ABC
from abc import abstractmethod
from torch.utils.data import Dataset
from megatron.utils import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_text
class GLUEAbstractDataset(ABC, Dataset):
    """Base class for GLUE-style datasets.

    Subclasses implement `process_samples_from_single_path`, which reads
    one file into raw samples; this class handles loading from all paths
    and per-item tokenization/trimming/padding.
    """

    def __init__(self, task_name, dataset_name, datapaths,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Report the input paths, then load samples from each of them.
        print_rank_0(' > paths:' + ''.join(' ' + p for p in datapaths))
        self.samples = []
        for path in datapaths:
            self.samples.extend(self.process_samples_from_single_path(path))
        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Tokenize/trim/pad lazily, one raw sample at a time.
        raw = self.samples[idx]
        ids, types, paddings = build_tokens_types_paddings_from_text(
            raw['text_a'], raw['text_b'],
            self.tokenizer, self.max_seq_length)
        return build_sample(ids, types, paddings,
                            raw['label'], raw['uid'])

    @abstractmethod
    def process_samples_from_single_path(self, datapath):
        """Abstract method that takes a single path / filename and
        returns a list of dataset samples, each sample being a dict of
        {'text_a': string, 'text_b': string, 'label': int, 'uid': int}
        """
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE finetuning/evaluation."""
from megatron.utils import print_rank_0
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
def glue_classification(args, num_classes, Dataset,
                        name_from_datapath_func):
    """Finetune/evaluate a BERT classification head on a GLUE task.

    Arguments:
        args: parsed command-line arguments (model + data settings).
        num_classes: number of output classes for the task.
        Dataset: a GLUEAbstractDataset subclass for the task.
        name_from_datapath_func: maps a datapath string to a short
            human-readable dataset name.
    """

    def train_valid_datasets_provider(args):
        """Build train and validation dataset."""
        train_dataset = Dataset('training', args.train_data,
                                args.tokenizer, args.seq_length)
        valid_dataset = Dataset('validation', args.valid_data,
                                args.tokenizer, args.seq_length)
        return train_dataset, valid_dataset

    def model_provider(args):
        """Build the model."""
        print_rank_0('building classification model for {} ...'.format(
            args.task))
        return Classification(
            num_classes=num_classes,
            num_layers=args.num_layers,
            vocab_size=args.vocab_size,
            hidden_size=args.hidden_size,
            num_attention_heads=args.num_attention_heads,
            embedding_dropout_prob=args.hidden_dropout,
            attention_dropout_prob=args.attention_dropout,
            output_dropout_prob=args.hidden_dropout,
            max_sequence_length=args.max_position_embeddings,
            checkpoint_activations=args.checkpoint_activations)

    def metrics_func_provider(args):
        """Provide metrics callback function."""
        def single_dataset_provider(datapath, args):
            name = name_from_datapath_func(datapath)
            return Dataset(name, [datapath], args.tokenizer, args.seq_length)
        return accuracy_func_provider(args, single_dataset_provider)

    # Finetune/evaluate.  (This was a stray bare string literal in the
    # middle of the function body -- a no-op statement; a comment is the
    # correct construct here.)
    finetune(args, train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
def main(args):
    """Dispatch to the requested GLUE task (MNLI or QQP)."""

    def _name_from_datapath(datapath, task):
        # Derive a short dataset name such as 'dev-matched' from the
        # path.  NOTE: the previous `.strip('.tsv')` stripped the
        # *characters* '.', 't', 's', 'v' from both ends (mangling e.g.
        # 'test.tsv' into 'te'); remove the extension explicitly instead.
        name = datapath.split(task)[-1]
        if name.endswith('.tsv'):
            name = name[:-len('.tsv')]
        return name.strip('/').replace('_', '-')

    if args.task == 'MNLI':
        num_classes = 3
        from .mnli import MNLIDataset as Dataset

        def name_from_datapath(datapath):
            return _name_from_datapath(datapath, 'MNLI')

    elif args.task == 'QQP':
        num_classes = 2
        from .qqp import QQPDataset as Dataset

        def name_from_datapath(datapath):
            return _name_from_datapath(datapath, 'QQP')

    else:
        raise NotImplementedError('GLUE task {} is not implemented.'.format(
            args.task))

    glue_classification(args, num_classes, Dataset, name_from_datapath)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNLI dataset."""
from megatron.utils import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
# Map the NLI label strings found in the TSV files to class indices.
LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
class MNLIDataset(GLUEAbstractDataset):
    """MNLI sentence-pair dataset read from GLUE-format TSV files."""

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='contradiction'):
        # Test files carry no gold labels; every test sample is assigned
        # `test_label` as a placeholder.
        self.test_label = test_label
        super().__init__('MNLI', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        # (Fixed docstring typo: was opened with four quote characters.)
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            for line in f:
                row = line.strip().split('\t')
                # The first row is the header; a 10-column header marks
                # an unlabelled test file.
                if first:
                    first = False
                    if len(row) == 10:
                        is_test = True
                        print_rank_0(
                            ' reading {}, {} and {} columns and setting '
                            'labels to {}'.format(
                                row[0].strip(), row[8].strip(),
                                row[9].strip(), self.test_label))
                    else:
                        print_rank_0(' reading {} , {}, {}, and {} columns '
                                     '...'.format(
                                         row[0].strip(), row[8].strip(),
                                         row[9].strip(), row[-1].strip()))
                    continue

                # Columns 8/9 hold the sentence pair, column 0 the id,
                # and the last column the gold label (train/dev only).
                text_a = clean_text(row[8].strip())
                text_b = clean_text(row[9].strip())
                unique_id = int(row[0].strip())
                label = row[-1].strip()
                if is_test:
                    label = self.test_label

                assert len(text_a) > 0
                assert len(text_b) > 0
                assert label in LABELS
                assert unique_id >= 0

                sample = {'text_a': text_a,
                          'text_b': text_b,
                          'label': LABELS[label],
                          'uid': unique_id}
                total += 1
                samples.append(sample)

                if total % 50000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QQP dataset."""
from megatron.utils import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
# The two valid binary labels for QQP samples.
LABELS = [0, 1]
class QQPDataset(GLUEAbstractDataset):
    """QQP question-pair dataset read from GLUE-format TSV files."""

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label=0):
        # Test files carry no gold labels; every test sample is assigned
        # `test_label` as a placeholder.
        self.test_label = test_label
        super().__init__('QQP', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        # (Fixed docstring typo: was opened with four quote characters.)
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            for line in f:
                row = line.strip().split('\t')
                # The first row is the header; a 3-column header marks an
                # unlabelled test file, a 6-column header a labelled one.
                if first:
                    first = False
                    if len(row) == 3:
                        is_test = True
                        print_rank_0(' reading {}, {}, and {} columns and '
                                     'setting labels to {}'.format(
                                         row[0].strip(), row[1].strip(),
                                         row[2].strip(), self.test_label))
                    else:
                        assert len(row) == 6
                        print_rank_0(' reading {}, {}, {}, and {} columns'
                                     ' ...'.format(
                                         row[0].strip(), row[3].strip(),
                                         row[4].strip(), row[5].strip()))
                    continue

                if is_test:
                    # Test rows: id, question-a, question-b.
                    assert len(row) == 3, 'expected length 3: {}'.format(row)
                    uid = int(row[0].strip())
                    text_a = clean_text(row[1].strip())
                    text_b = clean_text(row[2].strip())
                    label = self.test_label
                    assert len(text_a) > 0
                    assert len(text_b) > 0
                else:
                    # Labelled rows: questions in columns 3/4, label in 5.
                    # Malformed or empty rows are skipped with a warning
                    # rather than aborting the whole load.
                    if len(row) == 6:
                        uid = int(row[0].strip())
                        text_a = clean_text(row[3].strip())
                        text_b = clean_text(row[4].strip())
                        label = int(row[5].strip())
                    else:
                        print_rank_0('***WARNING*** index error, '
                                     'skipping: {}'.format(row))
                        continue
                    if len(text_a) == 0:
                        print_rank_0('***WARNING*** zero length a, '
                                     'skipping: {}'.format(row))
                        continue
                    if len(text_b) == 0:
                        print_rank_0('***WARNING*** zero length b, '
                                     'skipping: {}'.format(row))
                        continue

                assert label in LABELS
                assert uid >= 0

                sample = {'uid': uid,
                          'text_a': text_a,
                          'text_b': text_b,
                          'label': label}
                total += 1
                samples.append(sample)

                if total % 50000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
# Make the repository root importable so `from arguments import ...` and
# the megatron package resolve when this script is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
from arguments import get_args
def get_tasks_args(parser):
    """Provide extra arguments required for tasks.

    Adds a 'tasks' argument group to `parser` and returns the parser,
    matching the `extra_args_provider` contract of `get_args`.
    """
    group = parser.add_argument_group('tasks', 'tasks configurations')
    # All task arguments belong to the group; previously --task and
    # --pretrained-checkpoint were added to the bare parser and escaped
    # the 'tasks' section of --help.
    group.add_argument('--task', type=str, required=True,
                       help='task name.')
    group.add_argument('--epochs', type=int, required=True,
                       help='number of finetuning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='pretrained checkpoint used for finetuning.')
    group.add_argument('--keep-last', action='store_true',
                       help='keep the last batch (maybe incomplete) in '
                       'the data loader')
    return parser
if __name__ == '__main__':
    # Parse the standard megatron arguments plus the task-specific ones.
    args = get_args(extra_args_provider=get_tasks_args)

    # Dispatch to the per-task entry point based on --task.
    if args.task == 'RACE':
        from race.finetune import main
    elif args.task in ['MNLI', 'QQP']:
        from glue.finetune import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main(args)
import glob
import json
import os
import time
from torch.utils.data import Dataset
from megatron.utils import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import clean_text
# RACE provides four answer options per question.
NUM_CHOICES = 4
# Maximum number of tokens kept from a combined question+answer string.
MAX_QA_LENGTH = 128
class RaceDataset(Dataset):
    """Multiple-choice RACE dataset built from one or more data paths."""

    def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
                 max_qa_length=MAX_QA_LENGTH):
        self.dataset_name = dataset_name
        print_rank_0(' > building RACE dataset for {}:'.format(
            self.dataset_name))
        # Report which paths contribute to this dataset.
        paths_info = ' > paths:'
        for datapath in datapaths:
            paths_info += ' ' + datapath
        print_rank_0(paths_info)
        # Samples are fully materialized up front, one path at a time.
        self.samples = [sample
                        for datapath in datapaths
                        for sample in process_single_datapath(datapath,
                                                              tokenizer,
                                                              max_qa_length,
                                                              max_seq_length)]
        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
    """Read in RACE files, combine, clean-up, tokenize, and convert to
    samples.

    Each *.txt file under `datapath` holds one JSON document per line with
    keys "article", "questions", "options", and "answers".  Every question
    yields one sample containing NUM_CHOICES (ids, types, paddings)
    triples plus the integer label of the correct choice.
    """
    print_rank_0(' > working on {}'.format(datapath))
    start_time = time.time()

    # Get list of files.
    filenames = glob.glob(os.path.join(datapath, '*.txt'))

    samples = []
    num_docs = 0
    num_questions = 0
    num_samples = 0
    # Load all the files
    for filename in filenames:
        with open(filename, 'r') as f:
            for line in f:
                data = json.loads(line)
                num_docs += 1
                context = data["article"]
                questions = data["questions"]
                choices = data["options"]
                answers = data["answers"]
                # Check the length.
                assert len(questions) == len(answers)
                assert len(questions) == len(choices)

                # Context: clean up and convert to ids.
                context = clean_text(context)
                context_ids = tokenizer.tokenize(context)

                # Loop over questions.
                for qi, question in enumerate(questions):
                    num_questions += 1
                    # Label: answers are letters "A".."D", mapped to 0..3.
                    label = ord(answers[qi]) - ord("A")
                    assert label >= 0
                    assert label < NUM_CHOICES
                    assert len(choices[qi]) == NUM_CHOICES

                    # For each question, build num-choices samples.
                    ids_list = []
                    types_list = []
                    paddings_list = []
                    for ci in range(NUM_CHOICES):
                        choice = choices[qi][ci]
                        # Merge with choice: "_" marks a fill-in-the-blank
                        # style question; otherwise append the choice.
                        if "_" in question:
                            qa = question.replace("_", choice)
                        else:
                            qa = " ".join([question, choice])
                        # Clean QA.
                        qa = clean_text(qa)
                        # Tokenize.
                        qa_ids = tokenizer.tokenize(qa)
                        # Trim if needed.
                        if len(qa_ids) > max_qa_length:
                            qa_ids = qa_ids[0:max_qa_length]
                        # Build the sample.
                        ids, types, paddings \
                            = build_tokens_types_paddings_from_ids(
                                qa_ids, context_ids, max_seq_length,
                                tokenizer.cls, tokenizer.sep, tokenizer.pad)
                        ids_list.append(ids)
                        types_list.append(types)
                        paddings_list.append(paddings)

                    # Convert to numpy and add to samples.  num_samples
                    # doubles as the running unique id of the sample.
                    samples.append(build_sample(ids_list, types_list,
                                                paddings_list, label,
                                                num_samples))
                    num_samples += 1

    elapsed_time = time.time() - start_time
    print_rank_0(' > processed {} document, {} questions, and {} samples'
                 ' in {:.2f} seconds'.format(num_docs, num_questions,
                                             num_samples, elapsed_time))
    return samples
"""Race."""
from megatron.model.multiple_choice import MultipleChoice
from megatron.utils import print_rank_0
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
def train_valid_datasets_provider(args):
    """Provide train and validation datasets."""
    # Both splits share the tokenizer and sequence length from args and
    # differ only in their source paths.
    return (RaceDataset('training', args.train_data,
                        args.tokenizer, args.seq_length),
            RaceDataset('validation', args.valid_data,
                        args.tokenizer, args.seq_length))
def model_provider(args):
    """Build the model."""
    print_rank_0('building multichoice model for RACE ...')
    # Map the relevant command-line arguments onto the model's
    # constructor keywords, then build it in one call.
    model_kwargs = {
        'num_layers': args.num_layers,
        'vocab_size': args.vocab_size,
        'hidden_size': args.hidden_size,
        'num_attention_heads': args.num_attention_heads,
        'embedding_dropout_prob': args.hidden_dropout,
        'attention_dropout_prob': args.attention_dropout,
        'output_dropout_prob': args.hidden_dropout,
        'max_sequence_length': args.max_position_embeddings,
        'checkpoint_activations': args.checkpoint_activations,
    }
    return MultipleChoice(**model_kwargs)
def metrics_func_provider(args):
    """Provide metrics callback function."""
    def single_dataset_provider(datapath, args):
        # Derive a short name from the path, e.g. '.../RACE/dev/middle'
        # becomes 'dev-middle'.
        name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], args.tokenizer, args.seq_length)
    return accuracy_func_provider(args, single_dataset_provider)
def main(args):
    """Finetune/evaluate the RACE task."""
    finetune(args, train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment