Commit 28b63ec7 authored by Lawrence McAfee

moved distrib opt to own file.

parent 55695f81
......@@ -25,11 +25,8 @@ from lutil import pax, tp
# <<<
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
# >>>
from .optimizer import (
Float16OptimizerWithFloat16Params,
Float16DistributedOptimizer,
FP32Optimizer,
)
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
from .distrib_optimizer import DistributedOptimizer
# <<<
def get_param_groups(modules,
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron distributed optimizer."""
import math
from megatron import get_args
# >>>
from lutil import pax, tp
DEBUG_ITERATION = 2 # 10
# <<<
class Shard:
def __init__(self, start, end):
self.start = start
self.end = end
self.size = end - start
def normalize(self, start = 0):
return Shard(start, start + self.size)
def __str__(self):
return "%d,%d [%d]" % (self.start, self.end, self.size)
# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
# class Float16DistributedOptimizer(MegatronOptimizer):
# class Float16DistributedOptimizer(BaseFloat16Optimizer):
class DistributedOptimizer(MegatronOptimizer):
@classmethod
def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
# Param shard map.
param_world_index_map = model._grad_buffer_param_index_map[dtype]
param_shard_map = {}
for param, param_world_indexes in param_world_index_map.items():
# Shard range.
param_world_start, param_world_end = param_world_indexes
param_local_start = max(
0,
param_world_start - gbuf_world_shard.start)
param_local_end = min(
gbuf_world_shard.size,
param_world_end - gbuf_world_shard.start)
# Add shard, if within range.
if param_local_end > param_local_start:
param_local_shard = Shard(param_local_start, param_local_end)
# param_world_shard = param_local_shard.normalize(param_world_start)
param_world_shard = param_local_shard.normalize(
param_local_start + gbuf_world_shard.start)
sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
sub_param_shard = param_local_shard.normalize(sub_param_start)
param_shard_map[param] = {
"gbuf_world" : param_world_shard,
"gbuf_local" : param_local_shard,
"param" : sub_param_shard,
}
# pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
return param_shard_map
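# Illustrative example (hypothetical values): if this rank's gbuf_world_shard is
# [100, 200) and a param occupies [150, 230) of the world grad buffer, then the
# entry for that param is:
#   "gbuf_local" : [50, 100)   overlap, relative to this rank's buffer shard
#   "gbuf_world" : [150, 200)  same overlap, in world-buffer coordinates
#   "param"      : [0, 50)     same overlap, relative to the flattened param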
@classmethod
def get_model_gbuf_shard(cls, model, dtype):
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer shard.
grad_buffer = model._grad_buffers[dtype]
gbuf_size = grad_buffer.numel
max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
gbuf_world_all_shards = []
for r in range(data_parallel_world_size):
gbuf_world_start = r * max_gbuf_shard_size
gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
gbuf_world_all_shards.append(gbuf_world_shard)
# >>>
# if max_gbuf_shard_size != gbuf_world_shard.size:
# raise Exception("%d: smaller, rank %d. [ %d -> %d vs. %d]" % (
# data_parallel_rank,
# r,
# gbuf_size,
# max_gbuf_shard_size,
# gbuf_world_shard.size,
# ))
# <<<
gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
gbuf_local_shard = gbuf_world_shard.normalize()
# Param shards.
param_shard_map = cls.get_model_gbuf_param_shard_map(model,
dtype,
gbuf_world_shard)
# Altogether.
data = {
"local" : gbuf_local_shard,
"world" : gbuf_world_shard,
"world_all" : gbuf_world_all_shards,
"param_map" : param_shard_map,
"max_shard_size" : max_gbuf_shard_size,
}
# pax(0, {"data": data})
return data
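# Illustrative example (hypothetical values): with gbuf_size = 10 and a
# data-parallel world size of 4, max_gbuf_shard_size = ceil(10 / 4) = 3, giving
# world shards [0,3), [3,6), [6,9), [9,10); the last rank's shard is smaller.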
@classmethod
def get_model_gbuf_shard_map(cls, model):
return {
dtype : cls.get_model_gbuf_shard(model, dtype)
for dtype in model._grad_buffers
}
@classmethod
def get_param_gbuf_map(cls, model_gbuf_shards):
param_gbuf_map = {}
for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
for param, param_shard_map in gbuf_shard_map["param_map"].items():
# assert param not in param_size_map
# param_size_map[param] = param_shard_map["local"].size
param_gbuf_map[param] = (model_index, dtype)
# pax(0, {
# "dtype" : dtype,
# "gbuf_shard_map" : gbuf_shard_map,
# "param" : tp(param),
# "param_shard_map" : param_shard_map,
# })
# pax(0, {
# "model_gbuf_shards" : model_gbuf_shards,
# # "param_size_map" :
# # [ (str(p.shape), s) for p, s in param_size_map.items() ],
# "param_gbuf_map" : param_gbuf_map,
# })
return param_gbuf_map
@classmethod
def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
num_groups = len(param_groups)
# Param group map.
param_group_map = {}
for group_index, group in enumerate(param_groups):
for param in group["params"]:
assert param.requires_grad
param_group_map[param] = group_index
# Optimizer group shards.
group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
for model_gbuf_shard_map in model_gbuf_shards:
for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
for param in gbuf_shard_map["param_map"]:
group_index = param_group_map[param]
group_shard = group_shards[group_index]
param_size = gbuf_shard_map["param_map"][param]["param"].size
param_group_start = group_shard["size"]
param_group_end = param_group_start + param_size
param_group_shard = Shard(param_group_start, param_group_end)
# group_shard["max_size"] = gbuf_shard_map["max_shard_size"]
group_shard["size"] += param_size
group_shard["param_map"][param] = param_group_shard
# pax(0, {"gbuf_shard_map": gbuf_shard_map})
# >>>
# if torch.distributed.get_rank() == 1:
# print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
# torch.distributed.get_rank(),
# group_index,
# param_size,
# str(tuple(param.shape)),
# ))
# <<<
# Squeeze zero-size group shards.
for group_index, group_shard in enumerate(group_shards):
group_shard["orig_group"] = param_groups[group_index]
group_shards = [ g for g in group_shards if g["size"] > 0 ]
# [ ... x ... ] Synchronize group sizes across ranks.
# pax(0, {
# "param_group_map": [
# (g, str(p.shape))
# for p, g in param_group_map.items()
# ],
# "group_shards" : group_shards,
# })
return group_shards
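# Note: each group shard collects only the slices of this rank's grad-buffer
# shard that belong to params of the corresponding optimizer param group,
# packed contiguously via the running "size" offset; groups that end up empty
# on this rank are dropped from the local optimizer.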
@classmethod
def allocate_main_param_shards(cls, opt_group_shards):
# Allocate main param/grad shard.
# ** torch.nn.Parameter ??
# ** MemoryBuffer ??
allocate_shard = lambda shard_size, dtype : torch.empty(
(shard_size,),
dtype = dtype,
device = torch.cuda.current_device(),
requires_grad = True)
# main_param_shards = []
for group_index, group_shard in enumerate(opt_group_shards):
# pax(0, {
# "group_shard" : group_shard,
# })
group_size = group_shard["size"]
assert group_size != 0, "temporary check ... remove me."
# ** todo: for dtype in model_main_dtypes ........ **
# Allocate shard.
# if group_size == 0:
# main_param = None
# else:
main_param = allocate_shard(group_size, torch.float)
main_param.grad = allocate_shard(group_size, torch.float)
mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
# main_param_shards.append(main_param)
group_shard["orig_group"]["params"] = [ main_param ]
# # Update optimizer group.
# self.optimizer.param_groups[group_index]["params"] = [ main_param ]
# pax(1, {
# "opt_group_shards" : opt_group_shards,
# "main_param_shards" : main_param_shards,
# })
# return main_param_shards
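# Note: each surviving group gets one flat fp32 main-param tensor (with a
# matching fp32 .grad) sized to the group's shard, and the group's "params"
# list is rewired so the inner optimizer steps on this single flat tensor
# instead of the original model params.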
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models)
# >>>
args = get_args()
assert args.use_contiguous_buffers_in_local_ddp # already checked in args
# <<<
# Model grad buffer shards.
self.model_gbuf_shards = []
for model_index, model in enumerate(self.models):
self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
# Optimizer shards.
self.opt_group_shards = self.get_optimizer_group_shards(
self.optimizer.param_groups,
self.model_gbuf_shards)
# Allocate main param shards.
self.allocate_main_param_shards(self.opt_group_shards)
# >>>
# pax(0, {
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# "main_param_shards" : self.main_param_shards,
# })
# <<<
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self.optimizer.param_groups = \
[ g["orig_group"] for g in self.opt_group_shards ]
self.optimizer.load_state_dict(self.optimizer.state_dict())
# pax(0, {
# # "opt_group_shards" : self.opt_group_shards,
# # "param_groups" : self.optimizer.param_groups,
# "optimizer" : self.optimizer,
# "optimizer / state" : self.optimizer.state,
# })
# pax(1, {
# "optimizer" : self.optimizer,
# **{"optimizer / param_groups / %d" % i : g
# for i, g in enumerate(self.optimizer.param_groups)},
# "optimizer / state" : self.optimizer.state,
# "optimizer / state_dict" : self.optimizer.state_dict(),
# })
# Initialize main params.
self._copy_model_params_to_main_params()
@staticmethod
def has_nan_debug(tensors):
if isinstance(tensors, torch.Tensor):
tensors = [ tensors ]
assert isinstance(tensors, list)
has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
has_nan = any(has_nans)
return has_nan
def get_local_model_param_views(self):
'''** FOR DEBUGGING. **'''
model_param_views = []
for group_index, opt_group_shard in enumerate(self.opt_group_shards):
for param, opt_shard in opt_group_shard["param_map"].items():
model_index, dtype = self.param_gbuf_map[param]
gbuf_shard_map = \
self.model_gbuf_shards[model_index][dtype]["param_map"][param]
model_param_shard = gbuf_shard_map["param"]
model_param_views.append(
param.view(-1)[model_param_shard.start:model_param_shard.end])
return model_param_views
def get_local_model_grad_views(self):
'''** FOR DEBUGGING. **'''
model_grad_views = []
for group_index, opt_group_shard in enumerate(self.opt_group_shards):
for param, opt_shard in opt_group_shard["param_map"].items():
model_index, dtype = self.param_gbuf_map[param]
gbuf = self.models[model_index]._grad_buffers[dtype].data
gbuf_shard_map = \
self.model_gbuf_shards[model_index][dtype]["param_map"][param]
gbuf_world_shard = gbuf_shard_map["gbuf_world"]
model_grad_views.append(
gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
return model_grad_views
def get_world_model_params(self):
'''** FOR DEBUGGING. **'''
return [ p for m in self.models for p in m.parameters() ]
def get_world_model_grads(self):
'''** FOR DEBUGGING. **'''
return [ p.main_grad for p in self.get_world_model_params() ]
def get_main_params(self):
return [ g["params"][0] for g in self.optimizer.param_groups ]
def get_main_grads(self):
return [ p.grad for p in self.get_main_params() ]
def get_main_param(self, group_index):
# return self.optimizer.param_groups[group_index]["params"][0]
return self.get_main_params()[group_index]
def get_main_grad(self, group_index):
return self.get_main_param(group_index).grad
def load_state_dict(self):
raise Exception("load_state_dict() not yet supported by DistributedOptimizer.")
def reload_model_params(self):
raise Exception("reload_model_params() not yet supported by DistributedOptimizer.")
def state_dict(self):
raise Exception("state_dict() not yet supported by DistributedOptimizer.")
def zero_grad(self, set_to_none=True):
model_params = []
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
model_params.extend(param_map.keys())
# main_params = []
# for main_group in self.optimizer.param_groups:
# main_params.extend(main_group["params"])
# ** using contiguous buffer; don't set_to_none **
_zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
# _zero_grad_group_helper(params, set_to_none = False)
# pax(0, {"model_params": model_params})
# def get_model_grad_buffer_dp_views(self):
# # >>>
# # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
# args = get_args()
# assert args.use_contiguous_buffers_in_local_ddp
# # <<<
# # Grad buffer views.
# gbuf_view_items = []
# for model_index, model in enumerate(self.models):
# for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
# world_shards = gbuf_shard["world_all"]
# gbuf = model._grad_buffers[dtype].data
# gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
# gbuf_view_items.append((model_index, dtype, gbuf_views))
# # pax(0, {
# # "world_shards" : world_shards,
# # "gbuf_views" : gbuf_views,
# # })
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# **{
# "views / %d" % i : item[2]
# for i, item in enumerate(gbuf_view_items)
# },
# })
# return gbuf_view_items
def get_model_grad_buffer_dp_views(self):
# >>>
# ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
args = get_args()
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer views.
gbuf_view_items = []
for model_index, model in enumerate(self.models):
for dtype, gbuf in model._grad_buffers.items():
# gbuf_size = gbuf.numel_padded
assert gbuf.numel_padded % data_parallel_world_size == 0
shard_size = int(gbuf.numel_padded / data_parallel_world_size)
# pax(0, {
# "numel" : gbuf.numel,
# "numel_padded" : gbuf.numel_padded,
# "shard_size / f" : gbuf.numel_padded/data_parallel_world_size,
# "shard_size / i" : shard_size,
# })
gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
for r in range(data_parallel_world_size)]
gbuf_view_items.append((model_index, dtype, gbuf_views))
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# **{
# "views / %d" % i : item[2]
# for i, item in enumerate(gbuf_view_items)
# },
# })
return gbuf_view_items
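# Illustrative example (hypothetical values): with numel_padded = 8 and 2
# data-parallel ranks, shard_size = 4 and gbuf_views = [gbuf.data[0:4],
# gbuf.data[4:8]]; view r is the slice owned by rank r during
# reduce-scatter / all-gather.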
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync word embedding params.
# ... todo ...
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
# raise Exception("[fix] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# >>>
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
# Unreachable for now: the distributed optimizer requires local DDP's
# contiguous 'main_grad' buffers.
raise Exception("only 'main_grad' supported for distrib-opt.")
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# +++
# grad_shard = optimizer.get_grad_shard(word_embeddings)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_embedding_group())
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync T5 position embedding params.
# ... todo ...
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
# >>>
raise Exception("[fix] ready for t5 sync?")
# <<<
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
# >>>
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# +++
# grad_shard = optimizer.get_grad_shard(
# unwrapped_model.language_model.embedding.position_embeddings.weight)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_position_embedding_group())
# <<<
timers('backward-embedding-all-reduce').stop()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# timers('backward-params-reduce-scatter').start()
timers('backward-params-all-reduce').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
# pax(0, {"gbufs": [
# g.data
# for m in self.models
# for g in m._grad_buffers.values()
# ]})
# >>>
# buffer_.data /= mpu.get_data_parallel_world_size()
# torch.distributed.all_reduce(
# buffer_.data, group=mpu.get_data_parallel_group())
# <<<
# >>>
# self.debug_main_param(0, "before reduce scatter")
# self.debug_main_grad(0, "before reduce scatter")
# <<<
for model_index, dtype, gbuf_views in gbuf_view_items:
# coalesced /= mpu.get_data_parallel_world_size()
gbuf = self.models[model_index]._grad_buffers[dtype].data
# >>>
# ~~ distributed.py ~~
# gbuf /= data_parallel_world_size
# torch.distributed.all_reduce(gbuf, group=data_parallel_group)
# pax(0, {
# "gbuf" : tp(gbuf),
# })
# <<<
# torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
# gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
gbuf /= data_parallel_world_size
# if 1:
# try:
# pax(0, {"gbuf_views": gbuf_views})
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
group = data_parallel_group,
)
# except:
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# "gbuf_views" : gbuf_views,
# })
# else:
# torch.distributed.all_reduce(
# gbuf,
# group = data_parallel_group,
# )
# timers('backward-params-reduce-scatter').stop()
timers('backward-params-all-reduce').stop()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
def gather_params(self, ITERATION):
# >>>
timers = get_timers()
# <<<
timers('backward-params-all-gather').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# All-gather updated main params.
for model_index, dtype, gbuf_views in gbuf_view_items:
torch.distributed.all_gather(
gbuf_views,
gbuf_views[data_parallel_rank],
group = data_parallel_group,
)
# Each model param now contains its updated values in its
# '.main_grad' field.
# for param in self.param_gbuf_map: # ... incomplete param list.
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
for param in param_map:
param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
# pax(0, {"gbuf_view_items": gbuf_view_items})
# >>>
# self.debug_main(ITERATION, "after/inside gather_params.", 0)
# self.debug_model(ITERATION, "after/inside gather_params.", 0)
# if ITERATION == 2:
# pax(1, {
# "ITERATION" : ITERATION,
# # "gbufs" : [
# # tp(b.data)
# # for m in self.models
# # for b in m._grad_buffers.values()
# # ],
# "param_gbuf_map" : [ str(tuple(p.shape)) for p in self.param_gbuf_map ],
# })
# <<<
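# Note: gather_params all-gathers the updated parameter values through the
# grad buffers, which are reused as a staging area; that is why each model
# param is then copied out of its own .main_grad view, and no separate
# all-gather buffer is allocated.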
def _collect_main_grad_data_for_unscaling(self):
return [ g.data for g in self.get_main_grads() ]
def _copy_model_params_to_main_params(self):
for group_index, group_shard in enumerate(self.opt_group_shards):
main_param = self.get_main_param(group_index)
for model_param, main_shard in group_shard["param_map"].items():
# Model shard.
model_index, dtype = self.param_gbuf_map[model_param]
model_shard = self.model_gbuf_shards \
[model_index][dtype]["param_map"][model_param]["param"]
assert main_shard.size == model_shard.size
# Copy shard data.
main_view = main_param[main_shard.start:main_shard.end]
model_view = model_param.view(-1)[model_shard.start:model_shard.end]
main_view.detach().copy_(model_view)
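# Note: model-to-main param initialization above reads the "param" sub-range
# (an offset within the flattened model param), whereas the grad/param copies
# below read "gbuf_world" ranges directly out of the contiguous DDP grad buffer.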
def _copy_model_grads_to_main_grads(self, ITERATION):
for group_index, group_shard in enumerate(self.opt_group_shards):
for model_param, main_shard in group_shard["param_map"].items():
# Model shard.
model_index, dtype = self.param_gbuf_map[model_param]
model_shard = self.model_gbuf_shards \
[model_index][dtype]["param_map"][model_param]["gbuf_world"]
assert main_shard.size == model_shard.size
# pax(0, {
# "model_param" : tp(model_param),
# "main_shard" : str(main_shard),
# "param shard" : self.model_gbuf_shards \
# [model_index][dtype]["param_map"][model_param],
# })
# Copy from DDP's contiguous buffer to main shard's grad.
model_grad = self.models[model_index]._grad_buffers[dtype].data
main_grad = self.get_main_grad(group_index)
# Copy sub-range within tensor.
model_view = model_grad[model_shard.start:model_shard.end]
main_view = main_grad[main_shard.start:main_shard.end]
main_view.detach().copy_(model_view)
# pax(0, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# # "param" : tp(param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "model_grad" : tp(model_grad),
# "main_grad" : tp(main_grad),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# >>>
# if 1 or ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# # "model grads" : self.get_world_model_grads(),
# "main_grads" : self.get_main_grads(),
# "group shards" : [
# "group %d; %s" % (grp_idx, main_shard)
# for grp_idx, grp_shard in enumerate(self.opt_group_shards)
# for model_param, main_shard in grp_shard["param_map"].items()
# ],
# })
# <<<
def _copy_main_params_to_model_params(self, ITERATION):
for group_index, group_shard in enumerate(self.opt_group_shards):
for model_param, main_shard in group_shard["param_map"].items():
model_index, dtype = self.param_gbuf_map[model_param]
model_shard = self.model_gbuf_shards \
[model_index][dtype]["param_map"][model_param]["gbuf_world"]
assert main_shard.size == model_shard.size
# Use DDP's contiguous buffer to temporarily hold params.
model_param = self.models[model_index]._grad_buffers[dtype].data
main_param = self.get_main_param(group_index)
# Copy sub-range within tensor.
model_view = model_param[model_shard.start:model_shard.end]
main_view = main_param[main_shard.start:main_shard.end]
model_view.detach().copy_(main_view)
# Debug.
# pax(1, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# "model_param" : tp(model_param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "model_param" : tp(model_param),
# "main_param" : tp(main_param),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# >>>
# if ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# "model params" : self.get_world_model_params(),
# })
# <<<
# <<<
......@@ -265,7 +265,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
@classmethod
def debug_general(cls, ITERATION, key, value):
def debug_base(cls, ITERATION, key, value):
from megatron import get_args
args = get_args()
my_rank = torch.distributed.get_rank()
......@@ -281,21 +281,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
# else:
# exit(0)
exit(0)
# def _debug_model(self, ITERATION, key, use_param):
def debug_model(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
(p.main_grad.float() if use_grad else p.float())
for m in self.models for p in m.parameters()
]
# pax(0, {
# "params" : params,
# "params / abs" : [ torch.abs(p) for p in params ],
# "params / abs / sum" : [ torch.sum(torch.abs(p)) for p in params ],
# })
count = sum(t.nelement() for t in tensors)
return self.debug_general(
return self.debug_base(
ITERATION,
"model/%s, %s [count %d]" % (
"grad" if use_grad else "param",
......@@ -305,43 +298,6 @@ class BaseFloat16Optimizer(MegatronOptimizer):
# sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
sum(torch.sum(torch.abs(t)) for t in tensors),
)
# def debug_model_param(self, ITERATION, key):
# return self._debug_model(ITERATION, key, True)
# def debug_model_grad(self, ITERATION, key):
# return self._debug_model(ITERATION, key, False)
# def _debug_main(self, ITERATION, key0, key1, f, ff):
# count = sum(
# p.nelement()
# for g in self.optimizer.param_groups
# for p in g["params"]
# )
# return self.debug_general(
# ITERATION,
# "main/%s, %s [count %d]" % (key1, key0, count),
# sum(ff(f(p))
# for g in self.optimizer.param_groups
# for p in g["params"]).item() / count,
# )
# def debug_main_param(self, ITERATION, key):
# return self._debug_main(
# ITERATION,
# key,
# "param", # sum",
# # lambda p : p,
# lambda p : torch.abs(p),
# torch.sum,
# )
# def debug_main_grad(self, ITERATION, key):
# return self._debug_main(
# ITERATION,
# key,
# "grad", # sum",
# # lambda p : p.grad,
# lambda p : torch.abs(p.grad),
# torch.sum,
# )
# def _debug_main(self, ITERATION, key, use_param):
def debug_main(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
......@@ -351,7 +307,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
]
tensors = [ t.float() for t in tensors ]
count = sum(t.nelement() for t in tensors)
return self.debug_general(
return self.debug_base(
ITERATION,
"main/%s, %s [count %d]" % (
"grad" if use_grad else "param",
......@@ -360,10 +316,6 @@ class BaseFloat16Optimizer(MegatronOptimizer):
),
sum(torch.sum(torch.abs(t)) for t in tensors),
)
# def debug_main_param(self, ITERATION, key):
# return self._debug_main(ITERATION, key, True)
# def debug_main_grad(self, ITERATION, key):
# return self._debug_main(ITERATION, key, False)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
@torch.no_grad()
......@@ -787,779 +739,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
current_param.data.copy_(saved_param.data)
# >>>
import math
from megatron import get_args
# class ShardIndex:
class Shard:
def __init__(self, start, end):
self.start = start
self.end = end
self.size = end - start
def normalize(self, start = 0):
return Shard(start, start + self.size)
def __str__(self):
return "%d,%d [%d]" % (self.start, self.end, self.size)
# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
# class Float16DistributedOptimizer(MegatronOptimizer):
class Float16DistributedOptimizer(BaseFloat16Optimizer):
@classmethod
def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
# Param shard map.
param_world_index_map = model._grad_buffer_param_index_map[dtype]
param_shard_map = {}
for param, param_world_indexes in param_world_index_map.items():
# Shard range.
param_world_start, param_world_end = param_world_indexes
param_local_start = max(
0,
param_world_start - gbuf_world_shard.start)
param_local_end = min(
gbuf_world_shard.size,
param_world_end - gbuf_world_shard.start)
# Add shard, if within range.
if param_local_end > param_local_start:
param_local_shard = Shard(param_local_start, param_local_end)
# param_world_shard = param_local_shard.normalize(param_world_start)
param_world_shard = param_local_shard.normalize(
param_local_start + gbuf_world_shard.start)
sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
sub_param_shard = param_local_shard.normalize(sub_param_start)
param_shard_map[param] = {
"gbuf_world" : param_world_shard,
"gbuf_local" : param_local_shard,
"param" : sub_param_shard,
}
# pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
return param_shard_map
@classmethod
def get_model_gbuf_shard(cls, model, dtype):
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer shard.
grad_buffer = model._grad_buffers[dtype]
gbuf_size = grad_buffer.numel
max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
gbuf_world_all_shards = []
for r in range(data_parallel_world_size):
gbuf_world_start = r * max_gbuf_shard_size
gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
gbuf_world_all_shards.append(gbuf_world_shard)
# >>>
# if max_gbuf_shard_size != gbuf_world_shard.size:
# raise Exception("%d: smaller, rank %d. [ %d -> %d vs. %d]" % (
# data_parallel_rank,
# r,
# gbuf_size,
# max_gbuf_shard_size,
# gbuf_world_shard.size,
# ))
# <<<
gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
gbuf_local_shard = gbuf_world_shard.normalize()
# Param shards.
param_shard_map = cls.get_model_gbuf_param_shard_map(model,
dtype,
gbuf_world_shard)
# Altogether.
data = {
"local" : gbuf_local_shard,
"world" : gbuf_world_shard,
"world_all" : gbuf_world_all_shards,
"param_map" : param_shard_map,
"max_shard_size" : max_gbuf_shard_size,
}
# pax(0, {"data": data})
return data
@classmethod
def get_model_gbuf_shard_map(cls, model):
return {
dtype : cls.get_model_gbuf_shard(model, dtype)
for dtype in model._grad_buffers
}
@classmethod
def get_param_gbuf_map(cls, model_gbuf_shards):
param_gbuf_map = {}
for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
for param, param_shard_map in gbuf_shard_map["param_map"].items():
# assert param not in param_size_map
# param_size_map[param] = param_shard_map["local"].size
param_gbuf_map[param] = (model_index, dtype)
# pax(0, {
# "dtype" : dtype,
# "gbuf_shard_map" : gbuf_shard_map,
# "param" : tp(param),
# "param_shard_map" : param_shard_map,
# })
# pax(0, {
# "model_gbuf_shards" : model_gbuf_shards,
# # "param_size_map" :
# # [ (str(p.shape), s) for p, s in param_size_map.items() ],
# "param_gbuf_map" : param_gbuf_map,
# })
return param_gbuf_map
@classmethod
def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
num_groups = len(param_groups)
# Param group map.
param_group_map = {}
for group_index, group in enumerate(param_groups):
for param in group["params"]:
assert param.requires_grad
param_group_map[param] = group_index
# Optimizer group shards.
group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
for model_gbuf_shard_map in model_gbuf_shards:
for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
for param in gbuf_shard_map["param_map"]:
group_index = param_group_map[param]
group_shard = group_shards[group_index]
param_size = gbuf_shard_map["param_map"][param]["param"].size
param_group_start = group_shard["size"]
param_group_end = param_group_start + param_size
param_group_shard = Shard(param_group_start, param_group_end)
# group_shard["max_size"] = gbuf_shard_map["max_shard_size"]
group_shard["size"] += param_size
group_shard["param_map"][param] = param_group_shard
# pax(0, {"gbuf_shard_map": gbuf_shard_map})
# >>>
# if torch.distributed.get_rank() == 1:
# print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
# torch.distributed.get_rank(),
# group_index,
# param_size,
# str(tuple(param.shape)),
# ))
# <<<
# Squeeze zero-size group shards.
for group_index, group_shard in enumerate(group_shards):
group_shard["orig_group"] = param_groups[group_index]
group_shards = [ g for g in group_shards if g["size"] > 0 ]
# [ ... x ... ] Synchronize group sizes across ranks.
# pax(0, {
# "param_group_map": [
# (g, str(p.shape))
# for p, g in param_group_map.items()
# ],
# "group_shards" : group_shards,
# })
return group_shards
@classmethod
def allocate_main_param_shards(cls, opt_group_shards):
# Allocate main param/grad shard.
# ** torch.nn.Parameter ??
# ** MemoryBuffer ??
allocate_shard = lambda shard_size, dtype : torch.empty(
(shard_size,),
dtype = dtype,
device = torch.cuda.current_device(),
requires_grad = True)
# main_param_shards = []
for group_index, group_shard in enumerate(opt_group_shards):
# pax(0, {
# "group_shard" : group_shard,
# })
group_size = group_shard["size"]
assert group_size != 0, "temporary check ... remove me."
# ** todo: for dtype in model_main_dtypes ........ **
# Allocate shard.
# if group_size == 0:
# main_param = None
# else:
main_param = allocate_shard(group_size, torch.float)
main_param.grad = allocate_shard(group_size, torch.float)
mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
# main_param_shards.append(main_param)
group_shard["orig_group"]["params"] = [ main_param ]
# # Update optimizer group.
# self.optimizer.param_groups[group_index]["params"] = [ main_param ]
# pax(1, {
# "opt_group_shards" : opt_group_shards,
# "main_param_shards" : main_param_shards,
# })
# return main_param_shards
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models)
# >>>
args = get_args()
assert args.use_contiguous_buffers_in_local_ddp # already checked in args
# <<<
# Model grad buffer shards.
self.model_gbuf_shards = []
for model_index, model in enumerate(self.models):
self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
# Optimizer shards.
self.opt_group_shards = self.get_optimizer_group_shards(
self.optimizer.param_groups,
self.model_gbuf_shards)
# Allocate main param shards.
self.allocate_main_param_shards(self.opt_group_shards)
# >>>
# pax(0, {
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# "main_param_shards" : self.main_param_shards,
# })
# <<<
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self.optimizer.param_groups = \
[ g["orig_group"] for g in self.opt_group_shards ]
self.optimizer.load_state_dict(self.optimizer.state_dict())
# pax(0, {
# # "opt_group_shards" : self.opt_group_shards,
# # "param_groups" : self.optimizer.param_groups,
# "optimizer" : self.optimizer,
# "optimizer / state" : self.optimizer.state,
# })
# pax(1, {
# "optimizer" : self.optimizer,
# **{"optimizer / param_groups / %d" % i : g
# for i, g in enumerate(self.optimizer.param_groups)},
# "optimizer / state" : self.optimizer.state,
# "optimizer / state_dict" : self.optimizer.state_dict(),
# })
# Initialize main params.
self._copy_model_params_to_main_params()
@staticmethod
def has_nan_debug(tensors):
if isinstance(tensors, torch.Tensor):
tensors = [ tensors ]
assert isinstance(tensors, list)
has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
has_nan = any(has_nans)
return has_nan
def get_local_model_param_views(self):
'''** FOR DEBUGGING. **'''
model_param_views = []
for group_index, opt_group_shard in enumerate(self.opt_group_shards):
for param, opt_shard in opt_group_shard["param_map"].items():
model_index, dtype = self.param_gbuf_map[param]
gbuf_shard_map = \
self.model_gbuf_shards[model_index][dtype]["param_map"][param]
model_param_shard = gbuf_shard_map["param"]
model_param_views.append(
param.view(-1)[model_param_shard.start:model_param_shard.end])
return model_param_views
def get_local_model_grad_views(self):
'''** FOR DEBUGGING. **'''
model_grad_views = []
for group_index, opt_group_shard in enumerate(self.opt_group_shards):
for param, opt_shard in opt_group_shard["param_map"].items():
model_index, dtype = self.param_gbuf_map[param]
gbuf = self.models[model_index]._grad_buffers[dtype].data
gbuf_shard_map = \
self.model_gbuf_shards[model_index][dtype]["param_map"][param]
gbuf_world_shard = gbuf_shard_map["gbuf_world"]
model_grad_views.append(
gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
return model_grad_views
def get_world_model_params(self):
'''** FOR DEBUGGING. **'''
return [ p for m in self.models for p in m.parameters() ]
def get_world_model_grads(self):
'''** FOR DEBUGGING. **'''
return [ p.main_grad for p in self.get_world_model_params() ]
def get_main_params(self):
return [ g["params"][0] for g in self.optimizer.param_groups ]
def get_main_grads(self):
return [ p.grad for p in self.get_main_params() ]
def get_main_param(self, group_index):
# return self.optimizer.param_groups[group_index]["params"][0]
return self.get_main_params()[group_index]
def get_main_grad(self, group_index):
return self.get_main_param(group_index).grad
def load_state_dict(self):
raise Exception("hi.")
def reload_model_params(self):
raise Exception("hi.")
def state_dict(self):
raise Exception("hi.")
def zero_grad(self, set_to_none=True):
model_params = []
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
model_params.extend(param_map.keys())
# main_params = []
# for main_group in self.optimizer.param_groups:
# main_params.extend(main_group["params"])
# ** using contiguous buffer; don't set_to_none **
_zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
# _zero_grad_group_helper(params, set_to_none = False)
# pax(0, {"model_params": model_params})
# def get_model_grad_buffer_dp_views(self):
# # >>>
# # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
# args = get_args()
# assert args.use_contiguous_buffers_in_local_ddp
# # <<<
# # Grad buffer views.
# gbuf_view_items = []
# for model_index, model in enumerate(self.models):
# for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
# world_shards = gbuf_shard["world_all"]
# gbuf = model._grad_buffers[dtype].data
# gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
# gbuf_view_items.append((model_index, dtype, gbuf_views))
# # pax(0, {
# # "world_shards" : world_shards,
# # "gbuf_views" : gbuf_views,
# # })
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# **{
# "views / %d" % i : item[2]
# for i, item in enumerate(gbuf_view_items)
# },
# })
# return gbuf_view_items
def get_model_grad_buffer_dp_views(self):
# >>>
# ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
args = get_args()
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer views.
gbuf_view_items = []
for model_index, model in enumerate(self.models):
for dtype, gbuf in model._grad_buffers.items():
# gbuf_size = gbuf.numel_padded
assert gbuf.numel_padded % data_parallel_world_size == 0
shard_size = int(gbuf.numel_padded / data_parallel_world_size)
# pax(0, {
# "numel" : gbuf.numel,
# "numel_padded" : gbuf.numel_padded,
# "shard_size / f" : gbuf.numel_padded/data_parallel_world_size,
# "shard_size / i" : shard_size,
# })
gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
for r in range(data_parallel_world_size)]
gbuf_view_items.append((model_index, dtype, gbuf_views))
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# **{
# "views / %d" % i : item[2]
# for i, item in enumerate(gbuf_view_items)
# },
# })
return gbuf_view_items
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync word embedding params.
# ... todo ...
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
# raise Exception("[fix] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# >>>
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
raise Exception("only 'main_grad' supported for distrib-opt.")
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# +++
# grad_shard = optimizer.get_grad_shard(word_embeddings)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_embedding_group())
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync T5 position embedding params.
# ... todo ...
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
# >>>
raise Exception("[fix] ready for t5 sync?")
# <<<
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
# >>>
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# +++
# grad_shard = optimizer.get_grad_shard(
# unwrapped_model.language_model.embedding.position_embeddings.weight)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_position_embedding_group())
# <<<
timers('backward-embedding-all-reduce').stop()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# timers('backward-params-reduce-scatter').start()
timers('backward-params-all-reduce').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
# pax(0, {"gbufs": [
# g.data
# for m in self.models
# for g in m._grad_buffers.values()
# ]})
# >>>
# buffer_.data /= mpu.get_data_parallel_world_size()
# torch.distributed.all_reduce(
# buffer_.data, group=mpu.get_data_parallel_group())
# <<<
# >>>
# self.debug_main_param(0, "before reduce scatter")
# self.debug_main_grad(0, "before reduce scatter")
# <<<
for model_index, dtype, gbuf_views in gbuf_view_items:
# coalesced /= mpu.get_data_parallel_world_size()
gbuf = self.models[model_index]._grad_buffers[dtype].data
# >>>
# ~~ distributed.py ~~
# gbuf /= data_parallel_world_size
# torch.distributed.all_reduce(gbuf, group=data_parallel_group)
# pax(0, {
# "gbuf" : tp(gbuf),
# })
# <<<
# torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
# gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
gbuf /= data_parallel_world_size
# if 1:
# try:
# pax(0, {"gbuf_views": gbuf_views})
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
group = data_parallel_group,
)
# except:
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# "gbuf_views" : gbuf_views,
# })
# else:
# torch.distributed.all_reduce(
# gbuf,
# group = data_parallel_group,
# )
# timers('backward-params-reduce-scatter').stop()
timers('backward-params-all-reduce').stop()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
def gather_params(self, ITERATION):
# >>>
timers = get_timers()
# <<<
timers('backward-params-all-gather').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# All-gather updated main params.
for model_index, dtype, gbuf_views in gbuf_view_items:
torch.distributed.all_gather(
gbuf_views,
gbuf_views[data_parallel_rank],
group = data_parallel_group,
)
# Each model param now contains its updated values in its
# '.main_grad' field.
# for param in self.param_gbuf_map: # ... incomplete param list.
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
for param in param_map:
param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
# pax(0, {"gbuf_view_items": gbuf_view_items})
# >>>
# self.debug_main(ITERATION, "after/inside gather_params.", 0)
# self.debug_model(ITERATION, "after/inside gather_params.", 0)
# if ITERATION == 2:
# pax(1, {
# "ITERATION" : ITERATION,
# # "gbufs" : [
# # tp(b.data)
# # for m in self.models
# # for b in m._grad_buffers.values()
# # ],
# "param_gbuf_map" : [ str(tuple(p.shape)) for p in self.param_gbuf_map ],
# })
# <<<
def _collect_main_grad_data_for_unscaling(self):
return [ g.data for g in self.get_main_grads() ]
def _copy_model_params_to_main_params(self):
for group_index, group_shard in enumerate(self.opt_group_shards):
main_param = self.get_main_param(group_index)
for model_param, main_shard in group_shard["param_map"].items():
# Model shard.
model_index, dtype = self.param_gbuf_map[model_param]
model_shard = self.model_gbuf_shards \
[model_index][dtype]["param_map"][model_param]["param"]
assert main_shard.size == model_shard.size
# Copy shard data.
main_view = main_param[main_shard.start:main_shard.end]
model_view = model_param.view(-1)[model_shard.start:model_shard.end]
main_view.detach().copy_(model_view)
def _copy_model_grads_to_main_grads(self, ITERATION):
for group_index, group_shard in enumerate(self.opt_group_shards):
for model_param, main_shard in group_shard["param_map"].items():
# Model shard.
model_index, dtype = self.param_gbuf_map[model_param]
model_shard = self.model_gbuf_shards \
[model_index][dtype]["param_map"][model_param]["gbuf_world"]
assert main_shard.size == model_shard.size
# pax(0, {
# "model_param" : tp(model_param),
# "main_shard" : str(main_shard),
# "param shard" : self.model_gbuf_shards \
# [model_index][dtype]["param_map"][model_param],
# })
# Copy from DDP's contiguous buffer to main shard's grad.
model_grad = self.models[model_index]._grad_buffers[dtype].data
main_grad = self.get_main_grad(group_index)
# Copy sub-range within tensor.
model_view = model_grad[model_shard.start:model_shard.end]
main_view = main_grad[main_shard.start:main_shard.end]
main_view.detach().copy_(model_view)
# pax(0, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# # "param" : tp(param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "model_grad" : tp(model_grad),
# "main_grad" : tp(main_grad),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# >>>
# if 1 or ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# # "model grads" : self.get_world_model_grads(),
# "main_grads" : self.get_main_grads(),
# "group shards" : [
# "group %d; %s" % (grp_idx, main_shard)
# for grp_idx, grp_shard in enumerate(self.opt_group_shards)
# for model_param, main_shard in grp_shard["param_map"].items()
# ],
# })
# <<<
def _copy_main_params_to_model_params(self, ITERATION):
for group_index, group_shard in enumerate(self.opt_group_shards):
for model_param, main_shard in group_shard["param_map"].items():
model_index, dtype = self.param_gbuf_map[model_param]
model_shard = self.model_gbuf_shards \
[model_index][dtype]["param_map"][model_param]["gbuf_world"]
assert main_shard.size == model_shard.size
# Use DDP's contiguous buffer to temporarily hold params.
model_param = self.models[model_index]._grad_buffers[dtype].data
main_param = self.get_main_param(group_index)
# Copy sub-range within tensor.
model_view = model_param[model_shard.start:model_shard.end]
main_view = main_param[main_shard.start:main_shard.end]
model_view.detach().copy_(main_view)
# Debug.
# pax(1, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# "model_param" : tp(model_param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "model_param" : tp(model_param),
# "main_param" : tp(main_param),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# >>>
# if ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# "model params" : self.get_world_model_params(),
# })
# <<<
# <<<
class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad,
......