Unverified Commit d24d3de9 authored by Samyam Rajbhandari's avatar Samyam Rajbhandari Committed by GitHub
Browse files

Samyamr/cpu memory bloat fix zero (#233)

* Fix for CPU memory bloating issue caused by PyTorch backward-graph creation in all_gather. Fixed by calling detach on tensors before calling all_gather

* Fix for CPU memory bloating issue caused by PyTorch backward-graph creation in all_gather. Fixed by calling detach on tensors before calling all_gather

* Fix for CPU memory bloating issue caused by PyTorch backward-graph creation in all_gather. Fixed by calling detach on tensors before calling all_gather
parent abe2204d
...@@ -1112,23 +1112,15 @@ class FP16_DeepSpeedZeroOptimizer(object): ...@@ -1112,23 +1112,15 @@ class FP16_DeepSpeedZeroOptimizer(object):
1, 1,
partitioned_params[partition_id].numel() * dp_world_size // partitioned_params[partition_id].numel() * dp_world_size //
self.allgather_bucket_size) self.allgather_bucket_size)
if num_shards == 1:
dist.all_gather(partitioned_params,
partitioned_params[partition_id],
group=self.dp_process_group)
else:
shard_size = partitioned_params[partition_id].numel() // num_shards shard_size = partitioned_params[partition_id].numel() // num_shards
num_elements = shard_size num_elements = shard_size
assert shard_size * num_shards <= partitioned_params[partition_id].numel()
for shard_id in range(num_shards): for shard_id in range(num_shards):
#boundary condition
#TODO: Check correctness of boundary condition
if shard_id == (num_shards - 1): if shard_id == (num_shards - 1):
if shard_size * num_shards >= partitioned_params[
partition_id].numel():
break
else:
num_elements = partitioned_params[partition_id].numel( num_elements = partitioned_params[partition_id].numel(
) - shard_id * shard_size ) - shard_id * shard_size
...@@ -1137,8 +1129,9 @@ class FP16_DeepSpeedZeroOptimizer(object): ...@@ -1137,8 +1129,9 @@ class FP16_DeepSpeedZeroOptimizer(object):
curr_shard = partitioned_params[dp_id].narrow( curr_shard = partitioned_params[dp_id].narrow(
0, 0,
shard_id * shard_size, shard_id * shard_size,
num_elements) num_elements).detach()
shard_list.append(curr_shard) shard_list.append(curr_shard)
dist.all_gather(shard_list, dist.all_gather(shard_list,
shard_list[partition_id], shard_list[partition_id],
group=self.dp_process_group) group=self.dp_process_group)
......
...@@ -331,7 +331,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object): ...@@ -331,7 +331,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
list) # [rank] -> [(start,end), (start,end), ...] list) # [rank] -> [(start,end), (start,end), ...]
for idx in range(num_sub_partitions): for idx in range(num_sub_partitions):
rank_id = idx % world_size rank_id = idx % world_size
sub_partition = tensor.narrow(0, start, sub_partition_size) sub_partition = tensor.narrow(0, start, sub_partition_size).detach()
element_intervals[rank_id].append((start, start + sub_partition_size)) element_intervals[rank_id].append((start, start + sub_partition_size))
comm_partitions[comm_id].append(sub_partition) comm_partitions[comm_id].append(sub_partition)
start = start + sub_partition_size start = start + sub_partition_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment