OpenDAS / apex, commit 7a344314 (Unverified)
Authored Sep 08, 2022 by Jithun Nair; committed by GitHub on Sep 08, 2022

    Merge branch 'master' into dev/hubertlu/focal_loss_and_index_mul_2d_cuda

Parents: 9187ea1d, ae5ca671
Changes: showing 11 changed files with 411 additions and 108 deletions (+411 -108)

  +2   -7    apex/contrib/bottleneck/halo_exchangers.py
  +4   -0    apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
  +262 -48   apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
  +2   -0    apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
  +8   -2    apex/contrib/csrc/transducer/transducer_joint_kernel.cu
  +1   -0    apex/contrib/peer_memory/__init__.py
  +14  -6    apex/contrib/peer_memory/peer_halo_exchange_module_tests.py
  +36  -32   apex/contrib/peer_memory/peer_halo_exchanger_1d.py
  +1   -1    apex/contrib/test/run_rocm_extensions.py
  +7   -1    apex/contrib/test/transducer/test_transducer_joint.py
  +74  -11   setup.py
apex/contrib/bottleneck/halo_exchangers.py

@@ -107,15 +107,10 @@ class HaloExchangerPeer(HaloExchanger):
         right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
         pm.push_pull_halos_1d(
             self.diagnostics, self.explicit_nhwc, self.numSM,
-            left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
-            right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+            self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
+            self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
             self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
         )
-        # TODO: Add to push_pull_halos_1d kernel
-        if self.left_zero:
-            left_input_halo.zero_()
-        if self.right_zero:
-            right_input_halo.zero_()
         if not inplace:
             return left_input_halo, right_input_halo
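The effect of this change: the left_zero/right_zero flags now travel into the fused push_pull_halos_1d call so zeroing happens inside the kernel, replacing the host-side fixup that the removed TODO block performed after the exchange. For a non-periodic 1D split the intended semantics are roughly the following (a hedged Python sketch with torch-style tensors; the helper name is ours, not from the commit):

    def fill_input_halos(rank, size, left_inp, right_inp, recv_left, recv_right):
        # Rank 0 has no real left neighbor, and the last rank has no real
        # right neighbor, so those incoming halos are defined to be zero.
        if rank == 0:
            left_inp.zero_()
        else:
            left_inp.copy_(recv_left)
        if rank == size - 1:
            right_inp.zero_()
        else:
            right_inp.copy_(recv_right)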
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu

@@ -5,7 +5,11 @@
 #include <cstdio>
 #include <ctime>
 #include <cassert>
+#ifdef __HIP_PLATFORM_HCC__
+#include "rccl.h"
+#else
 #include "nccl.h"
+#endif
 
 /*
  * This file implements a crude but effective mechanism for copying data between tenors owned by different ranks
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu

@@ -5,8 +5,15 @@
 #include <cstdio>
 #include <cassert>
 #include <cuda_runtime_api.h>
+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_cooperative_groups.h>
+#include "rccl.h"
+#else
 #include <cooperative_groups.h>
 #include "nccl.h"
+#endif
 
 namespace cg = cooperative_groups;
 #define CUDACHECK(cmd) do { \
@@ -117,7 +124,20 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride
     }
 }
 
-template<class T, bool is_HWC>
+template<class T>
+__device__ void __zero(T* dst)
+{
+    *dst = T(0);
+}
+
+__device__ void __zero(int4* dst)
+{
+    int4 v;
+    v.x = v.y = v.z = v.w = 0;
+    *dst = v;
+}
+
+template<class T, bool is_HWC, bool zero>
 __device__ void strided_copy_kernel(
     T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
     const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,

@@ -131,23 +151,28 @@ __device__ void strided_copy_kernel(
 {
     size_t c, h, w;
     if (is_HWC) {
-        c = i % NC;
         w = i / NC;
+        c = i - w * NC;
         h = w / NW;
-        w = w % NW;
+        w = w - h * NW;
     }
     else {
-        w = i % NW;
         h = i / NW;
+        w = i - h * NW;
         c = h / NH;
-        h = h % NH;
+        h = h - c * NH;
     }
     size_t dst_off = c * dst_stride_C + h * dst_stride_H + w * dst_stride_W;
-    size_t src_off = c * src_stride_C + h * src_stride_H + w * src_stride_W;
-    dst[dst_off] = src[src_off];
+    if (zero) {
+        __zero(dst + dst_off);
+    }
+    else {
+        size_t src_off = c * src_stride_C + h * src_stride_H + w * src_stride_W;
+        dst[dst_off] = src[src_off];
+    }
 }
 
+template<bool top_zero, bool btm_zero>
 __device__ void checked_signal(
     volatile int* signal1_flag, volatile int* signal2_flag,
     const int v1, const int v2, const int v3, const int v4
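A side note on the index rewrite in strided_copy_kernel: the old form computed each coordinate with an extra modulo, while the new form derives the remainder from the quotient with a subtract-multiply, which is typically cheaper on GPUs. The two decompositions are equivalent; a quick Python check (illustrative only, not part of the commit):

    # Both decompositions recover the same (c, h, w) from a flat HWC index i.
    NC, NH, NW = 8, 5, 7
    for i in range(NC * NH * NW):
        c0 = i % NC            # old: modulo
        t = i // NC
        h0 = t // NW
        w0 = t % NW            # old: modulo
        w1 = i // NC           # new: one division per axis...
        c1 = i - w1 * NC       # ...remainder via subtract-multiply
        h1 = w1 // NW
        w1 = w1 - h1 * NW
        assert (c0, h0, w0) == (c1, h1, w1)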
@@ -160,29 +185,119 @@ __device__ void checked_signal(
     __threadfence_system();
     // wait for top or bottom neighbor to clear signal
     register int r1, r2, r3, r4;
-    bool top_zeroed = false, btm_zeroed = false, top_done = false, btm_done = false;
-    do {
-        do {
-            asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
-            if (!top_zeroed) {
-                if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
-            }
-            asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
-            if (!btm_zeroed) {
-                if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
-            }
-        } while ((top_zeroed == top_done) && (btm_zeroed == btm_done));
-        if (!top_done && top_zeroed) {
-            // signal to top neighbor my output is ready
-            asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
-            top_done = true;
-        }
-        if (!btm_done && btm_zeroed) {
-            // signal to bottom neighbor my output is ready
-            asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
-            btm_done = true;
-        }
-    } while (!top_done || !btm_done);
+    if (!(top_zero || btm_zero)) {
+        bool top_zeroed = false, top_done = false;
+        bool btm_zeroed = false, btm_done = false;
+        do {
+            do {
+                if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal1_flag);
+                    r2 = __builtin_nontemporal_load(signal1_flag+1);
+                    r3 = __builtin_nontemporal_load(signal1_flag+2);
+                    r4 = __builtin_nontemporal_load(signal1_flag+3);
+#else
+                    asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
+                }
+                if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal2_flag);
+                    r2 = __builtin_nontemporal_load(signal2_flag+1);
+                    r3 = __builtin_nontemporal_load(signal2_flag+2);
+                    r4 = __builtin_nontemporal_load(signal2_flag+3);
+#else
+                    asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
+                }
+            } while ((top_zeroed == top_done) && (btm_zeroed == btm_done));
+            if (!top_done && top_zeroed) {
+                // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal1_flag);
+                __builtin_nontemporal_store(v2, signal1_flag+1);
+                __builtin_nontemporal_store(v3, signal1_flag+2);
+                __builtin_nontemporal_store(v4, signal1_flag+3);
+#else
+                asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                top_done = true;
+            }
+            if (!btm_done && btm_zeroed) {
+                // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal2_flag);
+                __builtin_nontemporal_store(v2, signal2_flag+1);
+                __builtin_nontemporal_store(v3, signal2_flag+2);
+                __builtin_nontemporal_store(v4, signal2_flag+3);
+#else
+                asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                btm_done = true;
+            }
+        } while (!top_done || !btm_done);
+    }
+    else if (top_zero) {
+        bool btm_zeroed = false, btm_done = false;
+        do {
+            do {
+                if (!btm_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal2_flag);
+                    r2 = __builtin_nontemporal_load(signal2_flag+1);
+                    r3 = __builtin_nontemporal_load(signal2_flag+2);
+                    r4 = __builtin_nontemporal_load(signal2_flag+3);
+#else
+                    asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
+                }
+            } while (btm_zeroed == btm_done);
+            if (!btm_done && btm_zeroed) {
+                // signal to bottom neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal2_flag);
+                __builtin_nontemporal_store(v2, signal2_flag+1);
+                __builtin_nontemporal_store(v3, signal2_flag+2);
+                __builtin_nontemporal_store(v4, signal2_flag+3);
+#else
+                asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                btm_done = true;
+            }
+        } while (!btm_done);
+    }
+    else if (btm_zero) {
+        bool top_zeroed = false, top_done = false;
+        do {
+            do {
+                if (!top_zeroed) {
+#ifdef __HIP_PLATFORM_HCC__
+                    r1 = __builtin_nontemporal_load(signal1_flag);
+                    r2 = __builtin_nontemporal_load(signal1_flag+1);
+                    r3 = __builtin_nontemporal_load(signal1_flag+2);
+                    r4 = __builtin_nontemporal_load(signal1_flag+3);
+#else
+                    asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
+#endif
+                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
+                }
+            } while (top_zeroed == top_done);
+            if (!top_done && top_zeroed) {
+                // signal to top neighbor my output is ready
+#ifdef __HIP_PLATFORM_HCC__
+                __builtin_nontemporal_store(v1, signal1_flag);
+                __builtin_nontemporal_store(v2, signal1_flag+1);
+                __builtin_nontemporal_store(v3, signal1_flag+2);
+                __builtin_nontemporal_store(v4, signal1_flag+3);
+#else
+                asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+#endif
+                top_done = true;
+            }
+        } while (!top_done);
+    }
 }
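Stripped of the inline PTX and the HIP intrinsics, the handshake each rank runs in checked_signal is: spin until the neighbor has cleared your signal slot from the previous iteration (the four words no longer equal the agreed sentinel), then write the sentinel to announce that the new output halo is ready. A rough Python model of one side of the exchange (the function name and the read4/write4 memory accessors are ours, not from the commit):

    SENTINEL = (-987751720, 840868300, -225529332, 281513358)

    def checked_signal_one_side(flag, read4, write4):
        # Wait for the neighbor to clear the previous sentinel...
        while read4(flag) == SENTINEL:
            pass
        # ...then publish it again: "my output halo is ready to be read".
        write4(flag, SENTINEL)

With top_zero or btm_zero set, the corresponding side of the handshake is simply skipped, which is what the new template branches encode at compile time.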
@@ -196,7 +311,14 @@ __device__ void wait_for(
     register int r1, r2, r3, r4;
     // wait for senders to signal their output is read
     do {
+#ifdef __HIP_PLATFORM_HCC__
+        r1 = __builtin_nontemporal_load(wait_flag);
+        r2 = __builtin_nontemporal_load(wait_flag+1);
+        r3 = __builtin_nontemporal_load(wait_flag+2);
+        r4 = __builtin_nontemporal_load(wait_flag+3);
+#else
         asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
+#endif
     } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
 }
 cg::this_grid().sync(); // all threads wait for main
@@ -212,12 +334,19 @@ __device__ void clear_flag(
     if (is_main_thread) {
         register int r1, r2, r3, r4;
         r1 = 0; r2 = 0; r3 = 0; r4 = 0;
+#ifdef __HIP_PLATFORM_HCC__
+        __builtin_nontemporal_store(r1, wait_flag);
+        __builtin_nontemporal_store(r2, wait_flag+1);
+        __builtin_nontemporal_store(r3, wait_flag+2);
+        __builtin_nontemporal_store(r4, wait_flag+3);
+#else
         asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+#endif
     }
 }
 
-template<class T, bool is_HWC>
-#if __CUDA_ARCH__ >= 700
+template<class T, bool is_HWC, bool top_zero, bool btm_zero>
+#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
 __launch_bounds__(128, 16)
 #endif
 __global__ void push_pull_halos_1d_kernel(
@@ -241,20 +370,34 @@ __global__ void push_pull_halos_1d_kernel(
     ) {
         // push top output halo to transfer buffer
-        strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
+        if (!top_zero) strided_copy_kernel<T,is_HWC,false>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
         // push btm output halo to transfer buffer
-        strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
+        if (!btm_zero) strided_copy_kernel<T,is_HWC,false>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
         // signal to top and btm neigbhbors that output halos are ready to be read
         // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
-        checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+        if (!(top_zero || btm_zero)) {
+            checked_signal<false,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+        } else if (top_zero) {
+            checked_signal<true,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+        } else if (btm_zero) {
+            checked_signal<false,true>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
+        }
         // pull top halo from transfer buffer in peer memory to input
-        wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
-        strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
-        clear_flag(wait1_flag);
+        if (top_zero) {
+            strided_copy_kernel<T,is_HWC,true>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+        } else {
+            wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
+            strided_copy_kernel<T,is_HWC,false>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
+            clear_flag(wait1_flag);
+        }
         // pull btm halo from transfer buffer in peer memory to input
-        wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
-        strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
-        clear_flag(wait2_flag);
+        if (btm_zero) {
+            strided_copy_kernel<T,is_HWC,true>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+        } else {
+            wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
+            strided_copy_kernel<T,is_HWC,false>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
+            clear_flag(wait2_flag);
+        }
     }
 }
 
 __global__ void delay_kernel(int delay_nanoseconds, int* counter)
@@ -343,10 +486,12 @@ void push_pull_halos_1d(
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                  // number of SMs to use
+    bool top_zero,              // true if top halo should be zeroed
     at::Tensor top_out_halo,    // top output halo in sender device memory
     at::Tensor top_out_tx,      // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,      // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,    // top input halo in receiver device memory
+    bool btm_zero,              // true if btm halo should be zeroed
     at::Tensor btm_out_halo,    // btm output halo in sender device memory
     at::Tensor btm_out_tx,      // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,      // btm input transfer buffer in btm neighbor peer pool memory

@@ -368,6 +513,7 @@ void push_pull_halos_1d(
     TORCH_CHECK(top_signal.is_cuda());
     TORCH_CHECK(btm_signal.is_cuda());
     TORCH_CHECK(waits.is_cuda());
+    TORCH_CHECK(!(top_zero && btm_zero));
 
     // shapes and strides
     int toh_N, toh_C, toh_H, toh_W;
@@ -492,10 +638,34 @@ void push_pull_halos_1d(
         &NC, &NH, &NW,
         &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
     };
-    int numBlocksPerSm;
-    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
-    dim3 grid(numSM*numBlocksPerSm,1,1);
-    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
+    if (top_zero) {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,true,false>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    } else if (btm_zero) {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,true>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    } else {
+        int numBlocksPerSm;
+        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,false>, numThreads, 0);
+        dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+        hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+    }
 } else {
     // cannot do int4 transfers
     if (diagnostics) printf("CAN NOT DO INT4\n");
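One detail worth calling out in the launch code: because these are cooperative kernel launches (the kernel synchronizes via cg::this_grid().sync()), every block must be resident on the device at once, so the grid is sized from measured occupancy rather than from the problem size. In outline (Python arithmetic with hypothetical numbers, mirroring the C++ above):

    num_sm = 2              # SMs the caller dedicates to the exchange (numSM)
    num_blocks_per_sm = 16  # from cudaOccupancyMaxActiveBlocksPerMultiprocessor
    grid = (num_sm * num_blocks_per_sm, 1, 1)  # dim3 grid(numSM*numBlocksPerSm,1,1)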
@@ -513,13 +683,57 @@ void push_pull_halos_1d(
     };
     int numBlocksPerSm;
     if (is_nhwc) {
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
-        dim3 grid(numSM*numBlocksPerSm,1,1);
-        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
+        if (top_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,true,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else if (btm_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,true>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        }
     } else {
-        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
-        dim3 grid(numSM*numBlocksPerSm,1,1);
-        cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
+        if (top_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,true,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else if (btm_zero) {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,true>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        } else {
+            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,false>, numThreads, 0);
+            dim3 grid(numSM*numBlocksPerSm,1,1);
+#ifdef __HIP_PLATFORM_HCC__
+            hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
+#else
+            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
+#endif
+        }
     }
 });
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh

@@ -32,10 +32,12 @@ namespace apex { namespace contrib { namespace peer_memory {
     bool diagnostics,
     bool explicit_nhwc,
     int numSM,                  // number of SMs to use
+    bool top_zero,              // true if top halo should be zeroed
     at::Tensor top_out_halo,    // top output halo in sender device memory
     at::Tensor top_out_tx,      // top output transfer buffer in sender peer pool memory
     at::Tensor top_inp_tx,      // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,    // top input halo in receiver device memory
+    bool btm_zero,              // true if btm halo should be zeroed
     at::Tensor btm_out_halo,    // btm output halo in sender device memory
     at::Tensor btm_out_tx,      // btm output transfer buffer in sender peer pool memory
     at::Tensor btm_inp_tx,      // btm input transfer buffer in btm neighbor peer pool memory
apex/contrib/csrc/transducer/transducer_joint_kernel.cu

@@ -17,12 +17,18 @@
 #include "philox.cuh"
 
+#ifdef __HIP_PLATFORM_HCC__
+#define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width)
+#else
+#define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width)
+#endif
+
 // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
 // width should be a power of 2 and should be less than warpSize.
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){
     for (unsigned offset = width/2; offset > 0; offset /= 2){
-        x += __shfl_down_sync(0xffffffff, x, offset, width);
+        x += SHFL_DOWN(x, offset, width);
     }
     return x;
 }
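The SHFL_DOWN macro keeps warpReduce portable: __shfl_down_sync does not exist on ROCm, where the unsynchronized __shfl_down is used instead. The reduction itself halves the stride each step, with each lane adding the value held offset lanes above it, so after log2(width) steps lane 0 of each group holds the group sum. A small Python model of one lane group (illustrative only, not part of the commit):

    def warp_reduce(vals, width):
        # vals holds one value per lane; models x += SHFL_DOWN(x, offset, width)
        vals = list(vals)
        offset = width // 2
        while offset > 0:
            for lane in range(width - offset):
                vals[lane] += vals[lane + offset]
            offset //= 2
        return vals[0]  # lane 0 ends up with the group sum

    assert warp_reduce(range(8), 8) == sum(range(8))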
@@ -864,7 +870,7 @@ std::vector<torch::Tensor> transducer_joint_cuda_backward(
     int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>();
 
     // The number "y" I would like each thread to work on
     const int workPerThread = 32;
     // Since the bwd for f and g have the same thread block size, we need to use the max of the two.
     int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread - 1) / workPerThread);
     // Would like to have at least 2 warps
apex/contrib/peer_memory/__init__.py

 from .peer_memory import PeerMemoryPool
 from .peer_halo_exchanger_1d import PeerHaloExchanger1d
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py

@@ -40,8 +40,9 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
         btm_out_halo = y[:,:,:,W:W+half_halo]
         btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
 
-    top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
-    btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
+    mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
+    top_out_halo = top_out_halo.contiguous()
+    btm_out_halo = btm_out_halo.contiguous()
     top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
     torch.distributed.all_gather(top_inp_halos, top_out_halo)

@@ -49,8 +50,14 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
     torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
     top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
     btm_rank = (peer_rank + 1) % peer_group_size
-    top_inp_halo.copy_(btm_inp_halos[top_rank])
-    btm_inp_halo.copy_(top_inp_halos[btm_rank])
+    if peer_rank == 0:
+        top_inp_halo.zero_()
+    else:
+        top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))
+    if peer_rank == peer_group_size - 1:
+        btm_inp_halo.zero_()
+    else:
+        btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))
 
 def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):

@@ -141,12 +148,13 @@ def main():
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     torch.cuda.set_device(rank)
-    pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
+    peer_ranks = [i for i in range(world_size)]
+    pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
     num_steps = 100
     half_halo = 1
-    halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
+    halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)
     H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex, num_steps)
     W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex, num_steps)
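The test changes track the new constructor signatures: PeerMemoryPool is now built from (static size, dynamic size, peer ranks) rather than (rank, world sizes, sizes), and PeerHaloExchanger1d takes the rank list plus the caller's rank within it. A minimal usage sketch under those assumptions (sizes arbitrary):

    import torch
    from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    peer_ranks = [i for i in range(world_size)]
    pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
    halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo=1)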
apex/contrib/peer_memory/peer_halo_exchanger_1d.py

@@ -3,9 +3,15 @@ from apex.contrib.peer_memory import PeerMemoryPool
 import peer_memory_cuda as pm
 
 class PeerHaloExchanger1d:
-    def __init__(self, rank, peer_group_size, peer_pool, half_halo):
-        self.peer_group_size = peer_group_size
-        self.peer_rank = rank % peer_group_size
+    def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
+        self.peer_group_size = len(ranks)
+        self.ranks = ranks
+        self.peer_rank = rank_in_group
+        self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
+        self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size
+        self.low_zero = True if self.peer_rank == 0 else False
+        self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False
         self.peer_pool = peer_pool
         self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
         self.signals[self.peer_rank].zero_()

@@ -17,45 +23,43 @@ class PeerHaloExchanger1d:
             if explicit_nhwc:
                 _, Hs, _, _ = list(y.shape)
                 H = Hs - 2*self.half_halo
-                top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-                top_inp_halo = y[:,:self.half_halo,:,:]
-                btm_out_halo = y[:,H:H+self.half_halo,:,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-                btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
+                low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+                low_inp_halo = y[:,:self.half_halo,:,:]
+                high_out_halo = y[:,H:H+self.half_halo,:,:]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+                high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
             else:
                 _, _, Hs, _ = list(y.shape)
                 H = Hs - 2*self.half_halo
-                top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-                top_inp_halo = y[:,:,:self.half_halo,:]
-                btm_out_halo = y[:,:,H:H+self.half_halo,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-                btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
+                low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+                low_inp_halo = y[:,:,:self.half_halo,:]
+                high_out_halo = y[:,:,H:H+self.half_halo,:]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+                high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
         else:
             if explicit_nhwc:
                 _, _, Ws, _ = list(y.shape)
                 W = Ws - 2*self.half_halo
-                top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-                top_inp_halo = y[:,:,:self.half_halo,:]
-                btm_out_halo = y[:,:,W:W+self.half_halo,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-                btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
+                low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
+                low_inp_halo = y[:,:,:self.half_halo,:]
+                high_out_halo = y[:,:,W:W+self.half_halo,:]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
+                high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
             else:
                 _, _, _, Ws = list(y.shape)
                 W = Ws - 2*self.half_halo
-                top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-                top_inp_halo = y[:,:,:,:self.half_halo]
-                btm_out_halo = y[:,:,:,W:W+self.half_halo]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-                btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
-        top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
-        btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
+                low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
+                low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
+                low_inp_halo = y[:,:,:,:self.half_halo]
+                high_out_halo = y[:,:,:,W:W+self.half_halo]
+                high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
+                high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
         pm.push_pull_halos_1d(
             diagnostics, explicit_nhwc, numSM,
-            top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
-            btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
-            self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
+            self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo,
+            self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo,
+            self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank]
         )
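The constructor now precomputes the wrap-around neighbor indices and flags the two ends of the split (low_zero/high_zero), so the decision to zero a halo rides along with the kernel call instead of being patched up afterwards. The arithmetic in isolation (a Python restatement, for clarity):

    def neighbors(rank_in_group, group_size):
        low = (rank_in_group + group_size - 1) % group_size  # wrap-around low neighbor
        high = (rank_in_group + 1) % group_size              # wrap-around high neighbor
        low_zero = rank_in_group == 0                 # first rank: no real low halo
        high_zero = rank_in_group == group_size - 1   # last rank: no real high halo
        return low, high, low_zero, high_zero

    assert neighbors(0, 4) == (3, 1, True, False)
    assert neighbors(3, 4) == (2, 0, False, True)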
apex/contrib/test/run_rocm_extensions.py

@@ -2,7 +2,7 @@ import unittest
 import sys
 
-test_dirs = ["groupbn", "layer_norm", "multihead_attn", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
+test_dirs = ["groupbn", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
 ROCM_BLACKLIST = [
     "layer_norm"
 ]
apex/contrib/test/transducer/test_transducer_joint.py

@@ -121,6 +121,7 @@ class TransducerJointTest(unittest.TestCase):
     def test_transducer_joint_vec(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
 
+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)

@@ -133,25 +134,30 @@ class TransducerJointTest(unittest.TestCase):
     def test_transducer_joint_vec_relu(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
 
+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack_relu(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
 
     def test_transducer_joint_vec_pack_relu(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
 
+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
 
+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_vec_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
 
+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_pack_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
 
+    @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89")
     def test_transducer_joint_vec_pack_relu_dropout(self):
         self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
setup.py

@@ -42,6 +42,55 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_major, bare_metal_minor
 
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError(
+            "Cuda extensions are being compiled with a version of Cuda that does "
+            "not match the version used to compile Pytorch binaries. "
+            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+            + "In some cases, a minor-version mismatch will not cause later errors: "
+            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
+            "You can try commenting out this check (at your own risk)."
+        )
+
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    raise RuntimeError(
+        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
+    )
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
+    cudnn_available = torch.backends.cudnn.is_available()
+    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
+    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
+        warnings.warn(
+            f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
+            f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
+        )
+        return False
+    return True
+
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
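Of the new helpers, append_nvcc_threads is the one that changes build behavior most directly: with a bare-metal CUDA of 11.2 or newer it appends --threads 4 to every nvcc flag list, enabling parallel compilation, while older toolkits (and, per the hunks below, ROCm builds) get the flags back unchanged. Roughly:

    # With nvcc >= 11.2: ["-O3"] -> ["-O3", "--threads", "4"]
    # Otherwise the list is returned as-is.
    nvcc_flags = append_nvcc_threads(["-O3"])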
@@ -524,9 +573,13 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
     )
 
-if "--transducer" in sys.argv:
-    sys.argv.remove("--transducer")
-    raise_if_cuda_home_none("--transducer")
+if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--transducer" in sys.argv:
+        sys.argv.remove("--transducer")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--transducer")
     ext_modules.append(
         CUDAExtension(
             name="transducer_joint_cuda",

@@ -536,7 +589,8 @@ if "--transducer" in sys.argv:
             ],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros + generator_flag,
-                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag),
+                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag) if not IS_ROCM_PYTORCH else ["-O3"] + version_dependent_macros + generator_flag,
             },
             include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
         )

@@ -551,7 +605,8 @@ if "--transducer" in sys.argv:
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros),
+                "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros) if not IS_ROCM_PYTORCH else ["-O3"] + version_dependent_macros,
             },
         )
     )

@@ -571,9 +626,13 @@ if "--fast_bottleneck" in sys.argv:
     )
 
-if "--peer_memory" in sys.argv:
-    sys.argv.remove("--peer_memory")
-    raise_if_cuda_home_none("--peer_memory")
+if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--peer_memory" in sys.argv:
+        sys.argv.remove("--peer_memory")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--peer_memory")
     ext_modules.append(
         CUDAExtension(
             name="peer_memory_cuda",

@@ -585,9 +644,13 @@ if "--peer_memory" in sys.argv:
     )
 
-if "--nccl_p2p" in sys.argv:
-    sys.argv.remove("--nccl_p2p")
-    raise_if_cuda_home_none("--nccl_p2p")
+if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--nccl_p2p" in sys.argv:
+        sys.argv.remove("--nccl_p2p")
+    if not IS_ROCM_PYTORCH:
+        raise_if_cuda_home_none("--nccl_p2p")
     ext_modules.append(
         CUDAExtension(
             name="nccl_p2p_cuda",