Commit bec558b1 authored by Thor Johnsen

Add graphing, switch to peer mem exchanger as default

parent 4aeb24cb
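The first change below adds a graph_it helper that captures the bottleneck's forward and backward passes into CUDA graphs via torch.cuda.make_graphed_callables. A minimal, self-contained sketch of that API (the module, shapes, and names here are illustrative stand-ins, not taken from this commit):

```python
import torch

# Stand-in module and a static-shape sample input; CUDA is required for graph capture.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
x = torch.randn(8, 64, device='cuda', requires_grad=True)

# Returns a callable that replays the captured graphs on each call;
# input shapes must match the sample args exactly.
graphed_model = torch.cuda.make_graphed_callables(model, (x,))

y = graphed_model(x)
y.sum().backward()
```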
@@ -49,16 +49,29 @@ def rel_diff(x1, x2):
    return rel_diff_t(x1,x2)
def fprop_and_bprop(x, bottleneck, dy=None):
def graph_it(bottleneck, x):
    print("Graphing")
    with torch.no_grad():
        x = x.clone()
        x.grad = None
        x.requires_grad = True
    y = bottleneck(x)
    if dy is None:
        with torch.no_grad():
        return torch.cuda.make_graphed_callables(bottleneck, (x,))
def clone_inputs(bottleneck, x, dy=None):
    with torch.no_grad():
        x = x.clone()
        x.grad = None
        x.requires_grad = True
        if dy is None:
            y = bottleneck(x)
            dy = torch.randn_like(y) / 1e2
            torch.distributed.broadcast(dy, 0)
    return x, dy
def fprop_and_bprop(bottleneck, x, dy):
    y = bottleneck(x)
    y.backward(dy)
    dgrad = x.grad.detach()
    wgrad = {}
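For reference, the refactored fprop_and_bprop above runs the forward pass, backpropagates a supplied output gradient, and collects the input gradient plus a dict of parameter gradients. A self-contained sketch of that pattern with an assumed stand-in module (not the apex bottleneck block):

```python
import torch

# Hypothetical stand-in for the bottleneck block used by the test.
bottleneck = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1).cuda()
x = torch.randn(2, 4, 16, 16, device='cuda', requires_grad=True)

y = bottleneck(x)
dy = torch.randn_like(y) / 1e2                  # small random output gradient
y.backward(dy)

dgrad = x.grad.detach()                         # gradient w.r.t. the input
wgrad = {n: p.grad.detach() for n, p in bottleneck.named_parameters()}
```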
@@ -74,7 +87,8 @@ def ground_truth(N, C, H, W, dtype, memory_format, bottleneck):
        with torch.no_grad():
            x = torch.randn([N,H,W,C], dtype=dtype, device='cuda')
            torch.distributed.broadcast(x, 0)
        return fprop_and_bprop(x, bottleneck)
        x, dy = clone_inputs(bottleneck, x)
        return fprop_and_bprop(bottleneck, x, dy)
    else:
        # 2 -> native nhwc
        # 3 -> nchw
@@ -92,11 +106,9 @@ def print_ground_truth(gt):
def apply_to_different_bottleneck(gt, bottleneck):
    with torch.no_grad():
        x, y, dy, dgrad, wgrad = gt
        x = x.clone()
        x.requires_grad = True
        dy = dy.clone()
    return fprop_and_bprop(x, bottleneck, dy)
        x, _, dy, _, _ = gt
        x, dy = clone_inputs(bottleneck, x, dy)
    return fprop_and_bprop(bottleneck, x, dy)
def compare_single_field(results, f1, f2, l0, l1, l2):
@@ -161,15 +173,20 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
    spatial_group_rank = rank
    spatial_communicator = None
    spatial_halo_exchanger = halex
    spatial_method = 3 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
    spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
    use_delay_kernel = False
    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel)
    spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
    with torch.no_grad():
        Hs = H // spatial_group_size
        xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:]
        dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:]
    _, y, _, dgrad, wgrad = fprop_and_bprop(xs, spatial_bottleneck, dys)
        xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
        dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
        xs.requires_grad = True
    spatial_bottleneck = graph_it(spatial_bottleneck, xs)
    _, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys)
    # gather output pieces
    for n,p in wgrad.items():
        if fp32_reduce:
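The hunk above slices the ground-truth input and output gradient along H so each spatial rank works on its own contiguous NHWC slab, then graphs the spatial bottleneck on that slab. A self-contained sketch of the slicing (group size, rank, and shapes are assumed for illustration):

```python
import torch

N, H, W, C = 1, 8, 8, 4                        # assumed NHWC ground-truth shape
spatial_group_size, spatial_group_rank = 2, 0  # assumed 2-way split, this is rank 0

x = torch.randn(N, H, W, C)
dy = torch.randn(N, H, W, C)

Hs = H // spatial_group_size
xs = x[:, spatial_group_rank * Hs:(spatial_group_rank + 1) * Hs].clone()
dys = dy[:, spatial_group_rank * Hs:(spatial_group_rank + 1) * Hs].clone()
xs.requires_grad = True
assert xs.shape == (N, Hs, W, C)
```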
@@ -217,9 +234,9 @@ def main():
    peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
    #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
    halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
    #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
    #halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
    #print("halex.signals = %s" % (str(halex.signals)))
    # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
    #torch.cuda.synchronize()
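The last hunk switches the default halo exchanger from HaloExchangerSendRecv to the peer-memory HaloExchangerPeer. Whatever the transport, the exchange moves boundary rows between neighboring spatial slabs so each rank can convolve a padded slice; a single-process sketch of that idea (halo width and shapes assumed, no apex API involved):

```python
import torch

halo = 1                           # a 3x3 convolution needs one halo row per side
x0 = torch.randn(1, 4, 8, 4)       # NHWC slab owned by spatial rank 0
x1 = torch.randn(1, 4, 8, 4)       # NHWC slab owned by spatial rank 1

# Rank 0 appends rank 1's first rows as its bottom halo, and vice versa.
x0_padded = torch.cat([x0, x1[:, :halo]], dim=1)
x1_padded = torch.cat([x0[:, -halo:], x1], dim=1)
```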