Commit 34df0f79 authored by Thor Johnsen

wgrad2 in parallel stream, optional mode to wait for halo transfer

parent 834b1d01
@@ -253,31 +253,35 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             top_out1_halo = out1_pad[:,:1,:,:]
             btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
             spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
-            if spatial_group_rank < spatial_group_size-1:
-                stream2.wait_stream(stream1)
-                with torch.cuda.stream(stream2):
-                    btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                    btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                    btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
-            else:
-                with torch.cuda.stream(stream2):
-                    btm_out1_halo.zero_()
-            if spatial_group_rank > 0:
-                with torch.cuda.stream(stream1):
-                    top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                    top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                    top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+            if spatial_method == 1:
+                # overlap mid convolution with halo transfer
+                if spatial_group_rank < spatial_group_size-1:
+                    stream2.wait_stream(stream1)
+                    with torch.cuda.stream(stream2):
+                        btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                        btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+                if spatial_group_rank > 0:
+                    with torch.cuda.stream(stream1):
+                        top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                        top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+                inc.add_delay(10)
+            elif spatial_method == 2:
+                # wait for halo transfer to finish before doing a full convolution of padded x
+                torch.cuda.current_stream().wait_stream(stream1)
+                torch.cuda.current_stream().wait_stream(stream3)
+                fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
             else:
-                with torch.cuda.stream(stream1):
-                    top_out1_halo.zero_()
-        inc.add_delay(10)
-        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+                assert(False), "spatial_method must be 1 or 2"
+        if spatial_group_size <= 1 or spatial_method == 1:
+            fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
         # compute halo cells for outputs[1] (out2)
-        if spatial_group_size > 1:
+        if spatial_group_size > 1 and spatial_method == 1:
             out2 = outputs[1]
             if spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(stream1)
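
The restructured forward pass now supports two spatial modes: with spatial_method == 1 the two edge output rows are computed from 3-row "fat halos" on side streams while the main convolution runs concurrently, and with spatial_method == 2 the code waits for the halo transfer to land in out1_pad and then runs one convolution over the whole padded tile (forward_out2_pad). Below is a minimal sketch of the same control flow in plain PyTorch, assuming NCHW layout, a 3x3 stride-1 mid convolution, and halos already exchanged (or zeroed) into out1_pad; the actual code drives explicit-NHWC cuDNN kernels through fast_bottleneck.

```python
import torch
import torch.nn.functional as F

def out2_sketch(out1_pad, weight, spatial_method, side_stream):
    # out1_pad: (N, C, Hs + 2, W), one halo row above and below this rank's tile
    if spatial_method == 1:
        # edge rows from 3-row fat halos on a side stream, overlapping the main conv
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            top_row = F.conv2d(out1_pad[:, :, :3, :], weight, padding=(0, 1))
            btm_row = F.conv2d(out1_pad[:, :, -3:, :], weight, padding=(0, 1))
        out2 = F.conv2d(out1_pad[:, :, 1:-1, :], weight, padding=1)  # interior tile
        torch.cuda.current_stream().wait_stream(side_stream)
        out2[:, :, :1, :].copy_(top_row)   # replace the zero-padded edge rows
        out2[:, :, -1:, :].copy_(btm_row)
    elif spatial_method == 2:
        # halo transfer already complete: one conv over the padded tile,
        # with no convolution padding in H (the role played by outdimA1b below)
        out2 = F.conv2d(out1_pad, weight, padding=(0, 1))
    else:
        raise ValueError("spatial_method must be 1 or 2")
    return out2
```

Method 1 hides halo latency behind the main convolution at the cost of extra kernel launches and a join; method 2 gives up the overlap in exchange for a single larger kernel.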
@@ -290,6 +294,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
+            torch.cuda.current_stream().wait_stream(stream3)
             ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
             ctx.save_for_backward(*(args+outputs))
@@ -368,10 +373,13 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
         inc.add_delay(10)
-        if ctx.spatial_group_size > 1:
-            wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
-        else:
-            wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        wgrad2_stream = torch.cuda.Stream()
+        wgrad2_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(wgrad2_stream):
+            if ctx.spatial_group_size > 1:
+                wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
+            else:
+                wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute wgrad2 for internal cells
         #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
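
Stripped of the apex specifics, the new wgrad2 scheduling is the standard fork-join stream pattern: the weight gradient has no consumers inside backward, so it can proceed on its own stream while the data-gradient chain continues on the current one. A sketch, where compute_dgrad and compute_wgrad2 are hypothetical placeholders:

```python
import torch

def backward_sketch(compute_dgrad, compute_wgrad2):
    wgrad2_stream = torch.cuda.Stream()
    # fork: the side stream must see all work that produced its inputs
    wgrad2_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(wgrad2_stream):
        wgrad2 = compute_wgrad2()        # runs concurrently with dgrad below
    grad_input = compute_dgrad()         # default-stream work overlaps wgrad2
    # join: nothing may read wgrad2 (or free its inputs) before this point,
    # which is why the diff adds a wait_stream just before returning grads
    torch.cuda.current_stream().wait_stream(wgrad2_stream)
    return grad_input, wgrad2
```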
@@ -406,6 +414,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
         fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
+        torch.cuda.current_stream().wait_stream(wgrad2_stream)
         return (None, None, None, None, None, None, None, None, None, *grads)
......
@@ -161,8 +161,8 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
     spatial_group_rank = rank
     spatial_communicator = None
     spatial_halo_exchanger = halex
-    spatial_stream = None # Not in use
-    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream)
+    spatial_method = 2 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
+    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
     spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
     with torch.no_grad():
@@ -217,8 +217,14 @@ def main():
     peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
     #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
     #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
-    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
+    #halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
     #print("halex.signals = %s" % (str(halex.signals)))
+    # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
+    #torch.cuda.synchronize()
+    #torch.distributed.barrier()
     bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)
     compare(gt, bt2)
......
@@ -9,14 +9,17 @@ import peer_memory as pm
 # NB! This is only useful for performance testing.
 # NB! Do not use for actual production runs
 class HaloExchanger(object):
-    def __init__(self):
+    def __init__(self, spatial_group_size, rank):
         self.stream1 = torch.cuda.Stream()
         self.stream2 = torch.cuda.Stream()
         self.stream3 = torch.cuda.Stream()
+        spatial_rank = rank % spatial_group_size
+        self.left_zero = True if spatial_rank == 0 else False
+        self.right_zero = True if spatial_rank == spatial_group_size - 1 else False

 class HaloExchangerNoComm(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerNoComm, self).__init__()
+        super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)

     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
@@ -27,7 +30,7 @@ class HaloExchangerNoComm(HaloExchanger):
 class HaloExchangerAllGather(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerAllGather, self).__init__()
+        super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
         self.spatial_group_size = spatial_group_size
         self.local_rank = rank % spatial_group_size
         self.comm = comm
@@ -43,14 +46,24 @@ class HaloExchangerAllGather(HaloExchanger):
         ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
         ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
         if left_input_halo is None:
+            if self.left_zero:
+                ag_left_input_halo.zero_()
+            if self.right_zero:
+                ag_right_input_halo.zero_()
             return ag_left_input_halo, ag_right_input_halo
         else:
-            left_input_halo.copy_(ag_left_input_halo)
-            right_input_halo.copy_(ag_right_input_halo)
+            if self.left_zero:
+                left_input_halo.zero_()
+            else:
+                left_input_halo.copy_(ag_left_input_halo)
+            if self.right_zero:
+                right_input_halo.zero_()
+            else:
+                right_input_halo.copy_(ag_right_input_halo)

 class HaloExchangerSendRecv(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerSendRecv, self).__init__()
+        super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
         self.world_size = world_size
         self.spatial_group_size = spatial_group_size
         nccl_id = inc.get_unique_nccl_id(1).cuda()
@@ -60,14 +73,14 @@ class HaloExchangerSendRecv(HaloExchanger):
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
-            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
+            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
             return left_input_halo, right_input_halo
         else:
-            inc.left_right_halo_exchange_inplace(self.handle, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
+            inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)

 class HaloExchangerPeer(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
-        super(HaloExchangerPeer, self).__init__()
+        super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
         self.diagnostics = False
         self.spatial_group_size = spatial_group_size
         self.peer_rank = rank % spatial_group_size
@@ -93,6 +106,11 @@ class HaloExchangerPeer(HaloExchanger):
             right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
             self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
             )
+        # TODO: Add to push_pull_halos_1d kernel
+        if self.left_zero:
+            left_input_halo.zero_()
+        if self.right_zero:
+            right_input_halo.zero_()
         if not inplace:
             return left_input_halo, right_input_halo
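
The left_zero/right_zero flags move the boundary handling out of the bottleneck and into the exchangers: the outermost ranks of the spatial group now receive zeros instead of leaving the halo buffers to the caller. An illustrative stand-in for the inc NCCL extension, using torch.distributed batched P2P ops (assumes an initialized process group with a 1D layout where ranks are ordered left to right):

```python
import torch
import torch.distributed as dist

def left_right_halo_exchange(left_out, right_out, rank, group_size):
    left_in = torch.empty_like(right_out)
    right_in = torch.empty_like(left_out)
    ops = []
    if rank > 0:  # exchange with the left neighbor
        ops += [dist.P2POp(dist.isend, left_out, rank - 1),
                dist.P2POp(dist.irecv, left_in, rank - 1)]
    if rank < group_size - 1:  # exchange with the right neighbor
        ops += [dist.P2POp(dist.isend, right_out, rank + 1),
                dist.P2POp(dist.irecv, right_in, rank + 1)]
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()
    # boundary ranks get zeros, mirroring the left_zero/right_zero flags above
    if rank == 0:
        left_in.zero_()
    if rank == group_size - 1:
        right_in.zero_()
    return left_in, right_in
```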
......
@@ -1620,6 +1620,7 @@ struct bottleneck_forward_status {
   int64_t outdimA0[4];
   int64_t outdimA1[4];
+  int64_t outdimA1b[4]; // out1_pad
   int64_t outdimA2[4];
   int64_t outdimA3[4];
   int64_t outdimA4[4];
@@ -1633,6 +1634,7 @@ struct bottleneck_forward_status {
   int64_t outdim0[4]; // halo input shape
   int64_t outdim1[4];
+  int64_t outdim1b[4];
   int64_t outdim2[4];
   int64_t outdim3[4];
   int64_t outdim4[4]; // halo output shape
@@ -1672,6 +1674,7 @@ struct bottleneck_forward_status {
     // output dim in n,c,h,w used by backend
     outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
     outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
+    outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
     outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
     outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
     outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
@@ -1690,6 +1693,13 @@ struct bottleneck_forward_status {
     for (int dim = 0; dim < 2; dim++) {
       outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
     }
+    for (int dim = 0; dim < 4; dim++) {
+      if (dim == 2) {
+        outdimA1b[dim] = outdimA1[dim] + 2;
+      } else {
+        outdimA1b[dim] = outdimA1[dim];
+      }
+    }
     outdimA2[0] = outdimA1[0];
     outdimA2[1] = filterdimA2[0];
@@ -1715,6 +1725,7 @@ struct bottleneck_forward_status {
     // Create output tensor in the correct shape in pytorch's view
     outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
+    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
     outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
     outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
     if (explicit_nhwc) {
@@ -1726,6 +1737,7 @@ struct bottleneck_forward_status {
       for (int dim=0;dim<4;dim++) {
         outdim0[dim] = outdimA0[axis[dim]];
         outdim1[dim] = outdimA1[axis[dim]];
+        outdim1b[dim] = outdimA1b[axis[dim]];
         outdim2[dim] = outdimA2[axis[dim]];
         outdim3[dim] = outdimA3[axis[dim]];
         outdim4[dim] = outdimA4[axis[dim]];
@@ -1859,6 +1871,44 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
   DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
 }

+void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
+  std::cout << std::fixed;
+
+  // from _out1 method
+  at::Half* x = inputs[0].data_ptr<at::Half>();
+  auto out1 = outputs[0];
+  at::Half* y1 = out1_pad.data_ptr<at::Half>();
+
+  // run
+  at::Half* w = inputs[2].data_ptr<at::Half>();
+  at::Half* z = inputs[5].data_ptr<at::Half>();
+  at::Half* b = inputs[8].data_ptr<at::Half>();
+  auto out2 = outputs[1];
+  at::Half* y2 = out2.data_ptr<at::Half>();
+
+  //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
+  //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
+  //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
+  //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
+  //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
+  //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
+
+  run_conv_scale_bias_add_activation(forward_state.outdimA1b,
+                                     forward_state.padA2,
+                                     forward_state.convstrideA,
+                                     forward_state.dilationA,
+                                     forward_state.filterdimA2,
+                                     forward_state.outdimA2,
+                                     CUDNN_DATA_HALF,
+                                     y1,
+                                     w,
+                                     y2,
+                                     z,
+                                     b,
+                                     nullptr);
+
+  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+}
+
 void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
   std::cout << std::fixed;
@@ -2520,6 +2570,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
   m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
   m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
+  m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
   m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
   m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
   m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
......
@@ -100,7 +100,7 @@ class NcclCommWrapper
         });
     }

-    void left_right_halo_exchange_inplace(at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+    void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
     {
         auto stream = at::cuda::getCurrentCUDAStream();
         ncclGroupStart();
@@ -132,16 +132,18 @@ class NcclCommWrapper
             });
         }
         ncclGroupEnd();
+        if (left_zero) left_input_halo.zero_();
+        if (right_zero) right_input_halo.zero_();
     }

-    std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+    std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
     {
         // after halo exchange:
         // left_output_halo of rank+1 ends up in right_input_halo of rank
         // right_output_halo of rank-1 ends up in left_input_halo of rank
         auto right_input_halo = torch::empty_like(left_output_halo);
         auto left_input_halo = torch::empty_like(right_output_halo);
-        left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+        left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
         return {left_input_halo, right_input_halo};
     }
 };
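
Note that the zero_() calls sit after ncclGroupEnd() on the same stream, so they are ordered behind the grouped sends and receives; on a boundary rank the halo buffer that no neighbor writes is simply overwritten with zeros instead of being left as uninitialized empty_like memory.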
@@ -195,18 +197,18 @@ void nccl_recv(int handle, at::Tensor input, int sender)
     communicator.recv(input, sender);
 }

-void left_right_halo_exchange_inplace(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
 {
     assert(handle >= 0 && handle < nccl_comms.size());
     class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+    return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
 }

-std::vector<at::Tensor> left_right_halo_exchange(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
 {
     assert(handle >= 0 && handle < nccl_comms.size());
     class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange(left_output_halo, right_output_halo, group_size);
+    return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
 }

 void add_delay(int delay)
......
@@ -38,6 +38,8 @@ void nccl_recv(
     );
 void left_right_halo_exchange_inplace(
     int handle,
+    bool left_zero,
+    bool right_zero,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     at::Tensor left_input_halo,
@@ -45,6 +47,8 @@ void left_right_halo_exchange_inplace(
     int group_size);
 std::vector<at::Tensor> left_right_halo_exchange(
     int handle,
+    bool left_zero,
+    bool right_zero,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     int group_size
......
@@ -239,6 +239,7 @@ int64_t allocate_raw(int64_t size)
 {
     float* ptr = 0L;
     cudaMalloc(&ptr, size);
+    cudaMemset(ptr, 0, size);
     return (int64_t)ptr;
 }
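
The added cudaMemset is presumably what lets freshly allocated peer-pool memory (for example the signal flags that HaloExchangerPeer spins on) start from a known all-zero state on every rank; the commented-out synchronize/barrier in the test above points at the same initialization concern.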
......
@@ -53,7 +53,7 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
             btm_inp_halo.copy_(top_inp_halos[btm_rank])

-def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, numSM=1):
+def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
     if memory_format == 1:
         # 1 -> explicit nhwc
         explicit_nhwc = True
@@ -77,10 +77,23 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype,
     if memory_format == 2:
         y = y.to(memory_format=torch.channels_last)
     ym = y[:,:,:,half_halo:W+half_halo]
-    y2 = y.clone()
-    halo_ex(y, H_split, explicit_nhwc, numSM)
-    nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
-    is_equal = torch.all(torch.eq(y,y2))
+    y3 = y.clone()
+    list_y = []
+    for step in range(num_steps):
+        halo_ex(y, H_split, explicit_nhwc, numSM)
+        list_y.append(y.clone())
+        y.copy_(y3)
+    halo_ex.peer_pool.reset()
+    torch.distributed.barrier()
+    y2 = y3.clone()
+    list_y2 = []
+    for step in range(num_steps):
+        nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
+        list_y2.append(y2.clone())
+        y2.copy_(y3)
+    is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)]
+    is_equal = torch.tensor(is_equal, dtype=torch.bool)
+    is_equal = torch.all(is_equal)
     if peer_rank == 0:
         if memory_format == 1:
             memory_format_str = "explicit_nhwc"
@@ -99,26 +112,26 @@
     torch.distributed.barrier()

-def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex):
+def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
     Hr = 8*world_size
     Hp = ((H + Hr - 1) // Hr) * 8
     for i in range(4):
         div = int(pow(2,i))
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True)
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True)
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps)

-def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex):
+def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
     Wr = 8*world_size
     Wp = ((W + Wr - 1) // Wr) * 8
     for i in range(4):
         div = int(pow(2,i))
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False)
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False)
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps)

 def main():
@@ -130,11 +143,13 @@ def main():
     torch.cuda.set_device(rank)

     pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)

+    num_steps = 100
    half_halo = 1
     halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)

-    H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex)
-    W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex)
+    H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
+    W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)

 if __name__ == "__main__":
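
The reshaped test boils down to this pattern, condensed into a sketch where exchange and reference stand in for halo_ex and nccl_halo_ex:

```python
import torch

def verify_exchanger(exchange, reference, y0, num_steps):
    # Run both implementations num_steps times from the same starting tensor;
    # repeated runs catch state bugs (e.g. signal flags left dirty by a
    # previous call) that a single-shot comparison would miss.
    outs, refs = [], []
    y = y0.clone()
    for _ in range(num_steps):
        exchange(y)
        outs.append(y.clone())
        y.copy_(y0)  # reset the input between steps
    y = y0.clone()
    for _ in range(num_steps):
        reference(y)
        refs.append(y.clone())
        y.copy_(y0)
    return all(torch.equal(a, b) for a, b in zip(outs, refs))
```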
......