Commit eaa0c1df authored by Lawrence McAfee

updated zero_grad

parent 33626179
@@ -1143,11 +1143,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())
 
         # >>>
-        pax(0, {
-            "model_gbuf_shards" : self.model_gbuf_shards,
-            "opt_group_shards" : self.opt_group_shards,
-            "main_param_shards" : self.main_param_shards,
-        })
+        # pax(0, {
+        #     "model_gbuf_shards" : self.model_gbuf_shards,
+        #     "opt_group_shards" : self.opt_group_shards,
+        #     "main_param_shards" : self.main_param_shards,
+        # })
         # <<<
 
     # def get_loss_scale(self):
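The pax(...) calls that this commit comments out here (and re-adds in the last hunk below) are ad-hoc debug dumps used throughout this branch; the helper itself is not part of the diff. A minimal stand-in with the same calling convention (print a labeled dict of values, then stop) could look like the following sketch; the name and behavior shown are assumptions for illustration, not the helper actually used in this branch.

import pprint

def pax(depth, values):
    # Hypothetical stand-in for the pax() debug helper seen in this diff.
    # Assumed behavior: pretty-print each labeled value, then halt so the
    # surrounding training run stops at the inspected point.
    print("~~~~ pax (depth %d) ~~~~" % depth)
    for key, val in values.items():
        print("== %s ==" % key)
        pprint.pprint(val)
    raise SystemExit("pax: stopping for inspection.")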
@@ -1162,21 +1162,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     def state_dict(self):
         raise Exception("hi.")
 
+    # def zero_grad(self, set_to_none=True):
+    #     params = []
+    #     for model_param_group in self.model_param_groups:
+    #         params.extend(model_param_group["offset_map"].keys())
+    #     for main_group in self.optimizer.param_groups:
+    #         params.extend(main_group["params"])
+    #     # _zero_grad_group_helper(params, set_to_none)
+    #     _zero_grad_group_helper(params, set_to_none = False)
+    #     # pax(0, {
+    #     #     "model_param_groups" : self.model_param_groups,
+    #     #     "params" : params,
+    #     # })
     def zero_grad(self, set_to_none=True):
         params = []
-        for model_param_group in self.model_param_groups:
-            params.extend(model_param_group["offset_map"].keys())
+        for model in self.models:
+            for dtype, param_map in model._grad_buffer_param_index_map.items():
+                params.extend(param_map.keys())
         for main_group in self.optimizer.param_groups:
             params.extend(main_group["params"])
         # _zero_grad_group_helper(params, set_to_none)
         _zero_grad_group_helper(params, set_to_none = False)
-        # pax(0, {
-        #     "model_param_groups" : self.model_param_groups,
-        #     "params" : params,
-        # })
+        # pax(0, {"params": params})
 
     # def reduce_gradients(self, model):
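For context on the new zero_grad() above: each DDP-wrapped model in self.models is assumed to expose _grad_buffer_param_index_map, a dict mapping gradient dtype to {param: (start, end)} slices of the contiguous grad buffer, so iterating its keys visits every model parameter without the old offset_map bookkeeping. _zero_grad_group_helper is the existing Megatron-LM utility for zeroing a group of parameter gradients; a representative sketch is shown below (the in-tree version may differ in detail):

def _zero_grad_group_helper(group, set_to_none):
    # Zero the .grad of every parameter in the group. With set_to_none=True
    # the grads are dropped entirely; otherwise the existing grad tensors are
    # detached from the autograd graph and zeroed in place.
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()

Calling it with set_to_none=False, as the new code does, is consistent with the contiguous-buffer assumption asserted later in this file: the grads must stay allocated inside the shared buffer and be zeroed in place rather than replaced by fresh tensors.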
@@ -1289,6 +1302,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 
         # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
         assert args.use_contiguous_buffers_in_local_ddp
 
+        for model_index, model in enumerate(self.models):
+            pax(0, {
+                "model_index" : model_index,
+                "model" : model,
+            })
+
         world_sizes = []
         for r in self.world_shard_infos:
             # world_sizes.append(sum(g["size"] for g in r))
...
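The debug loop added in the last hunk dumps each DDP-wrapped model, which is where the _grad_buffer_param_index_map consumed by zero_grad() lives. A simplified sketch of how such a per-dtype index map can be assembled for a contiguous grad buffer (an assumption about the DDP wrapper for illustration, not code from this commit):

def build_grad_buffer_param_index_map(params):
    # Hypothetical illustration: for each gradient dtype, assign every
    # torch.nn.Parameter a (start, end) slice inside one contiguous buffer.
    # zero_grad() only needs the keys; reduce/gather paths can use the
    # slices to address the shared buffer directly.
    index_map = {}
    next_offset = {}
    for param in params:
        dtype = param.dtype
        start = next_offset.get(dtype, 0)
        end = start + param.data.nelement()
        index_map.setdefault(dtype, {})[param] = (start, end)
        next_offset[dtype] = end
    return index_map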