Commit c70f0e32 authored by Thor Johnsen

Fix deadlock issue when the peer memory halo exchanger is used with CUDA graphs

parent d8db8c15
import functools as func
import torch
import torch.distributed as dist
from torch import nn
@@ -8,6 +9,17 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
weight_tensor_nchw = tensor
nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
scale = weight * running_var.rsqrt()
bias = bias - running_mean * scale
w_scale.copy_(scale)
w_bias.copy_(bias)
def compute_scale_bias_method(nhwc, args):
for arg in args:
# arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias)
compute_scale_bias_one(nhwc, *arg)
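These helpers fold a frozen batch-norm into a single per-channel scale and bias. A minimal, self-contained sketch of the identity they rely on (shapes and values are arbitrary, for illustration only):

import torch

x = torch.randn(16)
weight, bias = torch.randn(16), torch.randn(16)
running_mean, running_var = torch.randn(16), torch.rand(16) + 0.5

scale = weight * running_var.rsqrt()               # same formula as compute_scale_bias_one
folded = x * scale + (bias - running_mean * scale)
direct = (x - running_mean) * running_var.rsqrt() * weight + bias
assert torch.allclose(folded, direct, atol=1e-5)   # the folding is exact up to rounding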
class FrozenBatchNorm2d(torch.jit.ScriptModule):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
@@ -150,6 +162,7 @@ class Bottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
@@ -173,23 +186,47 @@ class Bottleneck(torch.nn.Module):
for p in self.parameters():
with torch.no_grad():
p.data = p.data.permute(0,2,3,1).contiguous()
return
# Returns a single callable that recomputes scale and bias for all frozen batch-norms in this module.
# This method must be called before CUDA graph capture.
# The callable it returns can be called at any time.
# Once this method has been called, scale and bias are no longer recomputed on every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
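A hedged usage sketch of the intended call order with CUDA graph capture. Bottleneck is the class defined in this file; the constructor arguments, input shape, and variable names below are illustrative assumptions, and the same pattern applies to SpatialBottleneck.get_scale_bias_callable further down.

import torch

block = Bottleneck(in_channels=256, bottleneck_channels=64, out_channels=256,
                   use_cudnn=True).cuda()               # constructor arguments are assumptions
recompute_scale_bias = block.get_scale_bias_callable()  # must be created before capture
recompute_scale_bias()                                  # fill the w_scale / w_bias buffers

static_input = torch.randn(2, 256, 56, 56, device='cuda')
block(static_input)                        # warm-up pass, generally recommended before capture

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):              # the captured forward reads only the precomputed buffers
    static_output = block(static_input)

recompute_scale_bias()                     # refresh the same buffers if the frozen stats change
graph.replay()                             # the replayed forward picks up the refreshed values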
def forward(self, x):
if self.use_cudnn:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
else:
out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv)
return out
if self.explicit_nhwc:
@@ -251,7 +288,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
stream1.wait_stream(torch.cuda.current_stream())
stream3.wait_stream(torch.cuda.current_stream())
if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
@@ -291,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
inc.add_delay(10)
#inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
@@ -299,13 +336,26 @@ class SpatialBottleneckFunction(torch.autograd.Function):
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
elif spatial_method == 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
elif spatial_method == 2:
# wait for halo transfer to finish before doing a full convolution of padded x
torch.cuda.current_stream().wait_stream(stream1)
torch.cuda.current_stream().wait_stream(stream3)
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
elif spatial_method == 3:
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
# compute halo cells for outputs[1] (out2)
if spatial_group_size > 1:
@@ -405,6 +455,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
wgrad2_stream = torch.cuda.Stream()
wgrad2_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
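# Note: wgrad2 is now launched on its own stream here, before the grad_out2 halo exchange below,
# so the weight-gradient computation can overlap with the peer-to-peer halo traffic (it was
# previously launched only after the halo exchange, further down in this function).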
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if ctx.spatial_group_size > 1:
@@ -463,16 +518,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
inc.add_delay(10)
#inc.add_delay(10)
elif ctx.spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# compute grad_out1 for internal cells
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
@@ -577,6 +626,7 @@ class SpatialBottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
@@ -610,6 +660,27 @@ class SpatialBottleneck(torch.nn.Module):
self.spatial_parallel_args = spatial_parallel_args
return
# Returns a single callable that recomputes scale and bias for all frozen batch-norms in this module.
# This method must be called before CUDA graph capture.
# The callable it returns can be called at any time.
# Once this method has been called, scale and bias are no longer recomputed on every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
def forward(self, x):
if self.use_cudnn:
if self.thresholdTop is None:
@@ -620,19 +691,24 @@ class SpatialBottleneck(torch.nn.Module):
N,C,H,W = list(x.shape)
self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
self.w_scale = w_scale
self.w_bias = w_bias
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
else:
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
return out
if self.explicit_nhwc:
@@ -19,6 +19,7 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
m.def("zero", &apex::contrib::peer_memory::zero, "zero");
m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
@@ -148,33 +148,58 @@ __device__ void strided_copy_kernel(
}
}
template<bool wait, bool clear>
__device__ void dual_signal_wait_clear(
volatile int* signal1_flag, volatile int* wait1_flag,
volatile int* signal2_flag, volatile int* wait2_flag,
__device__ void checked_signal(
volatile int* signal1_flag, volatile int* signal2_flag,
const int v1, const int v2, const int v3, const int v4
)
{
register int r1, r2, r3, r4, r5, r6, r7, r8;
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
// signal and wait
if (is_main_thread) {
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
if (wait) {
if (blockIdx.x == 0) {
register int r1, r2, r3, r4;
if (threadIdx.x == 0) {
// wait for top neighbor to clear bottom signal (indicating it is ready for new input)
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
} while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
} else if (threadIdx.x == 1) {
// wait for bottom neighbor to clear top signal (indicating it is ready for new input)
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory");
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory");
} while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
} while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
// signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
}
}
cg::this_grid().sync();
if (clear) {
if (is_main_thread) {
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
}
}
__device__ void wait_for(
volatile int* wait_flag,
const int v1, const int v2, const int v3, const int v4
)
{
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
// wait for the sender to signal that its output is ready
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
} while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}
cg::this_grid().sync(); // all threads wait for main
}
__device__ void clear_flag(
volatile int* wait_flag
)
{
cg::this_grid().sync(); // wait for all threads in kernel to finish
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
}
}
@@ -208,11 +233,15 @@ __global__ void push_pull_halos_1d_kernel(
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
dual_signal_wait_clear<true,true>(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358);
checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
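// Unlike the old dual_signal_wait_clear path, checked_signal first spins until the neighbor has
// cleared the flag from the previous exchange and only then writes the new value; the receiver
// clears the flag again in clear_flag once the halo has been copied out. Decoupling signal, wait
// and clear this way prevents a rank from signalling over a value its neighbor has not yet
// consumed, which appears to be the deadlock with CUDA graphs that this commit addresses.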
// pull top halo from transfer buffer in peer memory to input
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
// pull btm halo from transfer buffer in peer memory to input
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
}
__global__ void delay_kernel(int delay_nanoseconds, int* counter)
@@ -248,6 +277,11 @@ void free_raw(int64_t raw)
cudaFree((void*)raw);
}
void zero(int64_t raw, int64_t size)
{
cudaMemset((void*)raw, 0, size);
}
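// zero() simply cudaMemsets a raw peer-memory allocation back to zero. It is exposed to Python
// through the new "zero" binding above; a plausible call pattern (not shown in this diff) is
// raw = allocate_raw(n); zero(raw, n) to put the signal flags into a known cleared state before
// halo exchanges are captured into or replayed from a CUDA graph, with free_raw(raw) at teardown.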
at::Tensor get_raw_ipc_address(int64_t raw)
{
cudaIpcMemHandle_t mem_handle;
@@ -22,6 +22,7 @@
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);