Unverified Commit d24d3de9 authored by Samyam Rajbhandari, committed by GitHub

Samyamr/cpu memory bloat fix zero (#233)

* Fix for CPU memory bloating issue caused by PyTorch backward-graph creation in all_gather. Fixed by calling detach() on tensors before calling all_gather.

* Fix for CPU memory bloating issue caused by PyTorch backward-graph creation in all_gather. Fixed by calling detach() on tensors before calling all_gather.

* Fix for CPU memory bloating issue caused by PyTorch backward-graph creation in all_gather. Fixed by calling detach() on tensors before calling all_gather.
parent abe2204d
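
The fix below boils down to gathering views that are cut off from the autograd graph. A minimal standalone sketch of that pattern (no process group needed; flat_param and the sizes are made-up illustrations, not names from the patch):

import torch

# Stand-in for one rank's flat parameter partition (hypothetical name and size).
flat_param = torch.nn.Parameter(torch.randn(8))

# A narrow() view of a Parameter stays attached to autograd, so passing it
# into code that runs every step (e.g. collectives) keeps graph metadata alive.
tracked_shard = flat_param.narrow(0, 0, 4)
print(tracked_shard.requires_grad)                         # True

# detach() returns a view of the same storage with no autograd history,
# which is what the patch feeds to dist.all_gather instead.
free_shard = flat_param.narrow(0, 0, 4).detach()
print(free_shard.requires_grad)                            # False
print(free_shard.data_ptr() == tracked_shard.data_ptr())   # True: same memory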
@@ -1112,36 +1112,29 @@ class FP16_DeepSpeedZeroOptimizer(object):
                 1,
                 partitioned_params[partition_id].numel() * dp_world_size //
                 self.allgather_bucket_size)
-            if num_shards == 1:
-                dist.all_gather(partitioned_params,
-                                partitioned_params[partition_id],
-                                group=self.dp_process_group)
-            else:
-                shard_size = partitioned_params[partition_id].numel() // num_shards
-                num_elements = shard_size
-                for shard_id in range(num_shards):
-                    #boundary condition
-                    #TODO: Check correctness of boundary condition
-                    if shard_id == (num_shards - 1):
-                        if shard_size * num_shards >= partitioned_params[
-                                partition_id].numel():
-                            break
-                        else:
-                            num_elements = partitioned_params[partition_id].numel(
-                            ) - shard_id * shard_size
-                    shard_list = []
-                    for dp_id in range(dp_world_size):
-                        curr_shard = partitioned_params[dp_id].narrow(
-                            0,
-                            shard_id * shard_size,
-                            num_elements)
-                        shard_list.append(curr_shard)
-                    dist.all_gather(shard_list,
-                                    shard_list[partition_id],
-                                    group=self.dp_process_group)
+            shard_size = partitioned_params[partition_id].numel() // num_shards
+            num_elements = shard_size
+            assert shard_size * num_shards <= partitioned_params[partition_id].numel()
+            for shard_id in range(num_shards):
+                if shard_id == (num_shards - 1):
+                    num_elements = partitioned_params[partition_id].numel(
+                    ) - shard_id * shard_size
+                shard_list = []
+                for dp_id in range(dp_world_size):
+                    curr_shard = partitioned_params[dp_id].narrow(
+                        0,
+                        shard_id * shard_size,
+                        num_elements).detach()
+                    shard_list.append(curr_shard)
+                dist.all_gather(shard_list,
+                                shard_list[partition_id],
+                                group=self.dp_process_group)
         timers('optimizer_allgather').stop()
         # TODO: we probably don't need this? just to be safe
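
For reference, the new path above splits each rank's partition into num_shards roughly equal all_gather calls, with the final shard absorbing the remainder of the integer division. A small worked example of that arithmetic (the concrete sizes are invented for illustration):

# Hypothetical sizes, chosen so the last shard exercises the remainder branch.
numel = 10                   # elements in this rank's partition
dp_world_size = 4            # data-parallel ranks
allgather_bucket_size = 12   # elements per all_gather bucket

num_shards = max(1, numel * dp_world_size // allgather_bucket_size)  # 40 // 12 -> 3
shard_size = numel // num_shards                                     # 10 // 3  -> 3
assert shard_size * num_shards <= numel                              # 9 <= 10

num_elements = shard_size
for shard_id in range(num_shards):
    if shard_id == num_shards - 1:
        # The last shard picks up the leftover elements (here 10 - 6 = 4).
        num_elements = numel - shard_id * shard_size
    print(shard_id, shard_id * shard_size, num_elements)
# prints: (0, 0, 3), (1, 3, 3), (2, 6, 4)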
@@ -331,7 +331,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             list)  # [rank] -> [(start,end), (start,end), ...]
         for idx in range(num_sub_partitions):
             rank_id = idx % world_size
-            sub_partition = tensor.narrow(0, start, sub_partition_size)
+            sub_partition = tensor.narrow(0, start, sub_partition_size).detach()
             element_intervals[rank_id].append((start, start + sub_partition_size))
             comm_partitions[comm_id].append(sub_partition)
             start = start + sub_partition_size