Commit 208c91e0 authored by Thor Johnsen

internal pipelining more similar to micro-benchmarks

parent 7ba6a038
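The diff below removes the experimental internal_pipeline and dwu_num_blk_st options and instead always splits each gradient block into dwu_num_chunks chunks, issuing every chunk's reduce-scatter (plus the inter-group all-reduce) and all-gather on round-robin CUDA streams and process groups, matching the layout used in the micro-benchmarks. For orientation, a minimal sketch of that per-chunk reduce-scatter pattern follows; it is not code from the commit, the helper name and arguments are hypothetical, and it uses stock torch.distributed.reduce_scatter without the no_copy extension the real optimizer asserts on.

import torch
import torch.distributed as dist

def chunked_reduce_scatter(flat_grads, block_id, block_size, num_chunks,
                           group_size, rank_in_group, rs_streams, rs_groups):
    # Split one block of the flat gradient into chunks and reduce-scatter each
    # chunk asynchronously on its own (round-robin) stream and process group.
    chunk_size = block_size // num_chunks
    shard_size = chunk_size // group_size
    block = flat_grads[block_id*block_size:(block_id+1)*block_size]
    works = []
    for chunk in range(num_chunks):
        grad_chunk = block[chunk*chunk_size:(chunk+1)*chunk_size]
        shards = [grad_chunk[i*shard_size:(i+1)*shard_size] for i in range(group_size)]
        stream = rs_streams[chunk % len(rs_streams)]
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            work = dist.reduce_scatter(shards[rank_in_group], shards,
                                       group=rs_groups[chunk % len(rs_groups)],
                                       async_op=True)
        works.append(work)
    return works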
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdam(torch.optim.Optimizer):
@@ -46,9 +47,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, internal_pipeline=False,
e5m2_allgather=False):
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -67,6 +68,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._overflow_buf = torch.cuda.IntTensor([0])
assert (not flat_mt), "flat_mt option is not safe in this version"
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
@@ -81,8 +84,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._internal_pipeline = internal_pipeline
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -118,11 +121,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._group_size
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._block_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))
self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
@@ -143,7 +147,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
self._num_blk_st = dwu_num_blk_st
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
@@ -152,6 +155,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
@@ -162,8 +166,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
@@ -172,16 +181,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._blk_st = []
for i in range(self._num_blk_st):
self._blk_st.append(torch.cuda.Stream())
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
import inspect
if 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
self._pg_supports_no_copy = True
else:
self._pg_supports_no_copy = False
print("WARNING! torch.distributed.reduce_scatter does not support no_copy op.")
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
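# (no_copy is an NVIDIA extension to c10d rather than stock PyTorch; presumably it
# lets reduce_scatter/all_gather work directly on the provided shard views instead
# of staging through an internal flattened buffer.)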
def set_last_step(self, last_step):
@@ -207,71 +215,65 @@ class DistributedFusedAdam(torch.optim.Optimizer):
return flush_block
def _pipeline_block_reductions(self, block_id, flat_grads):
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
start = block_id * self._block_size
end = start + self._block_size
grad_block = flat_grads[start:end]
grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
if self._internal_pipeline:
works = []
chunk_size = self._shard_size // self._num_chunks
for i in range(self._num_chunks):
chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
if self._pg_supports_no_copy:
work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True,no_copy=True)
else:
work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True)
if self._num_groups > 1:
work.wait()
work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
works.append(work)
else:
if self._pg_supports_no_copy:
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
else:
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
works = [work]
grad_block = self._flat_grads[start:end]
works = [None]*self._num_chunks
for chunk in range(self._num_chunks):
grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
rs_stream = self._rs_st[chunk%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[chunk%self._num_rs_pg],async_op=True,no_copy=True)
if self._num_groups > 1:
work.wait()
works = []
chunk_size = self._shard_size // self._num_chunks
for i in range(self._num_chunks):
chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
works.append(work)
if self._compute_L2_grad_norm:
with torch.cuda.stream(self._blk_st[0]):
for work in works:
ar_stream = self._ar_st[chunk%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
work.wait()
if block_id+1 == self._num_blocks:
self._L2_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
elif block_id != 0:
self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
else:
self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
self._L2_grad_norm.sqrt_()
for work in works:
work.wait()
work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[chunk%self._num_ar_pg],async_op=True)
works[chunk] = work
# NB!
# self._global_scale is used by this method.
if self._compute_L2_grad_norm:
for chunk in range(self._num_chunks):
grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
with torch.cuda.stream(self._l2_grad_norm_st):
works[chunk].wait()
l2_grad_sq = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
if block_id+1 == self._num_blocks and chunk == 0:
self._L2_grad_norm = l2_grad_sq
else:
self._L2_grad_norm += l2_grad_sq
if block_id == 0 and chunk+1 == self._num_chunks:
torch.distributed.all_reduce(self._L2_grad_norm,group=self._l2_grad_norm_pg)
self._L2_grad_norm.sqrt_()
self._reductions_works[block_id] = works
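# (The saved per-chunk reduction handles are what _pipeline_block_step waits on
# before running the optimizer math for each shard and all-gathering the result.)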
def _pipeline_block_step(self, block_id):
if self._new_params is None:
self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8 if self._e5m2_allgather else self._flat_grads.dtype)
def _pipeline_block_step(self, block_id, flat_grads, new_params):
start = block_id * self._block_size
new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
self._partial_step_single_shard(block_id)
if self._pg_supports_no_copy:
torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],no_copy=True)
else:
torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)])
end = start + self._block_size
new_params_block = self._new_params[start:end]
def _pipeline_block(self, block_id, flat_grads, new_params):
self._pipeline_block_reductions(block_id, flat_grads)
self._pipeline_block_step(block_id, flat_grads, new_params)
works = [None]*self._num_chunks
for chunk in range(self._num_chunks):
new_params_chunk = new_params_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
new_params_shards = [new_params_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
ag_stream = self._ag_st[chunk%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
self._reductions_works[block_id][chunk].wait()
self._partial_step_single_shard(block_id,chunk)
work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[chunk%self._num_ag_pg],async_op=True,no_copy=True)
works[chunk] = work
self._allgather_works[block_id] = works
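# (step() later waits on these all-gather handles from _completion_st before
# copying _new_params back into the model parameters.)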
def _flatten_grad_mt(self, scale):
if self._flat_mt:
@@ -281,10 +283,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if grad is not None:
grads.append(grad)
flat_grads.append( self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]] )
self._grads[p_i] = None
self._grads = [None]*len(self._grads_info)
if len(grads) > 0:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
@@ -295,7 +295,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
# handle overlapped reductions
if self._flat_mt:
self._grads[param_i] = grad
self._grads[param_i] = grad.view(-1)
else:
torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
self._grads_generated[param_i]=True
@@ -304,19 +304,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
if self._full_pipeline:
if self._new_params is None:
if self._e5m2_allgather:
self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
else:
self._new_params = torch.zeros_like(self._flat_grads)
self._pipeline_block(block_id, self._flat_grads, self._new_params)
else:
self._pipeline_block_reductions(block_id, self._flat_grads)
self._pipeline_block_reductions(block_id)
if self._full_pipeline:
self._pipeline_block_step(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
@@ -360,8 +350,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
@property
def L2_grad_norm(self):
if self._compute_L2_grad_norm:
for i, blk_st in enumerate(self._blk_st):
torch.cuda.current_stream().wait_stream(blk_st)
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
else:
return None
@@ -376,7 +365,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# This means we have to play around with indexes, which requires knowledge of block and shard number.
# Implement a method that performs a partial update of a single shard within a single block.
def _partial_step_single_shard(self, block_id, undo=False):
def _partial_step_single_shard(self, block_id, chunk_id, undo=False):
"""Perform step function for a single shard.
Arguments:
@@ -385,7 +374,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
"""
shard_id = self._rank_in_group
shard_start = block_id * self._block_size + shard_id * self._shard_size
shard_start = (block_id * self._num_chunks + chunk_id) * self._chunk_size + shard_id * self._shard_size
shard_end = shard_start + self._shard_size
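# Worked example of the new indexing (hypothetical sizes): with _num_chunks=4,
# _group_size=16, _chunk_size=1048576 and _shard_size=65536, the shard owned by
# rank_in_group=3 for (block_id=2, chunk_id=1) starts at
# (2*4+1)*1048576 + 3*65536 = 9633792.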
if self._fp32_p is None:
@@ -393,13 +382,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Allocate fp32 buffers on demand. Note that we don't make these part of the state
# since each rank only has partial buffers.
# To-Do:
self._fp32_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
if self._revert_method > 1:
self._fp32_backup_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_backup_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_backup_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_backup_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_backup_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._fp32_backup_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
self._copy_to_fp32 = True
step = None
@@ -445,7 +434,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if self._copy_to_fp32:
param_offset = clipped_start - shard_start
param_size = clipped_end - clipped_start
buffer_start = block_id * self._shard_size + param_offset
buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + param_offset
buffer_end = buffer_start + param_size
param_start = (clipped_start - start)
param_end = param_start + param_size
@@ -457,7 +446,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
group_offset = group_start - shard_start
group_shard_start = shard_start + group_offset
group_shard_end = group_shard_start + group_size
group_buffer_start = block_id * self._shard_size + group_offset
group_buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + group_offset
group_buffer_end = group_buffer_start + group_size
beta1, beta2 = group['betas']
@@ -520,12 +509,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for inv_block_id in range(self._num_blocks):
block_id = self._num_blocks - inv_block_id - 1
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
self._pipeline_block_reductions(block_id, self._flat_grads)
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._copy_to_fp32 = False
self._decomp_stats = None
@@ -536,7 +524,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
"""Revert effect of previously calling partial_step.
"""
for block_id in range(self._num_blocks):
self._partial_step_single_shard(block_id, undo=True)
for chunk in range(self._num_chunks):
self._partial_step_single_shard(block_id, chunk, undo=True)
def step(self, closure=None, skip_overflow_check=False):
loss = None
@@ -544,19 +533,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
if self._new_params is None:
if self._e5m2_allgather:
self._new_params = torch.zeros_like(self._flat_grads,dtype=torch.uint8)
else:
self._new_params = torch.zeros_like(self._flat_grads)
for inv_block_id in range(self._num_blocks):
block_id = self._num_blocks - inv_block_id - 1
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_step(block_id)
with torch.cuda.stream(self._blk_st[0]):
for i, blk_st in enumerate(self._blk_st):
torch.cuda.current_stream().wait_stream(blk_st)
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks-1,-1,-1):
for chunk in range(self._num_chunks):
self._allgather_works[block_id][chunk].wait()
# Check for overflow
# Store state for loss scaler calculation
@@ -570,7 +553,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self.revert_step()
else:
# Copy self._new_params to model params
if self._e5m2_allgather:
if self._e5m2_allgather or self._do_not_flatten_model:
p_in = []
p_out = []
with torch.no_grad():
@@ -585,7 +568,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
state['step'] += 1
nels = p.numel()
offset = self._grads_info[param_i]['param_offset']
if self._e5m2_allgather:
if self._e5m2_allgather or self._do_not_flatten_model:
p_in.append(self._new_params[offset:offset+nels].view_as(p))
p_out.append(p)
else:
@@ -596,9 +579,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
fused_adam_cuda.unpack_e5m2_mt,
self._overflow_buf,
[p_in, p_out]);
self._new_params = None
elif self._do_not_flatten_model:
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
[p_in, p_out],
1.0);
if not self._e5m2_allgather and not self._do_not_flatten_model:
self._new_params = None
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
torch.cuda.current_stream().wait_stream(self._blk_st[0])
return loss
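For orientation, a hedged sketch of constructing the updated optimizer; the import path, the lr keyword and the wrapped model are assumptions, while the dwu_*, pipeline and do_not_flatten_model arguments are taken from the new signature above.

# Assumes torch.distributed (NCCL backend) is already initialized and model is on the GPU.
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # assumed import path

optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,                           # assumed keyword, not shown in this diff
    overlap_reductions=True,           # start reductions while backprop is still running
    full_pipeline=True,                # also pipeline the optimizer step and all-gather
    compute_L2_grad_norm=True,
    dwu_num_blocks=4,
    dwu_num_chunks=4,                  # chunking factor that drives the new pipeline
    dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
    e5m2_allgather=False,
    do_not_flatten_model=False)        # new flag added in this commit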