Commit fd0f7631 authored by Thor Johnsen, committed by hubertlu-tw

Fixed peer halo exchange module test

parent c662c703
......@@ -107,15 +107,10 @@ class HaloExchangerPeer(HaloExchanger):
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM,
left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
)
# TODO: Add to push_pull_halos_1d kernel
if self.left_zero:
left_input_halo.zero_()
if self.right_zero:
right_input_halo.zero_()
if not inplace:
return left_input_halo, right_input_halo
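With the zero flags folded into the kernel call, the removed TODO is resolved: boundary halos are now zeroed inside push_pull_halos_1d rather than by separate tensor ops afterwards. A pure-Python model of the new semantics (illustrative only; push_pull_model and exchange are hypothetical stand-ins, with exchange modeling the peer copy through the transfer buffers):

def push_pull_model(left_zero, left_out, left_inp,
                    right_zero, right_out, right_inp, exchange):
    # On a zeroed side the kernel skips the push, the signal, and the
    # pull, and writes zeros into the input halo instead.
    if left_zero:
        left_inp.zero_()
    else:
        exchange(src=left_out, dst=left_inp)
    if right_zero:
        right_inp.zero_()
    else:
        exchange(src=right_out, dst=right_inp)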
......
......@@ -124,7 +124,20 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride
}
}
template<class T, bool is_HWC>
template<class T>
__device__ void __zero(T* dst)
{
*dst = T(0);
}
__device__ void __zero(int4* dst)
{
int4 v;
v.x = v.y = v.z = v.w = 0;
*dst = v;
}
template<class T, bool is_HWC, bool zero>
__device__ void strided_copy_kernel(
T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
......@@ -138,23 +151,28 @@ __device__ void strided_copy_kernel(
{
size_t c,h,w;
if (is_HWC) {
c = i % NC;
w = i / NC;
c = i - w * NC;
h = w / NW;
w = w % NW;
w = w - h * NW;
}
else {
w = i % NW;
h = i / NW;
w = i - h * NW;
c = h / NH;
h = h % NH;
h = h - c * NH;
}
size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
if (zero) {
__zero(dst+dst_off);
} else {
size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
dst[dst_off] = src[src_off];
}
}
}
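The rewritten index math above replaces each modulo with a multiply-subtract on a quotient that is already computed (once w = i / NC, i - w * NC equals i % NC), saving one division per axis. A small Python check of the HWC decomposition (decompose_hwc is an illustrative helper, not part of the kernel):

def decompose_hwc(i, NC, NW):
    # NHWC flat index: C varies fastest, then W, then H.
    w = i // NC
    c = i - w * NC       # same value as i % NC
    h = w // NW
    w = w - h * NW       # same value as (i // NC) % NW
    return c, h, w

# Spot check against the plain modulo formulation.
i, NC, NW = 12345, 64, 200
assert decompose_hwc(i, NC, NW) == (i % NC, (i // NC) // NW, (i // NC) % NW)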
template<bool top_zero, bool btm_zero>
__device__ void checked_signal(
volatile int* signal1_flag, volatile int* signal2_flag,
const int v1, const int v2, const int v3, const int v4
......@@ -167,7 +185,9 @@ __device__ void checked_signal(
__threadfence_system();
// wait for top or bottom neighbor to clear signal
register int r1, r2, r3, r4;
bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false;
if (!(top_zero || btm_zero)) {
bool top_zeroed=false, top_done=false;
bool btm_zeroed=false, btm_done=false;
do {
do {
if (!top_zeroed) {
......@@ -218,6 +238,66 @@ __device__ void checked_signal(
btm_done = true;
}
} while (!top_done || !btm_done);
} else if (top_zero) {
bool btm_zeroed=false, btm_done=false;
do {
do {
if (!btm_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal2_flag);
r2 = __builtin_nontemporal_load(signal2_flag + 1);
r3 = __builtin_nontemporal_load(signal2_flag + 2);
r4 = __builtin_nontemporal_load(signal2_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while(btm_zeroed == btm_done);
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal2_flag);
__builtin_nontemporal_store(v2, signal2_flag + 1);
__builtin_nontemporal_store(v3, signal2_flag + 2);
__builtin_nontemporal_store(v4, signal2_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
btm_done = true;
}
} while (!btm_done);
} else if (btm_zero) {
bool top_zeroed=false, top_done=false;
do {
do {
if (!top_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal1_flag);
r2 = __builtin_nontemporal_load(signal1_flag + 1);
r3 = __builtin_nontemporal_load(signal1_flag + 2);
r4 = __builtin_nontemporal_load(signal1_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
} while(top_zeroed == top_done);
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal1_flag);
__builtin_nontemporal_store(v2, signal1_flag + 1);
__builtin_nontemporal_store(v3, signal1_flag + 2);
__builtin_nontemporal_store(v4, signal1_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
top_done = true;
}
} while (!top_done);
}
}
}
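For reference, a single-threaded model of the handshake each active side of checked_signal performs (a hedged sketch; signal_side and MAGIC are illustrative names, and the real code spins on four 32-bit words with volatile or non-temporal accesses, skipping a side entirely when its zero template parameter is set):

# The four signal words; arbitrary values, but identical on every rank.
MAGIC = (-987751720, 840868300, -225529332, 281513358)

def signal_side(flag):
    # Wait until the neighbor has cleared the previous signal ...
    while tuple(flag) == MAGIC:
        pass
    # ... then publish the words: "my output halo is ready to be read".
    flag[:] = MAGIC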
......@@ -265,8 +345,8 @@ __device__ void clear_flag(
}
}
template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
template<class T, bool is_HWC, bool top_zero, bool btm_zero>
#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
......@@ -290,20 +370,34 @@ __global__ void push_pull_halos_1d_kernel(
)
{
// push top output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
if (!top_zero) strided_copy_kernel<T,is_HWC,false>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
// push btm output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
if (!btm_zero) strided_copy_kernel<T,is_HWC,false>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
if (!(top_zero || btm_zero)) {
checked_signal<false,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
} else if (top_zero) {
checked_signal<true,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
} else if (btm_zero) {
checked_signal<false,true>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
}
// pull top halo from transfer buffer in peer memory to input
if (top_zero) {
strided_copy_kernel<T,is_HWC,true>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
} else {
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
strided_copy_kernel<T,is_HWC,false>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
}
// pull btm halo from transfer buffer in peer memory to input
if (btm_zero) {
strided_copy_kernel<T,is_HWC,true>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
} else {
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
strided_copy_kernel<T,is_HWC,false>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
}
}
__global__ void delay_kernel(int delay_nanoseconds, int* counter)
......@@ -392,10 +486,12 @@ void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
bool top_zero, // true if top halo should be zeroed
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
bool btm_zero, // true if btm halo should be zeroed
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
......@@ -417,6 +513,7 @@ void push_pull_halos_1d(
TORCH_CHECK(top_signal.is_cuda());
TORCH_CHECK(btm_signal.is_cuda());
TORCH_CHECK(waits.is_cuda());
TORCH_CHECK(!(top_zero && btm_zero));
// shapes and strides
int toh_N, toh_C, toh_H, toh_W;
......@@ -541,14 +638,34 @@ void push_pull_halos_1d(
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
if (top_zero) {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
} else {
// cannot do int4 transfers
if (diagnostics) printf("CAN NOT DO INT4\n");
......@@ -566,23 +683,59 @@ void push_pull_halos_1d(
};
int numBlocksPerSm;
if (is_nhwc) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
if (top_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
} else {
if (top_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
}
}
} );
}
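Every branch above follows the same pattern: lower the runtime flags to compile-time template parameters, query occupancy for that exact instantiation, and launch it cooperatively. A compact model of the selection (select_variant is an illustrative helper):

def select_variant(can_int4, is_nhwc, top_zero, btm_zero):
    # Mirrors the host-side dispatch: the int4 fast path is NHWC-only,
    # and both zero flags set at once is rejected by the TORCH_CHECK.
    assert not (top_zero and btm_zero)
    elem_type = "int4" if can_int4 else "scalar_t"
    layout_hwc = True if can_int4 else is_nhwc
    return elem_type, layout_hwc, top_zero, btm_zero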
......
......@@ -32,10 +32,12 @@ namespace apex { namespace contrib { namespace peer_memory {
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
bool top_zero, // true if top halo should be zeroed
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
bool btm_zero, // true if btm halo should be zeroed
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
......
from .peer_memory import PeerMemoryPool
from .peer_halo_exchanger_1d import PeerHaloExchanger1d
......@@ -40,8 +40,9 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
btm_out_halo = y[:,:,:,W:W+half_halo]
btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
top_out_halo = top_out_halo.contiguous()
btm_out_halo = btm_out_halo.contiguous()
top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(top_inp_halos, top_out_halo)
......@@ -49,8 +50,14 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
btm_rank = (peer_rank + 1) % peer_group_size
top_inp_halo.copy_(btm_inp_halos[top_rank])
btm_inp_halo.copy_(top_inp_halos[btm_rank])
if peer_rank == 0:
top_inp_halo.zero_()
else:
top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))
if peer_rank == peer_group_size-1:
btm_inp_halo.zero_()
else:
btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))
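The NCCL reference path now mirrors the fused kernel's boundary handling: the first rank zeroes its top input halo and the last rank zeroes its bottom one instead of wrapping around. A tiny check of the flag logic (boundary_zero_flags is an illustrative helper):

def boundary_zero_flags(peer_rank, peer_group_size):
    # Only interior edges exchange halos; the two outer edges are zeroed.
    return peer_rank == 0, peer_rank == peer_group_size - 1

assert boundary_zero_flags(0, 4) == (True, False)
assert boundary_zero_flags(1, 4) == (False, False)
assert boundary_zero_flags(3, 4) == (False, True)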
def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
......@@ -141,12 +148,13 @@ def main():
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
peer_ranks = list(range(world_size))
pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
num_steps = 100
half_halo = 1
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)
H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)
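For context, a minimal setup using the updated constructor signatures (a sketch assuming an NCCL process group can be initialized; pool sizes follow the test above):

import torch
import torch.distributed as dist
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

peer_ranks = list(range(world_size))
# PeerMemoryPool now takes the pool sizes plus the participating ranks.
pool = PeerMemoryPool(64 * 1024, 2 * 1024 * 1024, peer_ranks)
# PeerHaloExchanger1d now takes the rank list and this rank's position in it.
halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo=1)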
......
......@@ -3,9 +3,15 @@ from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory_cuda as pm
class PeerHaloExchanger1d:
def __init__(self, rank, peer_group_size, peer_pool, half_halo):
self.peer_group_size = peer_group_size
self.peer_rank = rank % peer_group_size
def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
self.peer_group_size = len(ranks)
self.ranks = ranks
self.peer_rank = rank_in_group
self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size
self.low_zero = self.peer_rank == 0
self.high_zero = self.peer_rank == self.peer_group_size - 1
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
......@@ -17,45 +23,43 @@ class PeerHaloExchanger1d:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:self.half_halo,:,:]
btm_out_halo = y[:,H:H+self.half_halo,:,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
low_inp_halo = y[:,:self.half_halo,:,:]
high_out_halo = y[:,H:H+self.half_halo,:,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,H:H+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
low_inp_halo = y[:,:,:self.half_halo,:]
high_out_halo = y[:,:,H:H+self.half_halo,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,W:W+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
low_inp_halo = y[:,:,:self.half_halo,:]
high_out_halo = y[:,:,W:W+self.half_halo,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:,:self.half_halo]
btm_out_halo = y[:,:,:,W:W+self.half_halo]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
low_inp_halo = y[:,:,:,:self.half_halo]
high_out_halo = y[:,:,:,W:W+self.half_halo]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
pm.push_pull_halos_1d(
diagnostics, explicit_nhwc, numSM,
top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo,
self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo,
self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank]
)
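All four slicing branches share one geometry along the split dimension: with interior size H (or W) and halo width h, [0, h) is the input halo from the low neighbor, [h, 2h) the low output halo, [H, H+h) the high output halo, and [H+h, H+2h) the input halo from the high neighbor. Expressed compactly (halo_slices is an illustrative helper):

def halo_slices(H, h):
    # (low_inp, low_out, high_out, high_inp) along the padded split axis.
    return (slice(0, h),
            slice(h, 2 * h),
            slice(H, H + h),
            slice(H + h, H + 2 * h))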