Commit eaa0c1df authored by Lawrence McAfee

updated zero_grad
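zero_grad() now also walks each model's _grad_buffer_param_index_map when collecting parameters to zero, alongside the main parameters in the inner optimizer's param_groups; the earlier pax() shard dump is commented out, and a temporary per-model pax() dump is added before the world-shard size computation.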

parent 33626179
@@ -1143,11 +1143,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # >>>
        pax(0, {
            "model_gbuf_shards" : self.model_gbuf_shards,
            "opt_group_shards" : self.opt_group_shards,
            "main_param_shards" : self.main_param_shards,
        })
        # pax(0, {
        #     "model_gbuf_shards" : self.model_gbuf_shards,
        #     "opt_group_shards" : self.opt_group_shards,
        #     "main_param_shards" : self.main_param_shards,
        # })
        # <<<
    # def get_loss_scale(self):
@@ -1162,21 +1162,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
    def state_dict(self):
        raise Exception("hi.")
    # def zero_grad(self, set_to_none=True):
    #     params = []
    #     for model_param_group in self.model_param_groups:
    #         params.extend(model_param_group["offset_map"].keys())
    #     for main_group in self.optimizer.param_groups:
    #         params.extend(main_group["params"])
    #     # _zero_grad_group_helper(params, set_to_none)
    #     _zero_grad_group_helper(params, set_to_none = False)
    #     # pax(0, {
    #     #     "model_param_groups" : self.model_param_groups,
    #     #     "params" : params,
    #     # })
    def zero_grad(self, set_to_none=True):
        params = []
        for model_param_group in self.model_param_groups:
            params.extend(model_param_group["offset_map"].keys())
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                params.extend(param_map.keys())
        for main_group in self.optimizer.param_groups:
            params.extend(main_group["params"])
        # _zero_grad_group_helper(params, set_to_none)
        _zero_grad_group_helper(params, set_to_none = False)
        # pax(0, {
        #     "model_param_groups" : self.model_param_groups,
        #     "params" : params,
        # })
        # pax(0, {"params": params})
    # def reduce_gradients(self, model):
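For context, below is a minimal, self-contained sketch of the collect-then-zero pattern the new zero_grad() follows. _grad_buffer_param_index_map and _zero_grad_group_helper are the names that appear in the diff above; the helper body here is only an assumed stand-in (drop or zero each .grad in place), and collect_and_zero_grads() is an illustrative wrapper, not code from the repository.

def _zero_grad_group_helper(params, set_to_none):
    # Assumed stand-in for the helper referenced in the diff: either drop
    # each gradient tensor or zero it in place.
    for param in params:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None
        else:
            param.grad.detach_()
            param.grad.zero_()

def collect_and_zero_grads(models, optimizer):
    # Gather every parameter tracked in each model's grad-buffer index map,
    # plus the main parameters held by the inner optimizer, then zero them
    # in one pass. The diff passes set_to_none=False, presumably so the
    # gradient tensors (and the contiguous grad buffer) stay allocated
    # between iterations.
    params = []
    for model in models:
        for dtype, param_map in model._grad_buffer_param_index_map.items():
            params.extend(param_map.keys())
    for main_group in optimizer.param_groups:
        params.extend(main_group["params"])
    _zero_grad_group_helper(params, set_to_none=False)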
@@ -1289,6 +1302,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
        assert args.use_contiguous_buffers_in_local_ddp
        for model_index, model in enumerate(self.models):
            pax(0, {
                "model_index" : model_index,
                "model" : model,
            })
        world_sizes = []
        for r in self.world_shard_infos:
            # world_sizes.append(sum(g["size"] for g in r))