Commit 41ad0c5d authored by Lawrence McAfee

training; loss not yet checked.

parent fe3cfd86
@@ -559,15 +559,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         grads.
         '''
 
-        # >>>
-        # print_seq([
-        #     tp(b.data)
-        #     for m in self.models
-        #     for b in m._grad_buffers.values()
-        # ])
-        # print_seq("hi.")
-        # <<<
-
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
         self.allreduce_embedding_grads(args)
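
Note (not part of the diff): the debug prints removed above sat at the top of reduce_model_grads, which all-reduces the embedding grads and then each model's contiguous grad buffer across the data-parallel group, with timers around each collective. A minimal sketch of that grad all-reduce pattern, assuming an initialized torch.distributed process group; grad_buffer and data_parallel_group are illustrative names, not Megatron's exact API:

import torch
import torch.distributed as dist

def allreduce_grad_buffer(grad_buffer, data_parallel_group):
    # Average a contiguous grad buffer across the data-parallel group;
    # pre-dividing by the group size turns the sum-reduce into a mean.
    world_size = dist.get_world_size(group=data_parallel_group)
    grad_buffer /= world_size
    dist.all_reduce(grad_buffer, group=data_parallel_group)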
@@ -593,16 +584,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     group = data_parallel_group,
                 )
 
-        # >>>
-        # print_seq("hi.")
-        # <<<
-
         timers('backward-params-all-reduce').stop()
 
     def gather_model_params(self, args, timers):
 
-        raise Exception("hi.")
-
         timers('backward-params-all-gather').start()
 
         data_parallel_rank = mpu.get_data_parallel_rank()
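
With the leftover raise Exception("hi.") guard removed, gather_model_params can actually run after the optimizer step. Its role is to all-gather each data-parallel rank's updated parameter shard back into the full flat buffer, so every rank sees the new parameters. A hedged sketch of that collective under assumed names (param_buffer, data_parallel_group), not Megatron's exact buffers; the flat buffer is assumed 1-D and evenly divisible by the group size:

import torch
import torch.distributed as dist

def gather_param_shards(param_buffer, data_parallel_group):
    # Split the flat buffer into one contiguous view per data-parallel rank,
    # then all-gather so every rank ends up with all updated shards.
    world_size = dist.get_world_size(group=data_parallel_group)
    rank = dist.get_rank(group=data_parallel_group)
    shard_size = param_buffer.numel() // world_size
    shard_views = [param_buffer[r * shard_size:(r + 1) * shard_size]
                   for r in range(world_size)]
    local_shard = shard_views[rank].clone()  # avoid aliasing input and output
    dist.all_gather(shard_views, local_shard, group=data_parallel_group)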
@@ -756,14 +741,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         #     param.main_grad.detach().copy_(param)
 
     def _copy_main_params_to_model_params(self):
 
-        # >>>
-        # print_seq([
-        #     "grad = %s." % tp(p.grad)
-        #     for g in self.optimizer.param_groups
-        #     for p in g["params"]
-        # ])
-        # <<<
-
         def copy_group_params(shard_main_groups, full_model_groups):
             for shard_main_group, full_model_group in zip(shard_main_groups,
                                                           full_model_groups):
@@ -777,14 +754,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     full_model_grad = full_model_param.main_grad
                     shard_model_grad = full_model_grad.view(-1) \
                         [param_range.start:param_range.end]
-                    shard_main_param.grad = shard_model_grad.float()
 
-                    # print_seq([ "%s / %d, [%d] %s" % (
-                    #     k, i, len(g), ", ".join(str(p.nelement()) for p in g),
-                    # ) for k, gs in [
-                    #     ("model", self.full_float16_groups),
-                    #     ("main", self.shard_fp32_from_float16_groups),
-                    # ] for i, g in enumerate(gs)])
+                    shard_model_grad.data.copy_(shard_main_param)
 
         copy_group_params(self.shard_fp32_from_float16_groups,
                           self.full_float16_groups)
......
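
The one functional change in this commit is in _copy_main_params_to_model_params: the grad assignment shard_main_param.grad = shard_model_grad.float() (apparently left over from the grad-copy path) is replaced by shard_model_grad.data.copy_(shard_main_param), so the updated fp32 main-param shard is written into its view of the fp16 model buffer, which the subsequent all-gather then distributes. A minimal illustration of that copy direction, with assumed names and a flat fp16 buffer rather than the actual Megatron buffers:

import torch

def copy_shard_main_to_model(shard_main_param, full_model_buffer, start, end):
    # View the shard's slice of the flat fp16 buffer and copy the updated
    # fp32 main params into it; copy_ performs the fp32 -> fp16 cast.
    shard_model_view = full_model_buffer.view(-1)[start:end]
    shard_model_view.data.copy_(shard_main_param)

# Example: a 4-element fp32 shard written into positions [4, 8) of an fp16 buffer.
buf = torch.zeros(16, dtype=torch.float16)
main = torch.arange(4, dtype=torch.float32)
copy_shard_main_to_model(main, buf, 4, 8)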