Unverified Commit a53b4417 authored by Jithun Nair, committed by GitHub

Merge pull request #87 from ROCmSoftwarePlatform/dev/hubertlu/apex_peer_memory_nccl_p2p

Enable --peer_memory and --nccl_p2p extensions for ROCm
parents a27b4e43 bc64ee83
......@@ -107,15 +107,10 @@ class HaloExchangerPeer(HaloExchanger):
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM,
left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
)
# TODO: Add to push_pull_halos_1d kernel
if self.left_zero:
left_input_halo.zero_()
if self.right_zero:
right_input_halo.zero_()
if not inplace:
return left_input_halo, right_input_halo
......
......@@ -5,7 +5,11 @@
#include <cstdio>
#include <ctime>
#include <cassert>
#ifdef __HIP_PLATFORM_HCC__
#include "rccl.h"
#else
#include "nccl.h"
#endif
/*
* This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
......
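Editor's note: this hunk and the next apply the PR's basic portability pattern: when __HIP_PLATFORM_HCC__ is defined (a ROCm/HIP build), the file pulls in RCCL, which deliberately keeps the nccl* API names, so the rest of the translation unit compiles unchanged. A minimal sketch of the shared API surface (editorial example, not from the PR; assumes NCCL or RCCL headers and runtime are installed):

// sketch: the same call resolves against NCCL (CUDA) or RCCL (ROCm)
#ifdef __HIP_PLATFORM_HCC__
#include "rccl.h"
#else
#include "nccl.h"
#endif
#include <cstdio>

int main()
{
    ncclUniqueId id;
    if (ncclGetUniqueId(&id) == ncclSuccess)   // identical signature on both stacks
        printf("unique id obtained\n");
    return 0;
}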
......@@ -5,8 +5,15 @@
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#include "rccl.h"
#else
#include <cooperative_groups.h>
#include "nccl.h"
#endif
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
......@@ -117,7 +124,20 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride
}
}
template<class T, bool is_HWC>
template<class T>
__device__ void __zero(T* dst)
{
*dst = T(0);
}
__device__ void __zero(int4* dst)
{
int4 v;
v.x = v.y = v.z = v.w = 0;
*dst = v;
}
template<class T, bool is_HWC, bool zero>
__device__ void strided_copy_kernel(
T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
......@@ -131,23 +151,28 @@ __device__ void strided_copy_kernel(
{
size_t c,h,w;
if (is_HWC) {
c = i % NC;
w = i / NC;
c = i - w * NC;
h = w / NW;
w = w % NW;
w = w - h * NW;
}
else {
w = i % NW;
h = i / NW;
w = i - h * NW;
c = h / NH;
h = h % NH;
h = h - c * NH;
}
size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
dst[dst_off] = src[src_off];
if (zero) {
__zero(dst+dst_off);
} else {
size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
dst[dst_off] = src[src_off];
}
}
}
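Editor's note on the index arithmetic rewritten above: once the quotient w = i / NC is in hand, the remainder i % NC is exactly i - w*NC, so each % is replaced by a multiply-subtract and the second integer division it would imply (costly on GPUs) disappears. A standalone host-side check of the identity (editorial example, plain C++, not from the PR):

#include <cassert>
#include <cstddef>

int main()
{
    const size_t NC = 64, NW = 200;
    for (size_t i = 0; i < NC * NW * 3; ++i) {
        size_t w = i / NC;
        assert(i % NC == i - w * NC);   // c = i % NC  ==  c = i - w*NC
        size_t h = w / NW;
        assert(w % NW == w - h * NW);   // w = w % NW  ==  w = w - h*NW
    }
    return 0;
}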
template<bool top_zero, bool btm_zero>
__device__ void checked_signal(
volatile int* signal1_flag, volatile int* signal2_flag,
const int v1, const int v2, const int v3, const int v4
......@@ -160,29 +185,119 @@ __device__ void checked_signal(
__threadfence_system();
// wait for top or bottom neighbor to clear signal
register int r1, r2, r3, r4;
bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false;
do {
if (!(top_zero || btm_zero)) {
bool top_zeroed=false, top_done=false;
bool btm_zeroed=false, btm_done=false;
do {
if (!top_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
do {
if (!top_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal1_flag);
r2 = __builtin_nontemporal_load(signal1_flag + 1);
r3 = __builtin_nontemporal_load(signal1_flag + 2);
r4 = __builtin_nontemporal_load(signal1_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
if (!btm_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal2_flag);
r2 = __builtin_nontemporal_load(signal2_flag + 1);
r3 = __builtin_nontemporal_load(signal2_flag + 2);
r4 = __builtin_nontemporal_load(signal2_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal1_flag);
__builtin_nontemporal_store(v2, signal1_flag + 1);
__builtin_nontemporal_store(v3, signal1_flag + 2);
__builtin_nontemporal_store(v4, signal1_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
top_done = true;
}
if (!btm_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal2_flag);
__builtin_nontemporal_store(v2, signal2_flag + 1);
__builtin_nontemporal_store(v3, signal2_flag + 2);
__builtin_nontemporal_store(v4, signal2_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
btm_done = true;
}
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
top_done = true;
}
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
btm_done = true;
}
} while (!top_done || !btm_done);
} while (!top_done || !btm_done);
} else if (top_zero) {
bool btm_zeroed=false, btm_done=false;
do {
do {
if (!btm_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal2_flag);
r2 = __builtin_nontemporal_load(signal2_flag + 1);
r3 = __builtin_nontemporal_load(signal2_flag + 2);
r4 = __builtin_nontemporal_load(signal2_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while(btm_zeroed == btm_done);
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal2_flag);
__builtin_nontemporal_store(v2, signal2_flag + 1);
__builtin_nontemporal_store(v3, signal2_flag + 2);
__builtin_nontemporal_store(v4, signal2_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
btm_done = true;
}
} while (!btm_done);
} else if (btm_zero) {
bool top_zeroed=false, top_done=false;
do {
do {
if (!top_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal1_flag);
r2 = __builtin_nontemporal_load(signal1_flag + 1);
r3 = __builtin_nontemporal_load(signal1_flag + 2);
r4 = __builtin_nontemporal_load(signal1_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
} while(top_zeroed == top_done);
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal1_flag);
__builtin_nontemporal_store(v2, signal1_flag + 1);
__builtin_nontemporal_store(v3, signal1_flag + 2);
__builtin_nontemporal_store(v4, signal1_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
top_done = true;
}
} while (!top_done);
}
}
}
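Editor's note: checked_signal busy-waits until a neighbor has cleared the previous flag value, then publishes its own "output ready" value; the new top_zero/btm_zero template parameters drop the wait-and-signal loop on a side that has no neighbor. Because the PTX ld/st.volatile.global.v4.u32 inline assembly has no HIP equivalent, every ROCm branch reads or writes the four flag words with __builtin_nontemporal_load/__builtin_nontemporal_store instead. The repeated #ifdef pairs could be hoisted into helpers like the following (editorial sketch with hypothetical names, not code from the PR):

__device__ inline void load_flags(volatile int* f, int& r1, int& r2, int& r3, int& r4)
{
#ifdef __HIP_PLATFORM_HCC__
    // ROCm: four scalar nontemporal loads of the flag words
    r1 = __builtin_nontemporal_load(f);
    r2 = __builtin_nontemporal_load(f + 1);
    r3 = __builtin_nontemporal_load(f + 2);
    r4 = __builtin_nontemporal_load(f + 3);
#else
    // CUDA: one 128-bit volatile vector load, as in the original
    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                 : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(f) : "memory");
#endif
}

__device__ inline void store_flags(volatile int* f, int v1, int v2, int v3, int v4)
{
#ifdef __HIP_PLATFORM_HCC__
    __builtin_nontemporal_store(v1, f);
    __builtin_nontemporal_store(v2, f + 1);
    __builtin_nontemporal_store(v3, f + 2);
    __builtin_nontemporal_store(v4, f + 3);
#else
    asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
                 :: "l"(f), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
}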
......@@ -196,7 +311,14 @@ __device__ void wait_for(
register int r1, r2, r3, r4;
// wait for senders to signal their output is ready
do {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(wait_flag);
r2 = __builtin_nontemporal_load(wait_flag + 1);
r3 = __builtin_nontemporal_load(wait_flag + 2);
r4 = __builtin_nontemporal_load(wait_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
#endif
} while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}
cg::this_grid().sync(); // all threads wait for main
......@@ -212,12 +334,19 @@ __device__ void clear_flag(
if (is_main_thread) {
register int r1, r2, r3, r4;
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(r1, wait_flag);
__builtin_nontemporal_store(r2, wait_flag + 1);
__builtin_nontemporal_store(r3, wait_flag + 2);
__builtin_nontemporal_store(r4, wait_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
#endif
}
}
template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
template<class T, bool is_HWC, bool top_zero, bool btm_zero>
#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
......@@ -241,20 +370,34 @@ __global__ void push_pull_halos_1d_kernel(
)
{
// push top output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
if (!top_zero) strided_copy_kernel<T,is_HWC,false>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
// push btm output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
if (!btm_zero) strided_copy_kernel<T,is_HWC,false>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
if (!(top_zero || btm_zero)) {
checked_signal<false,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
} else if (top_zero) {
checked_signal<true,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
} else if (btm_zero) {
checked_signal<false,true>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
}
// pull top halo from transfer buffer in peer memory to input
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
if (top_zero) {
strided_copy_kernel<T,is_HWC,true>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
} else {
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC,false>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
}
// pull btm halo from transfer buffer in peer memory to input
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
if (btm_zero) {
strided_copy_kernel<T,is_HWC,true>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
} else {
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC,false>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
}
}
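Editor's note on the kernel's zero paths: a rank with top_zero (or btm_zero) set sits at the boundary of the 1D decomposition and has no neighbor on that side, so instead of waiting on a peer signal and pulling from a transfer buffer it zero-fills its own input halo via strided_copy_kernel<...,true> and skips wait_for/clear_flag; this is what lets the first hunk delete the separate zero_() calls from the Python wrapper. A self-contained toy showing the compile-time split (editorial sketch, hypothetical names, single GPU; not the shipped kernel):

#include <cstdio>
#include <cuda_runtime.h>

template<bool zero_side>
__global__ void recv_halo(float* halo, const float* peer_tx, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    if (zero_side)
        halo[i] = 0.f;          // boundary rank: nothing to receive, just zero
    else
        halo[i] = peer_tx[i];   // interior rank: pull from the peer buffer
}

int main()
{
    const int n = 8;
    float *halo, *tx;
    cudaMallocManaged(&halo, n * sizeof(float));
    cudaMallocManaged(&tx, n * sizeof(float));
    for (int i = 0; i < n; ++i) tx[i] = 1.f + i;
    recv_halo<true><<<1, n>>>(halo, tx, n);   // zero side: expect halo[i] == 0
    cudaDeviceSynchronize();
    printf("halo[0] = %f\n", halo[0]);
    cudaFree(halo);
    cudaFree(tx);
    return 0;
}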
__global__ void delay_kernel(int delay_nanoseconds, int* counter)
......@@ -343,10 +486,12 @@ void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
bool top_zero, // true if top halo should be zeroed
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
bool btm_zero, // true if btm halo should be zeroed
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
......@@ -368,6 +513,7 @@ void push_pull_halos_1d(
TORCH_CHECK(top_signal.is_cuda());
TORCH_CHECK(btm_signal.is_cuda());
TORCH_CHECK(waits.is_cuda());
TORCH_CHECK(!(top_zero && btm_zero));
// shapes and strides
int toh_N, toh_C, toh_H, toh_W;
......@@ -492,10 +638,34 @@ void push_pull_halos_1d(
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
if (top_zero) {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
} else {
// cannot do int4 transfers
if (diagnostics) printf("CAN NOT DO INT4\n");
......@@ -513,13 +683,57 @@ void push_pull_halos_1d(
};
int numBlocksPerSm;
if (is_nhwc) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
if (top_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
if (top_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
}
}
} );
......
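Editor's note: the host dispatch above now repeats the same occupancy query, grid computation, and #ifdef-guarded cooperative launch for every <T, is_HWC, top_zero, btm_zero> specialization (twelve near-identical blocks). A possible consolidation (editorial sketch only; launch_coop is a hypothetical helper, and it leaves cudaOccupancyMaxActiveBlocksPerMultiprocessor unguarded exactly as the PR does, relying on the hipify pass for ROCm builds):

#include <cuda_runtime.h>

template<class KernelT>
void launch_coop(KernelT* kernel, int numSM, int numThreads, dim3 block,
                 void** kernelArgs, cudaStream_t stream)
{
    int numBlocksPerSm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, kernel, numThreads, 0);
    dim3 grid(numSM * numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
    hipLaunchCooperativeKernel((void*)kernel, grid, block, kernelArgs, 0, stream);
#else
    cudaLaunchCooperativeKernel((void*)kernel, grid, block, kernelArgs, 0, stream);
#endif
}

// usage, replacing one of the branches above:
// launch_coop(push_pull_halos_1d_kernel<int4,true,false,false>,
//             numSM, numThreads, block, kernelArgs, current_stream);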
......@@ -32,10 +32,12 @@ namespace apex { namespace contrib { namespace peer_memory {
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
bool top_zero, // true if top halo should be zeroed
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
bool btm_zero, // true if btm halo should be zeroed
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
......
from .peer_memory import PeerMemoryPool
from .peer_halo_exchanger_1d import PeerHaloExchanger1d
......@@ -40,8 +40,9 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
btm_out_halo = y[:,:,:,W:W+half_halo]
btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
top_out_halo = top_out_halo.contiguous()
btm_out_halo = btm_out_halo.contiguous()
top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(top_inp_halos, top_out_halo)
......@@ -49,8 +50,14 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
btm_rank = (peer_rank + 1) % peer_group_size
top_inp_halo.copy_(btm_inp_halos[top_rank])
btm_inp_halo.copy_(top_inp_halos[btm_rank])
if peer_rank == 0:
top_inp_halo.zero_()
else:
top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))
if peer_rank == peer_group_size-1:
btm_inp_halo.zero_()
else:
btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))
def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
......@@ -141,12 +148,13 @@ def main():
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
peer_ranks = [i for i in range(world_size)]
pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
num_steps = 100
half_halo = 1
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)
H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)
......
......@@ -3,9 +3,15 @@ from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory_cuda as pm
class PeerHaloExchanger1d:
def __init__(self, rank, peer_group_size, peer_pool, half_halo):
self.peer_group_size = peer_group_size
self.peer_rank = rank % peer_group_size
def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
self.peer_group_size = len(ranks)
self.ranks = ranks
self.peer_rank = rank_in_group
self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size
self.low_zero = True if self.peer_rank == 0 else False
self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
......@@ -17,45 +23,43 @@ class PeerHaloExchanger1d:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:self.half_halo,:,:]
btm_out_halo = y[:,H:H+self.half_halo,:,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
low_inp_halo = y[:,:self.half_halo,:,:]
high_out_halo = y[:,H:H+self.half_halo,:,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,H:H+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
low_inp_halo = y[:,:,:self.half_halo,:]
high_out_halo = y[:,:,H:H+self.half_halo,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,W:W+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
low_inp_halo = y[:,:,:self.half_halo,:]
high_out_halo = y[:,:,W:W+self.half_halo,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:,:self.half_halo]
btm_out_halo = y[:,:,:,W:W+self.half_halo]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
low_inp_halo = y[:,:,:,:self.half_halo]
high_out_halo = y[:,:,:,W:W+self.half_halo]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
pm.push_pull_halos_1d(
diagnostics, explicit_nhwc, numSM,
top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo,
self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo,
self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank]
)
......@@ -42,6 +42,55 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. "
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+ "In some cases, a minor-version mismatch will not cause later errors: "
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk)."
)
def raise_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
raise RuntimeError(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
cudnn_available = torch.backends.cudnn.is_available()
cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
warnings.warn(
f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
)
return False
return True
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
......@@ -536,9 +585,13 @@ if "--fast_bottleneck" in sys.argv:
)
)
if "--peer_memory" in sys.argv:
sys.argv.remove("--peer_memory")
raise_if_cuda_home_none("--peer_memory")
if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
if "--peer_memory" in sys.argv:
sys.argv.remove("--peer_memory")
if not IS_ROCM_PYTORCH:
raise_if_cuda_home_none("--peer_memory")
ext_modules.append(
CUDAExtension(
name="peer_memory_cuda",
......@@ -550,9 +603,13 @@ if "--peer_memory" in sys.argv:
)
)
if "--nccl_p2p" in sys.argv:
sys.argv.remove("--nccl_p2p")
raise_if_cuda_home_none("--nccl_p2p")
if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
if "--nccl_p2p" in sys.argv:
sys.argv.remove("--nccl_p2p")
if not IS_ROCM_PYTORCH:
raise_if_cuda_home_none("--nccl_p2p")
ext_modules.append(
CUDAExtension(
name="nccl_p2p_cuda",
......