Commit c662c703 authored by hubertlu-tw

Enable --peer_memory and --nccl_p2p extensions for ROCm

parent 96850dfa
@@ -5,7 +5,11 @@
#include <cstdio>
#include <ctime>
#include <cassert>
#ifdef __HIP_PLATFORM_HCC__
#include "rccl.h"
#else
#include "nccl.h"
#endif
/*
* This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
......
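Note: the guard above works because RCCL is designed as a drop-in replacement for NCCL on ROCm, exposing the same C API under the same nccl* names. A minimal sketch of the pattern, assuming a build where the HIP toolchain defines __HIP_PLATFORM_HCC__ (hipcc does this for AMD targets):

    // Minimal sketch: one translation unit that builds against either library.
    // __HIP_PLATFORM_HCC__ is defined by the HIP toolchain on ROCm builds;
    // on CUDA builds it is absent, so the NCCL header is selected.
    #include <cstdio>
    #ifdef __HIP_PLATFORM_HCC__
    #include "rccl.h"   // RCCL mirrors the NCCL C API under the nccl* names
    #else
    #include "nccl.h"
    #endif

    int main() {
        int version = 0;
        // ncclGetVersion is part of both libraries' public API.
        if (ncclGetVersion(&version) == ncclSuccess)
            printf("communicator library version: %d\n", version);
        return 0;
    }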
@@ -5,8 +5,15 @@
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#include "rccl.h"
#else
#include <cooperative_groups.h>
#include "nccl.h"
#endif
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
@@ -164,22 +171,50 @@ __device__ void checked_signal(
    do {
        do {
            if (!top_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
                r1 = __builtin_nontemporal_load(signal1_flag);
                r2 = __builtin_nontemporal_load(signal1_flag + 1);
                r3 = __builtin_nontemporal_load(signal1_flag + 2);
                r4 = __builtin_nontemporal_load(signal1_flag + 3);
#else
                asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
#endif
                if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
            }
            if (!btm_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
                r1 = __builtin_nontemporal_load(signal2_flag);
                r2 = __builtin_nontemporal_load(signal2_flag + 1);
                r3 = __builtin_nontemporal_load(signal2_flag + 2);
                r4 = __builtin_nontemporal_load(signal2_flag + 3);
#else
                asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
#endif
                if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
            }
        } while((top_zeroed == top_done) && (btm_zeroed == btm_done));
        if (!top_done && top_zeroed) {
            // signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
            __builtin_nontemporal_store(v1, signal1_flag);
            __builtin_nontemporal_store(v2, signal1_flag + 1);
            __builtin_nontemporal_store(v3, signal1_flag + 2);
            __builtin_nontemporal_store(v4, signal1_flag + 3);
#else
            asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
            top_done = true;
        }
        if (!btm_done && btm_zeroed) {
            // signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
            __builtin_nontemporal_store(v1, signal2_flag);
            __builtin_nontemporal_store(v2, signal2_flag + 1);
            __builtin_nontemporal_store(v3, signal2_flag + 2);
            __builtin_nontemporal_store(v4, signal2_flag + 3);
#else
            asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
            btm_done = true;
        }
    } while (!top_done || !btm_done);
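Reviewer note on the pattern above: the CUDA path moves the four flag words with one volatile 128-bit PTX instruction, while the ROCm path approximates it with four nontemporal 32-bit accesses, since inline PTX is unavailable under HIP. A hedged sketch of how the two paths could be folded into shared helpers; the names load4/store4 are illustrative, not from this commit:

    // Hypothetical helpers folding the two platform paths shown above.
    // On ROCm, __builtin_nontemporal_load/store bypass the cache per word;
    // on CUDA, one volatile vector instruction moves all 128 bits at once.
    __device__ inline void load4(const int* flag, int& r1, int& r2, int& r3, int& r4) {
    #ifdef __HIP_PLATFORM_HCC__
        r1 = __builtin_nontemporal_load(flag);
        r2 = __builtin_nontemporal_load(flag + 1);
        r3 = __builtin_nontemporal_load(flag + 2);
        r4 = __builtin_nontemporal_load(flag + 3);
    #else
        asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                     : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4)
                     : "l"(flag) : "memory");
    #endif
    }

    __device__ inline void store4(int* flag, int v1, int v2, int v3, int v4) {
    #ifdef __HIP_PLATFORM_HCC__
        __builtin_nontemporal_store(v1, flag);
        __builtin_nontemporal_store(v2, flag + 1);
        __builtin_nontemporal_store(v3, flag + 2);
        __builtin_nontemporal_store(v4, flag + 3);
    #else
        asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
                     :: "l"(flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
    #endif
    }

One consequence worth noting: the HIP path is not a single atomic 16-byte access, so a reader can observe a torn quad; the polling loops above appear to tolerate this by re-reading until the full quad matches the expected values.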
@@ -196,7 +231,14 @@ __device__ void wait_for(
        register int r1, r2, r3, r4;
        // wait for senders to signal their output is ready
        do {
#ifdef __HIP_PLATFORM_HCC__
            r1 = __builtin_nontemporal_load(wait_flag);
            r2 = __builtin_nontemporal_load(wait_flag + 1);
            r3 = __builtin_nontemporal_load(wait_flag + 2);
            r4 = __builtin_nontemporal_load(wait_flag + 3);
#else
            asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
#endif
        } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
    }
    cg::this_grid().sync(); // all threads wait for main
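cg::this_grid().sync() above is a grid-wide barrier, which is why every launch site below must use the cooperative-launch API; an ordinary <<<>>> launch gives no cross-block synchronization guarantee. A small self-contained sketch of the portable include gating plus a two-phase kernel that depends on the grid barrier (kernel and buffer names are illustrative):

    // Sketch: portable cooperative-groups gating, as in the hunk above.
    #ifdef __HIP_PLATFORM_HCC__
    #include <hip/hip_cooperative_groups.h>
    #else
    #include <cooperative_groups.h>
    #endif
    namespace cg = cooperative_groups;

    // Phase 2 reads elements written by other blocks in phase 1, so it is
    // only correct if the whole grid synchronizes in between. That in turn
    // requires a cooperative launch (see the host-side hunks below).
    __global__ void two_phase(const int* in, int* tmp, int* out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) tmp[i] = in[i] * 2;                  // phase 1: every block writes
        cg::this_grid().sync();                         // grid-wide barrier
        if (i < n) out[i] = tmp[i] + tmp[n - 1 - i];    // phase 2: cross-block read
    }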
@@ -212,7 +254,14 @@ __device__ void clear_flag(
    if (is_main_thread) {
        register int r1, r2, r3, r4;
        r1 = 0; r2 = 0; r3 = 0; r4 = 0;
#ifdef __HIP_PLATFORM_HCC__
        __builtin_nontemporal_store(r1, wait_flag);
        __builtin_nontemporal_store(r2, wait_flag + 1);
        __builtin_nontemporal_store(r3, wait_flag + 2);
        __builtin_nontemporal_store(r4, wait_flag + 3);
#else
        asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
#endif
    }
}
@@ -495,7 +544,11 @@ void push_pull_halos_1d(
        int numBlocksPerSm;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
        dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
#else
        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
#endif
    } else {
        // cannot do int4 transfers
        if (diagnostics) printf("CAN NOT DO INT4\n");
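For context on the branch above: the int4 specialization moves halo data as 16-byte vectors, and the else branch falls back to per-element transfers when that is not legal. The exact eligibility test lives in code elided from this hunk; the predicate below only illustrates the usual requirements for 16-byte vector accesses:

    #include <cstddef>
    #include <cstdint>

    // Illustrative only: int4 (16-byte) loads/stores require 16-byte-aligned
    // pointers and a byte count that is a multiple of 16.
    static bool can_use_int4(const void* src, const void* dst, size_t nbytes) {
        auto aligned16 = [](const void* p) {
            return reinterpret_cast<uintptr_t>(p) % 16 == 0;
        };
        return aligned16(src) && aligned16(dst) && nbytes % 16 == 0;
    }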
@@ -515,11 +568,19 @@ void push_pull_halos_1d(
        if (is_nhwc) {
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
            dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
#else
            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
#endif
        } else {
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
            dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
#else
            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
#endif
        }
    }
} );
......
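Host-side, the hunks above keep the occupancy query unguarded and branch only on the launch call, since HIP names the cooperative entry point hipLaunchCooperativeKernel. A condensed sketch of that shape; the helper name is mine, and it assumes, as the unguarded source lines suggest, that the ROCm build hipifies the remaining cuda* calls and types:

    // Condensed sketch of the launch pattern above (helper name illustrative).
    // Assumes the ROCm build maps cudaOccupancyMaxActiveBlocksPerMultiprocessor
    // and cudaStream_t via hipify, mirroring how the source leaves them unguarded.
    template <typename KernelT>
    void launch_cooperative(KernelT kernel, void** kernelArgs, int numThreads,
                            int numSM, cudaStream_t stream) {
        int numBlocksPerSm = 0;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, kernel, numThreads, 0);
        dim3 grid(numSM * numBlocksPerSm, 1, 1);
        dim3 block(numThreads, 1, 1);
    #ifdef __HIP_PLATFORM_HCC__
        hipLaunchCooperativeKernel((void*)kernel, grid, block, kernelArgs, 0, stream);
    #else
        cudaLaunchCooperativeKernel((void*)kernel, grid, block, kernelArgs, 0, stream);
    #endif
    }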
@@ -536,9 +536,9 @@ if "--fast_bottleneck" in sys.argv:
         )
     )
-if "--peer_memory" in sys.argv:
-    sys.argv.remove("--peer_memory")
-    raise_if_cuda_home_none("--peer_memory")
+if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--peer_memory" in sys.argv:
+        sys.argv.remove("--peer_memory")
     ext_modules.append(
         CUDAExtension(
             name="peer_memory_cuda",
@@ -550,9 +550,9 @@ if "--peer_memory" in sys.argv:
         )
     )
-if "--nccl_p2p" in sys.argv:
-    sys.argv.remove("--nccl_p2p")
-    raise_if_cuda_home_none("--nccl_p2p")
+if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--nccl_p2p" in sys.argv:
+        sys.argv.remove("--nccl_p2p")
     ext_modules.append(
         CUDAExtension(
             name="nccl_p2p_cuda",
......
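Net effect of the two setup.py hunks: each extension now also builds under the umbrella --cuda_ext flag, the dedicated flags keep working, and the raise_if_cuda_home_none check is gone, presumably because ROCm environments have no CUDA_HOME to find. A minimal Python sketch of the resulting gating pattern (the gate_extension helper is hypothetical, written to mirror the duplicated logic above):

    import sys

    ext_modules = []

    def gate_extension(flag, build_ext):
        # Hypothetical helper mirroring the pattern above: build when either
        # the dedicated flag or the umbrella --cuda_ext flag is present.
        if flag in sys.argv or "--cuda_ext" in sys.argv:
            if flag in sys.argv:
                sys.argv.remove(flag)
            ext_modules.append(build_ext())

    # Usage (extension factories elided):
    # gate_extension("--peer_memory", make_peer_memory_ext)
    # gate_extension("--nccl_p2p", make_nccl_p2p_ext)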