Commit c662c703 authored by hubertlu-tw

Enable --peer_memory and --nccl_p2p extensions for ROCm

parent 96850dfa
@@ -5,7 +5,11 @@
#include <cstdio>
#include <ctime>
#include <cassert>
+#ifdef __HIP_PLATFORM_HCC__
+#include "rccl.h"
+#else
#include "nccl.h"
+#endif
/*
 * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
...
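For context: `__HIP_PLATFORM_HCC__` is defined by the HIP/ROCm toolchain, and RCCL exposes the same C API and types as NCCL, which is why swapping the header is the only change this file needs. A minimal standalone sketch of that pattern (illustrative only, not part of this commit; the `main` below is an assumption):

#ifdef __HIP_PLATFORM_HCC__
#include "rccl.h"
#else
#include "nccl.h"
#endif
#include <cstdio>

int main() {
  ncclUniqueId id;
  // ncclGetUniqueId has the same signature in NCCL and RCCL, so the code
  // below compiles unchanged for either backend.
  if (ncclGetUniqueId(&id) != ncclSuccess) {
    std::printf("failed to create a unique id\n");
    return 1;
  }
  std::printf("unique id created\n");
  return 0;
}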
@@ -5,8 +5,15 @@
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_cooperative_groups.h>
+#include "rccl.h"
+#else
#include <cooperative_groups.h>
#include "nccl.h"
+#endif
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
@@ -164,22 +171,50 @@ __device__ void checked_signal(
  do {
    do {
      if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+        r1 = __builtin_nontemporal_load(signal1_flag);
+        r2 = __builtin_nontemporal_load(signal1_flag + 1);
+        r3 = __builtin_nontemporal_load(signal1_flag + 2);
+        r4 = __builtin_nontemporal_load(signal1_flag + 3);
+#else
        asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
        if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
      }
      if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+        r1 = __builtin_nontemporal_load(signal2_flag);
+        r2 = __builtin_nontemporal_load(signal2_flag + 1);
+        r3 = __builtin_nontemporal_load(signal2_flag + 2);
+        r4 = __builtin_nontemporal_load(signal2_flag + 3);
+#else
        asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
        if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
      }
    } while((top_zeroed == top_done) && (btm_zeroed == btm_done));
    if (!top_done && top_zeroed) {
      // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+      __builtin_nontemporal_store(v1, signal1_flag);
+      __builtin_nontemporal_store(v2, signal1_flag + 1);
+      __builtin_nontemporal_store(v3, signal1_flag + 2);
+      __builtin_nontemporal_store(v4, signal1_flag + 3);
+#else
      asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
      top_done = true;
    }
    if (!btm_done && btm_zeroed) {
      // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+      __builtin_nontemporal_store(v1, signal2_flag);
+      __builtin_nontemporal_store(v2, signal2_flag + 1);
+      __builtin_nontemporal_store(v3, signal2_flag + 2);
+      __builtin_nontemporal_store(v4, signal2_flag + 3);
+#else
      asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
      btm_done = true;
    }
  } while (!top_done || !btm_done);
@@ -196,7 +231,14 @@ __device__ void wait_for(
    register int r1, r2, r3, r4;
    // wait for senders to signal their output is ready
    do {
+#ifdef __HIP_PLATFORM_HCC__
+      r1 = __builtin_nontemporal_load(wait_flag);
+      r2 = __builtin_nontemporal_load(wait_flag + 1);
+      r3 = __builtin_nontemporal_load(wait_flag + 2);
+      r4 = __builtin_nontemporal_load(wait_flag + 3);
+#else
      asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
+#endif
    } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
  }
  cg::this_grid().sync(); // all threads wait for main
@@ -212,7 +254,14 @@ __device__ void clear_flag(
  if (is_main_thread) {
    register int r1, r2, r3, r4;
    r1 = 0; r2 = 0; r3 = 0; r4 = 0;
+#ifdef __HIP_PLATFORM_HCC__
+    __builtin_nontemporal_store(r1, wait_flag);
+    __builtin_nontemporal_store(r2, wait_flag + 1);
+    __builtin_nontemporal_store(r3, wait_flag + 2);
+    __builtin_nontemporal_store(r4, wait_flag + 3);
+#else
    asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+#endif
  }
}
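The three hunks above all make the same substitution: on CUDA the four flag words are read or written with a single volatile 128-bit PTX vector access, while on ROCm, where inline PTX is unavailable, each word goes through Clang's `__builtin_nontemporal_load` / `__builtin_nontemporal_store`. A self-contained sketch of the polling idiom (the helper name and signature are assumptions, not code from this commit):

#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_runtime.h>
#endif

// Hypothetical device helper: spin until the 4-word flag equals {v1,v2,v3,v4}.
__device__ void spin_until_equal(int* flag, int v1, int v2, int v3, int v4) {
  int r1, r2, r3, r4;
  do {
#ifdef __HIP_PLATFORM_HCC__
    // ROCm path: read each word with Clang's non-temporal load builtin,
    // mirroring the hunks above.
    r1 = __builtin_nontemporal_load(flag);
    r2 = __builtin_nontemporal_load(flag + 1);
    r3 = __builtin_nontemporal_load(flag + 2);
    r4 = __builtin_nontemporal_load(flag + 3);
#else
    // CUDA path: one volatile 128-bit vector load straight from global memory.
    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                 : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4)
                 : "l"(flag) : "memory");
#endif
  } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}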
@@ -495,7 +544,11 @@ void push_pull_halos_1d(
    int numBlocksPerSm;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
    dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
+#else
    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
  } else {
    // cannot do int4 transfers
    if (diagnostics) printf("CAN NOT DO INT4\n");
@@ -515,11 +568,19 @@ void push_pull_halos_1d(
    if (is_nhwc) {
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
      dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+      hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
+#else
      cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
    } else {
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
      dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+      hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
+#else
      cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
    }
  }
} );
...
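The launch sites above rely on grid-wide synchronization (`cg::this_grid().sync()` in the device code), which is only legal for cooperatively launched kernels; HIP's `hipLaunchCooperativeKernel` takes the same argument list as `cudaLaunchCooperativeKernel`, so the port is a one-line switch per call site. A self-contained sketch of the pattern (kernel and function names are illustrative, not from this commit):

#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>
#else
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#endif
namespace cg = cooperative_groups;

__global__ void mark_and_sync(int* data) {
  data[blockIdx.x * blockDim.x + threadIdx.x] = 1;
  cg::this_grid().sync();  // grid-wide barrier: requires a cooperative launch
}

void launch_on_default_stream(int* d_data, int numSM) {
  const int numThreads = 128;
  int numBlocksPerSm = 0;
  void* kernelArgs[] = { (void*)&d_data };
  dim3 block(numThreads, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
  hipOccupancyMaxActiveBlocksPerMultiprocessor(
      &numBlocksPerSm, (const void*)mark_and_sync, numThreads, 0);
  dim3 grid(numSM * numBlocksPerSm, 1, 1);
  // Same argument layout as the CUDA call below, hence the one-line switch
  // at each launch site in the diff above.
  hipLaunchCooperativeKernel((void*)mark_and_sync, grid, block, kernelArgs, 0, 0);
#else
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &numBlocksPerSm, mark_and_sync, numThreads, 0);
  dim3 grid(numSM * numBlocksPerSm, 1, 1);
  cudaLaunchCooperativeKernel((void*)mark_and_sync, grid, block, kernelArgs, 0, 0);
#endif
}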
@@ -536,9 +536,9 @@ if "--fast_bottleneck" in sys.argv:
        )
    )
-if "--peer_memory" in sys.argv:
-    sys.argv.remove("--peer_memory")
-    raise_if_cuda_home_none("--peer_memory")
+if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--peer_memory" in sys.argv:
+        sys.argv.remove("--peer_memory")
    ext_modules.append(
        CUDAExtension(
            name="peer_memory_cuda",
@@ -550,9 +550,9 @@ if "--peer_memory" in sys.argv:
        )
    )
-if "--nccl_p2p" in sys.argv:
-    sys.argv.remove("--nccl_p2p")
-    raise_if_cuda_home_none("--nccl_p2p")
+if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--nccl_p2p" in sys.argv:
+        sys.argv.remove("--nccl_p2p")
    ext_modules.append(
        CUDAExtension(
            name="nccl_p2p_cuda",
...
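For readability, here is how the `--peer_memory` guard reads after this change (the `--nccl_p2p` guard changes the same way). The `raise_if_cuda_home_none` call is dropped, presumably so ROCm builds (which have no CUDA_HOME) are not rejected. The extension's source list and compile flags are elided in the hunk above and stay elided here, and the build invocations in the trailing comments are assumptions, not part of this commit:

import sys
from torch.utils.cpp_extension import CUDAExtension

ext_modules = []

# Sketch of the updated guard: the extension now also builds under the
# umbrella --cuda_ext flag, not only under its dedicated flag.
if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
    if "--peer_memory" in sys.argv:
        sys.argv.remove("--peer_memory")
    ext_modules.append(
        CUDAExtension(
            name="peer_memory_cuda",
            sources=[...],  # source list elided in the diff above
        )
    )

# Assumed build invocations (not shown in this commit):
#   python setup.py install --cuda_ext
#   python setup.py install --peer_memory --nccl_p2p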