Commit 2c1660e7 authored by Lawrence McAfee

cleaned distrib_optimizer.py.

parent efa3cbcf
@@ -25,11 +25,6 @@ from megatron import mpu
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
-# >>>
-from lutil import pax, tp
-DEBUG_ITERATION = 2 # 10
-# <<<
class Shard:
    def __init__(self, start, end):
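Only fragments of the Shard helper are visible in this diff. Below is a minimal sketch of the range type those fragments imply, assuming size is simply end - start; the docstring is added here for illustration and is not in the source.

class Shard:
    """Half-open index range [start, end) into a flat parameter/grad buffer."""
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)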
@@ -42,10 +37,6 @@ class Shard:
        return "%d,%d [%d]" % (self.start, self.end, self.size)
-# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
-# class Float16DistributedOptimizer(MegatronOptimizer):
-# class Float16DistributedOptimizer(BaseFloat16Optimizer):
-# class DistributedOptimizer(MegatronOptimizer):
class DistributedOptimizer(MixedPrecisionOptimizer):
    @classmethod
@@ -177,18 +168,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    def allocate_main_param_shards(cls, opt_group_shards):
        # Allocator method.
-        # >>>
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)
-        # allocate_shard = lambda shard_size, dtype : torch.zeros(
-        #     (shard_size,),
-        #     dtype = dtype,
-        #     device = torch.cuda.current_device(),
-        #     requires_grad = True)
-        # <<<
        # Allocate each group's param/grad shard.
        for group_index, group_shard in enumerate(opt_group_shards):
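The allocator kept above uses torch.empty (uninitialized memory) on the current CUDA device with requires_grad=True, while the removed alternative used torch.zeros. A minimal sketch of how such a main-param shard could be allocated and then filled from a flat model-param buffer; allocate_main_shard and its arguments are hypothetical names, not the Megatron API.

import torch

def allocate_main_shard(model_param_flat, shard):
    # Uninitialized fp32 buffer on this rank's GPU, sized to the shard.
    main = torch.empty((shard.size,),
                       dtype=torch.float32,
                       device=torch.cuda.current_device(),
                       requires_grad=True)
    # torch.empty leaves memory uninitialized, so copy real values in;
    # detach() keeps this in-place copy out of autograd.
    main.detach().copy_(model_param_flat.view(-1)[shard.start:shard.end])
    return main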
@@ -295,29 +279,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    def get_main_grad(self, group_index):
        return self.get_main_param(group_index).grad
-    # def load_state_dict(self):
-    #     raise Exception("hi.")
-    # # def reload_model_params(self): # ... done in MixedPrecisionOptimizer
-    # #     raise Exception("hi.")
-    # def state_dict(self):
-    #     raise Exception("hi.")
    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        # state_dict['params'] = \
-        #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
        state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
-        # pax(0, { # ... only called on model rank 0
-        #     # "optimizer" : self.optimizer,
-        #     "state_dict" : state_dict,
-        #     "state_dict / param_groups" : state_dict["optimizer"]["param_groups"],
-        #     "optimizer / groups" : self.optimizer.param_groups,
-        #     "state_dict / params" : [ p.shape for p in state_dict["params"] ],
-        #     "optimizer / params" :
-        #         [ p.shape for g in self.optimizer.param_groups for p in g["params"] ],
-        # })
        return state_dict
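After this change, the checkpoint produced by state_dict() carries three keys: the inner optimizer's state, the grad scaler state (when a scaler is used), and the per-group lists of main param shards under 'groups'. An illustrative sketch of that layout follows; the values are placeholders and not taken from Megatron.

import torch

example_state = {
    "optimizer": {"state": {}, "param_groups": []},  # inner optimizer's own state_dict
    "grad_scaler": {"scale": 2.0 ** 12},             # placeholder; present only with a scaler
    "groups": [[torch.zeros(4)], [torch.zeros(2)]],  # main param shard tensors, per group
}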
@@ -330,10 +297,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])
-        # pax(0, {
-        #     "state_dict" : state_dict,
-        #     "params" : state_dict["params"],
-        # })
        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
@@ -349,32 +312,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # Copy data for the main params.
        current_groups = [ g["params"] for g in self.optimizer.param_groups ]
        assert "groups" in state_dict, "key 'groups' not in state_dict."
-        # pax(0, {
-        #     "state_dict" : state_dict,
-        #     "current_groups" : current_groups,
-        #     "saved_groups" : state_dict[params_key],
-        # })
        for current_group, saved_group in zip(current_groups, state_dict["groups"]):
-            # pax(0, {
-            #     "current_group" : current_group,
-            #     "saved_group" : saved_group,
-            # })
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def zero_grad(self, set_to_none=True):
+        # Collect model params.
        model_params = []
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                model_params.extend(param_map.keys())
-        # main_params = []
-        # for main_group in self.optimizer.param_groups:
-        #     main_params.extend(main_group["params"])
-        # ** using contiguous buffer; don't set_to_none **
+        # Distributed optimizer requires contiguous buffer; don't set to None.
        _zero_grad_group_helper(model_params, set_to_none = False)
-        # _zero_grad_group_helper(params, set_to_none = False)

    def get_model_grad_buffer_dp_views(self):
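As the new comment in zero_grad() notes, grads live in a contiguous buffer, so they must be zeroed in place rather than dropped by setting .grad to None. A simplified sketch of what a _zero_grad_group_helper-style call does for set_to_none=False; this standalone helper is illustrative, not the Megatron function.

def zero_grads_in_place(params):
    for param in params:
        if param.grad is not None:
            # Setting param.grad = None would orphan the view into the
            # shared contiguous grad buffer, so zero the storage instead.
            param.grad.detach_()
            param.grad.zero_()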
@@ -422,7 +373,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        timers('backward-params-all-reduce').stop()

-    def gather_model_params(self, args, timers, ITERATION):
+    def gather_model_params(self, args, timers):
        timers('backward-params-all-gather').start()
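gather_model_params() (now without the debug ITERATION argument) is where each data-parallel rank's updated shard is propagated so every rank ends up with full parameters again. A hedged sketch of that pattern using the public torch.distributed.all_gather over a flat buffer split into equal per-rank shards; this is an illustration only, not the Megatron implementation.

import torch.distributed as dist

def all_gather_param_shards(flat_params, local_shard, group):
    # Assumes flat_params is contiguous and evenly divisible by the
    # data-parallel world size, and local_shard is this rank's slice.
    world_size = dist.get_world_size(group=group)
    shard_views = list(flat_params.chunk(world_size))
    # Each rank contributes its shard; results land directly in the views,
    # i.e. in the underlying flat parameter buffer.
    dist.all_gather(shard_views, local_shard, group=group)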
@@ -471,7 +422,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                main_view.detach().copy_(model_view)

-    def _copy_model_grads_to_main_grads(self, ITERATION):
+    def _copy_model_grads_to_main_grads(self):
        for group_index, group_shard in enumerate(self.opt_group_shards):
            for model_param, main_shard in group_shard["param_map"].items():
@@ -494,7 +445,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                main_view.detach().copy_(model_view)

-    def _copy_main_params_to_model_params(self, ITERATION):
+    def _copy_main_params_to_model_params(self):
        for group_index, group_shard in enumerate(self.opt_group_shards):
            for model_param, main_shard in group_shard["param_map"].items():
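Both _copy_model_grads_to_main_grads() and _copy_main_params_to_model_params() (their debug ITERATION argument is dropped here) reduce to copying between a view of a flat model buffer and a view of the fp32 main shard, as in the main_view.detach().copy_(model_view) lines above. A sketch of one direction of that copy; the function and its Shard-typed arguments are hypothetical names.

def copy_model_view_to_main_view(model_buffer, main_tensor, model_shard, main_shard):
    # Slice out the overlapping ranges on both sides.
    model_view = model_buffer[model_shard.start:model_shard.end]
    main_view = main_tensor[main_shard.start:main_shard.end]
    # detach() so the in-place copy is not recorded by autograd even though
    # the main shard tensor has requires_grad=True; copy_ handles the
    # fp16 -> fp32 cast if dtypes differ.
    main_view.detach().copy_(model_view)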