Commit 208c91e0 authored by Thor Johnsen

internal pipelining more similar to micro-benchmarks

parent 7ba6a038
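The diff below reworks DistributedFusedAdam's internal pipeline: instead of reduce-scattering whole blocks (with optional sub-chunking only when internal_pipeline was set), every block is now always split into dwu_num_chunks chunks, each chunk is issued on its own reduce-scatter/all-reduce/all-gather process group and CUDA stream, and the async work handles are kept per block and per chunk so later stages can wait on exactly what they need. A minimal standalone sketch of that pattern follows; all names are illustrative (not the optimizer's attributes) and only stock torch.distributed calls are used, whereas the commit itself additionally requires a c10d build that supports the no_copy option.

# Hedged sketch (illustrative names): split one flat gradient block into chunks,
# launch each chunk's reduce-scatter on its own CUDA stream and process group,
# and keep the async handles for later pipeline stages.
import torch
import torch.distributed as dist

def pipeline_block_reductions_sketch(flat_grads, block_start, block_size,
                                     num_chunks, group_size, rank_in_group,
                                     rs_groups, rs_streams):
    chunk_size = block_size // num_chunks
    shard_size = chunk_size // group_size
    grad_block = flat_grads[block_start:block_start + block_size]
    works = []
    for chunk in range(num_chunks):
        grad_chunk = grad_block[chunk * chunk_size:(chunk + 1) * chunk_size]
        shards = [grad_chunk[i * shard_size:(i + 1) * shard_size]
                  for i in range(group_size)]
        stream = rs_streams[chunk % len(rs_streams)]
        # The side stream must first see everything already queued on the
        # current stream (e.g. the backward kernels that produced these grads).
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            # Stock c10d reduce-scatter; each rank ends up owning one reduced shard.
            work = dist.reduce_scatter(shards[rank_in_group], shards,
                                       group=rs_groups[chunk % len(rs_groups)],
                                       async_op=True)
        works.append(work)
    return works

Compared with one collective per block, per-chunk collectives on separate streams can overlap with each other and with the per-chunk optimizer step and all-gather that follow.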
 import math
 import torch
 import importlib
+import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier

 class DistributedFusedAdam(torch.optim.Optimizer):

@@ -46,9 +47,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4, predivide=True, internal_pipeline=False,
-                 e5m2_allgather=False):
+                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
+                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
+                 do_not_flatten_model=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")

@@ -67,6 +68,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._overflow_buf = torch.cuda.IntTensor([0])
+        assert (not flat_mt), "flat_mt option is not safe in this version"
         # Way to revert a step
         # 3 -> undo kernel + double buffer (debug, print norm of difference)
         # 2 -> double buffer fp32 parameters
@@ -81,8 +84,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._predivide = predivide
-        self._internal_pipeline = internal_pipeline
         self._e5m2_allgather = e5m2_allgather
+        self._do_not_flatten_model = do_not_flatten_model
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None

@@ -118,11 +121,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._net_total_param_size = p_offset
         self._total_param_size = p_offset
-        dwu_min_page_size = 256 * self._num_blocks * self._group_size
+        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
         self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
         self._block_size = self._total_param_size // self._num_blocks
-        self._shard_size = self._block_size // self._group_size
-        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))
+        self._chunk_size = self._block_size // self._num_chunks
+        self._shard_size = self._chunk_size // self._group_size
+        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
         self._low_param_i = [0]*self._num_blocks
         for block_id in range(self._num_blocks-1,-1,-1):
@@ -143,7 +147,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_rs_pg = dwu_num_rs_pg
         self._num_ar_pg = dwu_num_ar_pg
         self._num_ag_pg = dwu_num_ag_pg
-        self._num_blk_st = dwu_num_blk_st
         if self._num_groups > 1:
             self._ar_pg = []
             for dev_i in range(self._group_size):

@@ -152,6 +155,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     grp = torch.distributed.new_group(ranks=ranks)
                     if torch.distributed.get_rank() in ranks:
                         self._ar_pg.append(grp)
+            self._ar_st = [torch.cuda.Stream()]*self._num_ar_pg
         rs_ranks = []
         for group_i in range(self._num_groups):
             rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])

@@ -162,8 +166,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 grp = torch.distributed.new_group(ranks=ranks)
                 if torch.distributed.get_rank() in ranks:
                     self._rs_pg.append(grp)
+            if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
+                self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+        self._rs_st = [torch.cuda.Stream()]*self._num_rs_pg
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
+            self._ag_st = self._rs_st
+            self._num_ag_pg = self._num_rs_pg
         else:
             self._ag_pg = []
             for group_i in range(self._num_groups):

@@ -172,16 +181,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     grp = torch.distributed.new_group(ranks=ranks)
                     if torch.distributed.get_rank() in ranks:
                         self._ag_pg.append(grp)
-        self._blk_st = []
-        for i in range(self._num_blk_st):
-            self._blk_st.append(torch.cuda.Stream())
+            self._ag_st = [torch.cuda.Stream()]*self._num_ag_pg
+        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
+        self._completion_st = torch.cuda.Stream()
+        self._reductions_works = [None]*self._num_blocks
+        self._allgather_works = [None]*self._num_blocks
         import inspect
-        if 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
-            self._pg_supports_no_copy = True
-        else:
-            self._pg_supports_no_copy = False
-            print("WARNING! torch.distributed.reduce_scatter does not support no_copy op.")
+        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

     def set_last_step(self, last_step):
@@ -207,71 +215,65 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         return flush_block

-    def _pipeline_block_reductions(self, block_id, flat_grads):
+    def _pipeline_block_reductions(self, block_id):
+        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
         start = block_id * self._block_size
         end = start + self._block_size
-        grad_block = flat_grads[start:end]
-        grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-        if self._internal_pipeline:
-            works = []
-            chunk_size = self._shard_size // self._num_chunks
-            for i in range(self._num_chunks):
-                chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
-                if self._pg_supports_no_copy:
-                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True,no_copy=True)
-                else:
-                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True)
-                if self._num_groups > 1:
-                    work.wait()
-                    work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
-                works.append(work)
-        else:
-            if self._pg_supports_no_copy:
-                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
-            else:
-                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
-            works = [work]
-            if self._num_groups > 1:
-                work.wait()
-                works = []
-                chunk_size = self._shard_size // self._num_chunks
-                for i in range(self._num_chunks):
-                    chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
-                    work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
-                    works.append(work)
+        grad_block = self._flat_grads[start:end]
+        works = [None]*self._num_chunks
+        for chunk in range(self._num_chunks):
+            grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
+            grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
+            rs_stream = self._rs_st[chunk%self._num_rs_pg]
+            rs_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(rs_stream):
+                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[chunk%self._num_rs_pg],async_op=True,no_copy=True)
+            if self._num_groups > 1:
+                ar_stream = self._ar_st[chunk%self._num_ar_pg]
+                with torch.cuda.stream(ar_stream):
+                    work.wait()
+                    work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[chunk%self._num_ar_pg],async_op=True)
+            works[chunk] = work
         if self._compute_L2_grad_norm:
-            with torch.cuda.stream(self._blk_st[0]):
-                for work in works:
-                    work.wait()
-                if block_id+1 == self._num_blocks:
-                    self._L2_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
-                elif block_id != 0:
-                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
-                else:
-                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
-                    torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
-                    self._L2_grad_norm.sqrt_()
-        for work in works:
-            work.wait()
+            for chunk in range(self._num_chunks):
+                grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
+                grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
+                with torch.cuda.stream(self._l2_grad_norm_st):
+                    works[chunk].wait()
+                    l2_grad_sq = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
+                    if block_id+1 == self._num_blocks and chunk == 0:
+                        self._L2_grad_norm = l2_grad_sq
+                    else:
+                        self._L2_grad_norm += l2_grad_sq
+                    if block_id == 0 and chunk+1 == self._num_chunks:
+                        torch.distributed.all_reduce(self._L2_grad_norm,group=self._l2_grad_norm_pg)
+                        self._L2_grad_norm.sqrt_()
+        self._reductions_works[block_id] = works
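For reference, the L2-norm bookkeeping above follows a simple pattern: each chunk's squared shard norm is accumulated on a dedicated stream as soon as that chunk's reduction finishes, and only the last contribution triggers the cross-rank all-reduce and the square root. A hedged sketch with illustrative names (not the optimizer's attributes):

# Hedged sketch: accumulate the squared L2 norm of local reduced shards on a
# dedicated stream, finishing with an all-reduce and sqrt on the last chunk.
import torch
import torch.distributed as dist

def accumulate_shard_norm_sketch(shard, reduction_work, norm_sq, norm_stream,
                                 first, last, norm_group):
    # `reduction_work` is the async handle that produced `shard`;
    # `norm_sq` is the running squared-norm accumulator (may be None when `first`).
    with torch.cuda.stream(norm_stream):
        reduction_work.wait()                          # shard now holds reduced gradients
        contrib = shard.norm(dtype=torch.float32, p=2) ** 2
        norm_sq = contrib if first else norm_sq + contrib
        if last:
            dist.all_reduce(norm_sq, group=norm_group)  # sum partial norms across ranks
            norm_sq = norm_sq.sqrt()                    # global gradient L2 norm
    return norm_sq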
-    # NB!
-    # self._global_scale is used by this method.
-    def _pipeline_block_step(self, block_id, flat_grads, new_params):
+    def _pipeline_block_step(self, block_id):
+        if self._new_params is None:
+            self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8 if self._e5m2_allgather else self._flat_grads.dtype)
         start = block_id * self._block_size
-        new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
-        self._partial_step_single_shard(block_id)
-        if self._pg_supports_no_copy:
-            torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],no_copy=True)
-        else:
-            torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)])
-
-    def _pipeline_block(self, block_id, flat_grads, new_params):
-        self._pipeline_block_reductions(block_id, flat_grads)
-        self._pipeline_block_step(block_id, flat_grads, new_params)
+        end = start + self._block_size
+        new_params_block = self._new_params[start:end]
+        works = [None]*self._num_chunks
+        for chunk in range(self._num_chunks):
+            new_params_chunk = new_params_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
+            new_params_shards = [new_params_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
+            ag_stream = self._ag_st[chunk%self._num_ag_pg]
+            with torch.cuda.stream(ag_stream):
+                self._reductions_works[block_id][chunk].wait()
+                self._partial_step_single_shard(block_id,chunk)
+                work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[chunk%self._num_ag_pg],async_op=True,no_copy=True)
+                works[chunk] = work
+        self._allgather_works[block_id] = works
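The new _pipeline_block_step gathers each chunk's updated parameter shards directly into the flat _new_params buffer. With stock c10d the same effect looks roughly like the sketch below (illustrative names); the no_copy=True argument used above is a c10d extension, which is exactly what the assert in __init__ checks for, and it skips an extra staging copy.

# Hedged sketch: all-gather one chunk's updated shards into views of a flat buffer.
import torch
import torch.distributed as dist

def allgather_chunk_sketch(new_params_chunk, shard_size, group_size,
                           rank_in_group, ag_group, ag_stream):
    # Output tensors are views into the flat chunk, so gathered values land in place.
    shards = [new_params_chunk[i * shard_size:(i + 1) * shard_size]
              for i in range(group_size)]
    with torch.cuda.stream(ag_stream):
        # Each rank contributes the shard it just updated; every rank receives all shards.
        work = dist.all_gather(shards, shards[rank_in_group],
                               group=ag_group, async_op=True)
    return work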
     def _flatten_grad_mt(self, scale):
         if self._flat_mt:

@@ -281,10 +283,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 if grad is not None:
                     grads.append(grad)
                     flat_grads.append( self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]] )
-                    self._grads[p_i] = None
+            self._grads = [None]*len(self._grads_info)
             if len(grads) > 0:
-                import amp_C
-                from apex.multi_tensor_apply import multi_tensor_applier
                 self._overflow_buf.zero_()
                 multi_tensor_applier(
                     amp_C.multi_tensor_scale,
@@ -295,7 +295,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
         # handle overlapped reductions
         if self._flat_mt:
-            self._grads[param_i] = grad
+            self._grads[param_i] = grad.view(-1)
         else:
             torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
         self._grads_generated[param_i]=True

@@ -304,19 +304,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             flush_block = self._get_flush_block()
             while flush_block:
                 block_id = flush_block[0] // self._block_size
-                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
-                self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                    if self._full_pipeline:
-                        if self._new_params is None:
-                            if self._e5m2_allgather:
-                                self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
-                            else:
-                                self._new_params = torch.zeros_like(self._flat_grads)
-                        self._pipeline_block(block_id, self._flat_grads, self._new_params)
-                    else:
-                        self._pipeline_block_reductions(block_id, self._flat_grads)
+                self._pipeline_block_reductions(block_id)
+                if self._full_pipeline:
+                    self._pipeline_block_step(block_id)
                 flush_block = self._get_flush_block()

     def set_global_scale(self, global_scale):
@@ -360,8 +350,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     @property
     def L2_grad_norm(self):
         if self._compute_L2_grad_norm:
-            for i, blk_st in enumerate(self._blk_st):
-                torch.cuda.current_stream().wait_stream(blk_st)
+            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
             return self._L2_grad_norm
         else:
             return None
@@ -376,7 +365,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     # This means we have to play around with indexes, which requires knowledge of block and shard number.
     # Implement a method that performs a partial update of a single shard within a single block.
-    def _partial_step_single_shard(self, block_id, undo=False):
+    def _partial_step_single_shard(self, block_id, chunk_id, undo=False):
         """Perform step function for a single shard.

         Arguments:

@@ -385,7 +374,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """
         shard_id = self._rank_in_group
-        shard_start = block_id * self._block_size + shard_id * self._shard_size
+        shard_start = (block_id * self._num_chunks + chunk_id) * self._chunk_size + shard_id * self._shard_size
         shard_end = shard_start + self._shard_size

         if self._fp32_p is None:
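Since the partial step is now indexed by (block, chunk) rather than by block alone, both the offset into the flat gradient buffer and the offset into the rank-local fp32 master buffers change accordingly. A hedged worked example with made-up sizes:

# Worked example of the new shard indexing (all sizes are made up).
block_size, num_chunks, group_size = 1024, 4, 2
chunk_size = block_size // num_chunks          # 256
shard_size = chunk_size // group_size          # 128
block_id, chunk_id, rank_in_group = 2, 3, 1
# Offset of this rank's shard in the flat gradient/parameter buffer:
shard_start = (block_id * num_chunks + chunk_id) * chunk_size + rank_in_group * shard_size
# Offset of the same shard in the rank-local fp32 buffers (one shard per block*chunk):
buffer_start = (block_id * num_chunks + chunk_id) * shard_size
assert shard_start == 2944 and buffer_start == 1408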
@@ -393,13 +382,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             # Allocate fp32 buffers on demand. Note that we don't make these part of the state
             # since each rank only has partial buffers.
             # To-Do:
-            self._fp32_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-            self._fp32_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-            self._fp32_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+            self._fp32_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+            self._fp32_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+            self._fp32_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
             if self._revert_method > 1:
-                self._fp32_backup_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-                self._fp32_backup_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-                self._fp32_backup_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+                self._fp32_backup_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+                self._fp32_backup_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+                self._fp32_backup_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
             self._copy_to_fp32 = True

         step = None

@@ -445,7 +434,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 if self._copy_to_fp32:
                     param_offset = clipped_start - shard_start
                     param_size = clipped_end - clipped_start
-                    buffer_start = block_id * self._shard_size + param_offset
+                    buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + param_offset
                     buffer_end = buffer_start + param_size
                     param_start = (clipped_start - start)
                     param_end = param_start + param_size

@@ -457,7 +446,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 group_offset = group_start - shard_start
                 group_shard_start = shard_start + group_offset
                 group_shard_end = group_shard_start + group_size
-                group_buffer_start = block_id * self._shard_size + group_offset
+                group_buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + group_offset
                 group_buffer_end = group_buffer_start + group_size

                 beta1, beta2 = group['betas']
@@ -520,12 +509,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         if self._last_step or not self._overlap_reductions:
             # nothing done so far, run full pipeline after reductions
-            for inv_block_id in range(self._num_blocks):
-                block_id = self._num_blocks - inv_block_id - 1
-                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
-                self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                    self._pipeline_block_reductions(block_id, self._flat_grads)
+            for block_id in range(self._num_blocks-1,-1,-1):
+                self._pipeline_block_reductions(block_id)
+
+        if self._compute_L2_grad_norm:
+            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

         self._copy_to_fp32 = False
         self._decomp_stats = None

@@ -536,7 +524,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Revert effect of previously calling partial_step.
         """
         for block_id in range(self._num_blocks):
-            self._partial_step_single_shard(block_id, undo=True)
+            for chunk in range(self._num_chunks):
+                self._partial_step_single_shard(block_id, chunk, undo=True)

     def step(self, closure=None, skip_overflow_check=False):
         loss = None
@@ -544,19 +533,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             loss = closure()

         if self._last_step or not self._overlap_reductions or not self._full_pipeline:
-            if self._new_params is None:
-                if self._e5m2_allgather:
-                    self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
-                else:
-                    self._new_params = torch.zeros_like(self._flat_grads)
-            for inv_block_id in range(self._num_blocks):
-                block_id = self._num_blocks - inv_block_id - 1
-                with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                    self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
+            for block_id in range(self._num_blocks-1,-1,-1):
+                self._pipeline_block_step(block_id)

-        with torch.cuda.stream(self._blk_st[0]):
-            for i, blk_st in enumerate(self._blk_st):
-                torch.cuda.current_stream().wait_stream(blk_st)
+        with torch.cuda.stream(self._completion_st):
+            for block_id in range(self._num_blocks-1,-1,-1):
+                for chunk in range(self._num_chunks):
+                    self._allgather_works[block_id][chunk].wait()

             # Check for overflow
             # Store state for loss scaler calculation
@@ -570,7 +553,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
-                if self._e5m2_allgather:
+                if self._e5m2_allgather or self._do_not_flatten_model:
                     p_in = []
                     p_out = []
                 with torch.no_grad():

@@ -585,7 +568,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                             state['step'] += 1
                             nels = p.numel()
                             offset = self._grads_info[param_i]['param_offset']
-                            if self._e5m2_allgather:
+                            if self._e5m2_allgather or self._do_not_flatten_model:
                                 p_in.append(self._new_params[offset:offset+nels].view_as(p))
                                 p_out.append(p)
                             else:

@@ -596,9 +579,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                         fused_adam_cuda.unpack_e5m2_mt,
                         self._overflow_buf,
                         [p_in, p_out]);
+                elif self._do_not_flatten_model:
+                    multi_tensor_applier(
+                        amp_C.multi_tensor_scale,
+                        self._overflow_buf,
+                        [p_in, p_out],
+                        1.0);
+                if not self._e5m2_allgather and not self._do_not_flatten_model:
                     self._new_params = None

-        torch.cuda.current_stream().wait_stream(self._blk_st[0])
+        torch.cuda.current_stream().wait_stream(self._completion_st)
+        self._reductions_works = [None]*self._num_blocks
+        self._allgather_works = [None]*self._num_blocks

         return loss