Commit 2ef5e0de authored by thomwolf's avatar thomwolf
Browse files

switch to pytorch DistributedDataParallel

parent 9ce37af9
@@ -902,12 +902,12 @@ def main():
             model.half()
         model.to(device)
         if args.local_rank != -1:
-            try:
-                from apex.parallel import DistributedDataParallel as DDP
-            except ImportError:
-                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-            model = DDP(model)
+            # try:
+            #     from apex.parallel import DistributedDataParallel as DDP
+            # except ImportError:
+            #     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+            model = torch.nn.parallel.DistributedDataParallel(model)
         elif n_gpu > 1:
             model = torch.nn.DataParallel(model)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment