Commit 6875dff5 authored by Lawrence McAfee

fix zero_grad; set_to_none = False

parent 1215c420
@@ -768,9 +768,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # Allocate shards.
         # (Also, collect world DP shard info.)
-        model_main_dtypes = set([ args.params_dtype, torch.float ])
+        # model_main_dtypes = set([ args.params_dtype, torch.float ])
+        model_main_dtypes = set([ torch.float ]) # fp32 only, for now
         self.world_shard_info_groups = [] # world_group_shard_infos ?
-        self.main_param_shard_groups = []
+        # self.main_param_shard_groups = []
         for group_index, model_param_group in enumerate(self.model_param_groups):
             # pax(0, {
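A note on the dtype change in this hunk: restricting model_main_dtypes to torch.float means every main parameter shard is kept in fp32, with the half-precision model params cast into it. Below is a minimal standalone sketch of that master-copy idea in plain PyTorch; the buffer name and size are invented and do not come from this repository.

import torch

# Hypothetical flat fp16 model shard for one data-parallel rank.
model_shard = torch.randn(1024, dtype=torch.float16)

# Keep the "main" copy in fp32 so optimizer updates do not lose precision
# (the "fp32 only, for now" choice above).
main_shard = model_shard.detach().clone().float()

main_shard.add_(1e-2)          # stand-in for an optimizer update
model_shard.copy_(main_shard)  # downcast the result back to fp16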
@@ -820,26 +821,27 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                     local_shard_end_index,
                     orig_end_index - local_shard_start_index)
-                if param_shard_end_index > param_shard_start_index:
-                    # Indexes are relative to local shard start index.
-                    # local_shard_info["param_index_map"][param] = {
-                    #     "param" : (
-                    #         param_shard_start_index,
-                    #         param_shard_end_index,
-                    #     ),
-                    #     "shard" : (
-                    #         param_shard_start_index - local_shard_start_index,
-                    #         param_shard_end_index - local_shard_start_index,
-                    #     ),
-                    # }
+                # if param_shard_end_index > param_shard_start_index:
+                #     # Indexes are relative to local shard start index.
+                #     # local_shard_info["param_index_map"][param] = {
+                #     #     "param" : (
+                #     #         param_shard_start_index,
+                #     #         param_shard_end_index,
+                #     #     ),
+                #     #     "shard" : (
+                #     #         param_shard_start_index - local_shard_start_index,
+                #     #         param_shard_end_index - local_shard_start_index,
+                #     #     ),
+                #     # }
                     # local_shard_info["param_slice_index_map"][param] = {
                     #     "param_start" :
                     #         param_shard_start_index,
                     #     "shard_start" :
                     #         param_shard_start_index - local_shard_start_index,
                     #     "size":
                     #         param_shard_end_index - param_shard_start_index,
                     # }
+                if shard_end_index > shard_start_index:
                     local_shard_info["param_slice_index_map"][param] = {
                         "orig_start" : orig_start_index,
                         "shard_start" : shard_start_index,
@@ -872,9 +874,15 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 # Allocate shards.
                 # (Non-fp32 shards are for convenience; e.g., intermediaries
                 # between model params and main fp32 shard. Necessary???)
-                main_param_shards = {
-                    ty : allocate_shard(local_shard_size, ty)
-                    for ty in model_main_dtypes}
+                # main_param_shards = {
+                #     ty : allocate_shard(local_shard_size, ty)
+                #     for ty in model_main_dtypes}
+                main_param_shards = {}
+                for dtype in model_main_dtypes:
+                    main_param = allocate_shard(local_shard_size, dtype)
+                    main_param.grad = allocate_shard(local_shard_size, dtype)
+                    # pax(0, {"main_param": main_param})
+                    main_param_shards[dtype] = main_param
                 # self.main_param_shard_groups.append(main_param_shards)
                 local_shard_info["data"] = main_param_shards
@@ -891,6 +899,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())
+        # >>>
+        # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
+        # <<<
     # def get_loss_scale(self):
     #     if self.grad_scaler is None:
     #         return self._scale_one
@@ -911,7 +923,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         for main_group in self.optimizer.param_groups:
             params.extend(main_group["params"])
-        _zero_grad_group_helper(params, set_to_none)
+        # _zero_grad_group_helper(params, set_to_none)
+        _zero_grad_group_helper(params, set_to_none = False)
         # pax(0, {
         #     "model_param_groups" : self.model_param_groups,
@@ -920,6 +933,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     def reduce_gradients(self, model):
+        # >>>
+        pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
+        # <<<
         # >>>
         args = get_args()
         # timers = get_timers()
@@ -962,27 +979,32 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             main_slice_index_map = local_shard_info["param_slice_index_map"]
             for param, main_slice_indexes in main_slice_index_map.items():
-                main_param_start_index = main_slice_indexes["param_start"]
-                main_shard_start_index = main_slice_indexes["shard_start"]
-                main_slice_size = ddd
-                main_size = main_shard_indexesddd
+                main_slice_orig_start_index = main_slice_indexes["orig_start"]
+                main_slice_shard_start_index = main_slice_indexes["shard_start"]
+                main_slice_size = main_slice_indexes["size"]
                 dtype_model_dict = param_model_map[param]
                 dtype = dtype_model_dict["dtype"]
                 vmodel = dtype_model_dict["model"]
                 model_grad_buffer = vmodel._grad_buffers[dtype]
                 model_grad_buffer_start_index = \
-                    vmodel._grad_buffer_param_index_map[dtype][param][0]
+                    vmodel._grad_buffer_param_index_map[dtype][param][0] + \
+                    main_slice_orig_start_index
-                # model_grad_buffer_indexes = [ model_grad_buffer_start_index + i
-                #                               for i in main_
-                # model_grad_view = model_grad_buffer.data[
-                pax(0, {"model_grad_buffer_indexes": model_grad_buffer_indexes})
-                main_grad_view = self.main_param_shard_groups \
-                    [group_index][torch.float].grad \
-                    [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
+                # main_grad_view = self.main_param_shard_groups \
+                #     [group_index][torch.float].grad \
+                #     [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
+                main_grad_view = local_shard_info["data"][torch.float]
+                pax(0, {
+                    "local_shard_info" : local_shard_info,
+                    "main_slice_orig_start_index" : main_slice_orig_start_index,
+                    "main_slice_shard_start_index" : main_slice_shard_start_index,
+                    "main_slice_size" : main_slice_size,
+                    "model_grad_buffer_start_index" :
+                        model_grad_buffer_start_index,
+                    "main_grad_view" : main_grad_view,
+                })
                 pax(0, {
                 #     "dtype" : dtype,
...
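For orientation, the index variables gathered in the last hunk are what a copy from the model's flat grad buffer into the main fp32 grad shard would need. The sketch below shows such a copy with invented tensors and offsets; it only illustrates the slice arithmetic, not the repository's reduce_gradients implementation.

import torch

def copy_model_grads_to_main_shard(model_grad_buffer, grad_buffer_start,
                                   main_grad_shard, slice_indexes):
    # grad_buffer_start is assumed to already include the "orig_start" offset,
    # as in the updated model_grad_buffer_start_index above.
    src_start = grad_buffer_start
    dst_start = slice_indexes["shard_start"]
    size = slice_indexes["size"]
    src = model_grad_buffer[src_start:src_start + size]
    dst = main_grad_shard[dst_start:dst_start + size]
    dst.copy_(src.float())  # upcast fp16/bf16 grads into the fp32 shard

# Tiny usage example with made-up numbers.
model_grads = torch.arange(16, dtype=torch.float16)
main_grads = torch.zeros(8, dtype=torch.float)
copy_model_grads_to_main_shard(model_grads, 6, main_grads,
                               {"shard_start": 2, "size": 6})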