OpenDAS / apex
Commit 7a344314
"tools/imglab/vscode:/vscode.git/clone" did not exist on "09b31219fac20814494dc23820b422b90b804fec"
Unverified commit 7a344314, authored Sep 08, 2022 by Jithun Nair, committed by GitHub on Sep 08, 2022

Merge branch 'master' into dev/hubertlu/focal_loss_and_index_mul_2d_cuda

Parents: 9187ea1d, ae5ca671
Showing 11 changed files with 411 additions and 108 deletions (+411, -108)
apex/contrib/bottleneck/halo_exchangers.py                    +2   -7
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu                   +4   -0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu             +262 -48
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh            +2   -0
apex/contrib/csrc/transducer/transducer_joint_kernel.cu       +8   -2
apex/contrib/peer_memory/__init__.py                          +1   -0
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py   +14  -6
apex/contrib/peer_memory/peer_halo_exchanger_1d.py            +36  -32
apex/contrib/test/run_rocm_extensions.py                      +1   -1
apex/contrib/test/transducer/test_transducer_joint.py         +7   -1
setup.py                                                      +74  -11
apex/contrib/bottleneck/halo_exchangers.py
...
...
@@ -107,15 +107,10 @@ class HaloExchangerPeer(HaloExchanger):
         right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
         pm.push_pull_halos_1d(
             self.diagnostics, self.explicit_nhwc, self.numSM,
-            left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
-            right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+            self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
+            self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
             self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
         )
-        # TODO: Add to push_pull_halos_1d kernel
-        if self.left_zero:
-            left_input_halo.zero_()
-        if self.right_zero:
-            right_input_halo.zero_()
         if not inplace:
             return left_input_halo, right_input_halo
...
...
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
...
...
@@ -5,7 +5,11 @@
 #include <cstdio>
 #include <ctime>
 #include <cassert>
+#ifdef __HIP_PLATFORM_HCC__
+#include "rccl.h"
+#else
 #include "nccl.h"
+#endif

 /*
  * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
...
...
@@ -5,8 +5,15 @@
 #include <cstdio>
 #include <cassert>
 #include <cuda_runtime_api.h>
+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_cooperative_groups.h>
+#include "rccl.h"
+#else
 #include <cooperative_groups.h>
 #include "nccl.h"
+#endif

 namespace cg = cooperative_groups;

 #define CUDACHECK(cmd) do { \
...
...
@@ -117,7 +124,20 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride
     }
 }

-template <class T, bool is_HWC>
+template <class T>
+__device__ void __zero(T* dst)
+{
+    *dst = T(0);
+}
+
+__device__ void __zero(int4* dst)
+{
+    int4 v;
+    v.x = v.y = v.z = v.w = 0;
+    *dst = v;
+}
+
+template <class T, bool is_HWC, bool zero>
 __device__ void strided_copy_kernel(
     T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
     const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
...
...
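The scalar template `__zero` writes `T(0)`, which works for arithmetic element types but does not compile for `int4`, since CUDA vector structs have no constructors; that is why the hunk adds a dedicated `int4` overload. A minimal standalone sketch of the same overload pattern (hypothetical names, not part of this diff):

    // zero_elem mirrors the __zero pair above: generic T(0) plus an int4 overload.
    #include <cuda_runtime.h>
    #include <cstdio>

    template <class T>
    __device__ void zero_elem(T* dst) { *dst = T(0); }   // fine for float, double, ...
    __device__ void zero_elem(int4* dst) {               // int4 has no T(0) conversion
        int4 v; v.x = v.y = v.z = v.w = 0; *dst = v;
    }

    __global__ void zero_both(float* f, int4* q) { zero_elem(f); zero_elem(q); }

    int main() {
        float* f; int4* q;
        cudaMalloc(&f, sizeof(float)); cudaMalloc(&q, sizeof(int4));
        zero_both<<<1, 1>>>(f, q);
        cudaDeviceSynchronize();
        return 0;
    }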
@@ -131,23 +151,28 @@ __device__ void strided_copy_kernel(
 {
     size_t c, h, w;
     if (is_HWC) {
-        c = i % NC;
         w = i / NC;
+        c = i - w*NC;
         h = w / NW;
-        w = w % NW;
+        w = w - h*NW;
     } else {
-        w = i % NW;
         h = i / NW;
+        w = i - h*NW;
         c = h / NH;
-        h = h % NH;
+        h = h - c*NH;
     }
     size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
-    size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
-    dst[dst_off] = src[src_off];
+    if (zero) {
+        __zero(dst + dst_off);
+    } else {
+        size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
+        dst[dst_off] = src[src_off];
+    }
     }
 }
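The rewritten index decomposition above drops every `%`: each remainder is derived from the quotient the next line needs anyway (`c = i - w*NC` instead of `c = i % NC`), saving an integer division per coordinate. A quick host-side equivalence check for the HWC path (illustrative sketch; extents are arbitrary, compiles with any C++ compiler or nvcc):

    #include <cassert>
    #include <cstddef>

    int main() {
        const size_t NC = 8, NW = 13, NH = 7;   // arbitrary extents
        for (size_t i = 0; i < NC * NW * NH; ++i) {
            // original form: a div/mod pair per coordinate
            size_t c0 = i % NC, w0 = (i / NC) % NW, h0 = (i / NC) / NW;
            // strength-reduced form from the hunk above
            size_t w = i / NC;  size_t c = i - w * NC;
            size_t h = w / NW;  w = w - h * NW;
            assert(c == c0 && w == w0 && h == h0);
        }
        return 0;
    }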
+template <bool top_zero, bool btm_zero>
 __device__ void checked_signal(
     volatile int* signal1_flag, volatile int* signal2_flag,
     const int v1, const int v2, const int v3, const int v4
...
@@ -160,29 +185,119 @@ __device__ void checked_signal(
     __threadfence_system();
     // wait for top or bottom neighbor to clear signal
     register int r1, r2, r3, r4;
-    bool top_zeroed = false, btm_zeroed = false, top_done = false, btm_done = false;
+    if (!(top_zero || btm_zero)) {
+        bool top_zeroed = false, top_done = false;
+        bool btm_zeroed = false, btm_done = false;
+        do {
+            do {
+                if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal1_flag);
+                    r2 = __builtin_nontemporal_load(signal1_flag+1);
+                    r3 = __builtin_nontemporal_load(signal1_flag+2);
+                    r4 = __builtin_nontemporal_load(signal1_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
+                }
+                if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal2_flag);
+                    r2 = __builtin_nontemporal_load(signal2_flag+1);
+                    r3 = __builtin_nontemporal_load(signal2_flag+2);
+                    r4 = __builtin_nontemporal_load(signal2_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
+                }
+            } while ((top_zeroed == top_done) && (btm_zeroed == btm_done));
+            if (!top_done && top_zeroed) {
+                // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal1_flag);
+                __builtin_nontemporal_store(v2, signal1_flag+1);
+                __builtin_nontemporal_store(v3, signal1_flag+2);
+                __builtin_nontemporal_store(v4, signal1_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                top_done = true;
+            }
+            if (!btm_done && btm_zeroed) {
+                // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal2_flag);
+                __builtin_nontemporal_store(v2, signal2_flag+1);
+                __builtin_nontemporal_store(v3, signal2_flag+2);
+                __builtin_nontemporal_store(v4, signal2_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                btm_done = true;
+            }
+        } while (!top_done || !btm_done);
+    } else if (top_zero) {
+        bool btm_zeroed = false, btm_done = false;
+        do {
+            do {
+                if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal2_flag);
+                    r2 = __builtin_nontemporal_load(signal2_flag+1);
+                    r3 = __builtin_nontemporal_load(signal2_flag+2);
+                    r4 = __builtin_nontemporal_load(signal2_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
+                }
+            } while (btm_zeroed == btm_done);
+            if (!btm_done && btm_zeroed) {
+                // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal2_flag);
+                __builtin_nontemporal_store(v2, signal2_flag+1);
+                __builtin_nontemporal_store(v3, signal2_flag+2);
+                __builtin_nontemporal_store(v4, signal2_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                btm_done = true;
+            }
+        } while (!btm_done);
+    } else if (btm_zero) {
+        bool top_zeroed = false, top_done = false;
+        do {
+            do {
+                if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal1_flag);
+                    r2 = __builtin_nontemporal_load(signal1_flag+1);
+                    r3 = __builtin_nontemporal_load(signal1_flag+2);
+                    r4 = __builtin_nontemporal_load(signal1_flag+3);
+#else
+                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
+                }
+            } while (top_zeroed == top_done);
+            if (!top_done && top_zeroed) {
+                // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal1_flag);
+                __builtin_nontemporal_store(v2, signal1_flag+1);
+                __builtin_nontemporal_store(v3, signal1_flag+2);
+                __builtin_nontemporal_store(v4, signal1_flag+3);
+#else
+                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                top_done = true;
+            }
+        } while (!top_done);
+    }
 }
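Both spin loops above read a 16-byte signal in one shot: PTX `ld.volatile.global.v4.u32` on CUDA, four `__builtin_nontemporal_load`s on ROCm, where inline PTX is unavailable. A minimal sketch isolating just that portable read (hypothetical helper name; assumes the flag pointer is 16-byte aligned, which cudaMalloc guarantees):

    #include <cuda_runtime.h>
    #include <cstdio>

    __device__ void load_flag4(int* f, int& r1, int& r2, int& r3, int& r4) {
    #ifdef __HIP_PLATFORM_HCC__
        r1 = __builtin_nontemporal_load(f);
        r2 = __builtin_nontemporal_load(f + 1);
        r3 = __builtin_nontemporal_load(f + 2);
        r4 = __builtin_nontemporal_load(f + 3);
    #else
        asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                     : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(f) : "memory");
    #endif
    }

    __global__ void read_flag(int* f) {
        int r1, r2, r3, r4;
        load_flag4(f, r1, r2, r3, r4);
        printf("%d %d %d %d\n", r1, r2, r3, r4);
    }

    int main() {
        int* f;
        cudaMalloc(&f, 4 * sizeof(int));
        int h[4] = {1, 2, 3, 4};
        cudaMemcpy(f, h, sizeof(h), cudaMemcpyHostToDevice);
        read_flag<<<1, 1>>>(f);
        cudaDeviceSynchronize();
        return 0;
    }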
...
...
@@ -196,7 +311,14 @@ __device__ void wait_for(
     register int r1, r2, r3, r4;
     // wait for senders to signal their output is ready
     do {
+#ifdef __HIP_PLATFORM_HCC__
+        r1 = __builtin_nontemporal_load(wait_flag);
+        r2 = __builtin_nontemporal_load(wait_flag+1);
+        r3 = __builtin_nontemporal_load(wait_flag+2);
+        r4 = __builtin_nontemporal_load(wait_flag+3);
+#else
         asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
+#endif
     } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
 }
 cg::this_grid().sync();  // all threads wait for main
...
...
@@ -212,12 +334,19 @@ __device__ void clear_flag(
     if (is_main_thread) {
         register int r1, r2, r3, r4;
         r1 = 0; r2 = 0; r3 = 0; r4 = 0;
+#ifdef __HIP_PLATFORM_HCC__
+        __builtin_nontemporal_store(r1, wait_flag);
+        __builtin_nontemporal_store(r2, wait_flag+1);
+        __builtin_nontemporal_store(r3, wait_flag+2);
+        __builtin_nontemporal_store(r4, wait_flag+3);
+#else
         asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+#endif
     }
 }

-template <class T, bool is_HWC>
-#if __CUDA_ARCH__ >= 700
+template <class T, bool is_HWC, bool top_zero, bool btm_zero>
+#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
 __launch_bounds__(128, 16)
 #endif
 __global__ void push_pull_halos_1d_kernel(
...
...
@@ -241,20 +370,34 @@ __global__ void push_pull_halos_1d_kernel(
 )
 {
     // push top output halo to transfer buffer
-    strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
+    if (!top_zero) strided_copy_kernel<T,is_HWC,false>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
     // push btm output halo to transfer buffer
-    strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
+    if (!btm_zero) strided_copy_kernel<T,is_HWC,false>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
     // signal to top and btm neighbors that output halos are ready to be read
     // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
-    checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    if (!(top_zero || btm_zero)) {
+        checked_signal<false,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    } else if (top_zero) {
+        checked_signal<true,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    } else if (btm_zero) {
+        checked_signal<false,true>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+    }
     // pull top halo from transfer buffer in peer memory to input
-    wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
-    strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
-    clear_flag(wait1_flag);
+    if (top_zero) {
+        strided_copy_kernel<T,is_HWC,true>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+    } else {
+        wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
+        strided_copy_kernel<T,is_HWC,false>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+        clear_flag(wait1_flag);
+    }
     // pull btm halo from transfer buffer in peer memory to input
-    wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
-    strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
-    clear_flag(wait2_flag);
+    if (btm_zero) {
+        strided_copy_kernel<T,is_HWC,true>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+    } else {
+        wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
+        strided_copy_kernel<T,is_HWC,false>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+        clear_flag(wait2_flag);
+    }
 }

 __global__ void delay_kernel(int delay_nanoseconds, int* counter)
...
...
@@ -343,10 +486,12 @@ void push_pull_halos_1d(
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                   // number of SMs to use
+    bool top_zero,               // true if top halo should be zeroed
     at::Tensor top_out_halo,     // top output halo in sender device memory
     at::Tensor top_out_tx,       // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,       // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,     // top input halo in receiver device memory
+    bool btm_zero,               // true if btm halo should be zeroed
     at::Tensor btm_out_halo,     // btm output halo in sender device memory
     at::Tensor btm_out_tx,       // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,       // btm input transfer buffer in btm neighbor peer pool memory
...
...
@@ -368,6 +513,7 @@ void push_pull_halos_1d(
     TORCH_CHECK(top_signal.is_cuda());
     TORCH_CHECK(btm_signal.is_cuda());
     TORCH_CHECK(waits.is_cuda());
+    TORCH_CHECK(!(top_zero && btm_zero));

     // shapes and strides
     int toh_N, toh_C, toh_H, toh_W;
...
...
@@ -492,10 +638,34 @@ void push_pull_halos_1d(
         &NC, &NH, &NW,
         &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
     };
-    int numBlocksPerSm;
-    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
-    dim3 grid(numSM*numBlocksPerSm, 1, 1);
-    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
+    if (top_zero) {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,true,false>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    } else if (btm_zero) {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,true>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    } else {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,false>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    }
 } else {
     // cannot do int4 transfers
     if (diagnostics) printf("CAN NOT DO INT4\n");
...
...
@@ -513,13 +683,57 @@ void push_pull_halos_1d(
     };
     int numBlocksPerSm;
     if (is_nhwc) {
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
-        dim3 grid(numSM*numBlocksPerSm, 1, 1);
-        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
+        if (top_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,true,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else if (btm_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,true>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        }
     } else {
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
-        dim3 grid(numSM*numBlocksPerSm, 1, 1);
-        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
+        if (top_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,true,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else if (btm_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,true>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm, 1, 1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        }
     }
 }
 );
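Every launch site above repeats the occupancy query and cooperative launch because `top_zero`/`btm_zero` are now template parameters: a runtime flag has to be mapped onto a fixed set of compile-time instantiations. The dispatch pattern in isolation (a hedged sketch with a hypothetical stub kernel, not the apex API):

    #include <cuda_runtime.h>
    #include <cstdio>

    template <bool TOP_ZERO, bool BTM_ZERO>
    __global__ void halo_kernel_stub() {
        // TOP_ZERO/BTM_ZERO are compile-time constants here, so branches on
        // them inside the real kernel cost nothing at run time.
        printf("top_zero=%d btm_zero=%d\n", (int)TOP_ZERO, (int)BTM_ZERO);
    }

    void launch(bool top_zero, bool btm_zero) {
        if (top_zero)      halo_kernel_stub<true,  false><<<1, 1>>>();
        else if (btm_zero) halo_kernel_stub<false, true ><<<1, 1>>>();
        else               halo_kernel_stub<false, false><<<1, 1>>>();
    }

    int main() {
        launch(true, false);
        launch(false, true);
        launch(false, false);
        cudaDeviceSynchronize();
        return 0;
    }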
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
...
...
@@ -32,10 +32,12 @@ namespace apex { namespace contrib { namespace peer_memory {
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                   // number of SMs to use
+    bool top_zero,               // true if top halo should be zeroed
     at::Tensor top_out_halo,     // top output halo in sender device memory
     at::Tensor top_out_tx,       // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,       // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,     // top input halo in receiver device memory
+    bool btm_zero,               // true if btm halo should be zeroed
     at::Tensor btm_out_halo,     // btm output halo in sender device memory
     at::Tensor btm_out_tx,       // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,       // btm input transfer buffer in btm neighbor peer pool memory
...
...
apex/contrib/csrc/transducer/transducer_joint_kernel.cu
...
...
@@ -17,12 +17,18 @@
#include "philox.cuh"
#ifdef __HIP_PLATFORM_HCC__
#define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width)
#else
#define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width)
#endif
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and should be less than warpSize.
template
<
typename
scalar_t
>
__device__
__forceinline__
scalar_t
warpReduce
(
scalar_t
x
,
int
width
=
C10_WARP_SIZE
){
for
(
unsigned
offset
=
width
/
2
;
offset
>
0
;
offset
/=
2
){
x
+=
__shfl_down_sync
(
0xffffffff
,
x
,
offset
,
width
);
x
+=
SHFL_DOWN
(
x
,
offset
,
width
);
}
return
x
;
}
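`SHFL_DOWN` hides the API split: CUDA's `__shfl_down_sync` takes a participation mask, while ROCm's `__shfl_down` does not. A small usage sketch in the spirit of `warpReduce`, summing one 32-lane group (standalone and illustrative; on ROCm, where warpSize is 64, `width` would normally come from `warpSize`/`C10_WARP_SIZE`):

    #include <cuda_runtime.h>
    #include <cstdio>

    #ifdef __HIP_PLATFORM_HCC__
    #define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width)
    #else
    #define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width)
    #endif

    __global__ void warp_sum() {
        int x = threadIdx.x;                        // lane values 0..31
        for (unsigned offset = 32 / 2; offset > 0; offset /= 2)
            x += SHFL_DOWN(x, offset, 32);          // tree reduction within the group
        if (threadIdx.x == 0) printf("sum = %d\n", x);   // 0+1+...+31 = 496
    }

    int main() {
        warp_sum<<<1, 32>>>();
        cudaDeviceSynchronize();
        return 0;
    }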
...
...
apex/contrib/peer_memory/__init__.py
 from .peer_memory import PeerMemoryPool
 from .peer_halo_exchanger_1d import PeerHaloExchanger1d
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py
...
...
@@ -40,8 +40,9 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
         btm_out_halo = y[:,:,:,W:W+half_halo]
         btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
-    top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
-    btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
+    mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+    top_out_halo = top_out_halo.contiguous()
+    btm_out_halo = btm_out_halo.contiguous()
     top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
     torch.distributed.all_gather(top_inp_halos, top_out_halo)
...
...
@@ -49,8 +50,14 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
     torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
     top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
     btm_rank = (peer_rank + 1) % peer_group_size
-    top_inp_halo.copy_(btm_inp_halos[top_rank])
-    btm_inp_halo.copy_(top_inp_halos[btm_rank])
+    if peer_rank == 0:
+        top_inp_halo.zero_()
+    else:
+        top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))
+    if peer_rank == peer_group_size - 1:
+        btm_inp_halo.zero_()
+    else:
+        btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))

 def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
...
...
@@ -141,12 +148,13 @@ def main():
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     torch.cuda.set_device(rank)
-    pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
+    peer_ranks = [i for i in range(world_size)]
+    pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
     num_steps = 100
     half_halo = 1
-    halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
+    halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)
     H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex, num_steps)
     W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex, num_steps)
...
...
apex/contrib/peer_memory/peer_halo_exchanger_1d.py
...
...
@@ -3,9 +3,15 @@ from apex.contrib.peer_memory import PeerMemoryPool
 import peer_memory_cuda as pm

 class PeerHaloExchanger1d:
-    def __init__(self, rank, peer_group_size, peer_pool, half_halo):
-        self.peer_group_size = peer_group_size
-        self.peer_rank = rank % peer_group_size
+    def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
+        self.peer_group_size = len(ranks)
+        self.ranks = ranks
+        self.peer_rank = rank_in_group
+        self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
+        self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size
+        self.low_zero = True if self.peer_rank == 0 else False
+        self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False
         self.peer_pool = peer_pool
         self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
         self.signals[self.peer_rank].zero_()
...
...
@@ -17,45 +23,43 @@ class PeerHaloExchanger1d:
         if explicit_nhwc:
             _, Hs, _, _ = list(y.shape)
             H = Hs - 2*self.half_halo
-            top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-            top_inp_halo = y[:,:self.half_halo,:,:]
-            btm_out_halo = y[:,H:H+self.half_halo,:,:]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-            btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
+            low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+            low_inp_halo = y[:,:self.half_halo,:,:]
+            high_out_halo = y[:,H:H+self.half_halo,:,:]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+            high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
         else:
             _, _, Hs, _ = list(y.shape)
             H = Hs - 2*self.half_halo
-            top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-            top_inp_halo = y[:,:,:self.half_halo,:]
-            btm_out_halo = y[:,:,H:H+self.half_halo,:]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-            btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
+            low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+            low_inp_halo = y[:,:,:self.half_halo,:]
+            high_out_halo = y[:,:,H:H+self.half_halo,:]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+            high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
     else:
         if explicit_nhwc:
             _, _, Ws, _ = list(y.shape)
             W = Ws - 2*self.half_halo
-            top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-            top_inp_halo = y[:,:,:self.half_halo,:]
-            btm_out_halo = y[:,:,W:W+self.half_halo,:]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-            btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
+            low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+            low_inp_halo = y[:,:,:self.half_halo,:]
+            high_out_halo = y[:,:,W:W+self.half_halo,:]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+            high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
         else:
             _, _, _, Ws = list(y.shape)
             W = Ws - 2*self.half_halo
-            top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
-            top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-            top_inp_halo = y[:,:,:,:self.half_halo]
-            btm_out_halo = y[:,:,:,W:W+self.half_halo]
-            btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-            btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
-    top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
-    btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
+            low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
+            low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+            low_inp_halo = y[:,:,:,:self.half_halo]
+            high_out_halo = y[:,:,:,W:W+self.half_halo]
+            high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+            high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
-    pm.push_pull_halos_1d(
-        diagnostics, explicit_nhwc, numSM,
-        top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
-        btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
-        self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
-    )
+    pm.push_pull_halos_1d(
+        diagnostics, explicit_nhwc, numSM,
+        self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo,
+        self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo,
+        self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank]
+    )
apex/contrib/test/run_rocm_extensions.py
...
...
@@ -2,7 +2,7 @@ import unittest
 import sys

-test_dirs = ["groupbn", "layer_norm", "multihead_attn", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
+test_dirs = ["groupbn", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
 ROCM_BLACKLIST = ["layer_norm"]
...
...
apex/contrib/test/transducer/test_transducer_joint.py
...
...
@@ -121,6 +121,7 @@ class TransducerJointTest(unittest.TestCase):
     def test_transducer_joint_vec(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
...
...
@@ -133,21 +134,26 @@ class TransducerJointTest(unittest.TestCase):
     def test_transducer_joint_vec_relu(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack_relu(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)

     def test_transducer_joint_vec_pack_relu(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_vec_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)

+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_vec_pack_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
...
...
setup.py
...
...
@@ -42,6 +42,55 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_major, bare_metal_minor

+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError(
+            "Cuda extensions are being compiled with a version of Cuda that does "
+            "not match the version used to compile Pytorch binaries. "
+            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+            + "In some cases, a minor-version mismatch will not cause later errors: "
+            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
+            "You can try commenting out this check (at your own risk)."
+        )
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    raise RuntimeError(
+        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
+    )
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
+    cudnn_available = torch.backends.cudnn.is_available()
+    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
+    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
+        warnings.warn(
+            f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
+            f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
+        )
+        return False
+    return True
+
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
...
...
@@ -524,9 +573,13 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
     )
 )

-if "--transducer" in sys.argv:
-    sys.argv.remove("--transducer")
+if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--transducer" in sys.argv:
+        sys.argv.remove("--transducer")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--transducer")
     ext_modules.append(
         CUDAExtension(
             name="transducer_joint_cuda",
...
...
@@ -536,7 +589,8 @@ if "--transducer" in sys.argv:
             ],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros + generator_flag,
-                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag),
+                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag)
+                        if not IS_ROCM_PYTORCH else ["-O3"] + version_dependent_macros + generator_flag,
             },
             include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
         )
...
...
@@ -551,7 +605,8 @@ if "--transducer" in sys.argv:
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros),
+                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros)
+                        if not IS_ROCM_PYTORCH else ["-O3"] + version_dependent_macros,
             },
         )
     )
...
...
@@ -571,9 +626,13 @@ if "--fast_bottleneck" in sys.argv:
     )
 )

-if "--peer_memory" in sys.argv:
-    sys.argv.remove("--peer_memory")
+if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--peer_memory" in sys.argv:
+        sys.argv.remove("--peer_memory")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--peer_memory")
     ext_modules.append(
         CUDAExtension(
             name="peer_memory_cuda",
...
...
@@ -585,9 +644,13 @@ if "--peer_memory" in sys.argv:
     )
 )

-if "--nccl_p2p" in sys.argv:
-    sys.argv.remove("--nccl_p2p")
+if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--nccl_p2p" in sys.argv:
+        sys.argv.remove("--nccl_p2p")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--nccl_p2p")
     ext_modules.append(
         CUDAExtension(
             name="nccl_p2p_cuda",
...
...