OpenDAS / apex / Commits

Commit 7a344314 (unverified)
Merge branch 'master' into dev/hubertlu/focal_loss_and_index_mul_2d_cuda
Authored Sep 08, 2022 by Jithun Nair; committed by GitHub, Sep 08, 2022
Parents: 9187ea1d, ae5ca671

Showing 11 changed files with 411 additions and 108 deletions (+411 -108)
apex/contrib/bottleneck/halo_exchangers.py                    +2   -7
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu                   +4   -0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu             +262 -48
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh            +2   -0
apex/contrib/csrc/transducer/transducer_joint_kernel.cu       +8   -2
apex/contrib/peer_memory/__init__.py                          +1   -0
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py   +14  -6
apex/contrib/peer_memory/peer_halo_exchanger_1d.py            +36  -32
apex/contrib/test/run_rocm_extensions.py                      +1   -1
apex/contrib/test/transducer/test_transducer_joint.py         +7   -1
setup.py                                                      +74  -11
apex/contrib/bottleneck/halo_exchangers.py
...
...
@@ -107,15 +107,10 @@ class HaloExchangerPeer(HaloExchanger):
         right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
-        pm.push_pull_halos_1d(
-            self.diagnostics, self.explicit_nhwc, self.numSM,
-            left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
-            right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+        pm.push_pull_halos_1d(
+            self.diagnostics, self.explicit_nhwc, self.numSM,
+            self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
+            self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
             self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group])
-        # TODO: Add to push_pull_halos_1d kernel
-        if self.left_zero:
-            left_input_halo.zero_()
-        if self.right_zero:
-            right_input_halo.zero_()
         if not inplace:
             return left_input_halo, right_input_halo
...
...
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
...
...
@@ -5,7 +5,11 @@
 #include <cstdio>
 #include <ctime>
 #include <cassert>
+#ifdef __HIP_PLATFORM_HCC__
+#include "rccl.h"
+#else
 #include "nccl.h"
+#endif
 /*
  * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
...
...
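The guarded include works because RCCL is a drop-in implementation of the NCCL API on ROCm: the ncclXxx types and functions keep their names, so only the header differs. A minimal hedged sketch of the same pattern as a reusable shim (the comm_shim.h name and helper are hypothetical, not part of this commit):

    // comm_shim.h -- hypothetical portability shim, assuming only that
    // RCCL exposes the standard NCCL symbol names (which it does by design).
    #ifdef __HIP_PLATFORM_HCC__
    #include "rccl.h"   // ROCm build: RCCL provides ncclComm_t, ncclAllReduce, ...
    #else
    #include "nccl.h"   // CUDA build: stock NCCL headers
    #endif

    // Everything below the shim is written once against the NCCL API:
    static inline const char* comm_err_str(ncclResult_t r) { return ncclGetErrorString(r); }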
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
...
...
@@ -5,8 +5,15 @@
 #include <cstdio>
 #include <cassert>
 #include <cuda_runtime_api.h>
+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_cooperative_groups.h>
+#include "rccl.h"
+#else
 #include <cooperative_groups.h>
 #include "nccl.h"
+#endif
 namespace cg = cooperative_groups;
 #define CUDACHECK(cmd) do { \
...
...
@@ -117,7 +124,20 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride
     }
 }
-template<class T, bool is_HWC>
+template<class T>
+__device__ void __zero(T* dst)
+{
+    *dst = T(0);
+}
+__device__ void __zero(int4* dst)
+{
+    int4 v;
+    v.x = v.y = v.z = v.w = 0;
+    *dst = v;
+}
+template<class T, bool is_HWC, bool zero>
 __device__ void strided_copy_kernel(
     T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
     const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
...
...
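The int4 overload exists because the generic path writes `*dst = T(0)`, which requires T to be constructible from an integer; CUDA's int4 is a plain struct with no such constructor, so it has to be zeroed member-wise. A minimal, self-contained sketch of the distinction (zero_elem and zero_both are hypothetical names, not from this commit):

    #include <cuda_runtime.h>

    template<class T>
    __device__ void zero_elem(T* dst) { *dst = T(0); }  // fine for float, double, __half, int

    __device__ void zero_elem(int4* dst)                // int4 has no int4(int) constructor,
    {                                                   // so T(0) would not compile here
        int4 v;
        v.x = v.y = v.z = v.w = 0;
        *dst = v;
    }

    __global__ void zero_both(float* a, int4* b)
    {
        zero_elem(a);   // instantiates the generic template
        zero_elem(b);   // overload resolution picks the int4 version
    }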
@@ -131,23 +151,28 @@ __device__ void strided_copy_kernel(
 {
     size_t c, h, w;
     if (is_HWC) {
-        c = i % NC;
         w = i / NC;
+        c = i - w * NC;
         h = w / NW;
-        w = w % NW;
+        w = w - h * NW;
     } else {
-        w = i % NW;
         h = i / NW;
+        w = i - h * NW;
         c = h / NH;
-        h = h % NH;
+        h = h - c * NH;
     }
     size_t dst_off = c * dst_stride_C + h * dst_stride_H + w * dst_stride_W;
+    if (zero) {
+        __zero(dst + dst_off);
+    } else {
         size_t src_off = c * src_stride_C + h * src_stride_H + w * src_stride_W;
         dst[dst_off] = src[src_off];
+    }
     }
 }
+template<bool top_zero, bool btm_zero>
 __device__ void checked_signal(
     volatile int* signal1_flag, volatile int* signal2_flag,
     const int v1, const int v2, const int v3, const int v4
...
...
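Replacing `%` with a multiply-and-subtract is a standard strength reduction: once `w = i / NC` has been computed, `i % NC` equals `i - w * NC` for unsigned values, so the second integer division hidden inside `%` is avoided. A quick self-contained check of the identity (decompose_mod and decompose_mulsub are hypothetical names used only for this sketch):

    #include <cassert>
    #include <cstddef>

    // Original form: two divisions (one for '%', one for '/').
    static void decompose_mod(size_t i, size_t NC, size_t* c, size_t* w)
    {
        *c = i % NC;
        *w = i / NC;
    }

    // Reduced form: a single division, then remainder by multiply-subtract.
    static void decompose_mulsub(size_t i, size_t NC, size_t* c, size_t* w)
    {
        *w = i / NC;
        *c = i - *w * NC;
    }

    int main()
    {
        for (size_t i = 0; i < 10000; ++i) {
            size_t c1, w1, c2, w2;
            decompose_mod(i, 67, &c1, &w1);
            decompose_mulsub(i, 67, &c2, &w2);
            assert(c1 == c2 && w1 == w2);   // identical results for unsigned i
        }
        return 0;
    }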
@@ -160,29 +185,119 @@ __device__ void checked_signal(
     __threadfence_system();
     // wait for top or bottom neighbor to clear signal
     register int r1, r2, r3, r4;
-    bool top_zeroed = false, btm_zeroed = false, top_done = false, btm_done = false;
+    if (!(top_zero || btm_zero)) {
+        bool top_zeroed = false, top_done = false;
+        bool btm_zeroed = false, btm_done = false;
         do {
             do {
                 if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal1_flag);
+                    r2 = __builtin_nontemporal_load(signal1_flag+1);
+                    r3 = __builtin_nontemporal_load(signal1_flag+2);
+                    r4 = __builtin_nontemporal_load(signal1_flag+3);
+#else
                     asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
                     if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
                 }
                 if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal2_flag);
+                    r2 = __builtin_nontemporal_load(signal2_flag+1);
+                    r3 = __builtin_nontemporal_load(signal2_flag+2);
+                    r4 = __builtin_nontemporal_load(signal2_flag+3);
+#else
                     asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
                     if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
                 }
             } while ((top_zeroed == top_done) && (btm_zeroed == btm_done));
             if (!top_done && top_zeroed) {
                 // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal1_flag);
+                __builtin_nontemporal_store(v2, signal1_flag+1);
+                __builtin_nontemporal_store(v3, signal1_flag+2);
+                __builtin_nontemporal_store(v4, signal1_flag+3);
+#else
                 asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
                 top_done = true;
             }
             if (!btm_done && btm_zeroed) {
                 // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal2_flag);
+                __builtin_nontemporal_store(v2, signal2_flag+1);
+                __builtin_nontemporal_store(v3, signal2_flag+2);
+                __builtin_nontemporal_store(v4, signal2_flag+3);
+#else
                 asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
                 btm_done = true;
             }
         } while (!top_done || !btm_done);
+    } else if (top_zero) {
+        bool btm_zeroed = false, btm_done = false;
+        do {
+            do {
+                if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal2_flag);
+                    r2 = __builtin_nontemporal_load(signal2_flag+1);
+                    r3 = __builtin_nontemporal_load(signal2_flag+2);
+                    r4 = __builtin_nontemporal_load(signal2_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
+                }
+            } while (btm_zeroed == btm_done);
+            if (!btm_done && btm_zeroed) {
+                // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal2_flag);
+                __builtin_nontemporal_store(v2, signal2_flag+1);
+                __builtin_nontemporal_store(v3, signal2_flag+2);
+                __builtin_nontemporal_store(v4, signal2_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                btm_done = true;
+            }
+        } while (!btm_done);
+    } else if (btm_zero) {
+        bool top_zeroed = false, top_done = false;
+        do {
+            do {
+                if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal1_flag);
+                    r2 = __builtin_nontemporal_load(signal1_flag+1);
+                    r3 = __builtin_nontemporal_load(signal1_flag+2);
+                    r4 = __builtin_nontemporal_load(signal1_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
+                }
+            } while (top_zeroed == top_done);
+            if (!top_done && top_zeroed) {
+                // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal1_flag);
+                __builtin_nontemporal_store(v2, signal1_flag+1);
+                __builtin_nontemporal_store(v3, signal1_flag+2);
+                __builtin_nontemporal_store(v4, signal1_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                top_done = true;
+            }
+        } while (!top_done);
+    }
     }
 }
...
...
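Both load paths read the same four-int signal flag: on CUDA a single ld.volatile.global.v4.u32 fetches all 16 bytes in one instruction and cannot be cached in registers by the compiler, while on ROCm __builtin_nontemporal_load gives uncached loads without inline PTX. A condensed sketch of the pattern in isolation (spin_until_match is a hypothetical helper name; the flag layout mirrors the diff):

    // Hypothetical helper: spin until a 4-int flag matches an expected pattern.
    __device__ void spin_until_match(volatile int* flag, int v1, int v2, int v3, int v4)
    {
        int r1, r2, r3, r4;
        do {
    #ifdef __HIP_PLATFORM_HCC__
            // ROCm: nontemporal loads bypass caches that could hide remote updates
            r1 = __builtin_nontemporal_load(flag);
            r2 = __builtin_nontemporal_load(flag + 1);
            r3 = __builtin_nontemporal_load(flag + 2);
            r4 = __builtin_nontemporal_load(flag + 3);
    #else
            // CUDA: one volatile vectorized load of all 16 bytes
            asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                         : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4)
                         : "l"(flag) : "memory");
    #endif
        } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
    }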
@@ -196,7 +311,14 @@ __device__ void wait_for(
     register int r1, r2, r3, r4;
     // wait for senders to signal their output is ready
     do {
+#ifdef __HIP_PLATFORM_HCC__
+        r1 = __builtin_nontemporal_load(wait_flag);
+        r2 = __builtin_nontemporal_load(wait_flag+1);
+        r3 = __builtin_nontemporal_load(wait_flag+2);
+        r4 = __builtin_nontemporal_load(wait_flag+3);
+#else
         asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
+#endif
     } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
 }
 cg::this_grid().sync();  // all threads wait for main
...
...
@@ -212,12 +334,19 @@ __device__ void clear_flag(
     if (is_main_thread) {
         register int r1, r2, r3, r4;
         r1 = 0; r2 = 0; r3 = 0; r4 = 0;
+#ifdef __HIP_PLATFORM_HCC__
+        __builtin_nontemporal_store(r1, wait_flag);
+        __builtin_nontemporal_store(r2, wait_flag+1);
+        __builtin_nontemporal_store(r3, wait_flag+2);
+        __builtin_nontemporal_store(r4, wait_flag+3);
+#else
         asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+#endif
     }
 }
-template<class T, bool is_HWC>
-#if __CUDA_ARCH__ >= 700
+template<class T, bool is_HWC, bool top_zero, bool btm_zero>
+#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
 __launch_bounds__(128, 16)
 #endif
 __global__ void push_pull_halos_1d_kernel(
...
...
@@ -241,20 +370,34 @@ __global__ void push_pull_halos_1d_kernel(
 )
 {
     // push top output halo to transfer buffer
-    strided_copy_kernel<T, is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
+    if (!top_zero) strided_copy_kernel<T, is_HWC, false>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
     // push btm output halo to transfer buffer
-    strided_copy_kernel<T, is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
+    if (!btm_zero) strided_copy_kernel<T, is_HWC, false>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
     // signal to top and btm neighbors that output halos are ready to be read
     // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
-    checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    if (!(top_zero || btm_zero)) {
+        checked_signal<false, false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    } else if (top_zero) {
+        checked_signal<true, false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    } else if (btm_zero) {
+        checked_signal<false, true>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    }
     // pull top halo from transfer buffer in peer memory to input
+    if (top_zero) {
+        strided_copy_kernel<T, is_HWC, true>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+    } else {
         wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
-        strided_copy_kernel<T, is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+        strided_copy_kernel<T, is_HWC, false>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
         clear_flag(wait1_flag);
+    }
     // pull btm halo from transfer buffer in peer memory to input
+    if (btm_zero) {
+        strided_copy_kernel<T, is_HWC, true>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+    } else {
         wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
-        strided_copy_kernel<T, is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+        strided_copy_kernel<T, is_HWC, false>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
         clear_flag(wait2_flag);
+    }
 }
 __global__ void delay_kernel(int delay_nanoseconds, int* counter)
...
...
@@ -343,10 +486,12 @@ void push_pull_halos_1d(
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                      // number of SMs to use
+    bool top_zero,                  // true if top halo should be zeroed
     at::Tensor top_out_halo,        // top output halo in sender device memory
     at::Tensor top_out_tx,          // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,          // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,        // top input halo in receiver device memory
+    bool btm_zero,                  // true if btm halo should be zeroed
     at::Tensor btm_out_halo,        // btm output halo in sender device memory
     at::Tensor btm_out_tx,          // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,          // btm input transfer buffer in btm neighbor peer pool memory
...
...
@@ -368,6 +513,7 @@ void push_pull_halos_1d(
     TORCH_CHECK(top_signal.is_cuda());
     TORCH_CHECK(btm_signal.is_cuda());
     TORCH_CHECK(waits.is_cuda());
+    TORCH_CHECK(!(top_zero && btm_zero));
     // shapes and strides
     int toh_N, toh_C, toh_H, toh_W;
...
...
@@ -492,10 +638,34 @@ void push_pull_halos_1d(
         &NC, &NH, &NW,
         &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p};
+    if (top_zero) {
         int numBlocksPerSm;
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true>, numThreads, 0);
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true, true, false>, numThreads, 0);
         dim3 grid(numSM*numBlocksPerSm, 1, 1);
-        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true>, grid, block, kernelArgs, 0, current_stream);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    } else if (btm_zero) {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true, false, true>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    } else {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true, false, false>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    }
 } else {
     // cannot do int4 transfers
     if (diagnostics) printf("CAN NOT DO INT4\n");
...
...
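The launch site follows a standard pattern for lifting runtime booleans into template parameters: each (top_zero, btm_zero) combination is a distinct kernel instantiation, so the per-element `if (zero)` branches inside strided_copy_kernel are resolved at compile time and cost nothing on the device. A reduced, hedged sketch of the dispatch (halo_kernel_demo and launch_demo are illustrative stand-ins, not the real kernel):

    #include <cuda_runtime.h>

    template<bool TOP_ZERO, bool BTM_ZERO>
    __global__ void halo_kernel_demo(int* out)   // hypothetical stand-in kernel
    {
        // The flags are compile-time constants here; dead branches are removed.
        if (TOP_ZERO) out[0] = 0;
        if (BTM_ZERO) out[1] = 0;
    }

    // Pick the instantiation once on the host, then launch cooperatively.
    // top_zero && btm_zero is excluded, mirroring the TORCH_CHECK above.
    void launch_demo(bool top_zero, bool btm_zero, int* out,
                     dim3 grid, dim3 block, cudaStream_t stream)
    {
        void* args[] = { &out };
        void* fn = top_zero ? (void*)halo_kernel_demo<true, false>
                 : btm_zero ? (void*)halo_kernel_demo<false, true>
                            : (void*)halo_kernel_demo<false, false>;
        cudaLaunchCooperativeKernel(fn, grid, block, args, 0, stream);
    }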
@@ -513,13 +683,57 @@ void push_pull_halos_1d(
     };
     int numBlocksPerSm;
     if (is_nhwc) {
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true>, numThreads, 0);
+        if (top_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true, true, false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, true, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else if (btm_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true, false, true>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true, false, false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
-            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true>, grid, block, kernelArgs, 0, current_stream);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true, false, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        }
     } else {
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false>, numThreads, 0);
+        if (top_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false, true, false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
-            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false>, grid, block, kernelArgs, 0, current_stream);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, true, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, true, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else if (btm_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false, false, true>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, true>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false, false, false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false, false, false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        }
     }
 }
 );
...
...
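Every branch sizes its grid the same way before the cooperative launch: a cooperative kernel requires all blocks to be resident on the device at once, so the grid is capped at the occupancy-reported blocks per SM times the number of SMs the caller wants to use, rather than being derived from the problem size. A generic hedged sketch of that sizing step (any_kernel and launch_occupancy_sized are illustrative names):

    #include <cuda_runtime.h>

    __global__ void any_kernel(int* data)   // hypothetical cooperative kernel
    {
        if (blockIdx.x == 0 && threadIdx.x == 0) data[0] = 1;  // trivial body for the sketch
    }

    cudaError_t launch_occupancy_sized(int numSM, int numThreads, int* data, cudaStream_t stream)
    {
        // How many blocks of this kernel can one SM hold at once?
        int numBlocksPerSm = 0;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, any_kernel, numThreads, 0);

        // Cooperative launches must fit entirely on the device, so the grid
        // is sized from occupancy instead of from the problem size.
        dim3 grid(numSM * numBlocksPerSm, 1, 1), block(numThreads, 1, 1);
        void* args[] = { &data };
        return cudaLaunchCooperativeKernel((void*)any_kernel, grid, block, args, 0, stream);
    }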
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
...
...
@@ -32,10 +32,12 @@ namespace apex { namespace contrib { namespace peer_memory {
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                      // number of SMs to use
+    bool top_zero,                  // true if top halo should be zeroed
     at::Tensor top_out_halo,        // top output halo in sender device memory
     at::Tensor top_out_tx,          // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,          // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,        // top input halo in receiver device memory
+    bool btm_zero,                  // true if btm halo should be zeroed
     at::Tensor btm_out_halo,        // btm output halo in sender device memory
     at::Tensor btm_out_tx,          // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,          // btm input transfer buffer in btm neighbor peer pool memory
...
...
apex/contrib/csrc/transducer/transducer_joint_kernel.cu
...
...
@@ -17,12 +17,18 @@
 #include "philox.cuh"
+#ifdef __HIP_PLATFORM_HCC__
+#define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width)
+#else
+#define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width)
+#endif
 // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
 // width should be a power of 2 and should be less than warpSize.
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){
     for (unsigned offset = width/2; offset > 0; offset /= 2){
-        x += __shfl_down_sync(0xffffffff, x, offset, width);
+        x += SHFL_DOWN(x, offset, width);
     }
     return x;
 }
...
...
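As the comment above warpReduce says, the width argument partitions the warp into independent groups; SHFL_DOWN merely papers over the API difference (__shfl_down on ROCm, __shfl_down_sync with a full mask on CUDA). A small hedged sketch of how such a sub-warp reduction behaves (groupwise_sum_demo is an illustrative kernel, not from this commit):

    #include <cuda_runtime.h>

    #ifdef __HIP_PLATFORM_HCC__
    #define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width)
    #else
    #define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width)
    #endif

    // With width = 8, lanes 0-7 reduce into lane 0, lanes 8-15 into lane 8, etc.
    __global__ void groupwise_sum_demo(const float* in, float* out)
    {
        const int width = 8;                  // power of 2, less than warpSize
        float x = in[threadIdx.x];
        for (unsigned offset = width / 2; offset > 0; offset /= 2)
            x += SHFL_DOWN(x, offset, width);
        if (threadIdx.x % width == 0)         // one result per group of 8 lanes
            out[threadIdx.x / width] = x;
    }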
apex/contrib/peer_memory/__init__.py

 from .peer_memory import PeerMemoryPool
+from .peer_halo_exchanger_1d import PeerHaloExchanger1d
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py
...
...
@@ -40,8 +40,9 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
         btm_out_halo = y[:,:,:,W:W+half_halo]
         btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
-    top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
-    btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
+    mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+    top_out_halo = top_out_halo.contiguous()
+    btm_out_halo = btm_out_halo.contiguous()
     top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
     torch.distributed.all_gather(top_inp_halos, top_out_halo)
...
...
@@ -49,8 +50,14 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
     torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
     top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
     btm_rank = (peer_rank + 1) % peer_group_size
-    top_inp_halo.copy_(btm_inp_halos[top_rank])
-    btm_inp_halo.copy_(top_inp_halos[btm_rank])
+    if peer_rank == 0:
+        top_inp_halo.zero_()
+    else:
+        top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))
+    if peer_rank == peer_group_size - 1:
+        btm_inp_halo.zero_()
+    else:
+        btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))
 def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
...
...
@@ -141,12 +148,13 @@ def main():
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     torch.cuda.set_device(rank)
-    pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
+    peer_ranks = [i for i in range(world_size)]
+    pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
     num_steps = 100
     half_halo = 1
-    halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
+    halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)
     H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex, num_steps)
     W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex, num_steps)
...
...
apex/contrib/peer_memory/peer_halo_exchanger_1d.py
...
...
@@ -3,9 +3,15 @@ from apex.contrib.peer_memory import PeerMemoryPool
 import peer_memory_cuda as pm

 class PeerHaloExchanger1d:
-    def __init__(self, rank, peer_group_size, peer_pool, half_halo):
-        self.peer_group_size = peer_group_size
-        self.peer_rank = rank % peer_group_size
+    def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
+        self.peer_group_size = len(ranks)
+        self.ranks = ranks
+        self.peer_rank = rank_in_group
+        self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
+        self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size
+        self.low_zero = True if self.peer_rank == 0 else False
+        self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False
         self.peer_pool = peer_pool
         self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
         self.signals[self.peer_rank].zero_()
...
...
@@ -17,45 +23,43 @@ class PeerHaloExchanger1d:
         if explicit_nhwc:
             _, Hs, _, _ = list(y.shape)
             H = Hs - 2 * self.half_halo
-            top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-            top_inp_halo = y[:,:self.half_halo,:,:]
-            btm_out_halo = y[:,H:H+self.half_halo,:,:]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-            btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
+            low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+            low_inp_halo = y[:,:self.half_halo,:,:]
+            high_out_halo = y[:,H:H+self.half_halo,:,:]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+            high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
         else:
             _, _, Hs, _ = list(y.shape)
             H = Hs - 2 * self.half_halo
-            top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-            top_inp_halo = y[:,:,:self.half_halo,:]
-            btm_out_halo = y[:,:,H:H+self.half_halo,:]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-            btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
+            low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+            low_inp_halo = y[:,:,:self.half_halo,:]
+            high_out_halo = y[:,:,H:H+self.half_halo,:]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+            high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
     else:
         if explicit_nhwc:
             _, _, Ws, _ = list(y.shape)
             W = Ws - 2 * self.half_halo
-            top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-            top_inp_halo = y[:,:,:self.half_halo,:]
-            btm_out_halo = y[:,:,W:W+self.half_halo,:]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-            btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
+            low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+            low_inp_halo = y[:,:,:self.half_halo,:]
+            high_out_halo = y[:,:,W:W+self.half_halo,:]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+            high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
         else:
             _, _, _, Ws = list(y.shape)
             W = Ws - 2 * self.half_halo
-            top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-            top_inp_halo = y[:,:,:,:self.half_halo]
-            btm_out_halo = y[:,:,:,W:W+self.half_halo]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-            btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
-    top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
-    btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
+            low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+            low_inp_halo = y[:,:,:,:self.half_halo]
+            high_out_halo = y[:,:,:,W:W+self.half_halo]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+            high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
     pm.push_pull_halos_1d(
-        diagnostics, explicit_nhwc, numSM,
-        top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
-        btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
-        self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
+        diagnostics, explicit_nhwc, numSM,
+        self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo,
+        self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo,
+        self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank]
     )
apex/contrib/test/run_rocm_extensions.py
...
...
@@ -2,7 +2,7 @@ import unittest
 import sys
-test_dirs = ["groupbn", "layer_norm", "multihead_attn", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
+test_dirs = ["groupbn", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
 ROCM_BLACKLIST = ["layer_norm"]
...
...
apex/contrib/test/transducer/test_transducer_joint.py
...
...
@@ -121,6 +121,7 @@ class TransducerJointTest(unittest.TestCase):
     def test_transducer_joint_vec(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
...
...
@@ -133,21 +134,26 @@ class TransducerJointTest(unittest.TestCase):
     def test_transducer_joint_vec_relu(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack_relu(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)

     def test_transducer_joint_vec_pack_relu(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_vec_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_vec_pack_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
...
...
setup.py
...
...
@@ -42,6 +42,55 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_major, bare_metal_minor

+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError(
+            "Cuda extensions are being compiled with a version of Cuda that does "
+            "not match the version used to compile Pytorch binaries. "
+            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+            + "In some cases, a minor-version mismatch will not cause later errors: "
+            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
+            "You can try commenting out this check (at your own risk)."
+        )
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    raise RuntimeError(
+        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
+    )
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
+    cudnn_available = torch.backends.cudnn.is_available()
+    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
+    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
+        warnings.warn(
+            f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
+            f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
+        )
+        return False
+    return True
+
 print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
...
...
@@ -524,9 +573,13 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
     )
 )

-if "--transducer" in sys.argv:
-    sys.argv.remove("--transducer")
+if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--transducer" in sys.argv:
+        sys.argv.remove("--transducer")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--transducer")
     ext_modules.append(
         CUDAExtension(
             name="transducer_joint_cuda",
...
...
@@ -536,7 +589,8 @@ if "--transducer" in sys.argv:
             ],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros + generator_flag,
-                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag),
+                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag) if not IS_ROCM_PYTORCH else ["-O3"] + version_dependent_macros + generator_flag,
             },
             include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
         )
...
...
@@ -551,7 +605,8 @@ if "--transducer" in sys.argv:
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros),
+                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros) if not IS_ROCM_PYTORCH else ["-O3"] + version_dependent_macros,
             },
         )
     )
...
...
@@ -571,9 +626,13 @@ if "--fast_bottleneck" in sys.argv:
     )
 )

-if "--peer_memory" in sys.argv:
-    sys.argv.remove("--peer_memory")
+if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--peer_memory" in sys.argv:
+        sys.argv.remove("--peer_memory")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--peer_memory")
     ext_modules.append(
         CUDAExtension(
             name="peer_memory_cuda",
...
...
@@ -585,9 +644,13 @@ if "--peer_memory" in sys.argv:
     )
 )

-if "--nccl_p2p" in sys.argv:
-    sys.argv.remove("--nccl_p2p")
+if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--nccl_p2p" in sys.argv:
+        sys.argv.remove("--nccl_p2p")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--nccl_p2p")
     ext_modules.append(
         CUDAExtension(
             name="nccl_p2p_cuda",
...
...