Commit 9b7854e4 authored by Lawrence McAfee

more cleanup of main params/grads

parent 291592e4
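The cleanup in this commit replaces direct indexing into self.optimizer.param_groups[...]["params"][0] and self.main_param_shards[...] with two small accessors, get_main_param and get_main_grad, which the hunks below introduce and then use. A minimal sketch of that accessor pattern follows; the MainShardAccessor class and the usage setup are illustrative assumptions, not Megatron code — only the method names, opt_group_shards, and the one-flat-fp32-shard-per-param-group layout come from the diff.

    import torch

    class MainShardAccessor:
        # Illustrative sketch only; the real class is Float16DistributedOptimizer.
        def __init__(self, optimizer, opt_group_shards):
            self.optimizer = optimizer              # fp32 optimizer over the main shards
            self.opt_group_shards = opt_group_shards

        def get_main_param(self, group_index):
            # One flat fp32 main-param shard per param group (first and only entry).
            return self.optimizer.param_groups[group_index]["params"][0]

        def get_main_grad(self, group_index):
            # Gradient attached to that shard; assumed allocated before unscaling.
            return self.get_main_param(group_index).grad

        def _collect_main_grad_data_for_unscaling(self):
            # One grad tensor per group, mirroring the refactored method in the diff.
            return [self.get_main_grad(gi).data
                    for gi in range(len(self.opt_group_shards))]

    # Usage sketch: two groups, each holding a single flat fp32 shard.
    shards = [torch.zeros(8, requires_grad=True), torch.zeros(4, requires_grad=True)]
    for s in shards:
        s.grad = torch.ones_like(s)
    opt = torch.optim.Adam([{"params": [s]} for s in shards])
    acc = MainShardAccessor(opt, opt_group_shards=[None, None])
    assert len(acc._collect_main_grad_data_for_unscaling()) == 2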
@@ -989,6 +989,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def get_main_param(self, group_index):
        return self.optimizer.param_groups[group_index]["params"][0]
    def get_main_grad(self, group_index):
        return self.get_main_param(group_index).grad

    def load_state_dict(self):
        raise Exception("hi.")

    def reload_model_params(self):
@@ -1098,15 +1103,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # })

    def _collect_main_grad_data_for_unscaling(self):
        return [ p.grad.data for p in self.main_param_shards ]
        # return [ p.grad.data for p in self.main_param_shards ]
        # return [ p.grad.data for p in self.main_param_shards if p is not None ]
        return [ self.get_main_grad(gi).data
                 for gi in range(len(self.opt_group_shards)) ]

    def _copy_model_params_to_main_params(self):
        for group_index, group_shard in enumerate(self.opt_group_shards):

            # main_param = self.main_param_shards[group_index]
            main_param = self.optimizer.param_groups[group_index]["params"][0]
            pax(0, {"main_param": tp(main_param)})
            # main_param = self.optimizer.param_groups[group_index]["params"][0]
            main_param = self.get_main_param(group_index)
            # if group_index > 0:
            #     pax({"main_param": tp(main_param)})

            for model_param, main_shard in group_shard["param_map"].items():

                # Model shard.
@@ -1152,7 +1161,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):

                # Copy from DDP's contiguous buffer to main shard's grad.
                model_grad = self.models[model_index]._grad_buffers[dtype].data
                main_grad = self.main_param_shards[group_index].grad
                # main_grad = self.main_param_shards[group_index].grad
                main_grad = self.get_main_grad(group_index)

                # Copy sub-range within tensor.
                model_view = model_grad[model_shard.start:model_shard.end]
@@ -1203,7 +1213,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):

                # Use DDP's contiguous buffer to temporarily hold params.
                model_param = self.models[model_index]._grad_buffers[dtype].data
                main_param = self.main_param_shards[group_index]
                # main_param = self.main_param_shards[group_index]
                main_param = self.get_main_param(group_index)

                # Copy sub-range within tensor.
                model_view = model_param[model_shard.start:model_shard.end]