Commit 291592e4 authored by Lawrence McAfee

removed zero-size optimizer group shards.

parent 23f9238d
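
The core of this change, sketched below in standalone form, is to drop any optimizer group whose local shard is empty while keeping a back-reference to the original param group. The helper name is illustrative, not the actual Megatron-LM API; the dict keys ("size", "orig_group") and the variable names (group_shards, param_groups) follow the diff:

    # Illustrative sketch of the "squeeze zero-size group shards" step.
    # Each group shard is assumed to be a dict with a "size" entry (number of
    # elements owned by this rank); the original optimizer param group is
    # attached as "orig_group" so the surviving shards can later rebuild
    # optimizer.param_groups.
    def squeeze_group_shards(group_shards, param_groups):
        # Attach each original param group to its shard before filtering,
        # while the group index still lines up with optimizer.param_groups.
        for group_index, group_shard in enumerate(group_shards):
            group_shard["orig_group"] = param_groups[group_index]
        # Keep only groups that actually own parameters on this rank.
        return [g for g in group_shards if g["size"] > 0]

    # Example: a rank that owns no elements of the second group drops it locally.
    shards = [{"size": 4}, {"size": 0}]
    groups = [{"params": []}, {"params": []}]
    print(len(squeeze_group_shards(shards, groups)))  # -> 1
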
@@ -875,7 +875,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # ))
         # <<<
 
-        # pax(1, {
+        # Squeeze zero-size group shards.
+        for group_index, group_shard in enumerate(group_shards):
+            group_shard["orig_group"] = param_groups[group_index]
+        group_shards = [ g for g in group_shards if g["size"] > 0 ]
+
+        # pax(0, {
         #     "param_group_map": [
         #         (g, str(p.shape))
         #         for p, g in param_group_map.items()
@@ -885,6 +890,47 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 
         return group_shards
 
+    @classmethod
+    def allocate_main_param_shards(cls, opt_group_shards):
+
+        # Allocate main param/grad shard.
+        # ** torch.nn.Parameter ??
+        # ** MemoryBuffer ??
+        allocate_shard = lambda shard_size, dtype : torch.empty(
+            (shard_size,),
+            dtype = dtype,
+            device = torch.cuda.current_device(),
+            requires_grad = True)
+
+        # main_param_shards = []
+        for group_index, group_shard in enumerate(opt_group_shards):
+
+            group_size = group_shard["size"]
+            assert group_size != 0, "temporary check ... remove me."
+
+            # ** todo: for dtype in model_main_dtypes ........ **
+
+            # Allocate shard.
+            # if group_size == 0:
+            #     main_param = None
+            # else:
+            main_param = allocate_shard(group_size, torch.float)
+            main_param.grad = allocate_shard(group_size, torch.float)
+            mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
+            # main_param_shards.append(main_param)
+
+            group_shard["orig_group"]["params"] = [ main_param ]
+
+            # # Update optimizer group.
+            # self.optimizer.param_groups[group_index]["params"] = [ main_param ]
+
+        # pax(1, {
+        #     "opt_group_shards" : opt_group_shards,
+        #     "main_param_shards" : main_param_shards,
+        # })
+
+        # return main_param_shards
+
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                  bf16, grad_scaler, models):
@@ -910,52 +956,36 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
         self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
 
-        # pax(0, {"param_gbuf_map": [ (str(tuple(p.shape)), d) for p, d in self.param_gbuf_map.items() ]})
 
         # Optimizer shards.
         self.opt_group_shards = self.get_optimizer_group_shards(
             self.optimizer.param_groups,
             self.model_gbuf_shards)
 
-        # Allocate main param/grad shard.
-        # ** torch.nn.Parameter ??
-        # ** MemoryBuffer ??
-        allocate_shard = lambda shard_size, dtype : torch.empty(
-            (shard_size,),
-            dtype = dtype,
-            device = torch.cuda.current_device(),
-            requires_grad = True)
-
-        self.main_param_shards = []
-        for group_index, group_shard in enumerate(self.opt_group_shards):
-
-            group_size = group_shard["size"]
-
-            # ** todo: for dtype in model_main_dtypes ........ **
-
-            # Allocate shard.
-            if group_size == 0:
-                main_param = None
-            else:
-                main_param = allocate_shard(group_size, torch.float)
-                main_param.grad = allocate_shard(group_size, torch.float)
-                mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
-
-            self.main_param_shards.append(main_param)
-
-            # Update optimizer group.
-            self.optimizer.param_groups[group_index]["params"] = [ main_param ]
+        # pax(0, {**{"opt_group_shards / %d" % i : g for i, g in enumerate(self.opt_group_shards)}})
+
+        # Allocate main param shards.
+        # self.main_param_shards = \
+        #     self.allocate_main_param_shards(self.opt_group_shards)
+        self.allocate_main_param_shards(self.opt_group_shards)
 
         # >>>
-        pax(0, {
-            "model_gbuf_shards" : self.model_gbuf_shards,
-            "opt_group_shards" : self.opt_group_shards,
-            "main_param_shards" : self.main_param_shards,
-        })
+        # pax(0, {
+        #     "model_gbuf_shards" : self.model_gbuf_shards,
+        #     "opt_group_shards" : self.opt_group_shards,
+        #     "main_param_shards" : self.main_param_shards,
+        # })
         # <<<
 
         # Initialize main params.
         self._copy_model_params_to_main_params()
 
-        # Leverage state_dict() and load_state_dict() to
-        # recast preexisting per-param state tensors
+        # Update optimizer groups.
+        # - Also, leverage state_dict() and load_state_dict() to
+        #   recast preexisting per-param state tensors.
+        self.optimizer.param_groups = \
+            [ g["orig_group"] for g in self.opt_group_shards ]
         self.optimizer.load_state_dict(self.optimizer.state_dict())
@@ -1069,11 +1099,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     def _collect_main_grad_data_for_unscaling(self):
         return [ p.grad.data for p in self.main_param_shards ]
+        # return [ p.grad.data for p in self.main_param_shards if p is not None ]
 
     def _copy_model_params_to_main_params(self):
 
         for group_index, group_shard in enumerate(self.opt_group_shards):
-            main_param = self.main_param_shards[group_index]
+            # main_param = self.main_param_shards[group_index]
+            main_param = self.optimizer.param_groups[group_index]["params"][0]
+            pax(0, {"main_param": tp(main_param)})
 
             for model_param, main_shard in group_shard["param_map"].items():
 
                 # Model shard.
@@ -1098,13 +1131,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 #     "model_shard" : str(model_shard),
                 # })
 
-        pax(1, {
-            **{
-                "opt_group_shards / %d" % i : s
-                for i, s in enumerate(self.opt_group_shards)
-            },
-            "main_param_shards" : self.main_param_shards,
-        })
+        # pax(1, {
+        #     **{
+        #         "opt_group_shards / %d" % i : s
+        #         for i, s in enumerate(self.opt_group_shards)
+        #     },
+        #     "main_param_shards" : self.main_param_shards,
+        # })
 
     def _copy_model_grads_to_main_grads(self):
 
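
Taken together, the constructor path after this commit is: squeeze zero-size group shards, allocate one fp32 main param/grad shard per surviving group, rebind optimizer.param_groups from each shard's "orig_group", and round-trip the optimizer state dict. Below is a minimal sketch of the allocation-and-rebinding step; the function name, the device argument, and the zero-initialized grad buffer are illustrative (the diff allocates both buffers with torch.empty on torch.cuda.current_device() and also tags the shard via mpu.set_tensor_model_parallel_attributes):

    import torch

    def allocate_main_param_shards_sketch(opt_group_shards, device="cpu"):
        # One fp32 main param shard, with a matching fp32 grad buffer, per
        # non-empty group; the shard becomes the only parameter of its group.
        for group_shard in opt_group_shards:
            size = group_shard["size"]
            assert size > 0, "zero-size groups were squeezed out earlier"
            main_param = torch.empty((size,), dtype=torch.float,
                                     device=device, requires_grad=True)
            main_param.grad = torch.zeros_like(main_param)
            group_shard["orig_group"]["params"] = [main_param]

    # After allocation, the diff rebuilds the optimizer groups from the
    # surviving shards and reloads the state dict so any preexisting
    # per-param state is recast against the new fp32 shards:
    #     optimizer.param_groups = [g["orig_group"] for g in opt_group_shards]
    #     optimizer.load_state_dict(optimizer.state_dict())
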