r""" Supportive modules to conduct distributed training """ import torch import torch.nn as nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from .utils import get_torch_default_comm class DistributedGroupedDataParallel(nn.Module): r""" A customized DDP module to support different all-reduce regions in the model. The all-reduce region is defined as an attribution `dp_comm` in the weight object. The grads of the weights are identified to be reduced in different groups according to the weigths' `dp_comm` attribute. If it is set to `dp`, it will only be reduced across the data-parallel groups, which means that in the model parallel group, they are not synchronized. If it is set to `world`, the gradients is synchronized across all workers, regardless their model or data parallel group. This is extremely useful for shared layers like the gate. """ def __init__( self, module, auto_allreduce=False, **kwargs ): assert not auto_allreduce, "Automatic all-reduce is not implemented yet" super().__init__() self.module = module self.comms = dict() for k in kwargs: if k.endswith('_group'): self.comms[k[:-6]] = kwargs[k] for k in ['dp', 'gate', 'moe', 'world']: if k not in self.comms: self.comms[k] = get_torch_default_comm() def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False): groups = dict() for p in self.module.parameters(): if not p.requires_grad or p.grad is None: continue if hasattr(p, "dp_comm"): dp_comm = p.dp_comm else: dp_comm = "dp" group_key = (dp_comm, p.dtype) if group_key not in groups: groups[group_key] = [p] else: groups[group_key].append(p) for (dp_comm, dtype), group in groups.items(): if dp_comm not in self.comms: continue comm = self.comms[dp_comm] grads = [p.grad.data for p in group] coalesced = _flatten_dense_tensors(grads) if fp32_allreduce and dtype != torch.float32: coalesced = coalesced.float() if not no_scale and not reduce_after: coalesced /= comm.size() torch.distributed.all_reduce(coalesced, group=comm) torch.cuda.synchronize() if not no_scale and reduce_after: coalesced /= comm.size() synced = _unflatten_dense_tensors(coalesced, grads) for g, s in zip(grads, synced): g.copy_(s) self.allreduce_params = allreduce_params self._sync_params() def _sync_params(self): r""" Note that this module does not guarantee initial consistency of parameters. Users are supposed to manually initalize the model on different workers with the same parameters using either this function or other methods like pre-defined random seeds. """ groups = dict() for p in self.module.parameters(): if not p.requires_grad or p.grad is None: continue if hasattr(p, "dp_comm"): dp_comm = p.dp_comm else: dp_comm = "dp" group_key = (dp_comm, p.dtype) if group_key not in groups: groups[group_key] = [p] else: groups[group_key].append(p) for (dp_comm, _), group in groups.items(): if dp_comm not in self.comms: continue comm = self.comms[dp_comm] datas = [p.data for p in group] coalesced = _flatten_dense_tensors(datas) torch.distributed.broadcast(coalesced, 0, group=comm) torch.cuda.synchronize() synced = _unflatten_dense_tensors(coalesced, datas) for d, s in zip(datas, synced): d.copy_(s) def forward(self, *args, **kwargs): r""" Directly call the module's forward function. """ return self.module(*args, **kwargs)