Commit 41038d54 authored by Lawrence McAfee

running & saving memory w/ _reduce_scatter_base/_all_gather_base.

parent 64b94f00
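The memory saving referenced in the commit message comes from switching the data-parallel grad-buffer collectives from the list-of-tensors APIs to the flat-tensor "_base" variants. A minimal sketch of the before/after call pattern for the reduce-scatter side (not part of the diff; the buffer, group, and function names here are illustrative placeholders, not the optimizer's actual attributes):

# Sketch only: contrasts the list-based collective with the flat-buffer
# "_base" variant this commit switches to. gbuf / data_parallel_group are
# placeholder names.
import torch
import torch.distributed as dist

def reduce_scatter_grad_buffer(gbuf, data_parallel_group):
    """Reduce-scatter a flat, padded grad buffer across the data-parallel group."""
    world_size = dist.get_world_size(group=data_parallel_group)
    rank = dist.get_rank(group=data_parallel_group)
    shard_size = gbuf.numel() // world_size  # assumes gbuf is padded to a multiple

    # Per-rank views into the same contiguous buffer (no copies).
    gbuf_views = [gbuf[r * shard_size:(r + 1) * shard_size]
                  for r in range(world_size)]

    # Old path: list-of-tensors API; the backend may stage the input list
    # into an extra contiguous buffer before the collective runs.
    # dist.reduce_scatter(gbuf_views[rank], gbuf_views, group=data_parallel_group)

    # New path: the "_base" variant consumes the flat buffer directly, so the
    # only storage involved is gbuf itself plus the rank-local view into it.
    dist._reduce_scatter_base(gbuf_views[rank], gbuf,
                              group=data_parallel_group)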
@@ -410,50 +410,53 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                 gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                               for r in range(data_parallel_world_size)]
-                gbuf_view_items.append((model_index, dtype, gbuf_views))
+                # gbuf_view_items.append((model_index, dtype, gbuf_views))
+                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))
         return gbuf_view_items
     # >>>
-    def get_model_grad_buffer_dp_views_SINGLE(self):
+    # def get_model_grad_buffer_dp_views_SINGLE(self):
-        data_parallel_world_size = mpu.get_data_parallel_world_size()
+        # data_parallel_world_size = mpu.get_data_parallel_world_size()
-        # Grad buffer views.
-        gbuf_items = []
-        for model_index, model in enumerate(self.models):
-            for dtype, gbuf in model._grad_buffers.items():
-                gbuf_items.append((model_index, dtype, gbuf.data))
+        # # Grad buffer views.
+        # gbuf_items = []
+        # for model_index, model in enumerate(self.models):
+            # for dtype, gbuf in model._grad_buffers.items():
+                # gbuf_items.append((model_index, dtype, gbuf.data))
-        return gbuf_items
+        # return gbuf_items
     # <<<
-    def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
+    # >>>
+    # def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
-        # Iterate grad buffers & chunk.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        chunk_view_items = []
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            # ** Sanity check. ** (should be unnecessary; see comment above)
-            view_numel = gbuf_views[0].nelement()
-            for view in gbuf_views:
-                assert view.nelement() == view_numel
-            # Compute chunk size (via savings factor).
-            chunk_numel_min = 131072
-            chunk_numel_max = view_numel
-            chunk_numel = int(
-                mem_savings_factor * chunk_numel_min
-                + (1 - mem_savings_factor) * chunk_numel_max
-            )
+        # # Iterate grad buffers & chunk.
+        # gbuf_view_items = self.get_model_grad_buffer_dp_views()
+        # chunk_view_items = []
+        # for model_index, dtype, gbuf_views in gbuf_view_items:
-            # Chunk views.
-            for start_index in range(0, view_numel, chunk_numel):
-                end_index = min(view_numel, start_index + chunk_numel)
-                chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
-                chunk_view_items.append((model_index, dtype, chunk_views))
+            # # ** Sanity check. ** (should be unnecessary; see comment above)
+            # view_numel = gbuf_views[0].nelement()
+            # for view in gbuf_views:
+                # assert view.nelement() == view_numel
+            # # Compute chunk size (via savings factor).
+            # chunk_numel_min = 131072
+            # chunk_numel_max = view_numel
+            # chunk_numel = int(
+                # mem_savings_factor * chunk_numel_min
+                # + (1 - mem_savings_factor) * chunk_numel_max
+            # )
+            # # Chunk views.
+            # for start_index in range(0, view_numel, chunk_numel):
+                # end_index = min(view_numel, start_index + chunk_numel)
+                # chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
+                # chunk_view_items.append((model_index, dtype, chunk_views))
-        return chunk_view_items
+        # return chunk_view_items
+    # <<<
     def reduce_model_grads(self, args, timers):
         '''Note: this is a different order of reduction, versus the non-
@@ -490,17 +493,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # )
         # +++
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
-        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
             # >>>
-            pax(0, {
-                "gbuf_view" : gbuf_views[data_parallel_rank].shape,
-                "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
-            })
+            # pax(0, {
+                # "gbuf_view" : gbuf_views[data_parallel_rank].shape,
+                # "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
+            # })
             # <<<
             torch.distributed._reduce_scatter_base(
                 gbuf_views[data_parallel_rank],
-                gbuf_view_items_SINGLE[index][2],
+                gbuf, # gbuf_view_items_SINGLE[index][2],
                 group = data_parallel_group,
             )
             # torch.distributed.reduce_scatter(
@@ -535,12 +538,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # )
         # +++
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
-        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
             torch.distributed._all_gather_base(
-                # gbuf_view_items_SINGLE[index][2],
-                # gbuf_views[data_parallel_rank],
-                # group = data_parallel_group,
+                gbuf, # gbuf_view_items_SINGLE[index][2],
+                gbuf_views[data_parallel_rank],
+                group = data_parallel_group,
             )
         # <<<
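For the parameter-gather direction changed in the last hunk, the same pattern applies in reverse. A companion sketch (again not from the commit; names are the same illustrative placeholders as in the sketch above):

# Sketch only: flat-buffer all-gather replacing the list-based call.
import torch
import torch.distributed as dist

def all_gather_param_buffer(gbuf, data_parallel_group):
    """All-gather each rank's updated shard back into the full flat buffer."""
    world_size = dist.get_world_size(group=data_parallel_group)
    rank = dist.get_rank(group=data_parallel_group)
    shard_size = gbuf.numel() // world_size
    gbuf_views = [gbuf[r * shard_size:(r + 1) * shard_size]
                  for r in range(world_size)]

    # _all_gather_base writes every rank's shard straight into the contiguous
    # buffer, so no per-rank output list has to be built and held alongside it.
    dist._all_gather_base(gbuf, gbuf_views[rank], group=data_parallel_group)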