Commit 55695f81 authored by Lawrence McAfee

padded ddp's grad_buffer to multiple of data parallel world size

parent 4f2356dc
@@ -15,6 +15,9 @@
 from abc import ABC
 from abc import abstractmethod
+# >>>
+import math
+# <<<
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -27,14 +30,16 @@ from .module import MegatronModule
 class MemoryBuffer:
-    def __init__(self, numel, dtype):
+    # >>>
+    def __init__(self, numel, numel_padded, dtype):
         self.numel = numel
+        self.numel_padded = numel_padded
         self.dtype = dtype
-        self.data = torch.zeros(self.numel,
+        self.data = torch.zeros(self.numel_padded,
                                 dtype=self.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False)
+    # <<<
     def zero(self):
         """Reset the buffer to zero."""
@@ -132,6 +137,7 @@ class DistributedDataParallel(DistributedDataParallelBase):
         # self._grad_buffer_param_offsets = defaultdict(dict)
         # self._grad_buffer_param_index_map = defaultdict(dict)
         self._grad_buffer_param_index_map = {}
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
         # <<<
 
         # Simple function to define buffer type.
@@ -149,7 +155,31 @@ class DistributedDataParallel(DistributedDataParallelBase):
         # Allocate the buffer.
         for dtype, num_elements in type_num_elements.items():
-            self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+            # >>>
+            # If using distributed optimizer, pad memory buffer to be
+            # multiple of data_parallel_world_size. (This padding is done
+            # due to a constraint with the reduce_scatter op, which requires
+            # all tensors have equal size. See: optimizer.py.)
+            num_elements_padded = data_parallel_world_size * \
+                int(math.ceil(num_elements / data_parallel_world_size))
+            # <<<
+
+            # Allocate grad buffer.
+            self._grad_buffers[dtype] = MemoryBuffer(num_elements,
+                                                     num_elements_padded,
+                                                     dtype)
+
+            # >>>
+            # from lutil import pax
+            # if True or num_elements % data_parallel_world_size != 0:
+            #     pax(0, {
+            #         "data_parallel_world_size" : data_parallel_world_size,
+            #         "num_elements" : num_elements,
+            #         "num_elements_padded" : num_elements_padded,
+            #         "modulo" : num_elements % data_parallel_world_size,
+            #         "grad buffer" : self._grad_buffers[dtype],
+            #     })
+            # <<<
 
         # Assume the back prop order is reverse the params order,
         # store the start index for the gradients.
...
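The padding rule added above is simply "round num_elements up to the next multiple of data_parallel_world_size". A small, self-contained sketch of that arithmetic; the function name and the example numbers are illustrative, not from the commit:

import math

def pad_to_multiple(num_elements: int, data_parallel_world_size: int) -> int:
    """Round num_elements up to the next multiple of data_parallel_world_size,
    mirroring the expression used in the grad-buffer allocation above."""
    return data_parallel_world_size * int(
        math.ceil(num_elements / data_parallel_world_size))

if __name__ == "__main__":
    # E.g. a 10-element buffer on 4 data-parallel ranks is padded to 12,
    # so each rank later owns an equal 3-element shard for reduce_scatter.
    assert pad_to_multiple(10, 4) == 12
    assert pad_to_multiple(12, 4) == 12   # already a multiple: no padding
    print(pad_to_multiple(10, 4) // 4)    # per-rank shard size -> 3

An equivalent integer-only form is (num_elements + data_parallel_world_size - 1) // data_parallel_world_size * data_parallel_world_size, which avoids going through floating point; the commit keeps the math.ceil form.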
@@ -626,17 +626,11 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
             if unwrapped_model.share_word_embeddings:
                 word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-                # >>>
                 if args.DDP_impl == 'local':
                     grad = word_embeddings_weight.main_grad
                 else:
                     grad = word_embeddings_weight.grad
                 torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
-                # +++
-                # grad_shard = optimizer.get_grad_shard(word_embeddings)
-                # torch.distributed.all_reduce(grad_shard,
-                #                              group=mpu.get_embedding_group())
-                # <<<
 
         # All-reduce position_embeddings grad across first (encoder) and split (decoder)
         # stages to ensure that position embeddings parameters stay in sync.
@@ -652,15 +646,8 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                 unwrapped_model, (torchDDP, LocalDDP, Float16Module))
             assert args.DDP_impl == 'local', \
                 'T5 model is only supported with local DDP mode'
-            # >>>
             grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
             torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
-            # +++
-            # grad_shard = optimizer.get_grad_shard(
-            #     unwrapped_model.language_model.embedding.position_embeddings.weight)
-            # torch.distributed.all_reduce(grad_shard,
-            #     group=mpu.get_position_embedding_group())
-            # <<<
 
         timers('backward-embedding-all-reduce').stop()
 
     def gather_params(self, ITERATION):
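Both embedding hunks above keep the pre-existing behaviour, all-reducing the shared word or position embedding gradient over its embedding process group so the copies on different pipeline stages stay in sync, and simply drop the commented-out grad_shard experiments. A small runnable illustration of that pattern using plain torch.distributed calls; the helper name and the single-process gloo group below are only for the example, not Megatron's mpu plumbing:

import torch
import torch.distributed as dist

def sync_shared_embedding_grad(weight: torch.nn.Parameter, group=None) -> None:
    """All-reduce the gradient of a weight replicated across a process group
    (e.g. an embedding shared between pipeline stages)."""
    # Local DDP stores the accumulated gradient in `main_grad`; fall back to
    # the regular autograd `.grad` otherwise, matching the branch in the diff.
    grad = getattr(weight, "main_grad", None)
    if grad is None:
        grad = weight.grad
    dist.all_reduce(grad, group=group)

if __name__ == "__main__":
    # Single-process group, just to make the example executable.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)
    w = torch.nn.Parameter(torch.randn(4, 8))
    w.grad = torch.ones_like(w)
    sync_shared_embedding_grad(w)       # sums grads across the group
    print(w.grad.mean().item())         # -> 1.0 with world_size == 1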
@@ -717,19 +704,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                 if main_param.grad is not None:
                     main_grads.append(main_param.grad.data)
 
-        # pax(1, {"main_grads": main_grads})
-
         # Append fp32 parameters.
         for main_group in self.fp32_from_fp32_groups:
             for main_param in main_group:
                 if main_param.grad is not None:
                     main_grads.append(main_param.grad.data)
 
-        # >>>
-        # from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
-        # pax(1, {"main_grads": [ (param_is_not_tensor_parallel_duplicate(t), tp(t)) for t in main_grads ]})
-        # <<<
-
         return main_grads
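What survives in this hunk is the plain collection of every available fp32 gradient into main_grads (the pax debugging dumps go away). A tiny stand-alone sketch of that collection step, with a generic param_groups argument instead of the optimizer's fp32_from_float16_groups / fp32_from_fp32_groups attributes:

import torch

def collect_main_grads(param_groups):
    """Gather the .grad tensors of all params that actually have a gradient,
    e.g. as input to a global grad-norm / clipping computation."""
    main_grads = []
    for group in param_groups:
        for param in group:
            if param.grad is not None:
                main_grads.append(param.grad.data)
    return main_grads

if __name__ == "__main__":
    a = torch.nn.Parameter(torch.randn(3))
    b = torch.nn.Parameter(torch.randn(2))   # left without a gradient
    a.grad = torch.ones(3)
    grads = collect_main_grads([[a, b]])
    print(len(grads))                        # -> 1, only params with grads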
@@ -827,40 +807,6 @@ class Shard:
 # class Float16DistributedOptimizer(MegatronOptimizer):
 class Float16DistributedOptimizer(BaseFloat16Optimizer):
-    # >>>
-    # @classmethod
-    # def test_reduce_scatter(cls):
-    #     torch.manual_seed(mpu.get_data_parallel_rank())
-    #     size = (20,)
-    #     dtype = torch.float
-    #     device = torch.cuda.current_device()
-    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
-    #     data_parallel_group = mpu.get_data_parallel_group()
-    #     input_list = [
-    #         # torch.randn(size, dtype = dtype, device = device)
-    #         5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
-    #         for _ in range(data_parallel_world_size)
-    #     ]
-    #     output = torch.empty(size, dtype = dtype, device = device)
-    #     torch.distributed.reduce_scatter(
-    #         output,
-    #         input_list,
-    #         group = data_parallel_group,
-    #     )
-    #     if torch.distributed.get_rank() == 0:
-    #         print(output)
-    #     pax(0, {
-    #         "data_parallel_world_size" : data_parallel_world_size,
-    #         "data_parallel_group" : data_parallel_group,
-    #         "input_list" : input_list,
-    #         "output" : tp(output),
-    #     })
-    # <<<
 
     @classmethod
     def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
@@ -913,6 +859,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
             gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
             gbuf_world_all_shards.append(gbuf_world_shard)
+            # >>>
+            # if max_gbuf_shard_size != gbuf_world_shard.size:
+            #     raise Exception("%d: smaller, rank %d. [ %d -> %d vs. %d]" % (
+            #         data_parallel_rank,
+            #         r,
+            #         gbuf_size,
+            #         max_gbuf_shard_size,
+            #         gbuf_world_shard.size,
+            #     ))
+            # <<<
         gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
         gbuf_local_shard = gbuf_world_shard.normalize()
@@ -927,9 +883,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             "world" : gbuf_world_shard,
             "world_all" : gbuf_world_all_shards,
             "param_map" : param_shard_map,
+            "max_shard_size" : max_gbuf_shard_size,
         }
 
-        # pax(1, {"data": data})
+        # pax(0, {"data": data})
 
         return data
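The shard map being built here cuts a grad buffer of gbuf_size elements into one contiguous Shard(start, end) range per data-parallel rank, and now also records max_shard_size in the returned dict. A compact sketch of that bookkeeping; the Shard dataclass and helper below are illustrative, not the commit's classes:

import math
from dataclasses import dataclass

@dataclass
class Shard:
    start: int
    end: int

    @property
    def size(self) -> int:
        return self.end - self.start

def build_world_shards(gbuf_size: int, dp_world_size: int):
    """Split a (padded) grad buffer into equal per-rank ranges; the last
    range is clipped only when the buffer is not padded."""
    max_shard_size = int(math.ceil(gbuf_size / dp_world_size))
    shards = []
    for rank in range(dp_world_size):
        start = rank * max_shard_size
        end = min(gbuf_size, start + max_shard_size)
        shards.append(Shard(start, end))
    return shards

if __name__ == "__main__":
    # With the padding added in this commit, gbuf_size is already a multiple
    # of dp_world_size, so every rank gets an equal-sized shard.
    print(build_world_shards(12, 4))   # four shards of size 3
    print(build_world_shards(10, 4))   # unpadded: last shard would be smaller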
@@ -992,9 +949,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 param_group_end = param_group_start + param_size
                 param_group_shard = Shard(param_group_start, param_group_end)
+                # group_shard["max_size"] = gbuf_shard_map["max_shard_size"]
                 group_shard["size"] += param_size
                 group_shard["param_map"][param] = param_group_shard
+                # pax(0, {"gbuf_shard_map": gbuf_shard_map})
 
                 # >>>
                 # if torch.distributed.get_rank() == 1:
                 #     print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
@@ -1010,6 +969,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             group_shard["orig_group"] = param_groups[group_index]
 
         group_shards = [ g for g in group_shards if g["size"] > 0 ]
+
+        # [ ... x ... ] Synchronize group sizes across ranks.
         # pax(0, {
         #     "param_group_map": [
         #         (g, str(p.shape))
@@ -1035,6 +996,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # main_param_shards = []
         for group_index, group_shard in enumerate(opt_group_shards):
+            # pax(0, {
+            #     "group_shard" : group_shard,
+            # })
             group_size = group_shard["size"]
             assert group_size != 0, "temporary check ... remove me."
@@ -1075,29 +1040,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         assert args.use_contiguous_buffers_in_local_ddp # already checked in args
         # <<<
 
-        # # Data parallel info.
-        # self.data_parallel_group = mpu.get_data_parallel_group()
-        # self.data_parallel_rank = mpu.get_data_parallel_rank()
-        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()
-
         # Model grad buffer shards.
         self.model_gbuf_shards = []
         for model_index, model in enumerate(self.models):
             self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
         self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
-        # pax(0, {"param_gbuf_map": [ (str(tuple(p.shape)), d) for p, d in self.param_gbuf_map.items() ]})
 
         # Optimizer shards.
         self.opt_group_shards = self.get_optimizer_group_shards(
             self.optimizer.param_groups,
             self.model_gbuf_shards)
-        # pax(0, {**{"opt_group_shards / %d" % i : g for i, g in enumerate(self.opt_group_shards)}})
 
         # Allocate main param shards.
-        # self.main_param_shards = \
-        #     self.allocate_main_param_shards(self.opt_group_shards)
         self.allocate_main_param_shards(self.opt_group_shards)
 
         # >>>
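The constructor lines kept here fix the setup order of the distributed optimizer: build per-model grad-buffer shard maps, derive the optimizer group shards from them, then allocate the fp32 main-parameter shards. A very rough, self-contained sketch of that flow; none of the helper names or data layouts below come from the commit:

import torch

def build_gbuf_shards(model_numels, dp_rank, dp_world_size):
    """One (start, end) range per model grad buffer for this rank,
    assuming each buffer is already padded to a multiple of dp_world_size."""
    shards = []
    for numel_padded in model_numels:
        shard_size = numel_padded // dp_world_size
        shards.append((dp_rank * shard_size, (dp_rank + 1) * shard_size))
    return shards

def allocate_main_param_shards(gbuf_shards, dtype=torch.float32):
    """Allocate an fp32 'main' parameter shard matching each local grad shard."""
    return [torch.zeros(end - start, dtype=dtype) for start, end in gbuf_shards]

if __name__ == "__main__":
    # Two models whose padded grad buffers hold 12 and 8 elements, 4 DP ranks.
    shards = build_gbuf_shards([12, 8], dp_rank=1, dp_world_size=4)
    mains = allocate_main_param_shards(shards)
    print(shards)                        # [(3, 6), (2, 4)]
    print([m.numel() for m in mains])    # [3, 2]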
@@ -1205,6 +1159,37 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(0, {"model_params": model_params})
 
+    # def get_model_grad_buffer_dp_views(self):
+
+    #     # >>>
+    #     # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
+    #     args = get_args()
+    #     assert args.use_contiguous_buffers_in_local_ddp
+    #     # <<<
+
+    #     # Grad buffer views.
+    #     gbuf_view_items = []
+    #     for model_index, model in enumerate(self.models):
+    #         for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
+    #             world_shards = gbuf_shard["world_all"]
+    #             gbuf = model._grad_buffers[dtype].data
+    #             gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
+    #             gbuf_view_items.append((model_index, dtype, gbuf_views))
+    #             # pax(0, {
+    #             #     "world_shards" : world_shards,
+    #             #     "gbuf_views" : gbuf_views,
+    #             # })
+
+    #     pax(0, {
+    #         "gbuf_view_items" : gbuf_view_items,
+    #         **{
+    #             "views / %d" % i : item[2]
+    #             for i, item in enumerate(gbuf_view_items)
+    #         },
+    #     })
+
+    #     return gbuf_view_items
+
     def get_model_grad_buffer_dp_views(self):
         # >>>
@@ -1213,21 +1198,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         assert args.use_contiguous_buffers_in_local_ddp
         # <<<
 
+        # data_parallel_rank = mpu.get_data_parallel_rank()
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
+
         # Grad buffer views.
         gbuf_view_items = []
         for model_index, model in enumerate(self.models):
-            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
-                world_shards = gbuf_shard["world_all"]
-                gbuf = model._grad_buffers[dtype].data
-                gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
-                gbuf_view_items.append((model_index, dtype, gbuf_views))
+            for dtype, gbuf in model._grad_buffers.items():
 
+                # gbuf_size = gbuf.numel_padded
+                assert gbuf.numel_padded % data_parallel_world_size == 0
+                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                 # pax(0, {
-                #     "world_shards" : world_shards,
-                #     "gbuf_views" : gbuf_views,
+                #     "numel" : gbuf.numel,
+                #     "numel_padded" : gbuf.numel_padded,
+                #     "shard_size / f" : gbuf.numel_padded/data_parallel_world_size,
+                #     "shard_size / i" : shard_size,
                 # })
 
+                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
+                              for r in range(data_parallel_world_size)]
+                gbuf_view_items.append((model_index, dtype, gbuf_views))
 
-        # pax(0, {"gbuf_view_items": gbuf_view_items})
+        # pax(0, {
+        #     "gbuf_view_items" : gbuf_view_items,
+        #     **{
+        #         "views / %d" % i : item[2]
+        #         for i, item in enumerate(gbuf_view_items)
+        #     },
+        # })
 
         return gbuf_view_items
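Because the buffer is now padded, the rewritten get_model_grad_buffer_dp_views no longer needs the precomputed world shards: it asserts that numel_padded is divisible by the data-parallel world size and cuts the buffer into equal, contiguous views. A minimal sketch of that slicing on a plain tensor; the function name is made up for the example:

import torch

def equal_dp_views(flat_buffer: torch.Tensor, dp_world_size: int):
    """Split a flat (padded) buffer into dp_world_size equal, contiguous views.
    The views alias the buffer's storage, so no copies are made."""
    assert flat_buffer.numel() % dp_world_size == 0, \
        "buffer must be padded to a multiple of the data-parallel world size"
    shard_size = flat_buffer.numel() // dp_world_size
    return [flat_buffer[r * shard_size:(r + 1) * shard_size]
            for r in range(dp_world_size)]

if __name__ == "__main__":
    buf = torch.arange(12, dtype=torch.float)   # e.g. numel_padded == 12
    views = equal_dp_views(buf, 4)
    print([v.tolist() for v in views])          # four contiguous 3-element views
    views[0].zero_()                            # writes through to `buf`
    print(buf[:3].tolist())                     # -> [0.0, 0.0, 0.0]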
@@ -1361,11 +1359,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             gbuf /= data_parallel_world_size
 
             # if 1:
+            # try:
+            #     pax(0, {"gbuf_views": gbuf_views})
             torch.distributed.reduce_scatter(
                 gbuf_views[data_parallel_rank],
                 gbuf_views,
                 group = data_parallel_group,
             )
+            # except:
+            #     pax(0, {
+            #         "data_parallel_rank" : data_parallel_rank,
+            #         "gbuf_views" : gbuf_views,
+            #     })
             # else:
             #     torch.distributed.all_reduce(
             #         gbuf,
...
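Putting the pieces together, the grad buffer is divided by the data-parallel world size and then reduce-scattered over the equal-sized views, so each rank ends up owning the averaged gradients for its shard only. A hedged, runnable sketch of that pattern on plain tensors; it mirrors the reduce_scatter call in the diff but is not Megatron's exact code, and it expects to be launched with torchrun on a backend that implements reduce_scatter (e.g. NCCL):

import torch
import torch.distributed as dist

def reduce_scatter_grad_buffer(gbuf: torch.Tensor, rank: int, world_size: int,
                               group=None) -> torch.Tensor:
    """Average a flat, padded grad buffer across ranks and return this rank's
    shard (an alias into gbuf)."""
    gbuf /= world_size                                    # pre-divide for averaging
    shard_size = gbuf.numel() // world_size
    views = [gbuf[r * shard_size:(r + 1) * shard_size]    # equal views, see above
             for r in range(world_size)]
    dist.reduce_scatter(views[rank], views, group=group)  # sum of all ranks' r-th view
    return views[rank]

if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=<N> reduce_scatter_sketch.py
    # where N divides 12; reduce_scatter needs NCCL (GPUs), not single-process gloo.
    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    gbuf = torch.ones(12, device="cuda") * (rank + 1)     # padded buffer
    shard = reduce_scatter_grad_buffer(gbuf, rank, world)
    print(rank, shard)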