Commit 6ae4a03c authored by Mohammad's avatar Mohammad
Browse files

Merge branch 'staging' into staging_gpt2_dataloader

parents e06159f2 7890681a
......@@ -59,6 +59,7 @@ def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
args = get_args()
device_count = torch.cuda.device_count()
if torch.distributed.is_initialized():
if args.rank == 0:
......@@ -66,23 +67,25 @@ def _initialize_distributed():
'skipping initialization ...', flush=True)
args.rank = torch.distributed.get_rank()
args.world_size = torch.distributed.get_world_size()
device = torch.cuda.current_device()
local_rank = args.rank % torch.cuda.device_count()
assert local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
if device_count > 0:
device = torch.cuda.current_device()
local_rank = args.rank % device_count
assert local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
else:
if args.rank == 0:
print('> initializing torch distributed ...', flush=True)
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
assert args.local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
else:
args.local_rank = device
torch.cuda.set_device(device)
if device_count > 0:
device = args.rank % device_count
if args.local_rank is not None:
assert args.local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
else:
args.local_rank = device
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
......@@ -94,7 +97,8 @@ def _initialize_distributed():
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
if device_count > 0:
mpu.initialize_model_parallel(args.model_parallel_size)
def _init_autoresume():
......@@ -112,7 +116,8 @@ def _set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
if torch.cuda.device_count() > 0:
mpu.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
......
import os
import argparse
import collections
import numpy as np
import torch
def process_files(args):
all_predictions = collections.OrderedDict()
all_labels = collections.OrderedDict()
all_uid = collections.OrderedDict()
for path in args.paths:
path = os.path.join(path, args.prediction_name)
try:
data = torch.load(path)
for dataset in data:
name, d = dataset
predictions, labels, uid = d
if name not in all_predictions:
all_predictions[name] = np.array(predictions)
if args.labels is None:
args.labels = [i for i in range(all_predictions[name].shape[1])]
if args.eval:
all_labels[name] = np.array(labels)
all_uid[name] = np.array(uid)
else:
all_predictions[name] += np.array(predictions)
assert np.allclose(all_uid[name], np.array(uid))
except Exception as e:
print(e)
continue
return all_predictions, all_labels, all_uid
def get_threshold(all_predictions, all_labels, one_threshold=False):
if one_threshold:
all_predictons = {'combined': np.concatenate(list(all_predictions.values()))}
all_labels = {'combined': np.concatenate(list(all_predictions.labels()))}
out_thresh = []
for dataset in all_predictions:
preds = all_predictions[dataset]
labels = all_labels[dataset]
out_thresh.append(calc_threshold(preds,labels))
return out_thresh
def calc_threshold(p, l):
trials = [(i)*(1./100.) for i in range(100)]
best_acc = float('-inf')
best_thresh = 0
for t in trials:
acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
if acc > best_acc:
best_acc = acc
best_thresh = t
return best_thresh
def apply_threshold(preds, t):
assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
prob = preds[:,-1]
thresholded = (prob >= t).astype(int)
preds = np.zeros_like(preds)
preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
return preds
def threshold_predictions(all_predictions, threshold):
if len(threshold)!=len(all_predictions):
threshold = [threshold[-1]]*(len(all_predictions)-len(threshold))
for i, dataset in enumerate(all_predictions):
thresh = threshold[i]
preds = all_predictions[dataset]
all_predictions[dataset] = apply_threshold(preds, thresh)
return all_predictions
def postprocess_predictions(all_predictions, all_labels, args):
for d in all_predictions:
all_predictions[d] = all_predictions[d]/len(args.paths)
if args.calc_threshold:
args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
print('threshold', args.threshold)
if args.threshold is not None:
all_predictions = threshold_predictions(all_predictions, args.threshold)
return all_predictions, all_labels
def write_predictions(all_predictions, all_labels, all_uid, args):
all_correct = 0
count = 0
for dataset in all_predictions:
preds = all_predictions[dataset]
preds = np.argmax(preds, -1)
if args.eval:
correct = (preds == all_labels[dataset]).sum()
num = len(all_labels[dataset])
accuracy = correct/num
count += num
all_correct += correct
accuracy = (preds == all_labels[dataset]).mean()
print(accuracy)
if not os.path.exists(os.path.join(args.outdir, dataset)):
os.makedirs(os.path.join(args.outdir, dataset))
outpath = os.path.join(args.outdir, dataset, os.path.splitext(args.prediction_name)[0]+'.tsv')
with open(outpath, 'w') as f:
f.write('id\tlabel\n')
f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist())))
if args.eval:
print(all_correct/count)
def ensemble_predictions(args):
all_predictions, all_labels, all_uid = process_files(args)
all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
write_predictions(all_predictions, all_labels, all_uid, args)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--paths', required=True, nargs='+',
help='paths to checkpoint directories used in ensemble')
parser.add_argument('--eval', action='store_true',
help='compute accuracy metrics against labels (dev set)')
parser.add_argument('--outdir',
help='directory to place ensembled predictions in')
parser.add_argument('--prediction-name', default='test_predictions.pt',
help='name of predictions in checkpoint directories')
parser.add_argument('--calc-threshold', action='store_true',
help='calculate threshold classification')
parser.add_argument('--one-threshold', action='store_true',
help='use on threshold for all subdatasets')
parser.add_argument('--threshold', nargs='+', default=None, type=float,
help='user supplied threshold for classification')
parser.add_argument('--labels',nargs='+', default=None,
help='whitespace separated list of label names')
args = parser.parse_args()
ensemble_predictions(args)
if __name__ == '__main__':
main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment