Commit 23f9238d authored by Lawrence McAfee

set main_param_shard = None if local group_size == 0 [e.g., word embedding params]

parent ac5ef637
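For context, a minimal sketch of the pattern this commit introduces (illustrative only, not the actual Megatron-LM code; `allocate_shard` and the flat group-size list are stand-ins for the diff's helpers): when a rank's local shard of an optimizer group is empty, no fp32 main-param buffer is allocated and `None` is recorded in its place.

```python
# Sketch, assuming a flat list of per-group local shard sizes. The real code
# allocates on the current CUDA device and sets tensor-parallel attributes.
import torch

def allocate_shard(numel, dtype):
    # Hypothetical stand-in for the diff's allocate_shard(): a flat buffer.
    return torch.empty(numel, dtype=dtype)

def build_main_param_shards(group_sizes):
    main_param_shards = []
    for group_size in group_sizes:
        if group_size == 0:
            # This rank owns no elements of this group (e.g., word embedding
            # params held entirely by other ranks): record a placeholder.
            main_param = None
        else:
            main_param = allocate_shard(group_size, torch.float)
            main_param.grad = allocate_shard(group_size, torch.float)
        main_param_shards.append(main_param)
    return main_param_shards

# Example: the empty group yields None rather than a zero-sized tensor.
shards = build_main_param_shards([8, 0, 16])
print([None if p is None else tuple(p.shape) for p in shards])
# -> [(8,), None, (16,)]
```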
@@ -293,7 +293,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):

         # If we found inf/nan, skip the update.
         if found_inf_flag:
-            pax(0, {"found_inf_flag": found_inf_flag})
+            # pax(0, {"found_inf_flag": found_inf_flag})
             return False, None, None

         # Clip the main gradients.
@@ -758,8 +758,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 "param" : sub_param_shard,
             }
             # >>>
-            if param_world_start < gbuf_world_shard.start:
-                raise Exception("hi.")
+            # if param_world_start < gbuf_world_shard.start:
+            #     pax({"param shards": param_shard_map[param]})
             # <<<

         # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
@@ -865,13 +865,23 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             group_shard["size"] += param_size
             group_shard["param_map"][param] = param_group_shard

-        # raise Exception("hi.")
-
-        # pax(0, {"param_group_map": [
-        #     (g, str(p.shape))
-        #     for p, g in param_group_map.items()
-        # ]})
-        # pax(0, {"group_shards": group_shards})
+            # >>>
+            # if torch.distributed.get_rank() == 1:
+            #     print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
+            #         torch.distributed.get_rank(),
+            #         group_index,
+            #         param_size,
+            #         str(tuple(param.shape)),
+            #     ))
+            # <<<
+
+        # pax(1, {
+        #     "param_group_map": [
+        #         (g, str(p.shape))
+        #         for p, g in param_group_map.items()
+        #     ],
+        #     "group_shards" : group_shards,
+        # })

         return group_shards
@@ -913,7 +923,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 dtype = dtype,
                 device = torch.cuda.current_device(),
                 requires_grad = True)

         self.main_param_shards = []
         for group_index, group_shard in enumerate(self.opt_group_shards):
@@ -922,14 +932,25 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             # ** todo: for dtype in model_main_dtypes ........ **

             # Allocate shard.
-            main_param = allocate_shard(group_size, torch.float)
-            main_param.grad = allocate_shard(group_size, torch.float)
+            if group_size == 0:
+                main_param = None
+            else:
+                main_param = allocate_shard(group_size, torch.float)
+                main_param.grad = allocate_shard(group_size, torch.float)
+                mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
             self.main_param_shards.append(main_param)
-            mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)

             # Update optimizer group.
             self.optimizer.param_groups[group_index]["params"] = [ main_param ]

+        # >>>
+        pax(0, {
+            "model_gbuf_shards" : self.model_gbuf_shards,
+            "opt_group_shards" : self.opt_group_shards,
+            "main_param_shards" : self.main_param_shards,
+        })
+        # <<<
+
         # Initialize main params.
         self._copy_model_params_to_main_params()
@@ -937,13 +958,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-        # >>>
-        # pax(0, {
-        #     "model_gbuf_shards" : self.model_gbuf_shards,
-        #     "opt_group_shards" : self.opt_group_shards,
-        #     "main_param_shards" : self.main_param_shards,
-        # })
-        # <<<
-
     def load_state_dict(self):
         raise Exception("hi.")
@@ -1071,22 +1085,26 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             # Copy shard data.
             main_view = main_param[main_shard.start:main_shard.end]
-            model_view = model_param[model_shard.start:model_shard.end].view(-1)
-            # try:
+            model_view = model_param.view(-1)[model_shard.start:model_shard.end]
             main_view.detach().copy_(model_view)
-            # except:
-            #     pax(0, {
-            #         "main_param" : tp(main_param),
-            #         "model_param" : tp(model_param),
-            #         "main_view" : tp(main_view),
-            #         "model_view" : tp(model_view),
-            #         "main_shard" : str(main_shard),
-            #         "model_shard" : str(model_shard),
-            #     })
+            # pax({
+            #     "main_param" : tp(main_param),
+            #     "model_param" : tp(model_param),
+            #     "main_view" : tp(main_view),
+            #     "model_view" : tp(model_view),
+            #     "main_shard" : str(main_shard),
+            #     "model_shard" : str(model_shard),
+            # })

-        # pax(0, {
-        #     "opt_group_shards" : self.opt_group_shards,
-        #     "main_param_shards" : self.main_param_shards,
-        # })
+        pax(1, {
+            **{
+                "opt_group_shards / %d" % i : s
+                for i, s in enumerate(self.opt_group_shards)
+            },
+            "main_param_shards" : self.main_param_shards,
+        })

     def _copy_model_grads_to_main_grads(self):
@@ -1128,14 +1146,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     "model_gbuf_shards" : self.model_gbuf_shards,
         #     "opt_group_shards" : self.opt_group_shards,
         # })
-        for param in self.main_param_shards:
-            grad = param.grad
-            is_nan = torch.any(torch.isnan(grad)).item()
-            if is_nan:
-                pax(0, {
-                    "grad" : tp(grad),
-                    "is_nan" : is_nan,
-                })
+        # for param in self.main_param_shards:
+        #     grad = param.grad
+        #     is_nan = torch.any(torch.isnan(grad)).item()
+        #     if is_nan:
+        #         pax(0, {
+        #             "grad" : tp(grad),
+        #             "is_nan" : is_nan,
+        #         })
         # <<<
@@ -1183,7 +1201,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         for param in self.param_gbuf_map:
             is_nan = torch.any(torch.isnan(param)).item()
             if is_nan:
-                pax(0, {
+                pax({
                     "param" : tp(param),
                     "is_nan" : is_nan,
                 })
...