Commit 291592e4 authored by Lawrence McAfee

removed zero-size optimizer group shards.

parent 23f9238d
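The change itself is small: each optimizer group shard records the param group it came from under `"orig_group"`, and shards whose local size is zero are then filtered out, so the wrapped optimizer never sees an empty group. Below is a minimal standalone sketch of that pattern; `param_groups` and `group_shards` here are made-up stand-ins, not the real Megatron-LM structures.

```python
# Illustrative stand-ins only; the real group_shards come from
# get_optimizer_group_shards() and carry per-parameter range maps.
param_groups = [{"lr": 1e-4}, {"lr": 1e-3}, {"lr": 1e-3}]
group_shards = [{"size": 1024}, {"size": 0}, {"size": 512}]

# Remember the originating param group before squeezing ...
for group_index, group_shard in enumerate(group_shards):
    group_shard["orig_group"] = param_groups[group_index]

# ... then drop shards with no elements on this rank.
group_shards = [g for g in group_shards if g["size"] > 0]

assert [g["size"] for g in group_shards] == [1024, 512]
```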
@@ -875,7 +875,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # ))
        # <<<
        # pax(1, {

        # Squeeze zero-size group shards.
        for group_index, group_shard in enumerate(group_shards):
            group_shard["orig_group"] = param_groups[group_index]
        group_shards = [ g for g in group_shards if g["size"] > 0 ]

        # pax(0, {
        #     "param_group_map": [
        #         (g, str(p.shape))
        #         for p, g in param_group_map.items()
@@ -885,6 +890,47 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        return group_shards

    @classmethod
    def allocate_main_param_shards(cls, opt_group_shards):

        # Allocate main param/grad shard.
        # ** torch.nn.Parameter ??
        # ** MemoryBuffer ??
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)

        # main_param_shards = []
        for group_index, group_shard in enumerate(opt_group_shards):

            group_size = group_shard["size"]
            assert group_size != 0, "temporary check ... remove me."

            # ** todo: for dtype in model_main_dtypes ........ **

            # Allocate shard.
            # if group_size == 0:
            #     main_param = None
            # else:
            main_param = allocate_shard(group_size, torch.float)
            main_param.grad = allocate_shard(group_size, torch.float)
            mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
            # main_param_shards.append(main_param)

            group_shard["orig_group"]["params"] = [ main_param ]

            # # Update optimizer group.
            # self.optimizer.param_groups[group_index]["params"] = [ main_param ]

        # pax(1, {
        #     "opt_group_shards" : opt_group_shards,
        #     "main_param_shards" : main_param_shards,
        # })

        # return main_param_shards

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler, models):
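For reference, the new `allocate_main_param_shards` classmethod boils down to one flat fp32 tensor (plus a matching grad buffer) per surviving group shard, installed as the single `"params"` entry of that shard's original param group. The following is a hedged sketch of that allocation pattern; the CPU fallback and the toy `group_shard` dict are additions so the snippet is self-contained (the real code always allocates on `torch.cuda.current_device()` and also tags the tensor with Megatron's tensor-model-parallel attributes).

```python
import torch

# Fallback device so the sketch runs without a GPU; the actual optimizer
# allocates on torch.cuda.current_device().
device = (torch.device("cuda", torch.cuda.current_device())
          if torch.cuda.is_available() else torch.device("cpu"))

def allocate_shard(shard_size, dtype):
    # One flat, contiguous 1-D buffer covering this rank's slice of the group.
    return torch.empty((shard_size,), dtype=dtype, device=device,
                       requires_grad=True)

# Toy group shard; the real one is produced by get_optimizer_group_shards().
group_shard = {"size": 512, "orig_group": {"lr": 1e-4, "params": []}}

main_param = allocate_shard(group_shard["size"], torch.float)
main_param.grad = allocate_shard(group_shard["size"], torch.float)

# The flat fp32 shard becomes the only parameter this group exposes to the
# inner optimizer.
group_shard["orig_group"]["params"] = [main_param]
```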
@@ -910,52 +956,36 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
        self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
        # pax(0, {"param_gbuf_map": [ (str(tuple(p.shape)), d) for p, d in self.param_gbuf_map.items() ]})

        # Optimizer shards.
        self.opt_group_shards = self.get_optimizer_group_shards(
            self.optimizer.param_groups,
            self.model_gbuf_shards)

        # Allocate main param/grad shard.
        # ** torch.nn.Parameter ??
        # ** MemoryBuffer ??
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)

        self.main_param_shards = []
        for group_index, group_shard in enumerate(self.opt_group_shards):

            group_size = group_shard["size"]

            # ** todo: for dtype in model_main_dtypes ........ **
            # pax(0, {**{"opt_group_shards / %d" % i : g for i, g in enumerate(self.opt_group_shards)}})

            # Allocate shard.
            if group_size == 0:
                main_param = None
            else:
                main_param = allocate_shard(group_size, torch.float)
                main_param.grad = allocate_shard(group_size, torch.float)
                mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
            self.main_param_shards.append(main_param)

            # Update optimizer group.
            self.optimizer.param_groups[group_index]["params"] = [ main_param ]

        # Allocate main param shards.
        # self.main_param_shards = \
        #     self.allocate_main_param_shards(self.opt_group_shards)
        self.allocate_main_param_shards(self.opt_group_shards)

        # >>>
        pax(0, {
            "model_gbuf_shards" : self.model_gbuf_shards,
            "opt_group_shards" : self.opt_group_shards,
            "main_param_shards" : self.main_param_shards,
        })
        # pax(0, {
        #     "model_gbuf_shards" : self.model_gbuf_shards,
        #     "opt_group_shards" : self.opt_group_shards,
        #     "main_param_shards" : self.main_param_shards,
        # })
        # <<<

        # Initialize main params.
        self._copy_model_params_to_main_params()

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_shards ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())
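The last few lines of the hunk above are the interesting part of the `__init__` rewrite: the inner optimizer's `param_groups` are swapped for the squeezed groups (which now hold the flat fp32 shards), and a `state_dict()` / `load_state_dict()` round trip makes `torch.optim` rebuild its per-param state against those new parameters. A small sketch of that round trip, using a throwaway `torch.optim.Adam` and a single flat stand-in shard rather than the distributed optimizer itself:

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Stand-in for one group's flat fp32 main-param shard.
flat_shard = torch.zeros(sum(p.numel() for p in model.parameters()),
                         requires_grad=True)

# Replace the group's params with the flat shard, keeping hyperparameters.
optimizer.param_groups = [{**optimizer.param_groups[0],
                           "params": [flat_shard]}]

# Round-trip the state dict so per-param state is (re)built for flat_shard.
optimizer.load_state_dict(optimizer.state_dict())

assert optimizer.param_groups[0]["params"][0] is flat_shard
```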
@@ -1069,11 +1099,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
    def _collect_main_grad_data_for_unscaling(self):
        return [ p.grad.data for p in self.main_param_shards ]
        # return [ p.grad.data for p in self.main_param_shards if p is not None ]

    def _copy_model_params_to_main_params(self):

        for group_index, group_shard in enumerate(self.opt_group_shards):

            main_param = self.main_param_shards[group_index]
            # main_param = self.main_param_shards[group_index]
            main_param = self.optimizer.param_groups[group_index]["params"][0]

            pax(0, {"main_param": tp(main_param)})

            for model_param, main_shard in group_shard["param_map"].items():

                # Model shard.
@@ -1098,13 +1131,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            #     "model_shard" : str(model_shard),
            # })

        pax(1, {
            **{
                "opt_group_shards / %d" % i : s
                for i, s in enumerate(self.opt_group_shards)
            },
            "main_param_shards" : self.main_param_shards,
        })
        # pax(1, {
        #     **{
        #         "opt_group_shards / %d" % i : s
        #         for i, s in enumerate(self.opt_group_shards)
        #     },
        #     "main_param_shards" : self.main_param_shards,
        # })

    def _copy_model_grads_to_main_grads(self):
......
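Finally, `_copy_model_params_to_main_params` now looks the main param up through `self.optimizer.param_groups[group_index]["params"][0]` rather than the old `self.main_param_shards` list, and copies each model parameter into its slice of that flat tensor. A generic, hedged sketch of that copy step follows; the `start`/`end` offsets are hypothetical stand-ins for the ranges kept in `group_shard["param_map"]`, since the actual copy loop is truncated in the hunk above.

```python
import torch

# Hypothetical fp16 model parameter and a flat fp32 main-param shard.
model_param = torch.randn(4, 4, dtype=torch.float16)
main_param = torch.zeros(64, dtype=torch.float32)

# Offsets are illustrative; the real ranges come from the shard's param_map.
start, end = 16, 16 + model_param.numel()

# copy_() casts fp16 -> fp32 while filling the shard slice.
main_param[start:end].copy_(model_param.detach().view(-1))
```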