Unverified commit 8d3250d7, authored by ver217 and committed by GitHub
Browse files

[zero] ZeRO supports pipeline parallel (#477)

parent 7f5e4592
#!/usr/bin/env python
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from collections import defaultdict
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from ._base_gradient_handler import BaseGradientHandler
from collections import defaultdict
@GRADIENT_HANDLER.register_module
......@@ -35,7 +37,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
for group, group_buckets in buckets.items():
for tp, bucket in group_buckets.items():
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
......@@ -262,3 +262,15 @@ class ShardedModelV2(nn.Module):
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
    """Loading a state dict is not supported by this wrapper.

    Args:
        state_dict: Mapping of parameter names to tensors (ignored).
        strict: Strict key-matching flag (ignored).

    Raises:
        NotImplementedError: Always; neither argument is inspected.
    """
    raise NotImplementedError()
def __getitem__(self, idx: int):
    """Return the entry at position ``idx`` of the wrapped module list.

    Indexing is only meaningful when ``self.module`` is an
    ``nn.ModuleList``; any other wrapped module fails the assertion.
    """
    layers = self.module
    assert isinstance(layers, nn.ModuleList)
    return layers[idx]
def __len__(self):
    """Return how many entries the wrapped module list holds.

    Only valid when ``self.module`` is an ``nn.ModuleList``; any other
    wrapped module fails the assertion.
    """
    layers = self.module
    assert isinstance(layers, nn.ModuleList)
    return len(layers)
def __iter__(self):
    """Return an iterator over the entries of the wrapped module list.

    Only valid when ``self.module`` is an ``nn.ModuleList``; any other
    wrapped module fails the assertion.
    """
    layers = self.module
    assert isinstance(layers, nn.ModuleList)
    return iter(layers)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment