Commit 9b7854e4 authored by Lawrence McAfee

more cleanup of main params/grads

parent 291592e4
@@ -989,6 +989,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())
 
+    def get_main_param(self, group_index):
+        return self.optimizer.param_groups[group_index]["params"][0]
+
+    def get_main_grad(self, group_index):
+        return self.get_main_param(group_index).grad
     def load_state_dict(self):
         raise Exception("hi.")
 
     def reload_model_params(self):
@@ -1098,15 +1103,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # })
 
     def _collect_main_grad_data_for_unscaling(self):
-        return [ p.grad.data for p in self.main_param_shards ]
+        # return [ p.grad.data for p in self.main_param_shards ]
         # return [ p.grad.data for p in self.main_param_shards if p is not None ]
+        return [ self.get_main_grad(gi).data
+                 for gi in range(len(self.opt_group_shards)) ]
 
     def _copy_model_params_to_main_params(self):
         for group_index, group_shard in enumerate(self.opt_group_shards):
             # main_param = self.main_param_shards[group_index]
-            main_param = self.optimizer.param_groups[group_index]["params"][0]
-            pax(0, {"main_param": tp(main_param)})
+            # main_param = self.optimizer.param_groups[group_index]["params"][0]
+            main_param = self.get_main_param(group_index)
+            # if group_index > 0:
+            #     pax({"main_param": tp(main_param)})
 
             for model_param, main_shard in group_shard["param_map"].items():
 
                 # Model shard.
@@ -1152,7 +1161,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 
                 # Copy from DDP's contiguous buffer to main shard's grad.
                 model_grad = self.models[model_index]._grad_buffers[dtype].data
-                main_grad = self.main_param_shards[group_index].grad
+                # main_grad = self.main_param_shards[group_index].grad
+                main_grad = self.get_main_grad(group_index)
 
                 # Copy sub-range within tensor.
                 model_view = model_grad[model_shard.start:model_shard.end]
@@ -1203,7 +1213,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 
                 # Use DDP's contiguous buffer to temporarily hold params.
                 model_param = self.models[model_index]._grad_buffers[dtype].data
-                main_param = self.main_param_shards[group_index]
+                # main_param = self.main_param_shards[group_index]
+                main_param = self.get_main_param(group_index)
 
                 # Copy sub-range within tensor.
                 model_view = model_param[model_shard.start:model_shard.end]
...
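Note: this commit routes all reads of the fp32 main-parameter shards through the new get_main_param / get_main_grad accessors, which index the single flat shard held in each optimizer param group instead of a separate main_param_shards list. Below is a minimal, self-contained sketch of that accessor pattern, not the repository's code; the shard tensors and free functions are hypothetical stand-ins for the per-group state that Float16DistributedOptimizer maintains.

import torch

# Hypothetical stand-ins for per-group fp32 main-parameter shards: each
# optimizer param group holds exactly one flat tensor ("params"[0]).
shard_a = torch.zeros(8, requires_grad=True)
shard_b = torch.zeros(4, requires_grad=True)
optimizer = torch.optim.Adam([{"params": [shard_a]},
                              {"params": [shard_b]}])

def get_main_param(opt, group_index):
    # The sole tensor in a group is that group's main-param shard.
    return opt.param_groups[group_index]["params"][0]

def get_main_grad(opt, group_index):
    # The main grad lives on the main param itself (None until grads
    # have been produced or copied in).
    return get_main_param(opt, group_index).grad

# Mirrors _collect_main_grad_data_for_unscaling: gather the raw grad
# tensor of every group's shard, e.g. for loss-scale unscaling.
shard_a.grad = torch.randn_like(shard_a)
shard_b.grad = torch.randn_like(shard_b)
grad_data = [get_main_grad(optimizer, gi).data
             for gi in range(len(optimizer.param_groups))]

Centralizing the lookup this way presumably lets the separate main_param_shards list be dropped once the shards live only inside optimizer.param_groups.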