Unverified Commit 4ac8bfb0 authored by CZYCW's avatar CZYCW Committed by GitHub
Browse files

[NFC] polish colossalai/engine/gradient_handler/utils.py code style (#2708)

parent 6427c406
import torch.distributed as dist from typing import Iterable
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors import torch.distributed as dist
from typing import Iterable import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
# get communication world size def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
comm_size = dist.get_world_size(group) # get communication world size
# bucketize and all-reduce comm_size = dist.get_world_size(group)
buckets = {} # bucketize and all-reduce
# Pack the buckets. buckets = {}
for param in param_list: # Pack the buckets.
if param.requires_grad and param.grad is not None: for param in param_list:
tp = param.data.type() if param.requires_grad and param.grad is not None:
if tp not in buckets: tp = param.data.type()
buckets[tp] = [] if tp not in buckets:
buckets[tp].append(param) buckets[tp] = []
buckets[tp].append(param)
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets: # For each bucket, all-reduce and copy all-reduced grads.
bucket = buckets[tp] for tp in buckets:
grads = [param.grad.data for param in bucket] bucket = buckets[tp]
coalesced = _flatten_dense_tensors(grads) grads = [param.grad.data for param in bucket]
coalesced /= comm_size coalesced = _flatten_dense_tensors(grads)
coalesced /= comm_size
dist.all_reduce(coalesced, group=group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): dist.all_reduce(coalesced, group=group)
buf.copy_(synced) for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment