Commit db8fb976 authored by Thor Johnsen

Add back support for multi tensor scale flattening

parent 3f717d95
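For context (not part of the commit): "multi tensor scale flattening" means copying many per-parameter gradient tensors into disjoint views of one flat buffer with a single fused kernel launch, via apex's multi_tensor_applier driving the amp_C.multi_tensor_scale op. A minimal standalone sketch, with illustrative shapes and a made-up scale factor:

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Two illustrative source gradients and a flat destination buffer.
grads = [torch.randn(5, device='cuda').half(), torch.randn(7, device='cuda').half()]
flat = torch.zeros(12, device='cuda').half()
flat_views = [flat[0:5], flat[5:12]]          # destination slices aliasing the flat buffer

overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
# One fused kernel scales each grads[i] by 0.5 and writes it into flat_views[i].
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [grads, flat_views], 0.5)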
@@ -107,6 +107,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 p_offset = ((p_offset + 63) // 64) * 64
                 p_i += 1
         self._grads_generated = [False]*len(self._grads_info)
+        self._flat_mt = flat_mt
+        self._grads = [None]*len(self._grads_info) if self._flat_mt else None
         if self._overlap_reductions:
             self._current_block = self._num_blocks
@@ -118,6 +120,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._shard_size = self._block_size // self._group_size
         print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))
+        self._low_param_i = [0]*self._num_blocks
+        for block_id in range(self._num_blocks-1,-1,-1):
+            p_i = len(self._grads_info)-1
+            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
+                p_i -= 1
+            self._low_param_i[block_id] = p_i
+        print(self._low_param_i)
         self._flat_grads = torch.zeros([self._total_param_size]).half().cuda()
         self._new_params = None
         self._fp32_p = None
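The loop added above precomputes, for each gradient block, the index of the last parameter whose offset does not lie beyond the block's start, i.e. the parameter that covers the block boundary; _get_flush_block uses it below as a cheap readiness check. A standalone sketch of the same computation with made-up sizes (not from the commit):

# Hypothetical offsets into the flat gradient buffer, one entry per parameter.
block_size = 256
num_blocks = 4
grads_info = [{"param_offset": 0}, {"param_offset": 192},
              {"param_offset": 448}, {"param_offset": 704}]

low_param_i = [0] * num_blocks
for block_id in range(num_blocks - 1, -1, -1):
    p_i = len(grads_info) - 1
    # Walk down to the last parameter whose offset is not past the block start.
    while p_i > 0 and grads_info[p_i]["param_offset"] > block_id * block_size:
        p_i -= 1
    low_param_i[block_id] = p_i

print(low_param_i)   # [0, 1, 2, 3]: the parameter covering each block's start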
@@ -175,20 +185,21 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        num_grads = len(self._grads_generated)
-        contiguous_idx = num_grads
-        while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
-            contiguous_idx -= 1
-
-        if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
-            self._current_block -= 1
-            start = self._current_block * self._block_size
-            end = (self._current_block+1) * self._block_size
-            flush_block = [start, end]
-
-        if self._current_block == 0:
-            # reset
-            self._grads_generated = [False]*len(self._grads_info)
+        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+            num_grads = len(self._grads_generated)
+            contiguous_idx = num_grads
+            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
+                contiguous_idx -= 1
+
+            if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
+                self._current_block -= 1
+                start = self._current_block * self._block_size
+                end = (self._current_block+1) * self._block_size
+                flush_block = [start, end]
+
+            if self._current_block == 0:
+                # reset
+                self._grads_generated = [False]*len(self._grads_info)

         return flush_block
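The effect of the rewrite above: the backwards contiguity scan now runs only once the gradient of the boundary parameter for the candidate block (_low_param_i[self._current_block-1]) has been produced, since the block cannot be complete before that. A tiny illustrative check with hypothetical state (not from the commit):

current_block = 3                        # next candidate block to flush is block 2
low_param_i = [0, 1, 2, 3]               # as in the earlier sketch
grads_generated = [True, True, False, True]

if grads_generated[low_param_i[current_block - 1]]:
    print("boundary grad ready: run the full contiguity scan")
else:
    print("boundary grad missing: skip the scan for now")   # taken here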
@@ -244,15 +255,38 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             self._pipeline_block_reductions(block_id, flat_grads)
             self._pipeline_block_step(block_id, flat_grads, new_params)
+    def _flatten_grad_mt(self, scale):
+        if self._flat_mt:
+            grads = []
+            flat_grads = []
+            for p_i, (grads_info, grad) in enumerate(zip(self._grads_info, self._grads)):
+                if grad is not None:
+                    grads.append(grad)
+                    flat_grads.append( self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]] )
+                    self._grads[p_i] = None
+            if len(grads) > 0:
+                import amp_C
+                from apex.multi_tensor_apply import multi_tensor_applier
+                self._overflow_buf.zero_()
+                multi_tensor_applier(
+                    amp_C.multi_tensor_scale,
+                    self._overflow_buf,
+                    [grads, flat_grads],
+                    scale)
+
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
         # handle overlapped reductions
-        torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
+        if self._flat_mt:
+            self._grads[param_i] = grad
+        else:
+            torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
         self._grads_generated[param_i]=True
         if not self._last_step:
             if self._overlap_reductions:
                 flush_block = self._get_flush_block()
                 while flush_block:
                     block_id = flush_block[0] // self._block_size
+                    self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
                     self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
                     with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
                         if self._full_pipeline:
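Taken together: with _flat_mt enabled, the gradient hook no longer copies each gradient into the flat buffer via torch.div; it only records the tensor, and _flatten_grad_mt performs one fused multi_tensor_scale copy (with the pre-divide scale folded in) just before the block is handed to its reduction stream. A minimal standalone sketch of that defer-then-flatten pattern, with illustrative names and sizes (not part of the commit):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

grads_info = [{"param_offset": 0, "param_grads_size": 5},
              {"param_offset": 64, "param_grads_size": 7}]
flat_grads = torch.zeros(128, device='cuda').half()
pending = [None] * len(grads_info)
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')

def record_grad(p_i, grad):
    # Deferred path: remember the gradient, no copy yet.
    pending[p_i] = grad

def flatten(scale):
    # Copy every pending gradient into its slice of the flat buffer in one fused kernel.
    srcs, dsts = [], []
    for p_i, (info, grad) in enumerate(zip(grads_info, pending)):
        if grad is not None:
            srcs.append(grad.view(-1))
            dsts.append(flat_grads[info["param_offset"]:info["param_offset"] + info["param_grads_size"]])
            pending[p_i] = None
    if srcs:
        overflow_buf.zero_()
        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [srcs, dsts], scale)

record_grad(0, torch.randn(5, device='cuda').half())
record_grad(1, torch.randn(7, device='cuda').half())
flatten(1.0 / 8)   # e.g. pre-divide by an assumed 8-GPU world size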
@@ -462,6 +496,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             # nothing done so far, run full pipeline after reductions
             for inv_block_id in range(self._num_blocks):
                 block_id = self._num_blocks - inv_block_id - 1
+                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
                 self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
                     self._pipeline_block_reductions(block_id, self._flat_grads)
...