Commit 45b364b1 authored by Lawrence McAfee

consolidated reduce_grad's sub-methods (i.e., allreduce_embedding_grads)

parent a9b1fc0a
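Note: for orientation, a minimal sketch of the call structure this commit introduces. Timer bookkeeping, the mpu parallel-state checks, and the MixedPrecisionOptimizer layer are omitted, and the stub bodies are placeholders rather than the real implementations.

    class MegatronOptimizer:
        """Base class: now owns the grad-reduction sub-methods."""

        def __init__(self, models):
            self.models = models

        def allreduce_word_embedding_grads(self):
            # All-reduce word-embedding grads across first/last pipeline stages.
            pass

        def allreduce_position_embedding_grads(self):
            # All-reduce T5 position-embedding grads across the encoder/decoder split.
            pass

        def allreduce_embedding_grads(self):
            # Consolidated entry point used by reduce_grads().
            self.allreduce_word_embedding_grads()
            self.allreduce_position_embedding_grads()

        def reduce_grads(self, args, timers):
            # Local-DDP grad all-reduce, then the embedding all-reduces.
            if args.DDP_impl == 'local':
                for model_module in self.models:
                    model_module.allreduce_gradients()
            self.allreduce_embedding_grads()


    class DistributedOptimizer(MegatronOptimizer):
        def reduce_grads(self, args, timers):
            # The distributed optimizer skips the per-model all-reduce and only
            # synchronizes embedding grads; grad reduce-scatter happens elsewhere.
            self.allreduce_embedding_grads()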
@@ -68,7 +68,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Add shard, if within range.
if param_local_end > param_local_start:
param_local_shard = Shard(param_local_start, param_local_end)
- # param_world_shard = param_local_shard.normalize(param_world_start)
param_world_shard = param_local_shard.normalize(
param_local_start + gbuf_world_shard.start)
sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
@@ -79,8 +78,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"param" : sub_param_shard,
}
- # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
return param_shard_map
@classmethod
@@ -94,26 +91,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
gbuf_size = grad_buffer.numel
max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
+ # All world shards. (i.e., across all data parallel ranks)
gbuf_world_all_shards = []
for r in range(data_parallel_world_size):
gbuf_world_start = r * max_gbuf_shard_size
gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
gbuf_world_all_shards.append(gbuf_world_shard)
- # >>>
- # if max_gbuf_shard_size != gbuf_world_shard.size:
- # raise Exception("%d: smaller, rank %d. [ %d -> %d vs. %d]" % (
- # data_parallel_rank,
- # r,
- # gbuf_size,
- # max_gbuf_shard_size,
- # gbuf_world_shard.size,
- # ))
- # <<<
+ # Local DP's shards.
gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
gbuf_local_shard = gbuf_world_shard.normalize()
- # Param shards.
+ # Get each param's shards.
param_shard_map = cls.get_model_gbuf_param_shard_map(model,
dtype,
gbuf_world_shard)
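Aside: the shard arithmetic above is plain ceil-division of the grad buffer across the data-parallel group. A small self-contained check of that split (the Shard namedtuple here is a stand-in for the class used in this file):

    import math
    from collections import namedtuple

    Shard = namedtuple("Shard", ["start", "end"])  # stand-in for the real Shard class

    def split_gbuf(gbuf_size, data_parallel_world_size):
        # Each data-parallel rank owns one contiguous, roughly equal slice.
        max_shard = int(math.ceil(gbuf_size / data_parallel_world_size))
        return [Shard(r * max_shard, min(gbuf_size, (r + 1) * max_shard))
                for r in range(data_parallel_world_size)]

    # A 10-element buffer over 4 ranks -> slice sizes 3, 3, 3, 1.
    print(split_gbuf(10, 4))  # [Shard(0, 3), Shard(3, 6), Shard(6, 9), Shard(9, 10)]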
@@ -127,8 +117,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"max_shard_size" : max_gbuf_shard_size,
}
- # pax(0, {"data": data})
return data
@classmethod
@@ -140,28 +128,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
@classmethod
def get_param_gbuf_map(cls, model_gbuf_shards):
+ '''Create a reverse of the model_gbuf_shards, for referencing in
+ opposite direction.'''
param_gbuf_map = {}
for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
for param, param_shard_map in gbuf_shard_map["param_map"].items():
- # assert param not in param_size_map
- # param_size_map[param] = param_shard_map["local"].size
param_gbuf_map[param] = (model_index, dtype)
- # pax(0, {
- # "dtype" : dtype,
- # "gbuf_shard_map" : gbuf_shard_map,
- # "param" : tp(param),
- # "param_shard_map" : param_shard_map,
- # })
- # pax(0, {
- # "model_gbuf_shards" : model_gbuf_shards,
- # # "param_size_map" :
- # # [ (str(p.shape), s) for p, s in param_size_map.items() ],
- # "param_gbuf_map" : param_gbuf_map,
- # })
return param_gbuf_map
@classmethod
@@ -190,82 +163,40 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
param_group_end = param_group_start + param_size
param_group_shard = Shard(param_group_start, param_group_end)
- # group_shard["max_size"] = gbuf_shard_map["max_shard_size"]
group_shard["size"] += param_size
group_shard["param_map"][param] = param_group_shard
- # pax(0, {"gbuf_shard_map": gbuf_shard_map})
- # >>>
- # if torch.distributed.get_rank() == 1:
- # print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
- # torch.distributed.get_rank(),
- # group_index,
- # param_size,
- # str(tuple(param.shape)),
- # ))
- # <<<
# Squeeze zero-size group shards.
for group_index, group_shard in enumerate(group_shards):
group_shard["orig_group"] = param_groups[group_index]
group_shards = [ g for g in group_shards if g["size"] > 0 ]
- # [ ... x ... ] Synchronize group sizes across ranks.
- # pax(0, {
- # "param_group_map": [
- # (g, str(p.shape))
- # for p, g in param_group_map.items()
- # ],
- # "group_shards" : group_shards,
- # })
return group_shards
@classmethod
def allocate_main_param_shards(cls, opt_group_shards):
- # Allocate main param/grad shard.
- # ** torch.nn.Parameter ??
- # ** MemoryBuffer ??
+ # Allocator method.
allocate_shard = lambda shard_size, dtype : torch.empty(
(shard_size,),
dtype = dtype,
device = torch.cuda.current_device(),
requires_grad = True)
- # main_param_shards = []
+ # Allocate each group's param/grad shard.
for group_index, group_shard in enumerate(opt_group_shards):
- # pax(0, {
- # "group_shard" : group_shard,
- # })
group_size = group_shard["size"]
assert group_size != 0, "temporary check ... remove me."
- # ** todo: for dtype in model_main_dtypes ........ **
# Allocate shard.
- # if group_size == 0:
- # main_param = None
- # else:
main_param = allocate_shard(group_size, torch.float)
main_param.grad = allocate_shard(group_size, torch.float)
mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
- # main_param_shards.append(main_param)
+ # Update group's param.
group_shard["orig_group"]["params"] = [ main_param ]
- # # Update optimizer group.
- # self.optimizer.param_groups[group_index]["params"] = [ main_param ]
- # pax(1, {
- # "opt_group_shards" : opt_group_shards,
- # "main_param_shards" : main_param_shards,
- # })
- # return main_param_shards
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
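Aside: a rough standalone sketch of what allocate_main_param_shards does per optimizer group: allocate one flat fp32 main param of the group's shard size, attach a same-size fp32 grad, and swap it into the group's params list. The tensor-parallel attribute call and CUDA placement are dropped here, and the function name is illustrative:

    import torch

    def allocate_group_main_params(opt_group_shards, device="cpu"):
        # Illustrative only; the real code allocates on the current CUDA device
        # and tags the tensor with tensor-model-parallel attributes.
        for group_shard in opt_group_shards:
            group_size = group_shard["size"]
            main_param = torch.empty((group_size,), dtype=torch.float,
                                     device=device, requires_grad=True)
            main_param.grad = torch.empty((group_size,), dtype=torch.float,
                                          device=device)
            # The optimizer group now updates only this rank's flat shard.
            group_shard["orig_group"]["params"] = [main_param]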
@@ -276,10 +207,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models)
- # >>>
+ # Verify that contiguous buffers are being used
+ # - Note: this should already be checked in arguments.py
args = get_args()
- assert args.use_contiguous_buffers_in_local_ddp # already checked in args
+ assert args.use_contiguous_buffers_in_local_ddp
- # <<<
# Model grad buffer shards.
self.model_gbuf_shards = []
@@ -295,14 +226,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Allocate main param shards.
self.allocate_main_param_shards(self.opt_group_shards)
- # >>>
- # pax(0, {
- # "model_gbuf_shards" : self.model_gbuf_shards,
- # "opt_group_shards" : self.opt_group_shards,
- # "main_param_shards" : self.main_param_shards,
- # })
- # <<<
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
@@ -310,27 +233,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
[ g["orig_group"] for g in self.opt_group_shards ]
self.optimizer.load_state_dict(self.optimizer.state_dict())
- # pax(0, {
- # # "opt_group_shards" : self.opt_group_shards,
- # # "param_groups" : self.optimizer.param_groups,
- # "optimizer" : self.optimizer,
- # "optimizer / state" : self.optimizer.state,
- # })
- # pax(1, {
- # "optimizer" : self.optimizer,
- # **{"optimizer / param_groups / %d" % i : g
- # for i, g in enumerate(self.optimizer.param_groups)},
- # "optimizer / state" : self.optimizer.state,
- # "optimizer / state_dict" : self.optimizer.state_dict(),
- # })
# Initialize main params.
self._copy_model_params_to_main_params()
def get_model_parallel_group(self):
- # >>>
- # i.e., no param replication across this group
- # <<<
return None
# @staticmethod
@@ -378,7 +284,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def get_main_grads(self):
return [ p.grad for p in self.get_main_params() ]
def get_main_param(self, group_index):
- # return self.optimizer.param_groups[group_index]["params"][0]
return self.get_main_params()[group_index]
def get_main_grad(self, group_index):
return self.get_main_param(group_index).grad
@@ -476,90 +381,77 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
return gbuf_view_items
- def reduce_grads(self, model):
+ # def reduce_grads(self, model):
+ def reduce_grads(self, args, timers):
# >>>
- from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
- from megatron import get_args
- from megatron import get_timers
- from megatron.model import DistributedDataParallel as LocalDDP
- from megatron.model import Float16Module
- from megatron.utils import unwrap_model
- args = get_args()
- timers = get_timers()
+ # from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+ # from megatron import get_args
+ # from megatron import get_timers
+ # from megatron.model import DistributedDataParallel as LocalDDP
+ # from megatron.model import Float16Module
+ # from megatron.utils import unwrap_model
+ # args = get_args()
+ # timers = get_timers()
# <<<
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # Sync word embedding params.
- # ... todo ...
- # All-reduce word_embeddings' grad across first and last stages to ensure
- # that word_embeddings parameters stay in sync.
- # This should only run for models that support pipelined model parallelism
- # (BERT and GPT-2).
+ # All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
- if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
- mpu.get_pipeline_model_parallel_world_size() > 1:
- # >>>
- # raise Exception("[fix] ready for weight sync?")
- # <<<
- if mpu.is_pipeline_first_stage(ignore_virtual=True):
- unwrapped_model = model[0]
- elif mpu.is_pipeline_last_stage(ignore_virtual=True):
- unwrapped_model = model[-1]
- else: # We do not support the interleaved schedule for T5 yet.
- unwrapped_model = model[0]
- unwrapped_model = unwrap_model(
- unwrapped_model, (torchDDP, LocalDDP, Float16Module))
- if unwrapped_model.share_word_embeddings:
- word_embeddings_weight = unwrapped_model.word_embeddings_weight()
- # >>>
- if args.DDP_impl == 'local':
- grad = word_embeddings_weight.main_grad
- else:
- raise Exception("only 'main_grad' supported for distrib-opt.")
- grad = word_embeddings_weight.grad
- torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
- # +++
- # grad_shard = optimizer.get_grad_shard(word_embeddings)
- # torch.distributed.all_reduce(grad_shard,
- # group=mpu.get_embedding_group())
- # <<<
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # Sync T5 position embedding params.
- # ... todo ...
- # All-reduce position_embeddings grad across first (encoder) and split (decoder)
- # stages to ensure that position embeddings parameters stay in sync.
- # This should only run for T5 models with pipeline parallelism
- if mpu.is_rank_in_position_embedding_group() and \
- mpu.get_pipeline_model_parallel_world_size() > 1 and \
- args.pipeline_model_parallel_split_rank is not None:
- # >>>
- raise Exception("[fix] ready for t5 sync?")
- # <<<
- unwrapped_model = model[0]
- unwrapped_model = unwrap_model(
- unwrapped_model, (torchDDP, LocalDDP, Float16Module))
- assert args.DDP_impl == 'local', \
- 'T5 model is only supported with local DDP mode'
- # >>>
- grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
- torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
- # +++
- # grad_shard = optimizer.get_grad_shard(
- # unwrapped_model.language_model.embedding.position_embeddings.weight)
- # torch.distributed.all_reduce(grad_shard,
- # group=mpu.get_position_embedding_group())
- # <<<
+ self.allreduce_embedding_grads()
timers('backward-embedding-all-reduce').stop()
+ # # All-reduce word_embeddings' grad across first and last stages to ensure
+ # # that word_embeddings parameters stay in sync.
+ # # This should only run for models that support pipelined model parallelism
+ # # (BERT and GPT-2).
+ # timers('backward-embedding-all-reduce').start()
+ # if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
+ # mpu.get_pipeline_model_parallel_world_size() > 1:
+ # if mpu.is_pipeline_first_stage(ignore_virtual=True):
+ # unwrapped_model = model[0]
+ # elif mpu.is_pipeline_last_stage(ignore_virtual=True):
+ # unwrapped_model = model[-1]
+ # else: # We do not support the interleaved schedule for T5 yet.
+ # unwrapped_model = model[0]
+ # unwrapped_model = unwrap_model(
+ # unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+ # if unwrapped_model.share_word_embeddings:
+ # word_embeddings_weight = unwrapped_model.word_embeddings_weight()
+ # if args.DDP_impl == 'local':
+ # grad = word_embeddings_weight.main_grad
+ # else:
+ # raise Exception("only 'main_grad' supported for distrib-opt.")
+ # grad = word_embeddings_weight.grad
+ # torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+ # # All-reduce position_embeddings grad across first (encoder) and split (decoder)
+ # # stages to ensure that position embeddings parameters stay in sync.
+ # # This should only run for T5 models with pipeline parallelism
+ # if mpu.is_rank_in_position_embedding_group() and \
+ # mpu.get_pipeline_model_parallel_world_size() > 1 and \
+ # args.pipeline_model_parallel_split_rank is not None:
+ # # >>>
+ # raise Exception("[fix] ready for t5 sync?")
+ # # <<<
+ # unwrapped_model = model[0]
+ # unwrapped_model = unwrap_model(
+ # unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+ # assert args.DDP_impl == 'local', \
+ # 'T5 model is only supported with local DDP mode'
+ # # >>>
+ # grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
+ # torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
+ # # +++
+ # # grad_shard = optimizer.get_grad_shard(
+ # # unwrapped_model.language_model.embedding.position_embeddings.weight)
+ # # torch.distributed.all_reduce(grad_shard,
+ # # group=mpu.get_position_embedding_group())
+ # # <<<
+ # timers('backward-embedding-all-reduce').stop()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# timers('backward-params-reduce-scatter').start()
...
@@ -183,33 +183,15 @@ class MegatronOptimizer(ABC):
def gather_params(self, ITERATION):
pass
- def reduce_grads(self, model):
+ def allreduce_word_embedding_grads(self):
+ '''
+ All-reduce word embedding grads.
+ Reduce grads across first and last stages to ensure that word_embeddings
+ parameters stay in sync. This should only run for models that support
+ pipelined model parallelism (BERT and GPT-2).
+ '''
- # >>>
- from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
- from megatron import get_args
- from megatron import get_timers
- from megatron.model import DistributedDataParallel as LocalDDP
- from megatron.model import Float16Module
- from megatron.utils import unwrap_model
- args = get_args()
- timers = get_timers()
- # <<<
- # All-reduce if needed.
- if args.DDP_impl == 'local':
- timers('backward-params-all-reduce').start()
- for model_module in model:
- model_module.allreduce_gradients()
- timers('backward-params-all-reduce').stop()
- # All-reduce word_embeddings' grad across first and last stages to ensure
- # that word_embeddings parameters stay in sync.
- # This should only run for models that support pipelined model parallelism
- # (BERT and GPT-2).
- timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
@@ -232,15 +214,16 @@ class MegatronOptimizer(ABC):
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
- # All-reduce position_embeddings grad across first (encoder) and split (decoder)
- # stages to ensure that position embeddings parameters stay in sync.
- # This should only run for T5 models with pipeline parallelism
+ def allreduce_position_embedding_grads(self):
+ '''
+ All-reduce position_embeddings grad across first (encoder) and
+ split (decoder) stages to ensure that position embeddings parameters
+ stay in sync. This should only run for T5 models with pipeline
+ parallelism.
+ '''
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
- # >>>
- raise Exception("[main] ready for t5 sync?")
- # <<<
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
@@ -248,8 +231,45 @@ class MegatronOptimizer(ABC):
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
+ def allreduce_embedding_grads(self):
+ self.allreduce_word_embedding_grads()
+ self.allreduce_position_embedding_grads()
+ # def reduce_grads(self, model):
+ def reduce_grads(self, args, timers):
+ # pax(0, {
+ # "*models" : self.models,
+ # "model" : model,
+ # })
+ # >>>
+ # from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+ # from megatron import get_args
+ # from megatron import get_timers
+ # from megatron.model import DistributedDataParallel as LocalDDP
+ # from megatron.model import Float16Module
+ # from megatron.utils import unwrap_model
+ # args = get_args()
+ # timers = get_timers()
+ # <<<
+ # All-reduce if needed.
+ if args.DDP_impl == 'local':
+ timers('backward-params-all-reduce').start()
+ for model_module in self.models:
+ model_module.allreduce_gradients()
+ timers('backward-params-all-reduce').stop()
+ # All-reduce embedding grads.
+ timers('backward-embedding-all-reduce').start()
+ self.allreduce_embedding_grads()
timers('backward-embedding-all-reduce').stop()
# class BaseFloat16Optimizer(MegatronOptimizer):
class MixedPrecisionOptimizer(MegatronOptimizer):
...
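Aside: model_module.allreduce_gradients() in the new reduce_grads is Megatron's local-DDP gradient sync, which roughly averages each contiguous grad buffer over the data-parallel group. A hedged sketch of that step (the function name and flat-buffer layout are illustrative, and an initialized torch.distributed process group is assumed):

    import torch
    import torch.distributed as dist

    def allreduce_grad_buffer(grad_buffer: torch.Tensor, data_parallel_group=None):
        # Average the flat grad buffer across data-parallel ranks:
        # scale locally, then sum with a single all-reduce.
        grad_buffer /= dist.get_world_size(group=data_parallel_group)
        dist.all_reduce(grad_buffer, group=data_parallel_group)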
@@ -436,29 +436,8 @@ def train_step(forward_step_func, data_iterator,
# <<<
# >>>
- # Reduce gradients. (with distributed optimizer option, optimizer
- # now responsible for reducing gradients)
- optimizer.reduce_grads(model)
+ # Reduce gradients.
+ optimizer.reduce_grads(args, timers) # model)
- # <<<
- # >>>
- # r = mpu.get_data_parallel_rank()
- # w = mpu.get_data_parallel_world_size()
- # gbufs = []
- # for m in model:
- # for g in m._grad_buffers.values():
- # t = g.data
- # n = t.nelement()
- # shard = int(n / w)
- # start_index = r * shard
- # end_index = min(n, start_index + shard)
- # gbufs.append(t[start_index:end_index])
- # pax(1, {"gbufs": gbufs})
- # <<<
- # >>>
- # from lutil import pax
- # pax(0, {"optimizer": optimizer})
# <<<
# Update parameters.
...