Commit 55695f81 authored by Lawrence McAfee

Padded DDP's grad_buffer to a multiple of the data-parallel world size

parent 4f2356dc
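The core of this commit: each dtype's contiguous grad buffer is now allocated with a padded length that divides evenly across data-parallel ranks, as required by the distributed optimizer's reduce-scatter path. A minimal sketch of the padding arithmetic (the helper name pad_to_multiple is hypothetical, not part of the commit):

import math

def pad_to_multiple(num_elements, data_parallel_world_size):
    # Round up to the nearest multiple of the data-parallel world size,
    # mirroring the formula introduced in the grad-buffer allocation below.
    return data_parallel_world_size * int(
        math.ceil(num_elements / data_parallel_world_size))

# Example: a 10-element grad buffer shared by 4 data-parallel ranks is
# padded to 12 elements, so each rank's reduce-scatter chunk holds 3.
assert pad_to_multiple(10, 4) == 12
assert pad_to_multiple(12, 4) == 12   # already a multiple: no extra padding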
@@ -15,6 +15,9 @@
from abc import ABC
from abc import abstractmethod
# >>>
import math
# <<<
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -27,14 +30,16 @@ from .module import MegatronModule
class MemoryBuffer:
def __init__(self, numel, dtype):
# >>>
def __init__(self, numel, numel_padded, dtype):
self.numel = numel
self.numel_padded = numel_padded
self.dtype = dtype
self.data = torch.zeros(self.numel,
self.data = torch.zeros(self.numel_padded,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# <<<
def zero(self):
"""Reset the buffer to zero."""
@@ -132,6 +137,7 @@ class DistributedDataParallel(DistributedDataParallelBase):
# self._grad_buffer_param_offsets = defaultdict(dict)
# self._grad_buffer_param_index_map = defaultdict(dict)
self._grad_buffer_param_index_map = {}
data_parallel_world_size = mpu.get_data_parallel_world_size()
# <<<
# Simple function to define buffer type.
@@ -149,7 +155,31 @@ class DistributedDataParallel(DistributedDataParallelBase):
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
# >>>
# If using distributed optimizer, pad memory buffer to be
# multiple of data_parallel_world_size. (This padding is done
# due to a constraint with the reduce_scatter op, which requires
# all tensors have equal size. See: optimizer.py.)
num_elements_padded = data_parallel_world_size * \
int(math.ceil(num_elements / data_parallel_world_size))
# <<<
# Allocate grad buffer.
self._grad_buffers[dtype] = MemoryBuffer(num_elements,
num_elements_padded,
dtype)
# >>>
# from lutil import pax
# if True or num_elements % data_parallel_world_size != 0:
# pax(0, {
# "data_parallel_world_size" : data_parallel_world_size,
# "num_elements" : num_elements,
# "num_elements_padded" : num_elements_padded,
# "modulo" : num_elements % data_parallel_world_size,
# "grad buffer" : self._grad_buffers[dtype],
# })
# <<<
# Assume the back prop order is the reverse of the params order,
# store the start index for the gradients.
...
@@ -626,17 +626,11 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# >>>
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# +++
# grad_shard = optimizer.get_grad_shard(word_embeddings)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_embedding_group())
# <<<
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
@@ -652,15 +646,8 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
# >>>
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# +++
# grad_shard = optimizer.get_grad_shard(
# unwrapped_model.language_model.embedding.position_embeddings.weight)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_position_embedding_group())
# <<<
timers('backward-embedding-all-reduce').stop()
def gather_params(self, ITERATION):
@@ -717,19 +704,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# pax(1, {"main_grads": main_grads})
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# >>>
# from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# pax(1, {"main_grads": [ (param_is_not_tensor_parallel_duplicate(t), tp(t)) for t in main_grads ]})
# <<<
return main_grads
@@ -827,40 +807,6 @@ class Shard:
# class Float16DistributedOptimizer(MegatronOptimizer):
class Float16DistributedOptimizer(BaseFloat16Optimizer):
# >>>
# @classmethod
# def test_reduce_scatter(cls):
# torch.manual_seed(mpu.get_data_parallel_rank())
# size = (20,)
# dtype = torch.float
# device = torch.cuda.current_device()
# data_parallel_world_size = mpu.get_data_parallel_world_size()
# data_parallel_group = mpu.get_data_parallel_group()
# input_list = [
# # torch.randn(size, dtype = dtype, device = device)
# 5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
# for _ in range(data_parallel_world_size)
# ]
# output = torch.empty(size, dtype = dtype, device = device)
# torch.distributed.reduce_scatter(
# output,
# input_list,
# group = data_parallel_group,
# )
# if torch.distributed.get_rank() == 0:
# print(output)
# pax(0, {
# "data_parallel_world_size" : data_parallel_world_size,
# "data_parallel_group" : data_parallel_group,
# "input_list" : input_list,
# "output" : tp(output),
# })
# <<<
@classmethod
def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
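The commented-out block above probes torch.distributed.reduce_scatter directly. For reference, a self-contained sketch of the same call pattern (the function name is hypothetical; it assumes torch.distributed is already initialized with a CUDA backend and that `group` is the data-parallel process group):

import torch

def demo_reduce_scatter(group):
    world_size = torch.distributed.get_world_size(group=group)
    device = torch.cuda.current_device()
    # Every entry of input_list must have the same shape as output; this
    # equal-size requirement is what the grad-buffer padding guarantees.
    input_list = [torch.randn(20, dtype=torch.float, device=device)
                  for _ in range(world_size)]
    output = torch.empty(20, dtype=torch.float, device=device)
    torch.distributed.reduce_scatter(output, input_list, group=group)
    # On rank r, output now holds the sum over all ranks of input_list[r].
    return output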
@@ -913,6 +859,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
gbuf_world_all_shards.append(gbuf_world_shard)
# >>>
# if max_gbuf_shard_size != gbuf_world_shard.size:
# raise Exception("%d: smaller, rank %d. [ %d -> %d vs. %d]" % (
# data_parallel_rank,
# r,
# gbuf_size,
# max_gbuf_shard_size,
# gbuf_world_shard.size,
# ))
# <<<
gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
gbuf_local_shard = gbuf_world_shard.normalize()
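A simplified sketch of the per-rank world-shard computation in this method (build_world_shards is a hypothetical stand-in for the Shard bookkeeping above):

import math

def build_world_shards(gbuf_size, data_parallel_world_size):
    # Each rank owns a contiguous [start, end) range of the grad buffer.
    max_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
    shards = []
    for rank in range(data_parallel_world_size):
        start = rank * max_shard_size
        end = min(gbuf_size, start + max_shard_size)
        shards.append((start, end))
    return shards

# Without padding, the last rank's shard can be smaller (the case probed by
# the commented-out check above); with the padded buffer size, all shards
# are equal.
assert build_world_shards(10, 4) == [(0, 3), (3, 6), (6, 9), (9, 10)]
assert build_world_shards(12, 4) == [(0, 3), (3, 6), (6, 9), (9, 12)]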
@@ -927,9 +883,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
"world" : gbuf_world_shard,
"world_all" : gbuf_world_all_shards,
"param_map" : param_shard_map,
"max_shard_size" : max_gbuf_shard_size,
}
# pax(1, {"data": data})
# pax(0, {"data": data})
return data
@@ -992,9 +949,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
param_group_end = param_group_start + param_size
param_group_shard = Shard(param_group_start, param_group_end)
# group_shard["max_size"] = gbuf_shard_map["max_shard_size"]
group_shard["size"] += param_size
group_shard["param_map"][param] = param_group_shard
# pax(0, {"gbuf_shard_map": gbuf_shard_map})
# >>>
# if torch.distributed.get_rank() == 1:
# print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
@@ -1010,6 +969,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
group_shard["orig_group"] = param_groups[group_index]
group_shards = [ g for g in group_shards if g["size"] > 0 ]
# [ ... x ... ] Synchronize group sizes across ranks.
# pax(0, {
# "param_group_map": [
# (g, str(p.shape))
@@ -1035,6 +996,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# main_param_shards = []
for group_index, group_shard in enumerate(opt_group_shards):
# pax(0, {
# "group_shard" : group_shard,
# })
group_size = group_shard["size"]
assert group_size != 0, "temporary check ... remove me."
@@ -1075,29 +1040,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
assert args.use_contiguous_buffers_in_local_ddp # already checked in args
# <<<
# # Data parallel info.
# self.data_parallel_group = mpu.get_data_parallel_group()
# self.data_parallel_rank = mpu.get_data_parallel_rank()
# self.data_parallel_world_size = mpu.get_data_parallel_world_size()
# Model grad buffer shards.
self.model_gbuf_shards = []
for model_index, model in enumerate(self.models):
self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
# pax(0, {"param_gbuf_map": [ (str(tuple(p.shape)), d) for p, d in self.param_gbuf_map.items() ]})
# Optimizer shards.
self.opt_group_shards = self.get_optimizer_group_shards(
self.optimizer.param_groups,
self.model_gbuf_shards)
# pax(0, {**{"opt_group_shards / %d" % i : g for i, g in enumerate(self.opt_group_shards)}})
# Allocate main param shards.
# self.main_param_shards = \
# self.allocate_main_param_shards(self.opt_group_shards)
self.allocate_main_param_shards(self.opt_group_shards)
# >>>
@@ -1205,6 +1159,37 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# pax(0, {"model_params": model_params})
# def get_model_grad_buffer_dp_views(self):
# # >>>
# # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
# args = get_args()
# assert args.use_contiguous_buffers_in_local_ddp
# # <<<
# # Grad buffer views.
# gbuf_view_items = []
# for model_index, model in enumerate(self.models):
# for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
# world_shards = gbuf_shard["world_all"]
# gbuf = model._grad_buffers[dtype].data
# gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
# gbuf_view_items.append((model_index, dtype, gbuf_views))
# # pax(0, {
# # "world_shards" : world_shards,
# # "gbuf_views" : gbuf_views,
# # })
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# **{
# "views / %d" % i : item[2]
# for i, item in enumerate(gbuf_view_items)
# },
# })
# return gbuf_view_items
def get_model_grad_buffer_dp_views(self):
# >>>
@@ -1213,21 +1198,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer views.
gbuf_view_items = []
for model_index, model in enumerate(self.models):
for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
world_shards = gbuf_shard["world_all"]
gbuf = model._grad_buffers[dtype].data
gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
gbuf_view_items.append((model_index, dtype, gbuf_views))
for dtype, gbuf in model._grad_buffers.items():
# gbuf_size = gbuf.numel_padded
assert gbuf.numel_padded % data_parallel_world_size == 0
shard_size = int(gbuf.numel_padded / data_parallel_world_size)
# pax(0, {
# "world_shards" : world_shards,
# "gbuf_views" : gbuf_views,
# "numel" : gbuf.numel,
# "numel_padded" : gbuf.numel_padded,
# "shard_size / f" : gbuf.numel_padded/data_parallel_world_size,
# "shard_size / i" : shard_size,
# })
gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
for r in range(data_parallel_world_size)]
gbuf_view_items.append((model_index, dtype, gbuf_views))
# pax(0, {"gbuf_view_items": gbuf_view_items})
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# **{
# "views / %d" % i : item[2]
# for i, item in enumerate(gbuf_view_items)
# },
# })
return gbuf_view_items
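A concrete view of the slicing above: because the padded length divides evenly by the data-parallel world size, the buffer splits into equal, contiguous per-rank views (plain CPU tensor used for illustration):

import torch

padded = torch.zeros(12)                   # e.g. numel=10 padded to 12
data_parallel_world_size = 4
shard_size = padded.numel() // data_parallel_world_size   # 3 elements/rank
gbuf_views = [padded[r * shard_size:(r + 1) * shard_size]
              for r in range(data_parallel_world_size)]
assert all(v.numel() == shard_size for v in gbuf_views)
# Each view aliases the padded buffer, so writing through a view updates it.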
@@ -1361,11 +1359,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
gbuf /= data_parallel_world_size
# if 1:
# try:
# pax(0, {"gbuf_views": gbuf_views})
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
group = data_parallel_group,
)
# except:
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# "gbuf_views" : gbuf_views,
# })
# else:
# torch.distributed.all_reduce(
# gbuf,
...
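Putting the pieces together, the reduce-scatter step above can be summarized with the following sketch (hypothetical helper; it assumes a 1-D padded grad buffer whose length is a multiple of the data-parallel world size, and an initialized data-parallel group):

def reduce_scatter_grad_buffer(gbuf, data_parallel_group):
    world_size = torch.distributed.get_world_size(group=data_parallel_group)
    rank = torch.distributed.get_rank(group=data_parallel_group)
    gbuf /= world_size                        # average rather than sum
    shard_size = gbuf.numel() // world_size   # equal shards thanks to padding
    gbuf_views = [gbuf[r * shard_size:(r + 1) * shard_size]
                  for r in range(world_size)]
    # The output view aliases gbuf, so this rank's reduced shard is written
    # back in place.
    torch.distributed.reduce_scatter(gbuf_views[rank], gbuf_views,
                                     group=data_parallel_group)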