Unverified Commit 7240abf3 authored by Samyam Rajbhandari, committed by GitHub

Samyamr/grad acc stage2 (#338)



* Adding gradient accumulation support for ZeRO Stage 2. Changing all Megatron-LM tests to also test gradient accumulation

* Gradient Accumulation support for Stage 2. Model tests added to test the feature

* formatting

* Update deepspeed_light.py

removing comment

* Update ds_config_func_bs8_zero1.json

Reverting this file back; it's not needed for this PR.

* defining baseline prefix
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 458c0d92
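With the stage-2 assertion removed (first hunk below), gradient accumulation looks the same from the training loop's point of view regardless of ZeRO stage. A minimal sketch of such a loop, assuming a model_engine returned by deepspeed.initialize() with a config that sets gradient_accumulation_steps greater than 1 and ZeRO stage 2; data_loader and the loss computation are placeholders:

# Sketch of a training loop with gradient accumulation under ZeRO stage 2.
# `model_engine` is assumed to come from deepspeed.initialize(); `data_loader`
# and the loss computation are placeholders for illustration only.
for step, batch in enumerate(data_loader):
    loss = model_engine(batch)    # forward pass on one micro-batch
    model_engine.backward(loss)   # gradients are accumulated internally
    model_engine.step()           # optimizer runs only on accumulation boundaries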
......@@ -598,8 +598,6 @@ class DeepSpeedLight(Module):
dp_process_group=self.data_parallel_group,
mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
assert self.gradient_accumulation_steps(
) == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1"
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
timers=self.timers,
......@@ -721,15 +719,19 @@ class DeepSpeedLight(Module):
return loss
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
if self.is_gradient_accumulation_boundary():
# ZeRO stage 2 communicates outside of gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter()
self.optimizer.reduce_scatter_gradients(
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_average=self.gradient_average)
elif self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
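The reordering above lets stage 2 run its reduce epilogue on every micro-step, while stage 1 and the non-ZeRO fallback still communicate only at accumulation boundaries. A rough sketch of how such a boundary check can be derived from a micro-step counter; the names used here are illustrative, not the engine's actual attributes:

# Illustrative only: `micro_steps` and `gradient_accumulation_steps` are
# assumed names, not necessarily the engine's real attributes.
def is_gradient_accumulation_boundary(micro_steps, gradient_accumulation_steps):
    # A boundary is reached once every `gradient_accumulation_steps` micro-batches.
    return (micro_steps + 1) % gradient_accumulation_steps == 0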
......
......@@ -446,16 +446,34 @@ class FP16_DeepSpeedZeroOptimizer(object):
torch.cuda.synchronize()
for i, _ in enumerate(self.fp16_groups):
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=torch.half,
device=torch.cuda.current_device(),
return_tensor_list=True)
if i not in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=torch.half,
device=torch.cuda.current_device(),
return_tensor_list=True)
else:
# When gradient accumulation is greater than 1,
# this code path is triggered and adds the new
# gradients to the accumulated averaged gradients
avg_new = self.get_flat_partition(self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=torch.half,
device=torch.cuda.current_device(),
return_tensor_list=True)
for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
accumulated_grad.add_(new_avg_grad)
self._release_ipg_buffers()
# No need to keep the gradients anymore.
# All gradients required by the step
# are in self.averaged_gradients
self.zero_grad()
see_memory_usage(f"End ipg_epilogue")
# resets all partitions to the not-reduced state
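The new else branch above is what enables accumulation for stage 2: on every micro-step after the first, the freshly reduced partition is added element-wise into the partition already held in averaged_gradients. A small self-contained sketch of that accumulation pattern with plain tensors (shapes are made up for illustration):

import torch

# Partitions accumulated from earlier micro-steps, plus a newly reduced one.
accumulated = [torch.ones(4), torch.ones(2)]
new_avg = [torch.full((4,), 0.5), torch.full((2,), 0.5)]

# In-place element-wise accumulation, mirroring accumulated_grad.add_(new_avg_grad).
for acc, new in zip(accumulated, new_avg):
    acc.add_(new)

print(accumulated)  # [tensor of 1.5s with 4 elements, tensor of 1.5s with 2 elements]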
......@@ -1103,6 +1121,9 @@ class FP16_DeepSpeedZeroOptimizer(object):
if self.overflow:
see_memory_usage('After overflow before clearing gradients')
self.zero_grad()
for key in self.averaged_gradients:
self.averaged_gradients[key] = None
see_memory_usage('After overflow after clearing gradients')
logger.info(
......
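Setting every averaged_gradients entry back to None after an overflow matters because the epilogue shown earlier only creates a fresh partition when an entry is missing or None; without the reset, the next window would keep accumulating on top of gradients from the discarded step. A small sketch of that interplay, with plain lists standing in for the flattened gradient partitions:

# Simplified model of the accumulate-or-reset logic; plain lists stand in
# for the flattened gradient partitions used by the real optimizer.
averaged_gradients = {}

def ipg_epilogue(i, new_grads):
    if i not in averaged_gradients or averaged_gradients[i] is None:
        averaged_gradients[i] = list(new_grads)   # first micro-step of a window
    else:
        for j, g in enumerate(new_grads):
            averaged_gradients[i][j] += g         # later micro-steps accumulate

ipg_epilogue(0, [1.0])
ipg_epilogue(0, [1.0])
print(averaged_gradients[0])   # [2.0]: two micro-steps accumulated

averaged_gradients[0] = None   # overflow path: discard the stale window
ipg_epilogue(0, [1.0])
print(averaged_gradients[0])   # [1.0]: accumulation restarts cleanly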
{
"train_micro_batch_size_per_gpu":8,
"gradient_accumulation_steps": 3,
"steps_per_print": 1,
"zero_optimization": {
"stage":0,
"reduce_bucket_size": 7000000,
"allgather_bucket_size": 7000000,
"reduce_scatter": true
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"activation_checkpointing": {
"partition_activations": true,
"contiguous_memory_optimization": true
}
}
{
"train_micro_batch_size_per_gpu":8,
"gradient_accumulation_steps": 3,
"steps_per_print": 1,
"zero_optimization": {
"stage":2,
"reduce_bucket_size": 7000000,
"allgather_bucket_size": 7000000,
"reduce_scatter": true
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"activation_checkpointing": {
"partition_activations": true,
"contiguous_memory_optimization": true
}
}
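The two configs above differ only in the ZeRO stage (0 for the baseline, 2 for the run under test), so any divergence in loss can be attributed to the stage-2 gradient accumulation path. For the new 4-GPU, model-parallel-2 test below, the effective batch size per optimizer step works out as follows, assuming the data-parallel degree is the GPU count divided by the model-parallel size:

# Effective batch size per optimizer step for the new gas test.
# data_parallel_degree = 2 is an assumption derived from gpus=4, mp=2.
micro_batch_per_gpu = 8
gradient_accumulation_steps = 3
data_parallel_degree = 4 // 2

effective_batch = micro_batch_per_gpu * gradient_accumulation_steps * data_parallel_degree
print(effective_batch)  # 48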
......@@ -198,6 +198,28 @@ class GPT2FuncTestCase(BaseTestCase):
succ = self.run_partition_activations_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu4_node1_zero2_gas(self):
test_config = {
"mp": 2,
"gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1000,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": SEQ_LEN,
"heads": ATTN_HEADS,
"deepspeed": True,
"json": "ds_config_func_bs8_zero2_gas10.json",
"baseline": "ds_config_func_bs8_zero0_gas10.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
succ = self.run_partition_activations_test(test_config, 0.01)
self.assertTrue(succ)
def test_optimizer_scheduler(self):
test_config = {
"mp": 1,
......@@ -224,9 +246,22 @@ class GPT2FuncTestCase(BaseTestCase):
baseline_prefix = "gpt2_func_"
prefix = "gpt2_partition_activation_"
deepspeed_config = test_config["json"]
baseline_deepspeed_config = False
# baseline run...
test_config["deepspeed"] = False
base_file = self.gen_output_name(test_config, baseline_prefix)
# turn off deepspeed if a baseline deepspeed config
# is not provided
if "baseline" not in test_config:
test_config["deepspeed"] = False
else:
test_config["json"] = test_config["baseline"]
baseline_prefix += test_config["json"][0:-5]
baseline_deepspeed_config = True
base_file = self.gen_output_name(test_config,
baseline_prefix,
baseline_config=baseline_deepspeed_config)
# skip baseline run if it exists.
if not self.has_loss_data(base_file):
......@@ -238,6 +273,7 @@ class GPT2FuncTestCase(BaseTestCase):
# DeepSpeed run...
test_config["deepspeed"] = True
test_config["other_args"] = "--deepspeed-activation-checkpointing"
test_config["json"] = deepspeed_config
print("{0}: DeepSpeed run.".format(self.id()))
test_file = self.gen_output_name(test_config, prefix)
self.run_gpt2_test(test_config, test_file)
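When a baseline config is supplied, the baseline run above still goes through DeepSpeed, just with the baseline config swapped in, and the config file name (minus its ".json" suffix) is folded into the output prefix so the baseline and DeepSpeed logs do not collide. A quick illustration of that prefix construction, using the file names from the new gas test:

# How the baseline prefix is derived from the baseline config name.
baseline_prefix = "gpt2_func_"
baseline_json = "ds_config_func_bs8_zero0_gas10.json"
baseline_prefix += baseline_json[0:-5]   # strip the trailing ".json"
print(baseline_prefix)                   # gpt2_func_ds_config_func_bs8_zero0_gas10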
......@@ -249,10 +285,25 @@ class GPT2FuncTestCase(BaseTestCase):
print("{0}: starting......".format(self.id()))
prefix = "gpt2_func"
baseline_prefix = prefix
deepspeed_config = test_config["json"]
baseline_deepspeed_config = False
# baseline run...
# turn off deepspeed if a baseline deepspeed config
# is not provided
if "baseline" not in test_config:
test_config["deepspeed"] = False
else:
test_config["json"] = test_config["baseline"]
baseline_prefix = prefix + test_config["json"][0:-5]
baseline_deepspeed_config = True
# baseline run...
test_config["deepspeed"] = False
base_file = self.gen_output_name(test_config, prefix)
base_file = self.gen_output_name(test_config,
baseline_prefix,
baseline_config=baseline_deepspeed_config)
# skip baseline run if it exists.
if not self.has_loss_data(base_file):
......@@ -263,6 +314,8 @@ class GPT2FuncTestCase(BaseTestCase):
# DeepSpeed run...
test_config["deepspeed"] = True
test_config["json"] = deepspeed_config
print("{0}: DeepSpeed run.".format(self.id()))
test_file = self.gen_output_name(test_config, prefix)
self.run_gpt2_test(test_config, test_file)
......@@ -305,7 +358,10 @@ def suite():
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2'))
suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2'))
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2_gas'))
suite.addTest(GPT2FuncTestCase('test_optimizer_scheduler'))
return suite
......
......@@ -16,7 +16,7 @@ class BaseTestCase(unittest.TestCase):
self.baseline_dir = "./baseline"
self.timestr = time.strftime("%Y%m%d-%H%M%S")
def gen_output_name(self, test_config, prefix):
def gen_output_name(self, test_config, prefix, baseline_config=False):
other_args = test_config["other_args"] if "other_args" in test_config else ""
zero_args = "_zero" if "zero" in test_config and test_config["zero"] else ""
other_args = other_args.strip(' -\\').replace(" ", "").replace("\"", "")
......@@ -24,7 +24,7 @@ class BaseTestCase(unittest.TestCase):
if other_args:
other_args = "_" + other_args
if test_config["deepspeed"]:
if test_config["deepspeed"] and not baseline_config:
file_name = "_mp{0}_gpu{1}_node{2}_bs{3}_step{4}_layer{5}_hidden{6}_seq{7}_head{8}{9}_ds{10}-{11}.log".format(
test_config["mp"],
test_config["gpus"],
......
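The new baseline_config flag means a baseline that runs under DeepSpeed (as in the gas tests) is still named like a plain baseline rather than getting the DeepSpeed-style "_ds..." suffix. A minimal sketch of that selection; the name formats are abbreviated and purely illustrative of which branch is taken:

# Abbreviated name selection; the real format strings in BaseTestCase are longer.
def output_name(prefix, deepspeed, baseline_config=False):
    if deepspeed and not baseline_config:
        return prefix + "_ds.log"   # DeepSpeed-style name
    return prefix + ".log"          # baseline-style name

# A DeepSpeed-driven baseline still gets the baseline-style name.
print(output_name("gpt2_func_ds_config_func_bs8_zero0_gas10", True, baseline_config=True))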