"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5559d0423771e8b6e454b1541164e0d5d54b6265"
Unverified Commit bc36b91d authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

add support for predivide as a config option (#235)

* add support for predivide as a flag
* add predivide json config, remove allgather_disable (as it's not currently used anymore)
parent 01e848b3
......@@ -119,6 +119,12 @@ def get_prescale_gradients(param_dict):
return get_scalar_param(param_dict, PRESCALE_GRADIENTS, PRESCALE_GRADIENTS_DEFAULT)
def get_gradient_predivide_factor(param_dict):
    """Return the user-configured gradient predivide factor, or the default
    (1.0, i.e. no predivision) when the key is absent from the config dict."""
    return get_scalar_param(
        param_dict, GRADIENT_PREDIVIDE_FACTOR, GRADIENT_PREDIVIDE_FACTOR_DEFAULT)
def get_steps_per_print(param_dict):
    """Return how many training steps elapse between progress prints,
    falling back to the default when the config does not set it."""
    return get_scalar_param(param_dict,
                            STEPS_PER_PRINT,
                            STEPS_PER_PRINT_DEFAULT)
......@@ -295,6 +301,7 @@ class DeepSpeedConfig(object):
self.disable_allgather = get_disable_allgather(param_dict)
self.allreduce_always_fp32 = get_allreduce_always_fp32(param_dict)
self.prescale_gradients = get_prescale_gradients(param_dict)
self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
self.allgather_size = get_allgather_size(param_dict)
......
......@@ -171,7 +171,7 @@ FP32_ALLREDUCE = "fp32_allreduce"
FP32_ALLREDUCE_DEFAULT = False
#########################################
# Scale gradients before allreduce
# Scale/predivide gradients before allreduce
#########################################
# Prescale gradients. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
......@@ -182,6 +182,13 @@ Gradient prescaling should be enabled as:
# JSON key and default for prescaling: when enabled, gradients are scaled
# before the allreduce instead of after it.
PRESCALE_GRADIENTS = "prescale_gradients"
PRESCALE_GRADIENTS_DEFAULT = False
# Usage example shown to users when the gradient_predivide_factor entry
# is malformed in ds_config.json.
GRADIENT_PREDIVIDE_FACTOR_FORMAT = '''
Gradient predivide factor should be enabled as:
"gradient_predivide_factor": 1.0
'''
# JSON key and default for the predivide factor; 1.0 means gradients are
# not predivided before the allreduce.
GRADIENT_PREDIVIDE_FACTOR = "gradient_predivide_factor"
GRADIENT_PREDIVIDE_FACTOR_DEFAULT = 1.0
#########################################
# Disable AllGather
#########################################
......
......@@ -119,7 +119,6 @@ class DeepSpeedLight(Module):
self.global_steps = 0
self.micro_steps = 0
self.skipped_steps = 0
self.gradient_predivide_factor = 1.0
self.gradient_average = True
self.warn_unscaled_loss = True
self.config_params = config_params
......@@ -327,6 +326,9 @@ class DeepSpeedLight(Module):
def postscale_gradients(self):
    """Return True when gradients should be scaled after the allreduce,
    i.e. whenever prescaling is not enabled in the config."""
    prescale = self._config.prescale_gradients
    return not prescale
def gradient_predivide_factor(self):
    """Return the factor by which gradients are divided before the
    allreduce, as read from the parsed DeepSpeed config."""
    config = self._config
    return config.gradient_predivide_factor
def steps_per_print(self):
    """Return the number of training steps between progress log lines,
    as configured in the parsed DeepSpeed config."""
    config = self._config
    return config.steps_per_print
......@@ -587,7 +589,9 @@ class DeepSpeedLight(Module):
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
mpu=self.mpu)
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor())
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
logging.info('Creating fp16 zero stage {} optimizer'.format(zero_stage))
......@@ -690,7 +694,7 @@ class DeepSpeedLight(Module):
assert self.zero_reduce_scatter()
self.optimizer.reduce_scatter_gradients(
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor,
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_average=self.gradient_average)
elif self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
......@@ -905,14 +909,14 @@ class DeepSpeedLight(Module):
tensor_to_allreduce = tensor.float()
if self.postscale_gradients():
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
if self.gradient_predivide_factor() != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor())
dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)
if self.gradient_average:
if self.gradient_predivide_factor != self.dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor /
if self.gradient_predivide_factor() != self.dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor() /
self.dp_world_size)
else:
tensor_to_allreduce.div_(self.dp_world_size)
......
......@@ -111,7 +111,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
reduce_scatter=True,
overlap_comm=False,
mpu=None,
clip_grad=0.0):
clip_grad=0.0,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0):
if dist.get_rank() == 0:
print(f"Reduce bucket size {reduce_bucket_size}")
......@@ -148,6 +151,14 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.overflow = False
self.clip_grad = clip_grad
self.allreduce_always_fp32 = allreduce_always_fp32
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
# param flattened by groups
self.fp16_groups = []
......@@ -562,6 +573,32 @@ class FP16_DeepSpeedZeroOptimizer(object):
if dist.get_rank() == 0:
print(message)
def gradient_reduction_w_predivide(self, tensor):
    """All-reduce ``tensor`` across the data-parallel group, applying the
    configured pre/post gradient scaling, and return it (updated in place).

    With ``postscale_gradients`` (the default), the gradient is optionally
    divided by ``gradient_predivide_factor`` before the allreduce (which can
    help fp16 stability at large scale) and rescaled by
    ``factor / world_size`` afterwards, so the net result is an average.
    Otherwise the gradient is divided by the world size up front.
    """
    dp_world_size = dist.get_world_size(group=self.dp_process_group)

    tensor_to_allreduce = tensor

    # Optionally communicate in fp32 for accuracy; copied back at the end.
    if self.allreduce_always_fp32:
        tensor_to_allreduce = tensor.float()

    if self.postscale_gradients:
        # Pre-divide for fp16 stability, then undo after the allreduce so
        # the result equals the mean across ranks.
        if self.gradient_predivide_factor != 1.0:
            tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

        dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

        # BUGFIX: gradient_predivide_factor is a float attribute on this
        # class (assigned from the constructor argument), not a method --
        # calling it as self.gradient_predivide_factor() raised
        # "TypeError: 'float' object is not callable".
        if self.gradient_predivide_factor != dp_world_size:
            tensor_to_allreduce.mul_(self.gradient_predivide_factor /
                                     dp_world_size)
    else:
        tensor_to_allreduce.div_(dp_world_size)
        dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

    # When the allreduce ran on an fp32 copy, fold the result back into the
    # original (possibly fp16) tensor.
    if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
        tensor.copy_(tensor_to_allreduce)

    return tensor
def average_tensor(self, tensor):
if self.overlap_comm:
torch.cuda.synchronize()
......@@ -571,8 +608,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
with torch.cuda.stream(stream):
if not self.reduce_scatter:
tensor.div_(dist.get_world_size(group=self.dp_process_group))
dist.all_reduce(tensor, group=self.dp_process_group)
self.gradient_reduction_w_predivide(tensor)
return
# Accumulate destination ranks and bucket offsets for each gradient slice.
......
......@@ -84,18 +84,18 @@ Example of ***scheduler***
| ------------------------------------ | ------- |
| During gradient averaging perform allreduce with 32 bit values | `false` |
***disable\_allgather***: [boolean]
| Description | Default |
| ---------------------------- | ------- |
| Disable allgather when using the ZeRO optimizer and instead use broadcast | `false` |
***prescale\_gradients***: [boolean]
| Description | Default |
| -------------------------------------- | ------- |
| Scale gradients before doing allreduce | `false` |
***gradient_predivide_factor***: [float]
| Description | Default |
| ---------------------------- | ------- |
| Before gradient averaging, pre-divide gradients by a specified factor; this can sometimes help with fp16 stability when scaling to a large number of GPUs | `1.0` |
***sparse\_gradients***: [boolean]
| Description | Default |
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment