OpenDAS / apex · Commit fd0f7631

Authored Aug 15, 2022 by Thor Johnsen; committed by hubertlu-tw on Aug 22, 2022

Fixed peer halo exchange module test

parent c662c703

Showing 6 changed files with 276 additions and 113 deletions (+276 -113)
apex/contrib/bottleneck/halo_exchangers.py                    +2    -7
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu             +221  -68
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh            +2    -0
apex/contrib/peer_memory/__init__.py                          +1    -0
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py   +14   -6
apex/contrib/peer_memory/peer_halo_exchanger_1d.py            +36   -32
apex/contrib/bottleneck/halo_exchangers.py

@@ -107,15 +107,10 @@ class HaloExchangerPeer(HaloExchanger):
         right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
         pm.push_pull_halos_1d(
             self.diagnostics, self.explicit_nhwc, self.numSM,
-            left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
-            right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+            self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
+            self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
             self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
         )
-        # TODO: Add to push_pull_halos_1d kernel
-        if self.left_zero:
-            left_input_halo.zero_()
-        if self.right_zero:
-            right_input_halo.zero_()
         if not inplace:
             return left_input_halo, right_input_halo
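Before this change, HaloExchangerPeer zeroed the boundary input halos on the Python side after the kernel returned (the removed TODO block); the zeroing is now fused into push_pull_halos_1d via the new left_zero/right_zero arguments. A plain-torch sketch of the per-rank semantics this implies (illustrative only; the helper name and the rank-to-flag mapping are assumptions, not the peer-memory implementation):

    import torch

    def reference_halo_exchange(left_output_halo, right_output_halo,
                                neighbor_halos, rank, group_size):
        # neighbor_halos[r] = (left_out, right_out) halos published by rank r.
        left_rank = (rank + group_size - 1) % group_size
        right_rank = (rank + 1) % group_size
        if rank == 0:
            # left_zero: no rank below, the kernel writes zeros directly
            left_input_halo = torch.zeros_like(left_output_halo)
        else:
            left_input_halo = neighbor_halos[left_rank][1].clone()
        if rank == group_size - 1:
            # right_zero: no rank above
            right_input_halo = torch.zeros_like(right_output_halo)
        else:
            right_input_halo = neighbor_halos[right_rank][0].clone()
        return left_input_halo, right_input_halo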
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu

@@ -124,7 +124,20 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride
     }
 }
-template<class T, bool is_HWC>
+template<class T>
+__device__ void __zero(T* dst)
+{
+    *dst = T(0);
+}
+
+__device__ void __zero(int4* dst)
+{
+    int4 v;
+    v.x = v.y = v.z = v.w = 0;
+    *dst = v;
+}
+
+template<class T, bool is_HWC, bool zero>
 __device__ void strided_copy_kernel(
     T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
     const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
@@ -138,23 +151,28 @@ __device__ void strided_copy_kernel(
 {
         size_t c, h, w;
         if (is_HWC) {
-            c = i % NC;
             w = i / NC;
+            c = i - w * NC;
             h = w / NW;
-            w = w % NW;
+            w = w - h * NW;
         } else {
-            w = i % NW;
             h = i / NW;
+            w = i - h * NW;
             c = h / NH;
-            h = h % NH;
+            h = h - c * NH;
         }
         size_t dst_off = c * dst_stride_C + h * dst_stride_H + w * dst_stride_W;
-        size_t src_off = c * src_stride_C + h * src_stride_H + w * src_stride_W;
-        dst[dst_off] = src[src_off];
+        if (zero) {
+            __zero(dst + dst_off);
+        } else {
+            size_t src_off = c * src_stride_C + h * src_stride_H + w * src_stride_W;
+            dst[dst_off] = src[src_off];
+        }
     }
 }
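In the rewritten index math, each second modulo is replaced by a subtract-multiply on the quotient that is already in hand, so the compiler emits one division per axis instead of a division plus a modulo. A small host-side check of the equivalence (plain Python, arbitrary illustrative extents):

    # HWC case: original form vs. the rewritten form from the diff
    NC, NW, NH = 3, 5, 4
    for i in range(NC * NW * NH):
        c0 = i % NC                 # original
        w0 = (i // NC) % NW
        h0 = (i // NC) // NW
        w = i // NC                 # rewritten: reuse the quotient
        c = i - w * NC
        h = w // NW
        w = w - h * NW
        assert (c, h, w) == (c0, h0, w0)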
+template<bool top_zero, bool btm_zero>
 __device__ void checked_signal(
     volatile int* signal1_flag,
     volatile int* signal2_flag,
     const int v1, const int v2, const int v3, const int v4
@@ -167,57 +185,119 @@ __device__ void checked_signal(
     __threadfence_system();
     // wait for top or bottom neighbor to clear signal
     register int r1, r2, r3, r4;
-    bool top_zeroed = false, btm_zeroed = false, top_done = false, btm_done = false;
+    if (!(top_zero || btm_zero)) {
+        bool top_zeroed = false, top_done = false;
+        bool btm_zeroed = false, btm_done = false;
         do {
             do {
                 if (!top_zeroed) {
 #ifdef __HIP_PLATFORM_HCC__
                     r1 = __builtin_nontemporal_load(signal1_flag);
                     r2 = __builtin_nontemporal_load(signal1_flag+1);
                     r3 = __builtin_nontemporal_load(signal1_flag+2);
                     r4 = __builtin_nontemporal_load(signal1_flag+3);
 #else
                     asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
 #endif
                     if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
                 }
                 if (!btm_zeroed) {
 #ifdef __HIP_PLATFORM_HCC__
                     r1 = __builtin_nontemporal_load(signal2_flag);
                     r2 = __builtin_nontemporal_load(signal2_flag+1);
                     r3 = __builtin_nontemporal_load(signal2_flag+2);
                     r4 = __builtin_nontemporal_load(signal2_flag+3);
 #else
                     asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
 #endif
                     if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
                 }
             } while ((top_zeroed == top_done) && (btm_zeroed == btm_done));
             if (!top_done && top_zeroed) {
                 // signal to top neighbor my output is ready
 #ifdef __HIP_PLATFORM_HCC__
                 __builtin_nontemporal_store(v1, signal1_flag);
                 __builtin_nontemporal_store(v2, signal1_flag+1);
                 __builtin_nontemporal_store(v3, signal1_flag+2);
                 __builtin_nontemporal_store(v4, signal1_flag+3);
 #else
                 asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
 #endif
                 top_done = true;
             }
             if (!btm_done && btm_zeroed) {
                 // signal to bottom neighbor my output is ready
 #ifdef __HIP_PLATFORM_HCC__
                 __builtin_nontemporal_store(v1, signal2_flag);
                 __builtin_nontemporal_store(v2, signal2_flag+1);
                 __builtin_nontemporal_store(v3, signal2_flag+2);
                 __builtin_nontemporal_store(v4, signal2_flag+3);
 #else
                 asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
 #endif
                 btm_done = true;
             }
         } while (!top_done || !btm_done);
+    } else if (top_zero) {
+        bool btm_zeroed = false, btm_done = false;
+        do {
+            do {
+                if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal2_flag);
+                    r2 = __builtin_nontemporal_load(signal2_flag+1);
+                    r3 = __builtin_nontemporal_load(signal2_flag+2);
+                    r4 = __builtin_nontemporal_load(signal2_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
+                }
+            } while (btm_zeroed == btm_done);
+            if (!btm_done && btm_zeroed) {
+                // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal2_flag);
+                __builtin_nontemporal_store(v2, signal2_flag+1);
+                __builtin_nontemporal_store(v3, signal2_flag+2);
+                __builtin_nontemporal_store(v4, signal2_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                btm_done = true;
+            }
+        } while (!btm_done);
+    } else if (btm_zero) {
+        bool top_zeroed = false, top_done = false;
+        do {
+            do {
+                if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal1_flag);
+                    r2 = __builtin_nontemporal_load(signal1_flag+1);
+                    r3 = __builtin_nontemporal_load(signal1_flag+2);
+                    r4 = __builtin_nontemporal_load(signal1_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
+                }
+            } while (top_zeroed == top_done);
+            if (!top_done && top_zeroed) {
+                // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal1_flag);
+                __builtin_nontemporal_store(v2, signal1_flag+1);
+                __builtin_nontemporal_store(v3, signal1_flag+2);
+                __builtin_nontemporal_store(v4, signal1_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                top_done = true;
+            }
+        } while (!top_done);
+    }
 }
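The handshake checked_signal implements amounts to: spin until the neighbor has cleared my previous signal words (they no longer hold v1..v4), then publish the new values; the templated variants simply drop whichever side has no neighbor. A toy host-side model of one direction of that round trip (Python threads standing in for ranks; timing and data types are illustrative, this is not the device code):

    import threading, time

    flag = [0, 0, 0, 0]   # stands in for one signal*_flag location
    V = [-987751720, 840868300, -225529332, 281513358]

    def sender():
        # checked_signal: wait until the receiver cleared the flag, then set it
        while flag == V:
            time.sleep(1e-4)
        flag[:] = V

    def receiver():
        # wait_for + clear_flag: spin until the agreed values appear, then clear
        while flag != V:
            time.sleep(1e-4)
        flag[:] = [0, 0, 0, 0]

    t1 = threading.Thread(target=sender)
    t2 = threading.Thread(target=receiver)
    t1.start(); t2.start(); t1.join(); t2.join()
    print("one signal/consume round completed")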
@@ -238,7 +318,7 @@ __device__ void wait_for(
         r4 = __builtin_nontemporal_load(wait_flag+3);
 #else
         asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
 #endif
     } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
 }
 cg::this_grid().sync();  // all threads wait for main
@@ -265,8 +345,8 @@ __device__ void clear_flag(
     }
 }
-template<class T, bool is_HWC>
-#if __CUDA_ARCH__ >= 700
+template<class T, bool is_HWC, bool top_zero, bool btm_zero>
+#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
 __launch_bounds__(128, 16)
 #endif
 __global__ void push_pull_halos_1d_kernel(
@@ -290,20 +370,34 @@ __global__ void push_pull_halos_1d_kernel(
     )
 {
     // push top output halo to transfer buffer
-    strided_copy_kernel<T, is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
+    if (!top_zero) strided_copy_kernel<T, is_HWC, false>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
     // push btm output halo to transfer buffer
-    strided_copy_kernel<T, is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
+    if (!btm_zero) strided_copy_kernel<T, is_HWC, false>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
     // signal to top and btm neigbhbors that output halos are ready to be read
-    checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
+    if (!(top_zero || btm_zero)) {
+        checked_signal<false, false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    } else if (top_zero) {
+        checked_signal<true, false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    } else if (btm_zero) {
+        checked_signal<false, true>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    }
     // pull top halo from transfer buffer in peer memory to input
-    wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
-    strided_copy_kernel<T, is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
-    clear_flag(wait1_flag);
+    if (top_zero) {
+        strided_copy_kernel<T, is_HWC, true>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+    } else {
+        wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
+        strided_copy_kernel<T, is_HWC, false>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+        clear_flag(wait1_flag);
+    }
     // pull btm halo from transfer buffer in peer memory to input
-    wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
-    strided_copy_kernel<T, is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
-    clear_flag(wait2_flag);
+    if (btm_zero) {
+        strided_copy_kernel<T, is_HWC, true>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+    } else {
+        wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
+        strided_copy_kernel<T, is_HWC, false>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+        clear_flag(wait2_flag);
+    }
 }

 __global__ void delay_kernel(int delay_nanoseconds, int* counter)
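With the new zero template parameter, each pull phase either runs the usual wait/copy/clear sequence or, when the neighbor does not exist, skips the peer read and the flag handshake entirely and clears the input halo in place. The same branch expressed in plain torch terms (a sketch; dst and src are stand-ins for the halo and transfer-buffer views):

    import torch

    def pull_halo(dst, src, zero):
        # mirrors strided_copy_kernel<T, is_HWC, zero> as used in the pull phase
        if zero:
            dst.zero_()    # boundary rank: no neighbor, no wait_for/clear_flag
        else:
            dst.copy_(src) # interior rank: copy from the peer transfer buffer

    halo = torch.ones(1, 4, 1, 8)
    pull_halo(halo, torch.full_like(halo, 2.0), zero=False)
    assert halo.eq(2.0).all()
    pull_halo(halo, torch.full_like(halo, 3.0), zero=True)
    assert halo.eq(0.0).all()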
@@ -392,10 +486,12 @@ void push_pull_halos_1d(
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                   // number of SMs to use
+    bool top_zero,               // true if top halo should be zeroed
     at::Tensor top_out_halo,     // top output halo in sender device memory
     at::Tensor top_out_tx,       // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,       // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,     // top input halo in receiver device memory
+    bool btm_zero,               // true if btm halo should be zeroed
     at::Tensor btm_out_halo,     // btm output halo in sender device memory
     at::Tensor btm_out_tx,       // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,       // btm input transfer buffer in btm neighbor peer pool memory
@@ -417,6 +513,7 @@ void push_pull_halos_1d(
     TORCH_CHECK(top_signal.is_cuda());
     TORCH_CHECK(btm_signal.is_cuda());
     TORCH_CHECK(waits.is_cuda());
+    TORCH_CHECK(!(top_zero && btm_zero));

     // shapes and strides
     int toh_N, toh_C, toh_H, toh_W;
@@ -541,14 +638,34 @@ void push_pull_halos_1d(
             &NC, &NH, &NW,
             &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p };
-        int numBlocksPerSm;
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true>, numThreads, 0);
-        dim3 grid(numSM*numBlocksPerSm, 1, 1);
-#ifdef __HIP_PLATFORM_HCC__
-        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true>, grid, block, kernelArgs, 0, current_stream);
-#else
-        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true>, grid, block, kernelArgs, 0, current_stream);
-#endif
+        if (top_zero) {
+            int numBlocksPerSm;
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true, true, false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else if (btm_zero) {
+            int numBlocksPerSm;
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true, false, true>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else {
+            int numBlocksPerSm;
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true, false, false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        }
     } else {
         // cannot do int4 transfers
         if (diagnostics) printf("CAN NOT DO INT4\n");
@@ -566,21 +683,57 @@ void push_pull_halos_1d(
             };
             int numBlocksPerSm;
             if (is_nhwc) {
-                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true>, numThreads, 0);
-                dim3 grid(numSM*numBlocksPerSm, 1, 1);
-#ifdef __HIP_PLATFORM_HCC__
-                hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true>, grid, block, kernelArgs, 0, current_stream);
-#else
-                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true>, grid, block, kernelArgs, 0, current_stream);
-#endif
+                if (top_zero) {
+                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true, true, false>, numThreads, 0);
+                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+                } else if (btm_zero) {
+                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true, false, true>, numThreads, 0);
+                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#else
+                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+                } else {
+                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true, false, false>, numThreads, 0);
+                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+                }
             } else {
-                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false>, numThreads, 0);
-                dim3 grid(numSM*numBlocksPerSm, 1, 1);
-#ifdef __HIP_PLATFORM_HCC__
-                hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false>, grid, block, kernelArgs, 0, current_stream);
-#else
-                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false>, grid, block, kernelArgs, 0, current_stream);
-#endif
+                if (top_zero) {
+                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false, true, false>, numThreads, 0);
+                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, true, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, true, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+                } else if (btm_zero) {
+                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false, false, true>, numThreads, 0);
+                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, true>, grid, block, kernelArgs, 0, current_stream);
+#else
+                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+                } else {
+                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false, false, false>, numThreads, 0);
+                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+                }
             }
         }
     );
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh

@@ -32,10 +32,12 @@ namespace apex { namespace contrib { namespace peer_memory {
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                   // number of SMs to use
+    bool top_zero,               // true if top halo should be zeroed
     at::Tensor top_out_halo,     // top output halo in sender device memory
     at::Tensor top_out_tx,       // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,       // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,     // top input halo in receiver device memory
+    bool btm_zero,               // true if btm halo should be zeroed
     at::Tensor btm_out_halo,     // btm output halo in sender device memory
     at::Tensor btm_out_tx,       // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,       // btm input transfer buffer in btm neighbor peer pool memory
apex/contrib/peer_memory/__init__.py

 from .peer_memory import PeerMemoryPool
 from .peer_halo_exchanger_1d import PeerHaloExchanger1d
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py

@@ -40,8 +40,9 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
         btm_out_halo = y[:,:,:,W:W+half_halo]
         btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
-    top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
-    btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
+    mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+    top_out_halo = top_out_halo.contiguous()
+    btm_out_halo = btm_out_halo.contiguous()
     top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
     torch.distributed.all_gather(top_inp_halos, top_out_halo)
@@ -49,8 +50,14 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
     torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
     top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
     btm_rank = (peer_rank + 1) % peer_group_size
-    top_inp_halo.copy_(btm_inp_halos[top_rank])
-    btm_inp_halo.copy_(top_inp_halos[btm_rank])
+    if peer_rank == 0:
+        top_inp_halo.zero_()
+    else:
+        top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))
+    if peer_rank == peer_group_size - 1:
+        btm_inp_halo.zero_()
+    else:
+        btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))

 def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
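The NCCL reference path also switches from clone(memory_format=torch.preserve_format) to an explicit contiguous() before all_gather, remembering the tensor's original layout in mf and restoring it with .to(memory_format=mf) on the copy back. The layout round-trip in isolation (a CPU-only sketch; shapes are illustrative):

    import torch

    y = torch.randn(2, 8, 6, 6).to(memory_format=torch.channels_last)
    mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
    halo = y[:, :, :, 0:2].contiguous()   # dense buffer, as all_gather expects
    restored = halo.to(memory_format=mf)  # back to the original memory format
    assert restored.is_contiguous(memory_format=torch.channels_last)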
@@ -141,12 +148,13 @@ def main():
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     torch.cuda.set_device(rank)
-    pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
+    peer_ranks = [i for i in range(world_size)]
+    pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
     num_steps = 100
     half_halo = 1
-    halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
+    halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)
     H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex, num_steps)
     W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex, num_steps)
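Taken together, main() now drives the updated constructors: PeerMemoryPool takes the two pool sizes plus an explicit list of participating ranks, and PeerHaloExchanger1d takes that rank list plus this process's position in it. A condensed usage sketch of the new API as exercised by the test (assumes a NCCL process group with one GPU per rank and peer-to-peer capable devices, so it is not runnable standalone):

    import torch
    from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d

    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(rank)

    peer_ranks = list(range(world_size))
    # pool sizes as used by the test: 64 KiB static, 2 MiB dynamic
    pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
    halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, 1)  # half_halo = 1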
apex/contrib/peer_memory/peer_halo_exchanger_1d.py

@@ -3,9 +3,15 @@ from apex.contrib.peer_memory import PeerMemoryPool
 import peer_memory_cuda as pm

 class PeerHaloExchanger1d:
-    def __init__(self, rank, peer_group_size, peer_pool, half_halo):
-        self.peer_group_size = peer_group_size
-        self.peer_rank = rank % peer_group_size
+    def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
+        self.peer_group_size = len(ranks)
+        self.ranks = ranks
+        self.peer_rank = rank_in_group
+        self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
+        self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size
+        self.low_zero = True if self.peer_rank == 0 else False
+        self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False
         self.peer_pool = peer_pool
         self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
         self.signals[self.peer_rank].zero_()
@@ -17,45 +23,43 @@ class PeerHaloExchanger1d:
             if explicit_nhwc:
                 _, Hs, _, _ = list(y.shape)
                 H = Hs - 2*self.half_halo
-                top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-                top_inp_halo = y[:,:self.half_halo,:,:]
-                btm_out_halo = y[:,H:H+self.half_halo,:,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-                btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
+                low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+                low_inp_halo = y[:,:self.half_halo,:,:]
+                high_out_halo = y[:,H:H+self.half_halo,:,:]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+                high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
             else:
                 _, _, Hs, _ = list(y.shape)
                 H = Hs - 2*self.half_halo
-                top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-                top_inp_halo = y[:,:,:self.half_halo,:]
-                btm_out_halo = y[:,:,H:H+self.half_halo,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-                btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
+                low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+                low_inp_halo = y[:,:,:self.half_halo,:]
+                high_out_halo = y[:,:,H:H+self.half_halo,:]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+                high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
         else:
             if explicit_nhwc:
                 _, _, Ws, _ = list(y.shape)
                 W = Ws - 2*self.half_halo
-                top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-                top_inp_halo = y[:,:,:self.half_halo,:]
-                btm_out_halo = y[:,:,W:W+self.half_halo,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-                btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
+                low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+                low_inp_halo = y[:,:,:self.half_halo,:]
+                high_out_halo = y[:,:,W:W+self.half_halo,:]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+                high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
             else:
                 _, _, _, Ws = list(y.shape)
                 W = Ws - 2*self.half_halo
-                top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-                top_inp_halo = y[:,:,:,:self.half_halo]
-                btm_out_halo = y[:,:,:,W:W+self.half_halo]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-                btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
-        top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
-        btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
+                low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+                low_inp_halo = y[:,:,:,:self.half_halo]
+                high_out_halo = y[:,:,:,W:W+self.half_halo]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+                high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
         pm.push_pull_halos_1d(
             diagnostics, explicit_nhwc, numSM,
-            top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
-            btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
-            self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
+            self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo,
+            self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo,
+            self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank]
         )
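The neighbor indices and zero flags that drive this call are now fixed at construction time: rank 0 zeroes its low halo and the last rank its high halo, while the ring indexing is kept so the transfer buffers of adjacent ranks still line up. The rule in isolation (a helper extracted for illustration; it mirrors the __init__ logic above):

    def neighbors_and_zero_flags(rank_in_group, group_size):
        low_neighbor = (rank_in_group + group_size - 1) % group_size
        high_neighbor = (rank_in_group + 1) % group_size
        low_zero = rank_in_group == 0                  # no rank below: zero instead of wrap
        high_zero = rank_in_group == group_size - 1    # no rank above
        return low_neighbor, high_neighbor, low_zero, high_zero

    print([neighbors_and_zero_flags(r, 4) for r in range(4)])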