Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
c662c703
Commit
c662c703
authored
Aug 22, 2022
by
hubertlu-tw
Browse files
Enable --peer_memory and --nccl_p2p extensions for ROCm
parent
96850dfa
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
71 additions
and
6 deletions
+71
-6
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
+4
-0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
+61
-0
setup.py
setup.py
+6
-6
No files found.
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
View file @
c662c703
...
@@ -5,7 +5,11 @@
...
@@ -5,7 +5,11 @@
#include <cstdio>
#include <cstdio>
#include <ctime>
#include <ctime>
#include <cassert>
#include <cassert>
#ifdef __HIP_PLATFORM_HCC__
#include "rccl.h"
#else
#include "nccl.h"
#include "nccl.h"
#endif
/*
/*
* This file implements a crude but effective mechanism for copying data between tenors owned by different ranks
* This file implements a crude but effective mechanism for copying data between tenors owned by different ranks
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
View file @
c662c703
...
@@ -5,8 +5,15 @@
...
@@ -5,8 +5,15 @@
#include <cstdio>
#include <cstdio>
#include <cassert>
#include <cassert>
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#include "rccl.h"
#else
#include <cooperative_groups.h>
#include <cooperative_groups.h>
#include "nccl.h"
#include "nccl.h"
#endif
namespace
cg
=
cooperative_groups
;
namespace
cg
=
cooperative_groups
;
#define CUDACHECK(cmd) do { \
#define CUDACHECK(cmd) do { \
...
@@ -164,22 +171,50 @@ __device__ void checked_signal(
...
@@ -164,22 +171,50 @@ __device__ void checked_signal(
do
{
do
{
do
{
do
{
if
(
!
top_zeroed
)
{
if
(
!
top_zeroed
)
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
signal1_flag
);
r2
=
__builtin_nontemporal_load
(
signal1_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
signal1_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
signal1_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal1_flag
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal1_flag
)
:
"memory"
);
#endif
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
top_zeroed
=
true
;
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
top_zeroed
=
true
;
}
}
if
(
!
btm_zeroed
)
{
if
(
!
btm_zeroed
)
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
signal2_flag
);
r2
=
__builtin_nontemporal_load
(
signal2_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
signal2_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
signal2_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal2_flag
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal2_flag
)
:
"memory"
);
#endif
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
btm_zeroed
=
true
;
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
btm_zeroed
=
true
;
}
}
}
while
((
top_zeroed
==
top_done
)
&&
(
btm_zeroed
==
btm_done
));
}
while
((
top_zeroed
==
top_done
)
&&
(
btm_zeroed
==
btm_done
));
if
(
!
top_done
&&
top_zeroed
)
{
if
(
!
top_done
&&
top_zeroed
)
{
// signal to top neighbor my output is ready
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
v1
,
signal1_flag
);
__builtin_nontemporal_store
(
v2
,
signal1_flag
+
1
);
__builtin_nontemporal_store
(
v3
,
signal1_flag
+
2
);
__builtin_nontemporal_store
(
v4
,
signal1_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
#endif
top_done
=
true
;
top_done
=
true
;
}
}
if
(
!
btm_done
&&
btm_zeroed
)
{
if
(
!
btm_done
&&
btm_zeroed
)
{
// signal to bottom neighbor my output is ready
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
v1
,
signal2_flag
);
__builtin_nontemporal_store
(
v2
,
signal2_flag
+
1
);
__builtin_nontemporal_store
(
v3
,
signal2_flag
+
2
);
__builtin_nontemporal_store
(
v4
,
signal2_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
#endif
btm_done
=
true
;
btm_done
=
true
;
}
}
}
while
(
!
top_done
||
!
btm_done
);
}
while
(
!
top_done
||
!
btm_done
);
...
@@ -196,7 +231,14 @@ __device__ void wait_for(
...
@@ -196,7 +231,14 @@ __device__ void wait_for(
register
int
r1
,
r2
,
r3
,
r4
;
register
int
r1
,
r2
,
r3
,
r4
;
// wait for senders to signal their output is read
// wait for senders to signal their output is read
do
{
do
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
wait_flag
);
r2
=
__builtin_nontemporal_load
(
wait_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
wait_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
wait_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait_flag
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait_flag
)
:
"memory"
);
#endif
}
while
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
);
}
while
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
);
}
}
cg
::
this_grid
().
sync
();
// all threads wait for main
cg
::
this_grid
().
sync
();
// all threads wait for main
...
@@ -212,7 +254,14 @@ __device__ void clear_flag(
...
@@ -212,7 +254,14 @@ __device__ void clear_flag(
if
(
is_main_thread
)
{
if
(
is_main_thread
)
{
register
int
r1
,
r2
,
r3
,
r4
;
register
int
r1
,
r2
,
r3
,
r4
;
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
r1
,
wait_flag
);
__builtin_nontemporal_store
(
r2
,
wait_flag
+
1
);
__builtin_nontemporal_store
(
r3
,
wait_flag
+
2
);
__builtin_nontemporal_store
(
r4
,
wait_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
#endif
}
}
}
}
...
@@ -495,7 +544,11 @@ void push_pull_halos_1d(
...
@@ -495,7 +544,11 @@ void push_pull_halos_1d(
int
numBlocksPerSm
;
int
numBlocksPerSm
;
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
numBlocksPerSm
,
push_pull_halos_1d_kernel
<
int4
,
true
>
,
numThreads
,
0
);
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
numBlocksPerSm
,
push_pull_halos_1d_kernel
<
int4
,
true
>
,
numThreads
,
0
);
dim3
grid
(
numSM
*
numBlocksPerSm
,
1
,
1
);
dim3
grid
(
numSM
*
numBlocksPerSm
,
1
,
1
);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
int4
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
#else
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
int4
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
int4
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
#endif
}
else
{
}
else
{
// cannot do int4 transfers
// cannot do int4 transfers
if
(
diagnostics
)
printf
(
"CAN NOT DO INT4
\n
"
);
if
(
diagnostics
)
printf
(
"CAN NOT DO INT4
\n
"
);
...
@@ -515,11 +568,19 @@ void push_pull_halos_1d(
...
@@ -515,11 +568,19 @@ void push_pull_halos_1d(
if
(
is_nhwc
)
{
if
(
is_nhwc
)
{
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
numBlocksPerSm
,
push_pull_halos_1d_kernel
<
scalar_t
,
true
>
,
numThreads
,
0
);
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
numBlocksPerSm
,
push_pull_halos_1d_kernel
<
scalar_t
,
true
>
,
numThreads
,
0
);
dim3
grid
(
numSM
*
numBlocksPerSm
,
1
,
1
);
dim3
grid
(
numSM
*
numBlocksPerSm
,
1
,
1
);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
scalar_t
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
#else
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
scalar_t
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
scalar_t
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
#endif
}
else
{
}
else
{
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
numBlocksPerSm
,
push_pull_halos_1d_kernel
<
scalar_t
,
false
>
,
numThreads
,
0
);
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
numBlocksPerSm
,
push_pull_halos_1d_kernel
<
scalar_t
,
false
>
,
numThreads
,
0
);
dim3
grid
(
numSM
*
numBlocksPerSm
,
1
,
1
);
dim3
grid
(
numSM
*
numBlocksPerSm
,
1
,
1
);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
scalar_t
,
false
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
#else
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
scalar_t
,
false
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
scalar_t
,
false
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
#endif
}
}
}
}
}
);
}
);
...
...
setup.py
View file @
c662c703
...
@@ -536,9 +536,9 @@ if "--fast_bottleneck" in sys.argv:
...
@@ -536,9 +536,9 @@ if "--fast_bottleneck" in sys.argv:
)
)
)
)
if
"--peer_memory"
in
sys
.
argv
:
if
"--peer_memory"
in
sys
.
argv
or
"--cuda_ext"
in
sys
.
argv
:
if
"--peer_memory"
in
sys
.
argv
:
sys
.
argv
.
remove
(
"--peer_memory"
)
sys
.
argv
.
remove
(
"--peer_memory"
)
raise_if_cuda_home_none
(
"--peer_memory"
)
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
CUDAExtension
(
name
=
"peer_memory_cuda"
,
name
=
"peer_memory_cuda"
,
...
@@ -550,9 +550,9 @@ if "--peer_memory" in sys.argv:
...
@@ -550,9 +550,9 @@ if "--peer_memory" in sys.argv:
)
)
)
)
if
"--nccl_p2p"
in
sys
.
argv
:
if
"--nccl_p2p"
in
sys
.
argv
or
"--cuda_ext"
in
sys
.
argv
:
if
"--nccl_p2p"
in
sys
.
argv
:
sys
.
argv
.
remove
(
"--nccl_p2p"
)
sys
.
argv
.
remove
(
"--nccl_p2p"
)
raise_if_cuda_home_none
(
"--nccl_p2p"
)
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
CUDAExtension
(
name
=
"nccl_p2p_cuda"
,
name
=
"nccl_p2p_cuda"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment