Commit a59abedf authored by thomwolf's avatar thomwolf
Browse files

DDP update

parent 2ef5e0de
......@@ -907,7 +907,7 @@ def main():
# except ImportError:
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = torch.nn.parallel.DistributedDataParallel(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment