Commit c70f0e32 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Fix deadlock issue when the peer memory halo exchanger is used with CUDA graphs

parent d8db8c15
import functools as func
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import nn from torch import nn
...@@ -8,6 +9,17 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): ...@@ -8,6 +9,17 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
weight_tensor_nchw = tensor weight_tensor_nchw = tensor
nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity) nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
    """Fold one frozen batch-norm's statistics into a scale/bias pair.

    Computes scale = weight / sqrt(running_var) and
    bias = bias - running_mean * scale, writing the results in place into
    ``w_scale`` and ``w_bias`` so the destination tensors keep their storage
    (e.g. so they can be captured once by a CUDA graph and refreshed later).
    ``nhwc`` is unused here; it is kept for signature parity with callers
    that thread the layout flag through.
    """
    folded_scale = weight * running_var.rsqrt()
    folded_bias = bias - running_mean * folded_scale
    w_scale.copy_(folded_scale)
    w_bias.copy_(folded_bias)
def compute_scale_bias_method(nhwc, args):
    """Recompute scale/bias in place for every frozen batch-norm in ``args``.

    Each element of ``args`` is a tuple of
    (weight, bias, running_mean, running_var, w_scale, w_bias); the final
    two tensors of each tuple are overwritten with the folded values.
    """
    for weight, bias, running_mean, running_var, w_scale, w_bias in args:
        compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias)
class FrozenBatchNorm2d(torch.jit.ScriptModule): class FrozenBatchNorm2d(torch.jit.ScriptModule):
""" """
BatchNorm2d where the batch statistics and the affine parameters are fixed BatchNorm2d where the batch statistics and the affine parameters are fixed
...@@ -150,6 +162,7 @@ class Bottleneck(torch.nn.Module): ...@@ -150,6 +162,7 @@ class Bottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels) self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels) self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels) self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn self.use_cudnn = use_cudnn
...@@ -173,23 +186,47 @@ class Bottleneck(torch.nn.Module): ...@@ -173,23 +186,47 @@ class Bottleneck(torch.nn.Module):
for p in self.parameters(): for p in self.parameters():
with torch.no_grad(): with torch.no_grad():
p.data = p.data.permute(0,2,3,1).contiguous() p.data = p.data.permute(0,2,3,1).contiguous()
return return
    # Returns single callable that recomputes scale and bias for all frozen batch-norms.
    # This method must be called before cuda graphing.
    # The callable it returns can be called anytime.
    # Calling this method will prevent these from being computed every forward call.
    def get_scale_bias_callable(self):
        """Allocate persistent scale/bias tensors for bn1..bn3 (plus the
        downsample batch-norm when present) and return a callable that
        refreshes their values in place.

        Side effect: sets self.w_scale / self.w_bias to lists of
        layout-appropriate views of the allocated tensors — (1, 1, 1, C)
        when self.explicit_nhwc, (1, C, 1, 1) otherwise.  Once self.w_scale
        is not None, forward() uses these precomputed lists instead of
        recomputing scale/bias on every call.
        """
        self.w_scale, self.w_bias, args = [], [], []
        batch_norms = [self.bn1, self.bn2, self.bn3]
        if self.downsample is not None:
            # element [1] of the downsample container is its batch-norm
            batch_norms.append(self.downsample[1])
        for bn in batch_norms:
            # Fresh tensors that outlive this call: the returned callable
            # writes into them, so consumers holding the views stay valid.
            s = torch.empty_like(bn.weight)
            b = torch.empty_like(s)
            args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
            if self.explicit_nhwc:
                self.w_scale.append( s.reshape(1, 1, 1, -1) )
                self.w_bias.append( b.reshape(1, 1, 1, -1) )
            else:
                self.w_scale.append( s.reshape(1, -1, 1, 1) )
                self.w_bias.append( b.reshape(1, -1, 1, 1) )
        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
def forward(self, x): def forward(self, x):
if self.use_cudnn: if self.use_cudnn:
# calculate scale/bias from registered buffers if self.w_scale is None:
# TODO: make this better # calculate scale/bias from registered buffers
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc) # TODO: make this better
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc) s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc) s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3] s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_bias = [b1, b2, b3] w_scale = [s1, s2, s3]
if self.downsample is not None: w_bias = [b1, b2, b3]
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc) if self.downsample is not None:
w_scale.append(s4) s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_bias.append(b4) w_scale.append(s4)
w_bias.append(b4)
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv) out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
else:
out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv)
return out return out
if self.explicit_nhwc: if self.explicit_nhwc:
...@@ -251,7 +288,7 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -251,7 +288,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format) out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
stream1.wait_stream(torch.cuda.current_stream()) stream1.wait_stream(torch.cuda.current_stream())
stream3.wait_stream(torch.cuda.current_stream()) if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream3): with torch.cuda.stream(stream3):
if explicit_nhwc: if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1) out1_pad[:,1:Hs+1,:,:].copy_(out1)
...@@ -291,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -291,7 +328,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_fat_halo[:,:,:1,:].copy_(top_out1_halo) top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:]) top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args) top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
inc.add_delay(10) #inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3: elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3" assert(False), "spatial_method must be 1, 2 or 3"
...@@ -299,13 +336,26 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -299,13 +336,26 @@ class SpatialBottleneckFunction(torch.autograd.Function):
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs) fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
elif spatial_method == 1: elif spatial_method == 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs) fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
elif spatial_method == 2: elif spatial_method == 2:
# wait for halo transfer to finish before doing a full convolution of padded x # wait for halo transfer to finish before doing a full convolution of padded x
torch.cuda.current_stream().wait_stream(stream1) torch.cuda.current_stream().wait_stream(stream1)
torch.cuda.current_stream().wait_stream(stream3) if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad) fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
elif spatial_method == 3: elif spatial_method == 3:
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom) fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
# compute halo cells for outputs[1] (out2) # compute halo cells for outputs[1] (out2)
if spatial_group_size > 1: if spatial_group_size > 1:
...@@ -405,6 +455,11 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -405,6 +455,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads) grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
wgrad2_stream = torch.cuda.Stream() wgrad2_stream = torch.cuda.Stream()
wgrad2_stream.wait_stream(torch.cuda.current_stream()) wgrad2_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# do halo exchange of grad_out2 here # do halo exchange of grad_out2 here
# compute halo cells for grad_out1 # compute halo cells for grad_out1
if ctx.spatial_group_size > 1: if ctx.spatial_group_size > 1:
...@@ -463,16 +518,10 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -463,16 +518,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:] top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else: else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:] top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
inc.add_delay(10) #inc.add_delay(10)
elif ctx.spatial_method != 3: elif ctx.spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3" assert(False), "spatial_method must be 1, 2 or 3"
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# compute grad_out1 for internal cells # compute grad_out1 for internal cells
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2: if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2) grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
...@@ -577,6 +626,7 @@ class SpatialBottleneck(torch.nn.Module): ...@@ -577,6 +626,7 @@ class SpatialBottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels) self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels) self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels) self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn self.use_cudnn = use_cudnn
...@@ -610,6 +660,27 @@ class SpatialBottleneck(torch.nn.Module): ...@@ -610,6 +660,27 @@ class SpatialBottleneck(torch.nn.Module):
self.spatial_parallel_args = spatial_parallel_args self.spatial_parallel_args = spatial_parallel_args
return return
    # Returns single callable that recomputes scale and bias for all frozen batch-norms.
    # This method must be called before cuda graphing.
    # The callable it returns can be called anytime.
    # Calling this method will prevent these from being computed every forward call.
    def get_scale_bias_callable(self):
        """Allocate persistent scale/bias tensors for bn1..bn3 (plus the
        downsample batch-norm when present) and return a callable that
        refreshes their values in place.

        Side effect: sets self.w_scale / self.w_bias to lists of
        layout-appropriate views of the allocated tensors — (1, 1, 1, C)
        when self.explicit_nhwc, (1, C, 1, 1) otherwise.  Once self.w_scale
        is not None, forward() uses these precomputed lists instead of
        recomputing scale/bias on every call.
        """
        self.w_scale, self.w_bias, args = [], [], []
        batch_norms = [self.bn1, self.bn2, self.bn3]
        if self.downsample is not None:
            # element [1] of the downsample container is its batch-norm
            batch_norms.append(self.downsample[1])
        for bn in batch_norms:
            # Fresh tensors that outlive this call: the returned callable
            # writes into them, so consumers holding the views stay valid.
            s = torch.empty_like(bn.weight)
            b = torch.empty_like(s)
            args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
            if self.explicit_nhwc:
                self.w_scale.append( s.reshape(1, 1, 1, -1) )
                self.w_bias.append( b.reshape(1, 1, 1, -1) )
            else:
                self.w_scale.append( s.reshape(1, -1, 1, 1) )
                self.w_bias.append( b.reshape(1, -1, 1, 1) )
        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
def forward(self, x): def forward(self, x):
if self.use_cudnn: if self.use_cudnn:
if self.thresholdTop is None: if self.thresholdTop is None:
...@@ -620,19 +691,24 @@ class SpatialBottleneck(torch.nn.Module): ...@@ -620,19 +691,24 @@ class SpatialBottleneck(torch.nn.Module):
N,C,H,W = list(x.shape) N,C,H,W = list(x.shape)
self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda') self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda') self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
# calculate scale/bias from registered buffers
# TODO: make this better if self.w_scale is None:
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc) # calculate scale/bias from registered buffers
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc) # TODO: make this better
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc) s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3] s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
w_bias = [b1, b2, b3] s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
if self.downsample is not None: w_scale = [s1, s2, s3]
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc) w_bias = [b1, b2, b3]
w_scale.append(s4) if self.downsample is not None:
w_bias.append(b4) s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv) w_bias.append(b4)
self.w_scale = w_scale
self.w_bias = w_bias
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
else:
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
return out return out
if self.explicit_nhwc: if self.explicit_nhwc:
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw"); m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw"); m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
m.def("zero", &apex::contrib::peer_memory::zero, "zero");
m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address"); m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers"); m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half"); m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
......
...@@ -148,33 +148,58 @@ __device__ void strided_copy_kernel( ...@@ -148,33 +148,58 @@ __device__ void strided_copy_kernel(
} }
} }
template<bool wait, bool clear> __device__ void checked_signal(
__device__ void dual_signal_wait_clear( volatile int* signal1_flag, volatile int* signal2_flag,
volatile int* signal1_flag, volatile int* wait1_flag,
volatile int* signal2_flag, volatile int* wait2_flag,
const int v1, const int v2, const int v3, const int v4 const int v1, const int v2, const int v3, const int v4
) )
{ {
register int r1, r2, r3, r4, r5, r6, r7, r8; if (blockIdx.x == 0) {
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false; register int r1, r2, r3, r4;
// signal and wait if (threadIdx.x == 0) {
if (is_main_thread) { // wait for top neighbor to clear bottom signal (indicating ready for new input)
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); do {
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
if (wait) { } while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
} else if (threadIdx.x == 1) {
// wait for bottom neighbor to clear top signal (indicating ready for new input)
do { do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory"); asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory"); } while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
} while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4); // signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
} }
} }
cg::this_grid().sync(); }
if (clear) {
if (is_main_thread) { __device__ void wait_for(
r1 = 0; r2 = 0; r3 = 0; r4 = 0; volatile int* wait_flag,
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory"); const int v1, const int v2, const int v3, const int v4
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory"); )
} {
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
// wait for senders to signal their output is ready
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
} while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}
cg::this_grid().sync(); // all threads wait for main
}
__device__ void clear_flag(
volatile int* wait_flag
)
{
cg::this_grid().sync(); // wait for all threads in kernel to finish
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
} }
} }
...@@ -208,11 +233,15 @@ __global__ void push_pull_halos_1d_kernel( ...@@ -208,11 +233,15 @@ __global__ void push_pull_halos_1d_kernel(
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW); strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neighbors that output halos are ready to be read // signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
dual_signal_wait_clear<true,true>(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358); checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
// pull top halo from transfer buffer in peer memory to input // pull top halo from transfer buffer in peer memory to input
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW); strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
// pull btm halo from transfer buffer in peer memory to input // pull btm halo from transfer buffer in peer memory to input
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW); strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
} }
__global__ void delay_kernel(int delay_nanoseconds, int* counter) __global__ void delay_kernel(int delay_nanoseconds, int* counter)
...@@ -248,6 +277,11 @@ void free_raw(int64_t raw) ...@@ -248,6 +277,11 @@ void free_raw(int64_t raw)
cudaFree((void*)raw); cudaFree((void*)raw);
} }
// Zero-fill `size` bytes of a raw device allocation obtained from
// allocate_raw (used to reset the signal/wait flag words before reuse).
// NOTE(review): cudaMemset is asynchronous with respect to the host for
// device memory on a non-default stream; presumably callers rely on the
// default-stream ordering here — confirm against call sites.
void zero(int64_t raw, int64_t size)
{
    cudaMemset((void*)raw, 0, size);
}
at::Tensor get_raw_ipc_address(int64_t raw) at::Tensor get_raw_ipc_address(int64_t raw)
{ {
cudaIpcMemHandle_t mem_handle; cudaIpcMemHandle_t mem_handle;
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
namespace apex { namespace contrib { namespace peer_memory { namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size); int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw); void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw); at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw); std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last); at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment