Commit 41038d54 authored by Lawrence McAfee
Browse files

running & saving memory w/ _reduce_scatter_base/_all_gather_base.

parent 64b94f00
...@@ -410,50 +410,53 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -410,50 +410,53 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
shard_size = int(gbuf.numel_padded / data_parallel_world_size) shard_size = int(gbuf.numel_padded / data_parallel_world_size)
gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)] gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
for r in range(data_parallel_world_size)] for r in range(data_parallel_world_size)]
gbuf_view_items.append((model_index, dtype, gbuf_views)) # gbuf_view_items.append((model_index, dtype, gbuf_views))
gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))
return gbuf_view_items return gbuf_view_items
# >>> # >>>
def get_model_grad_buffer_dp_views_SINGLE(self): # def get_model_grad_buffer_dp_views_SINGLE(self):
data_parallel_world_size = mpu.get_data_parallel_world_size() # data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer views. # # Grad buffer views.
gbuf_items = [] # gbuf_items = []
for model_index, model in enumerate(self.models): # for model_index, model in enumerate(self.models):
for dtype, gbuf in model._grad_buffers.items(): # for dtype, gbuf in model._grad_buffers.items():
gbuf_items.append((model_index, dtype, gbuf.data)) # gbuf_items.append((model_index, dtype, gbuf.data))
return gbuf_items # return gbuf_items
# <<< # <<<
def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor): # >>>
# def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
# Iterate grad buffers & chunk.
gbuf_view_items = self.get_model_grad_buffer_dp_views() # # Iterate grad buffers & chunk.
chunk_view_items = [] # gbuf_view_items = self.get_model_grad_buffer_dp_views()
for model_index, dtype, gbuf_views in gbuf_view_items: # chunk_view_items = []
# for model_index, dtype, gbuf_views in gbuf_view_items:
# ** Sanity check. ** (should be unnecessary; see comment above)
view_numel = gbuf_views[0].nelement() # # ** Sanity check. ** (should be unnecessary; see comment above)
for view in gbuf_views: # view_numel = gbuf_views[0].nelement()
assert view.nelement() == view_numel # for view in gbuf_views:
# assert view.nelement() == view_numel
# Compute chunk size (via savings factor).
chunk_numel_min = 131072 # # Compute chunk size (via savings factor).
chunk_numel_max = view_numel # chunk_numel_min = 131072
chunk_numel = int( # chunk_numel_max = view_numel
mem_savings_factor * chunk_numel_min # chunk_numel = int(
+ (1 - mem_savings_factor) * chunk_numel_max # mem_savings_factor * chunk_numel_min
) # + (1 - mem_savings_factor) * chunk_numel_max
# )
# Chunk views.
for start_index in range(0, view_numel, chunk_numel): # # Chunk views.
end_index = min(view_numel, start_index + chunk_numel) # for start_index in range(0, view_numel, chunk_numel):
chunk_views = [ t[start_index:end_index] for t in gbuf_views ] # end_index = min(view_numel, start_index + chunk_numel)
chunk_view_items.append((model_index, dtype, chunk_views)) # chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
# chunk_view_items.append((model_index, dtype, chunk_views))
return chunk_view_items
# return chunk_view_items
# <<<
def reduce_model_grads(self, args, timers): def reduce_model_grads(self, args, timers):
'''Note: this is a different order of reduction, versus the non- '''Note: this is a different order of reduction, versus the non-
...@@ -490,17 +493,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -490,17 +493,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# ) # )
# +++ # +++
gbuf_view_items = self.get_model_grad_buffer_dp_views() gbuf_view_items = self.get_model_grad_buffer_dp_views()
gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE() # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items): for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
# >>> # >>>
pax(0, { # pax(0, {
"gbuf_view" : gbuf_views[data_parallel_rank].shape, # "gbuf_view" : gbuf_views[data_parallel_rank].shape,
"gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape, # "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
}) # })
# <<< # <<<
torch.distributed._reduce_scatter_base( torch.distributed._reduce_scatter_base(
gbuf_views[data_parallel_rank], gbuf_views[data_parallel_rank],
gbuf_view_items_SINGLE[index][2], gbuf, # gbuf_view_items_SINGLE[index][2],
group = data_parallel_group, group = data_parallel_group,
) )
# torch.distributed.reduce_scatter( # torch.distributed.reduce_scatter(
...@@ -535,12 +538,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -535,12 +538,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# ) # )
# +++ # +++
gbuf_view_items = self.get_model_grad_buffer_dp_views() gbuf_view_items = self.get_model_grad_buffer_dp_views()
gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE() # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items): for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
torch.distributed._all_gather_base( torch.distributed._all_gather_base(
# gbuf_view_items_SINGLE[index][2], gbuf, # gbuf_view_items_SINGLE[index][2],
# gbuf_views[data_parallel_rank], gbuf_views[data_parallel_rank],
# group = data_parallel_group, group = data_parallel_group,
) )
# <<< # <<<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment