Unverified Commit 46018859 authored by Samyam Rajbhandari and committed by GitHub

Samyamr/inference hook fix (#851)



* Fix mis-aligned-grad

When a parameter's number of elements is not divisible by the world size, the partitioned gradients become mis-aligned because the padding is handled incorrectly. This PR fixes that (see the sketch just before the diff below for the scenario).

* Formatting fix

* Adding the static_scale test back for Z3, and also changing the hidden size so it is not divisible by world_size

* also removing alignment from flat fp16 buffers

* Testing for hidden dim alignment

* inference hook fix

* Update stage3.py

* formatting

* [bug-fix] move params to gpu if offload params is turned off
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 73d762c8
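
Note on the first bullet: below is a minimal, hypothetical sketch of the partition arithmetic behind the bug report. The helper name partition_extent and the ceil-based split are illustrative assumptions, not DeepSpeed's exact bookkeeping; they only show why the last rank ends up with fewer real elements than its padded partition.

# Illustrative partition math: each rank reserves ceil(numel / world_size) slots,
# but the last rank may own fewer real elements, the remainder being padding.
import math

def partition_extent(numel, rank, world_size):
    partition_size = math.ceil(numel / world_size)      # padded size reserved per rank
    start = rank * partition_size
    valid = max(0, min(numel - start, partition_size))  # real (unpadded) elements on this rank
    return start, valid, partition_size

# A 9x9 weight (hidden_dim=9) has 81 elements; with world_size=2 the split is uneven:
for rank in range(2):
    print(rank, partition_extent(81, rank, 2))
# rank 0 -> (0, 41, 41); rank 1 -> (41, 40, 41)
# If a rank sizes its gradient partition by `valid` (40) instead of `partition_size` (41),
# the flat gradient buffers across ranks no longer line up: the mis-alignment described above.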
@@ -807,8 +807,12 @@ class Init(InsertPostInitMethodToModuleSubClasses):
        if start < param.ds_numel:
            elements = min(param.ds_numel - start, partition_size)
-            dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements)
+            dest_tensor_full_buffer = partition_buffer.view(-1).narrow(
+                0,
+                0,
+                partition_size)
+            dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
            src_tensor = param.grad.view(-1).narrow(0, start, elements)
            # just copy the grad partition to the buffer
@@ -841,7 +845,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
            # elements))
            #print("after partition gradients")
-            param.grad.data = dest_tensor.data
+            param.grad.data = dest_tensor_full_buffer.data

        see_memory_usage("After partitioning gradients", force=False)
@@ -961,10 +961,9 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
                #create flat buffer in CPU and move to GPU
                self.fp16_partitioned_groups_flat.append(
-                    flatten_dense_tensors_aligned(
-                        self.fp16_partitioned_groups[i],
-                        dist.get_world_size(group=self.dp_process_group)).cuda(
-                            torch.cuda.current_device()))
+                    flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i],
+                                                  1).cuda(
+                                                      torch.cuda.current_device()))
                see_memory_usage(
                    f"After flattening and moving param group {i} to GPU",
                    force=False)
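
This hunk (the "removing alignment from flat fp16 buffers" commit) passes an alignment of 1 instead of the data-parallel world size, so the per-group flat fp16 buffer is exactly as long as the sum of its partition tensors. Only the name flatten_dense_tensors_aligned appears in the diff; the implementation below is an assumed sketch of what such a helper typically does, built on torch's internal flatten utility:

import torch
from torch._utils import _flatten_dense_tensors

def flatten_dense_tensors_aligned(tensor_list, alignment):
    # Pad the total length up to a multiple of `alignment`, then flatten into one buffer.
    num_elements = sum(t.numel() for t in tensor_list)
    remaining = num_elements % alignment
    if remaining:
        pad = torch.zeros(alignment - remaining,
                          dtype=tensor_list[0].dtype,
                          device=tensor_list[0].device)
        tensor_list = list(tensor_list) + [pad]
    return _flatten_dense_tensors(tensor_list)

parts = [torch.ones(41), torch.ones(40)]                  # uneven partitions, 81 elements
print(flatten_dense_tensors_aligned(parts, 2).numel())    # 82: one pad element appended
print(flatten_dense_tensors_aligned(parts, 1).numel())    # 81: alignment 1 adds no padding

With alignment 1 no extra padding is appended, so the flat buffer's length matches the partitions exactly.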
@@ -976,10 +975,12 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
                    flat_offset,
                    total_elements)
                self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
-                self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
-                                          self.fp16_partitioned_groups_flat[i])
                flat_offset += total_elements
+
+            # move param to flat buffer for both param offload on/off
+            self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
+                                      self.fp16_partitioned_groups_flat[i])
            see_memory_usage(f"After Flattening param group {i}", force=False)

    def _create_fp32_partitions(self):
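
Here the _move_to_flat_buffer call is hoisted out of the offload-specific branch so the partitions are re-pointed into the flat buffer whether parameter offload is on or off (the "[bug-fix] move params to gpu if offload params is turned off" commit). A minimal sketch of what such a move-to-flat-buffer step generally looks like; move_to_flat_buffer below is illustrative, not the stage3.py implementation:

import torch

def move_to_flat_buffer(tensors, flat_buffer):
    # Make each tensor a view into one contiguous flat buffer, so the flat buffer
    # and the individual partitions alias the same storage.
    offset = 0
    for t in tensors:
        n = t.numel()
        view = flat_buffer.narrow(0, offset, n).view_as(t)
        view.copy_(t.data)   # preserve current values
        t.data = view        # re-point the tensor at the flat storage
        offset += n

parts = [torch.randn(41), torch.randn(40)]
flat = torch.empty(81)
move_to_flat_buffer(parts, flat)
assert parts[0].data_ptr() == flat.data_ptr()   # first partition now aliases the flat buffer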
@@ -1036,6 +1037,14 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
        self.hierarchy = 0
        self._register_hooks_recursively(self.module)
+
+        #reset step if in inference mode
+        def _end_of_forward_hook(module, *args):
+            if not torch._C.is_grad_enabled():
+                self.param_coordinator.reset_step()
+
+        self.module.register_forward_hook(_end_of_forward_hook)
+
    def persistent_parameters(self):
        persistent_params = []
        total_persistent_parameters = 0
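
This added hook is the inference fix from the PR title: after a forward pass that ran without gradients, no backward will follow, so the parameter coordinator's per-step bookkeeping is reset at the end of the forward. A self-contained sketch of the same pattern on a plain nn.Module; DummyCoordinator is a stand-in for DeepSpeed's param_coordinator, and the public torch.is_grad_enabled() is used instead of the private torch._C flag:

import torch
import torch.nn as nn

class DummyCoordinator:
    def __init__(self):
        self.step_id = 0
    def record_step(self):
        self.step_id += 1
    def reset_step(self):
        self.step_id = 0

coordinator = DummyCoordinator()
model = nn.Linear(4, 4)

def _end_of_forward_hook(module, inputs, output):
    # In inference mode (no grad) no backward will run, so reset per-step state here.
    if not torch.is_grad_enabled():
        coordinator.reset_step()

model.register_forward_hook(_end_of_forward_hook)

coordinator.record_step()
with torch.no_grad():
    model(torch.randn(2, 4))
assert coordinator.step_id == 0   # reset because the forward ran under no_grad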
@@ -347,9 +347,6 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
        pytest.skip("cpu-adam is not compatible")

-    if zero_stage == 3:
-        pytest.skip("skip for now")
-
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
@@ -371,8 +368,9 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=2)
-    def _test_zero_static_scale(args, zero_stage):
-        hidden_dim = 10
+    def _test_zero_static_scale(args, zero_stage, hidden_dim):
+        #making hidden size not divisible by DP for covering this scenario
+        hidden_dim = hidden_dim
        model = SimpleModel(hidden_dim)
        model, optim, _, _ = deepspeed.initialize(args=args,
@@ -393,7 +391,10 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
        model.backward(loss)
        model.step()

-    _test_zero_static_scale(args=args, zero_stage=zero_stage)
+    #test when hidden_dim is not aligned with world size
+    _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=9)
+    #test when hidden_dim is aligned with world size
+    _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=10)

def test_zero_static_scale_deprecated_format(tmpdir):