Commit 2af29c19 authored by Michael Carilli

Removing orphaned /distributed/run_distributed.sh

parent 1d45fada
distributed/run_distributed.sh (deleted):

-export CUDA_VISIBLE_DEVICES=0,1; python -m apex.parallel.multiproc main.py
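The deleted one-liner wrapped apex.parallel.multiproc, which spawns one copy of main.py per visible GPU and gives each copy its own rank. The sketch below shows what a launcher of that shape does; it is an illustration, not apex's actual implementation, and the MASTER_ADDR/MASTER_PORT values and the --local_rank flag are assumptions chosen to match the env:// initialization used in main.py.

    # Hypothetical launcher sketch -- NOT apex.parallel.multiproc itself.
    # Spawns one training process per visible GPU and sets the env://
    # rendezvous variables that init_process_group(init_method='env://') reads.
    import os
    import subprocess
    import sys

    import torch

    def launch(argv):
        world_size = torch.cuda.device_count()
        procs = []
        for rank in range(world_size):
            env = os.environ.copy()
            env["WORLD_SIZE"] = str(world_size)
            env["RANK"] = str(rank)
            env["MASTER_ADDR"] = "127.0.0.1"  # assumed: single-node run
            env["MASTER_PORT"] = "29500"      # assumed: any free port
            cmd = [sys.executable] + argv + ["--local_rank", str(rank)]
            procs.append(subprocess.Popen(cmd, env=env))
        for p in procs:
            p.wait()

    if __name__ == "__main__":
        launch(sys.argv[1:])  # e.g. python launcher.py main.py --fp16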
@@ -98,7 +98,7 @@ def main():
     args.gpu = 0
     args.world_size = 1
     if args.distributed:
         args.gpu = args.local_rank % torch.cuda.device_count()
         torch.cuda.set_device(args.gpu)
@@ -100,13 +100,14 @@ def main():
     args.distributed = int(os.environ['WORLD_SIZE']) > 1
     args.gpu = 0
     args.world_size = 1
     if args.distributed:
         args.gpu = args.local_rank % torch.cuda.device_count()

     if args.distributed:
         torch.cuda.set_device(args.gpu)
         torch.distributed.init_process_group(backend='nccl',
                                              init_method='env://')
+        args.world_size = torch.distributed.get_world_size()
     if args.fp16:
         assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
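The hunk above ends up storing the process count in args.world_size (defaulting to 1), which is what lets the later hunks drop their torch.distributed.get_world_size() calls: that call is only valid after init_process_group, so the cached value keeps single-process, non-distributed runs working. A standalone sketch of the pattern, with an assumed LOCAL_RANK environment variable standing in for the script's --local_rank argument:

    # Sketch of the init pattern above, runnable outside the example script.
    # LOCAL_RANK here is an assumption; main.py actually parses --local_rank.
    import os
    import torch
    import torch.distributed as dist

    class Args:
        pass

    args = Args()
    args.distributed = int(os.environ.get('WORLD_SIZE', '1')) > 1
    args.gpu = 0
    args.world_size = 1  # safe default for single-process runs
    if args.distributed:
        args.gpu = int(os.environ['LOCAL_RANK']) % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend='nccl', init_method='env://')
        # Cache once: get_world_size() raises if no process group exists,
        # so later code reads args.world_size instead of calling it again.
        args.world_size = dist.get_world_size()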
@@ -324,8 +325,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
           'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
           'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
            epoch, i, len(train_loader),
-           torch.distributed.get_world_size() * args.batch_size / batch_time.val,
-           torch.distributed.get_world_size() * args.batch_size / batch_time.avg,
+           args.world_size * args.batch_size / batch_time.val,
+           args.world_size * args.batch_size / batch_time.avg,
            batch_time=batch_time,
            data_time=data_time, loss=losses, top1=top1, top5=top5))
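The quantity on the changed lines is the aggregate throughput the progress line reports: every process contributes batch_size images per iteration, so the product is divided by the iteration time. A worked example with illustrative numbers (not output from this commit):

    # Illustrative values only.
    world_size = 8      # training processes, one per GPU
    batch_size = 256    # images per process per iteration
    batch_time = 0.5    # seconds for the last iteration (batch_time.val)
    speed = world_size * batch_size / batch_time
    print(speed)        # 4096.0 images/sec across all GPUs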
@@ -382,8 +383,8 @@ def validate(val_loader, model, criterion):
           'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
           'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
            i, len(val_loader),
-           torch.distributed.get_world_size() * args.batch_size / batch_time.val,
-           torch.distributed.get_world_size() * args.batch_size / batch_time.avg,
+           args.world_size * args.batch_size / batch_time.val,
+           args.world_size * args.batch_size / batch_time.avg,
            batch_time=batch_time, loss=losses,
            top1=top1, top5=top5))
@@ -445,7 +446,7 @@ def accuracy(output, target, topk=(1,)):
 def reduce_tensor(tensor):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.reduce_op.SUM)
-    rt /= torch.distributed.get_world_size()
+    rt /= args.world_size
     return rt

 if __name__ == '__main__':
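reduce_tensor averages a value across all ranks: an all-reduce sum followed by division by the process count yields the mean, so every process logs the same global loss or accuracy rather than its local shard. A hedged usage sketch, with world_size passed explicitly instead of read from the module-level args:

    # Usage sketch; assumes a process group is already initialized as in main().
    import torch
    import torch.distributed as dist

    def reduce_tensor(tensor, world_size):
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # dist.reduce_op.SUM in older PyTorch
        rt /= world_size
        return rt

    # Inside the training loop, every rank then records the same averaged loss:
    # reduced_loss = reduce_tensor(loss.data, args.world_size)
    # losses.update(reduced_loss.item(), input.size(0))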