Commit 5e56e563 authored by Neel Kant

Merge master into realm-mlm

parents 6c0a5bd8 569b3dab
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -96,7 +96,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                    model, optimizer, lr_scheduler,
                    train_data_iterator, valid_data_iterator)
-
     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
@@ -173,7 +172,7 @@ def get_optimizer(model):
                                    dynamic_loss_scale=args.dynamic_loss_scale,
                                    dynamic_loss_args={
                                        'scale_window': args.loss_scale_window,
-                                       'min_scale':args.min_scale,
+                                       'min_scale': args.min_scale,
                                        'delayed_shift': args.hysteresis})

     return optimizer
@@ -228,7 +227,7 @@ def backward_step(optimizer, model, loss):
         torch.cuda.synchronize()

     # Backward pass.
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
         optimizer.backward(loss, update_master_grads=False)
     else:
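Note on the `optimizer.zero_grad(set_grads_to_None=True)` change in the hunk above: `set_grads_to_None` is the keyword exposed by the apex-style `FP16_Optimizer` wrapper used here; setting gradients to `None` releases the gradient buffers instead of refilling them with zeros. A minimal sketch of the behavior, assuming a plain list of parameters (not the actual Megatron optimizer code):

```python
import torch

def zero_grad(params, set_grads_to_None=True):
    """Sketch of gradient clearing; mirrors the keyword above."""
    for p in params:
        if p.grad is None:
            continue
        if set_grads_to_None:
            # Drop the reference: frees the gradient buffer immediately
            # and skips a zeroing kernel; backward re-allocates it lazily.
            p.grad = None
        else:
            p.grad.detach_()
            p.grad.zero_()
```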
@@ -297,6 +296,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     # Logging.
     timers_to_log = []
+
     def add_to_logging(name):
         if name in timers.timers:
             timers_to_log.append(name)
@@ -431,7 +431,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             # Reduce across processes.
             for key in loss_dict:
                 total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
-                                       loss_dict[key]
+                    loss_dict[key]

     # Move model back to the train mode.
     model.train()
@@ -521,14 +521,14 @@ def build_train_valid_test_data_iterators(
     # Shift the start iterations.
     if train_dataloader is not None:
         train_dataloader.batch_sampler.start_iter = args.iteration % \
-                                                    len(train_dataloader)
+            len(train_dataloader)
         print_rank_0('setting training data start iteration to {}'.
                      format(train_dataloader.batch_sampler.start_iter))
     if valid_dataloader is not None:
         start_iter_val = (args.iteration // args.eval_interval) * \
-                         args.eval_iters
+            args.eval_iters
         valid_dataloader.batch_sampler.start_iter = start_iter_val % \
-                                                    len(valid_dataloader)
+            len(valid_dataloader)
         print_rank_0('setting validation data start iteration to {}'.
                      format(valid_dataloader.batch_sampler.start_iter))
...
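A worked example of the start-iteration arithmetic in the last hunk above (the numbers are made up):

```python
iteration = 5000           # args.iteration restored from a checkpoint
eval_interval = 1000       # args.eval_interval
eval_iters = 100           # args.eval_iters
len_valid_dataloader = 320

# One evaluation run of eval_iters batches happens every eval_interval steps,
# so the validation sampler resumes after all previously consumed batches.
start_iter_val = (iteration // eval_interval) * eval_iters  # 5 * 100 = 500
print(start_iter_val % len_valid_dataloader)                # 500 % 320 = 180
```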
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,7 +48,7 @@ def report_memory(name):
         torch.cuda.max_memory_allocated() / mega_bytes)
     string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
     string += ' | max cached: {}'.format(
-        torch.cuda.max_memory_cached()/ mega_bytes)
+        torch.cuda.max_memory_cached() / mega_bytes)
     print_rank_0(string)
@@ -119,8 +119,7 @@ def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
                                     reset_attention_mask,
-                                    eod_mask_loss,
-                                    fp16):
+                                    eod_mask_loss):
     """Build masks and position id for left to right model."""

     # Extract batch size and sequence length.
@@ -164,14 +163,13 @@ def get_ltor_masks_and_position_ids(data,
                 i = eod_index[j]
                 # Mask attention loss.
                 if reset_attention_mask:
-                    attention_mask[b, 0, (i+1):, :(i+1)] = 0
+                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                 # Reset positions.
                 if reset_position_ids:
-                    position_ids[b, (i+1):] -= (i + 1 - prev_index)
+                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                     prev_index = i + 1

-    # Convert
-    if fp16:
-        attention_mask = attention_mask.half()
+    # Convert attention mask to binary:
+    attention_mask = (attention_mask < 0.5)

     return attention_mask, loss_mask, position_ids
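With the `fp16` branch gone, the mask returned above is boolean (`True` marks positions to exclude), so it no longer needs a dtype conversion before use. A hedged sketch of how such a binary mask is typically consumed ahead of the softmax (illustrative, not the exact Megatron kernel):

```python
import torch

def apply_attention_mask(attention_scores, attention_mask):
    # attention_scores: [b, num_heads, s, s]; attention_mask: [b, 1, s, s], bool.
    # Filling masked positions with a large negative value drives them to
    # ~0 after softmax, regardless of the scores' dtype (fp16 or fp32).
    return attention_scores.masked_fill(attention_mask, -10000.0)
```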
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -65,8 +65,7 @@ def get_batch(data_iterator):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, labels, loss_mask, attention_mask, position_ids
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,8 +75,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
     # A.
     len_text_a = len(text_a_ids)
     ids.extend(text_a_ids)
-    types.extend([0]*len_text_a)
-    paddings.extend([1]*len_text_a)
+    types.extend([0] * len_text_a)
+    paddings.extend([1] * len_text_a)

     # [SEP].
     ids.append(sep_id)
@@ -87,8 +87,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
     if text_b_ids is not None:
         len_text_b = len(text_b_ids)
         ids.extend(text_b_ids)
-        types.extend([1]*len_text_b)
-        paddings.extend([1]*len_text_b)
+        types.extend([1] * len_text_b)
+        paddings.extend([1] * len_text_b)

     # Cap the size.
     trimmed = False
@@ -111,8 +111,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
     # Padding.
     padding_length = max_seq_length - len(ids)
     if padding_length > 0:
-        ids.extend([pad_id]*padding_length)
-        types.extend([pad_id]*padding_length)
-        paddings.extend([0]*padding_length)
+        ids.extend([pad_id] * padding_length)
+        types.extend([pad_id] * padding_length)
+        paddings.extend([0] * padding_length)

     return ids, types, paddings
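For reference, the layout the routine above produces for a sentence pair, shown on made-up tokens (`sep_id` and `pad_id` stand for the tokenizer's actual values, and the leading classification token is assumed from the surrounding code):

```python
# ids:      [CLS] a1  a2  [SEP] b1  b2  [SEP] pad pad
# types:      0    0   0    0    1   1    1   pad pad   # [0]*len_a, [1]*len_b, pad_id fill
# paddings:   1    1   1    1    1   1    1    0   0    # 1 = real token, 0 = padding
```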
@@ -5,6 +5,7 @@ import collections
 import numpy as np
 import torch

+
 def process_files(args):
     all_predictions = collections.OrderedDict()
     all_labels = collections.OrderedDict()
@@ -40,12 +41,12 @@ def get_threshold(all_predictions, all_labels, one_threshold=False):
     for dataset in all_predictions:
         preds = all_predictions[dataset]
         labels = all_labels[dataset]
-        out_thresh.append(calc_threshold(preds,labels))
+        out_thresh.append(calc_threshold(preds, labels))
     return out_thresh


 def calc_threshold(p, l):
-    trials = [(i)*(1./100.) for i in range(100)]
+    trials = [(i) * (1. / 100.) for i in range(100)]
     best_acc = float('-inf')
     best_thresh = 0
     for t in trials:
@@ -58,7 +59,7 @@ def calc_threshold(p, l):

 def apply_threshold(preds, t):
     assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
-    prob = preds[:,-1]
+    prob = preds[:, -1]
     thresholded = (prob >= t).astype(int)
     preds = np.zeros_like(preds)
     preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
@@ -66,8 +67,8 @@ def apply_threshold(preds, t):

 def threshold_predictions(all_predictions, threshold):
-    if len(threshold)!=len(all_predictions):
-        threshold = [threshold[-1]]*(len(all_predictions)-len(threshold))
+    if len(threshold) != len(all_predictions):
+        threshold = [threshold[-1]] * (len(all_predictions) - len(threshold))
     for i, dataset in enumerate(all_predictions):
         thresh = threshold[i]
         preds = all_predictions[dataset]
@@ -77,7 +78,7 @@ def threshold_predictions(all_predictions, threshold):

 def postprocess_predictions(all_predictions, all_labels, args):
     for d in all_predictions:
-        all_predictions[d] = all_predictions[d]/len(args.paths)
+        all_predictions[d] = all_predictions[d] / len(args.paths)

     if args.calc_threshold:
         args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
@@ -98,19 +99,22 @@ def write_predictions(all_predictions, all_labels, all_uid, args):
         if args.eval:
             correct = (preds == all_labels[dataset]).sum()
             num = len(all_labels[dataset])
-            accuracy = correct/num
+            accuracy = correct / num
             count += num
             all_correct += correct
         accuracy = (preds == all_labels[dataset]).mean()
         print(accuracy)
         if not os.path.exists(os.path.join(args.outdir, dataset)):
             os.makedirs(os.path.join(args.outdir, dataset))
-        outpath = os.path.join(args.outdir, dataset, os.path.splitext(args.prediction_name)[0]+'.tsv')
+        outpath = os.path.join(
+            args.outdir, dataset, os.path.splitext(
+                args.prediction_name)[0] + '.tsv')
         with open(outpath, 'w') as f:
             f.write('id\tlabel\n')
-            f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist())))
+            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
+                              for uid, p in zip(all_uid[dataset], preds.tolist())))
     if args.eval:
-        print(all_correct/count)
+        print(all_correct / count)


 def ensemble_predictions(args):
@@ -119,7 +123,7 @@ def ensemble_predictions(args):
     write_predictions(all_predictions, all_labels, all_uid, args)


 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--paths', required=True, nargs='+',
                         help='paths to checkpoint directories used in ensemble')
@@ -135,11 +139,11 @@ def main():
                         help='use on threshold for all subdatasets')
     parser.add_argument('--threshold', nargs='+', default=None, type=float,
                         help='user supplied threshold for classification')
-    parser.add_argument('--labels',nargs='+', default=None,
+    parser.add_argument('--labels', nargs='+', default=None,
                         help='whitespace separated list of label names')

     args = parser.parse_args()
     ensemble_predictions(args)


 if __name__ == '__main__':
     main()
\ No newline at end of file
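To make the threshold sweep concrete, here is a self-contained toy run of the `apply_threshold` logic above (the arrays are invented; `calc_threshold` performs the same scan over `t in {0.00, ..., 0.99}` and keeps the best-accuracy threshold):

```python
import numpy as np

def apply_threshold(preds, t):
    # Same logic as the script above: predict class 1 iff P(class 1) >= t.
    prob = preds[:, -1]
    thresholded = (prob >= t).astype(int)
    out = np.zeros_like(preds)
    out[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return out

preds = np.array([[0.9, 0.1], [0.4, 0.6], [0.7, 0.3]])  # rows sum to 1
labels = np.array([0, 1, 0])
for t in (0.3, 0.5, 0.7):
    acc = (apply_threshold(preds, t).argmax(-1) == labels).mean()
    print(t, acc)  # 0.5 gives perfect accuracy on this toy data
```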
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
@@ -53,7 +53,7 @@ def _cross_entropy_forward_step(batch, model):
     timers('batch generator').start()
     try:
         batch_ = next(batch)
-    except:
+    except BaseException:
         batch_ = batch
     tokens, types, labels, attention_mask = process_batch(batch_)
     timers('batch generator').stop()
@@ -146,7 +146,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     # For each remaining epoch
     timers('interval time').start()
     for epoch in range(start_epoch, args.epochs):
-        print_rank_0('working on epoch {} ...'.format(epoch+1))
+        print_rank_0('working on epoch {} ...'.format(epoch + 1))

         # Set the data loader epoch to shuffle the index iterator.
         train_dataloader.sampler.set_epoch(args.seed + epoch)
@@ -172,7 +172,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                          report_memory_flag)

         # Autoresume
         if args.adlr_autoresume and \
-           (iteration % args.adlr_autoresume_interval == 0):
+                (iteration % args.adlr_autoresume_interval == 0):
             check_adlr_autoresume_termination(iteration, model,
                                               optimizer, lr_scheduler)
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,11 +48,9 @@ class GLUEAbstractDataset(ABC, Dataset):
         print_rank_0('  >> total number of samples: {}'.format(
             len(self.samples)))

-
     def __len__(self):
         return len(self.samples)

-
     def __getitem__(self, idx):
         raw_sample = self.samples[idx]
         ids, types, paddings = build_tokens_types_paddings_from_text(
@@ -62,7 +60,6 @@ class GLUEAbstractDataset(ABC, Dataset):
                            raw_sample['label'], raw_sample['uid'])
         return sample

-
     @abstractmethod
     def process_samples_from_single_path(self, datapath):
         """Abstract method that takes a single path / filename and
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,7 +38,6 @@ def glue_classification(num_classes, Dataset,
         return train_dataset, valid_dataset

-
     def model_provider():
         """Build the model."""

         args = get_args()
@@ -48,7 +47,6 @@ def glue_classification(num_classes, Dataset,
         return Classification(num_classes=num_classes, num_tokentypes=2)

-
     def metrics_func_provider():
         """Privde metrics callback function."""

         def single_dataset_provider(datapath):
@@ -59,7 +57,6 @@ def glue_classification(num_classes, Dataset,
             return Dataset(name, [datapath], tokenizer, args.seq_length)

         return accuracy_func_provider(single_dataset_provider)

-
     """Finetune/evaluate."""
     finetune(train_valid_datasets_provider, model_provider,
              end_of_epoch_callback_provider=metrics_func_provider)
@@ -72,6 +69,7 @@ def main():
         num_classes = 3
         from tasks.glue.mnli import MNLIDataset as Dataset
+
         def name_from_datapath(datapath):
             return datapath.split('MNLI')[-1].strip(
                 '.tsv').strip('/').replace('_', '-')
@@ -80,6 +78,7 @@ def main():
         num_classes = 2
         from tasks.glue.qqp import QQPDataset as Dataset
+
         def name_from_datapath(datapath):
             return datapath.split('QQP')[-1].strip(
                 '.tsv').strip('/').replace('_', '-')
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,6 @@ class MNLIDataset(GLUEAbstractDataset):
         super().__init__('MNLI', name, datapaths,
                          tokenizer, max_seq_length)

-
     def process_samples_from_single_path(self, filename):
         """"Implement abstract method."""
         print_rank_0(' > Processing {} ...'.format(filename))
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,6 @@ class QQPDataset(GLUEAbstractDataset):
         super().__init__('QQP', name, datapaths,
                          tokenizer, max_seq_length)

-
     def process_samples_from_single_path(self, filename):
         """"Implement abstract method."""
         print_rank_0(' > Processing {} ...'.format(filename))
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ def get_tasks_args(parser):
     group.add_argument('--overlapping-eval', type=int, default=32,
                        help='Sliding window for overlapping evaluation.')
     group.add_argument('--strict-lambada', action='store_true',
-                      help='Use more difficult formulation of lambada.')
+                       help='Use more difficult formulation of lambada.')

     return parser
...
@@ -39,16 +39,13 @@ class RaceDataset(Dataset):
         print_rank_0('  >> total number of samples: {}'.format(
             len(self.samples)))

-
     def __len__(self):
         return len(self.samples)

-
     def __getitem__(self, idx):
         return self.samples[idx]


-
 def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
     """Read in RACE files, combine, clean-up, tokenize, and convert to
     samples."""
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,12 +64,12 @@ class _LMDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         start_idx = idx * self.overalapping_eval
         end_idx = start_idx + self.seq_len
-        tokens = self.tokens[start_idx:end_idx+1]
+        tokens = self.tokens[start_idx:end_idx + 1]
         num_tokens = len(tokens)
-        pad_mask = [1]*num_tokens
-        if num_tokens < self.seq_len+1:
-            num_pad = (self.seq_len+1-num_tokens)
-            pad_mask += [0]*(num_pad)
+        pad_mask = [1] * num_tokens
+        if num_tokens < self.seq_len + 1:
+            num_pad = (self.seq_len + 1 - num_tokens)
+            pad_mask += [0] * (num_pad)
             tokens += [self.pad_idx] * num_pad
         pad_mask = np.array(pad_mask[1:])
         if self.overalapping_eval != self.seq_len and idx != 0:
@@ -103,7 +103,7 @@ class _LambadaDataset(torch.utils.data.Dataset):
         last_token = text.split()[-1]
         start_idx = text.rfind(last_token)
         beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
-        last_token = self.tokenizer.tokenize(' '+last_token)
+        last_token = self.tokenizer.tokenize(' ' + last_token)
         return beginning_tokens, last_token

     def __len__(self):
@@ -112,14 +112,14 @@ class _LambadaDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         tokens = self.tokens[idx]
         num_tokens = len(tokens)
-        pad_mask = [0]*num_tokens
+        pad_mask = [0] * num_tokens
         labels = self.labels[idx]
-        pad_mask += [1]*len(labels)
-        tokens = tokens+labels
+        pad_mask += [1] * len(labels)
+        tokens = tokens + labels
         num_tokens = len(tokens)
-        if num_tokens < self.seq_len+1:
-            num_pad = (self.seq_len+1-num_tokens)
-            pad_mask += [0]*(num_pad)
+        if num_tokens < self.seq_len + 1:
+            num_pad = (self.seq_len + 1 - num_tokens)
+            pad_mask += [0] * (num_pad)
             tokens += [self.pad_idx] * num_pad
         pad_mask = np.array(pad_mask[1:])
...
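The `_LMDataset.__getitem__` hunk above implements sliding-window evaluation: each window reads `seq_len + 1` tokens, but only the last `overlapping_eval` of them are scored, since the rest were already scored by the previous window. A simplified stand-alone sketch (names follow the class; this is not the exact code):

```python
import numpy as np

def lm_window(all_tokens, idx, seq_len, overlapping_eval, pad_idx):
    start = idx * overlapping_eval
    tokens = list(all_tokens[start:start + seq_len + 1])
    pad_mask = [1] * len(tokens)
    if len(tokens) < seq_len + 1:                # pad out the last window
        num_pad = seq_len + 1 - len(tokens)
        pad_mask += [0] * num_pad
        tokens += [pad_idx] * num_pad
    pad_mask = np.array(pad_mask[1:])            # shift: masks the targets
    if overlapping_eval != seq_len and idx != 0:
        # Tokens before the overlap were already scored by window idx - 1.
        pad_mask[:-overlapping_eval] = 0
    return np.array(tokens), pad_mask
```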
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,64 +19,62 @@ import re


 def ptb_detokenizer(string):
     string = string.replace(" '", "'")
     string = string.replace(" \n", "\n")
     string = string.replace("\n ", "\n")
     string = string.replace(" n't", "n't")
-    string = string.replace(" N ","1 ")
+    string = string.replace(" N ", "1 ")
     string = string.replace("$ 1", "$1")
     string = string.replace("# 1", "#1")
     return string


 def wikitext_detokenizer(string):
-    #contractions
+    # contractions
     string = string.replace("s '", "s'")
     string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
     # number separators
     string = string.replace(" @-@ ", "-")
     string = string.replace(" @,@ ", ",")
     string = string.replace(" @.@ ", ".")
-    #punctuation
+    # punctuation
     string = string.replace(" : ", ": ")
     string = string.replace(" ; ", "; ")
     string = string.replace(" . ", ". ")
     string = string.replace(" ! ", "! ")
     string = string.replace(" ? ", "? ")
     string = string.replace(" , ", ", ")
     # double brackets
     string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
     string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
     string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
     string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
     string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
     # miscellaneous
     string = string.replace("= = = =", "====")
     string = string.replace("= = =", "===")
     string = string.replace("= =", "==")
-    string = string.replace(" "+chr(176)+" ", chr(176))
+    string = string.replace(" " + chr(176) + " ", chr(176))
     string = string.replace(" \n", "\n")
     string = string.replace("\n ", "\n")
     string = string.replace(" N ", " 1 ")
     string = string.replace(" 's", "'s")
     return string


 def lambada_detokenizer(string):
     return string

-
 _DETOKENIZERS = {
     'ptb': ptb_detokenizer,
     'wikitext': wikitext_detokenizer,
     'lambada': lambada_detokenizer,
 }

-
 def get_detokenizer(path):
     for key in _DETOKENIZERS.keys():
         if key in path:
             return _DETOKENIZERS[key]
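Example use of the lookup above; the path is illustrative, and note that `get_detokenizer` implicitly returns `None` when no key matches the path:

```python
detok = get_detokenizer('/data/wikitext-103/wiki.test.tokens')
print(detok("the game 's rules = = History = ="))
# -> "the game's rules == History =="
```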
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -71,8 +71,7 @@ def process_batch(batch):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, labels, attention_mask, position_ids, loss_mask
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...