Unverified Commit 765db512 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

fixed ddp bug on torch 1.8 (#194)

parent 569357fe
@@ -348,12 +348,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                 "added even though not specified in the configuration",
                 ranks=[0])
     elif is_using_sequence():
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment