Commit 6875dff5 authored by Lawrence McAfee

fix zero_grad; set_to_none = False

parent 1215c420
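
Context for the flag flip: the allocation hunk below now gives each main fp32 param shard a pre-allocated .grad buffer (main_param.grad = allocate_shard(...)), and reduce_gradients copies model grads into that buffer, so zero_grad must not discard it. The snippet below is a minimal, hedged sketch of the behavior the set_to_none flag controls; zero_grad_group is an illustrative stand-in for the repo's _zero_grad_group_helper, not its exact code.

    import torch

    def zero_grad_group(params, set_to_none):
        # Illustrative sketch (not the repo's exact helper): zero the grads
        # of a list of parameters. With set_to_none=True the .grad tensor is
        # dropped entirely, so a pre-allocated gradient shard would have to
        # be re-created before the next reduce; with set_to_none=False the
        # existing storage is detached and zeroed in place, which is what
        # this commit relies on for the main param shards.
        for param in params:
            if param.grad is None:
                continue
            if set_to_none:
                param.grad = None
            else:
                param.grad = param.grad.detach()
                param.grad.zero_()

    # Example: a pre-allocated fp32 shard whose .grad must survive zero_grad.
    shard = torch.zeros(8, dtype=torch.float)
    shard.grad = torch.ones(8, dtype=torch.float)
    zero_grad_group([shard], set_to_none=False)
    assert shard.grad is not None and shard.grad.sum().item() == 0.0
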
@@ -768,9 +768,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# Allocate shards.
# (Also, collect world DP shard info.)
model_main_dtypes = set([ args.params_dtype, torch.float ])
# model_main_dtypes = set([ args.params_dtype, torch.float ])
model_main_dtypes = set([ torch.float ]) # fp32 only, for now
self.world_shard_info_groups = [] # world_group_shard_infos ?
self.main_param_shard_groups = []
# self.main_param_shard_groups = []
for group_index, model_param_group in enumerate(self.model_param_groups):
# pax(0, {
@@ -820,26 +821,27 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
local_shard_end_index,
orig_end_index - local_shard_start_index)
if param_shard_end_index > param_shard_start_index:
# Indexes are relative to local shard start index.
# local_shard_info["param_index_map"][param] = {
# "param" : (
# param_shard_start_index,
# param_shard_end_index,
# ),
# "shard" : (
# param_shard_start_index - local_shard_start_index,
# param_shard_end_index - local_shard_start_index,
# ),
# }
# local_shard_info["param_slice_index_map"][param] = {
# "param_start" :
# param_shard_start_index,
# "shard_start" :
# param_shard_start_index - local_shard_start_index,
# "size":
# param_shard_end_index - param_shard_start_index,
# }
# if param_shard_end_index > param_shard_start_index:
# # Indexes are relative to local shard start index.
# # local_shard_info["param_index_map"][param] = {
# # "param" : (
# # param_shard_start_index,
# # param_shard_end_index,
# # ),
# # "shard" : (
# # param_shard_start_index - local_shard_start_index,
# # param_shard_end_index - local_shard_start_index,
# # ),
# # }
# local_shard_info["param_slice_index_map"][param] = {
# "param_start" :
# param_shard_start_index,
# "shard_start" :
# param_shard_start_index - local_shard_start_index,
# "size":
# param_shard_end_index - param_shard_start_index,
# }
if shard_end_index > shard_start_index:
local_shard_info["param_slice_index_map"][param] = {
"orig_start" : orig_start_index,
"shard_start" : shard_start_index,
@@ -872,9 +874,15 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# Allocate shards.
# (Non-fp32 shards are for convenience; e.g., intermediaries
# between model params and main fp32 shard. Necessary???)
main_param_shards = {
ty : allocate_shard(local_shard_size, ty)
for ty in model_main_dtypes}
# main_param_shards = {
# ty : allocate_shard(local_shard_size, ty)
# for ty in model_main_dtypes}
main_param_shards = {}
for dtype in model_main_dtypes:
main_param = allocate_shard(local_shard_size, dtype)
main_param.grad = allocate_shard(local_shard_size, dtype)
# pax(0, {"main_param": main_param})
main_param_shards[dtype] = main_param
# self.main_param_shard_groups.append(main_param_shards)
local_shard_info["data"] = main_param_shards
@@ -891,6 +899,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# >>>
# pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# <<<
# def get_loss_scale(self):
# if self.grad_scaler is None:
# return self._scale_one
@@ -911,7 +923,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
for main_group in self.optimizer.param_groups:
params.extend(main_group["params"])
_zero_grad_group_helper(params, set_to_none)
# _zero_grad_group_helper(params, set_to_none)
_zero_grad_group_helper(params, set_to_none = False)
# pax(0, {
# "model_param_groups" : self.model_param_groups,
@@ -920,6 +933,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
def reduce_gradients(self, model):
# >>>
pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# <<<
# >>>
args = get_args()
# timers = get_timers()
@@ -962,27 +979,32 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
main_slice_index_map = local_shard_info["param_slice_index_map"]
for param, main_slice_indexes in main_slice_index_map.items():
main_param_start_index = main_slice_indexes["param_start"]
main_shard_start_index = main_slice_indexes["shard_start"]
main_slice_size = ddd
main_size = main_shard_indexesddd
main_slice_orig_start_index = main_slice_indexes["orig_start"]
main_slice_shard_start_index = main_slice_indexes["shard_start"]
main_slice_size = main_slice_indexes["size"]
dtype_model_dict = param_model_map[param]
dtype = dtype_model_dict["dtype"]
vmodel = dtype_model_dict["model"]
model_grad_buffer = vmodel._grad_buffers[dtype]
model_grad_buffer_start_index = \
vmodel._grad_buffer_param_index_map[dtype][param][0]
# model_grad_buffer_indexes = [ model_grad_buffer_start_index + i
# for i in main_
# model_grad_view = model_grad_buffer.data[
vmodel._grad_buffer_param_index_map[dtype][param][0] + \
main_slice_orig_start_index
pax(0, {"model_grad_buffer_indexes": model_grad_buffer_indexes})
# main_grad_view = self.main_param_shard_groups \
# [group_index][torch.float].grad \
# [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
main_grad_view = local_shard_info["data"][torch.float]
main_grad_view = self.main_param_shard_groups \
[group_index][torch.float].grad \
[shard_indexes["shard"][0]:shard_indexes["shard"][1]]
pax(0, {
"local_shard_info" : local_shard_info,
"main_slice_orig_start_index" : main_slice_orig_start_index,
"main_slice_shard_start_index" : main_slice_shard_start_index,
"main_slice_size" : main_slice_size,
"model_grad_buffer_start_index" :
model_grad_buffer_start_index,
"main_grad_view" : main_grad_view,
})
pax(0, {
# "dtype" : dtype,
......