Commit 87833e1f authored by Jeff Rasley

calculate grad norm wrt sub partitions

parent 17f36f1b
@@ -67,7 +67,7 @@ class CheckOverflow(object):
         return bool(overflow)
 
-    def check(self, param_groups=None):
+    def check(self, param_groups=None, raw_grads=False):
         params = []
         if param_groups is None:
             params = self.params
@@ -79,17 +79,18 @@ class CheckOverflow(object):
             for param in group:
                 params.append(param)
 
-        return self.has_overflow(params)
+        return self.has_overflow(params, raw_grads)
 
     # `params` is a list / generator of torch.Variable
-    def has_overflow_serial(self, params):
+    def has_overflow_serial(self, params, raw_grads=False):
         for i, p in enumerate(params):
-            if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
+            grad = p if raw_grads else p.grad
+            if grad is not None and self._has_inf_or_nan(grad.data, i):
                 return True
         return False
 
-    def has_overflow(self, params):
-        overflow = self.has_overflow_serial(params)
+    def has_overflow(self, params, raw_grads=False):
+        overflow = self.has_overflow_serial(params, raw_grads)
         # Since each model parallel GPU carries only part of the model,
         # make sure overflow flag is synced across all the model parallel GPUs
         overflow_gpu = torch.cuda.ByteTensor([overflow])
...
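Note: the new raw_grads flag lets the overflow check run directly over gradient tensors (for example, the flattened local sub-partition gradients in ZeRO stage 1) instead of over parameters whose .grad attribute is inspected. A minimal usage sketch, assuming an existing CheckOverflow instance named overflow_checker and the hypothetical tensor lists below:

# Default path: elements are parameters, so p.grad is inspected.
overflow = overflow_checker.has_overflow(model_params)

# New path: elements are themselves gradient tensors (e.g. flattened
# sub-partition grads), so each tensor is checked directly.
overflow = overflow_checker.has_overflow(local_grad_sub_partitions,
                                         raw_grads=True)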
@@ -6,7 +6,7 @@ from collections import defaultdict
 from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import get_grad_norm, CheckOverflow
+from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
 from deepspeed.utils import logger, log_dist
@@ -642,7 +642,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         partition_id = dist.get_rank(group=self.dp_process_group)
         for i, group in enumerate(self.fp16_groups):
             #TODO RS: update get grad norm to support sub partitions
-            norm_groups.append(get_grad_norm(group, mpu=self.mpu))
+            # norm_groups.append(get_grad_norm(group, mpu=self.mpu))
 
             #RS: update free grads w.r.t. sub partitions
             #free gradients for all the parameters that are not updated by this process
@@ -667,6 +667,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             self.free_grad_in_param_list(
                 self.params_in_rank_sub_partitions[i][partition_id])
 
+            # calculate grad norm w.r.t. local sub partitions
+            norm_groups.append(
+                self.get_grad_norm_sub_partitions(local_grad_sub_partitions,
+                                                  mpu=self.mpu))
+
             local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
 
             #RS: update unscale/clip with sub partitions
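The value appended here is the global 2-norm of the group's gradients, assembled from the sub-partitions that each data-parallel rank owns locally: every rank sums the squares of its local partition norms, the partial sums are all-reduced, and the square root of the total gives the norm of the full gradient. A standalone sketch of that math, independent of the optimizer class (the helper name and the dp_group argument are illustrative):

import torch
import torch.distributed as dist

def global_grad_norm_from_sub_partitions(sub_partitions, dp_group=None):
    # Sum of squared 2-norms over the sub-partitions held by this rank.
    local_sq_sum = sum(p.data.float().norm(2).item() ** 2 for p in sub_partitions)

    # All-reduce the partial sums across data-parallel ranks, then take
    # the square root to recover the norm of the unsharded gradient.
    total = torch.cuda.FloatTensor([local_sq_sum])
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=dp_group)
    return total[0].item() ** 0.5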
@@ -706,6 +711,40 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         return self.overflow
 
+    def get_grad_norm_sub_partitions(self, sub_partitions, mpu):
+        norm_type = 2.0
+        total_norm = 0.
+        for partition in sub_partitions:
+            if mpu is not None:
+                # if (mpu.get_model_parallel_rank() == 0
+                #    ) or is_model_parallel_parameter(p):
+                #     param_norm = p.grad.data.float().norm(norm_type)
+                #     total_norm += param_norm.item()**norm_type
+                raise NotImplementedError(
+                    "support grad norm of model parallel parameters")
+            else:
+                param_norm = partition.data.float().norm(norm_type)
+                total_norm += param_norm.item()**norm_type
+
+        # Sum across all DP ranks who each have different grad sub-partitions
+        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+        torch.distributed.all_reduce(total_norm_cuda,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=self.dp_process_group)
+
+        if mpu is not None:
+            # Sum across all model parallel GPUs.
+            torch.distributed.all_reduce(total_norm_cuda,
+                                         op=torch.distributed.ReduceOp.SUM,
+                                         group=mpu.get_model_parallel_group())
+
+        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+
+        if total_norm == float(
+                'inf') or total_norm == -float('inf') or total_norm != total_norm:
+            total_norm = -1
+
+        return total_norm
+
     def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
         total_norm = 0.0
         for norm in norm_groups:
...
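The per-group norms collected above feed unscale_and_clip_grads, only the first lines of which are visible in this diff. A hedged sketch of the usual FP16-optimizer pattern it presumably follows (combine group norms into one global norm, then fold loss-scale removal and clipping into a single multiplier); the parameter names loss_scale and clip_grad are assumptions, and the real DeepSpeed implementation may differ:

def unscale_and_clip_grads_sketch(grad_groups_flat, norm_groups,
                                  loss_scale, clip_grad):
    # Global norm is the root of the sum of squared per-group norms.
    total_norm = sum(norm ** 2 for norm in norm_groups) ** 0.5

    # Undo loss scaling; if the unscaled norm exceeds the clipping
    # threshold, shrink all gradients by the same factor.
    combined_scale = loss_scale
    if clip_grad > 0:
        clip = (total_norm / loss_scale) / clip_grad
        if clip > 1:
            combined_scale = clip * loss_scale

    for grad in grad_groups_flat:
        grad.data.mul_(1.0 / combined_scale)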