OpenDAS / apex · Commits

Commit a4eb97fb, authored Mar 23, 2022 by Thor Johnsen
Bug fixes
parent 40a0e025
Showing 3 changed files with 148 additions and 70 deletions.
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu    +97 -42
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh   +6 -3
apex/contrib/peer_memory/peer_memory.py              +45 -25
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
@@ -30,16 +30,24 @@ void deleter(void* ptr)
 */
 template<class T>
-at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options)
+at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
 {
-    std::vector<int64_t> strides(shape.size());
     size_t size = 1;
+    std::vector<int64_t> strides(shape.size());
+    if (channels_last) {
+        assert(shape.size() == 4);
+        strides[0] = shape[1]*shape[2]*shape[3];
+        strides[1] = 1;
+        strides[2] = shape[1]*shape[3];
+        strides[3] = shape[1];
+    } else {
         int idx = strides.size();
         for (auto it = shape.rbegin();  it != shape.rend();  ++it)
         {
             strides[--idx] = size;
             size *= *it;
         }
+    }
     size *= sizeof(T);
     // TODO: Implement dynamic reuse of pooled peer memory.
     // We provide no deleter function because all peer memory allocations are static in this implementation.
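For reference, a minimal sketch (not part of the commit) of the channels_last stride rule blob_view now applies to a 4-D (N, C, H, W) shape, checked against PyTorch's own channels_last layout; the shape values are arbitrary.

import torch

N, C, H, W = 2, 16, 8, 8
shape = (N, C, H, W)
strides = [0, 0, 0, 0]
strides[0] = shape[1] * shape[2] * shape[3]   # N stride = C*H*W
strides[1] = 1                                # C stride = 1 (channels contiguous)
strides[2] = shape[1] * shape[3]              # H stride = C*W
strides[3] = shape[1]                         # W stride = C

reference = torch.empty(N, C, H, W).to(memory_format=torch.channels_last).stride()
assert tuple(strides) == reference, (strides, reference)
print(strides)   # [1024, 1, 128, 16]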
@@ -139,11 +147,11 @@ __device__ void strided_copy_kernel(
     }
 }
+template <bool wait, bool clear>
 __device__ void dual_signal_wait_clear(
     volatile int* signal1_flag, volatile int* wait1_flag, volatile int* signal2_flag, volatile int* wait2_flag,
-    const int v1, const int v2, const int v3, const int v4, const bool clear
+    const int v1, const int v2, const int v3, const int v4
     )
 {
     register int r1, r2, r3, r4, r5, r6, r7, r8;
@@ -152,18 +160,21 @@ __device__ void dual_signal_wait_clear(
     if (is_main_thread) {
         asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
         asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
+        if (wait) {
             do {
                 asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory");
                 asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory");
             } while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
+        }
     }
     cg::this_grid().sync();
     // optionally clear wait flag
-    if (clear && is_main_thread) {
+    if (clear) {
+        if (is_main_thread) {
             r1 = 0; r2 = 0; r3 = 0; r4 = 0;
             asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
             asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+        }
     }
 }
 template<class T, bool is_HWC>
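As an aside, a rough CPU analogy (not part of the commit) for the signal/wait/clear handshake that dual_signal_wait_clear implements with volatile PTX stores and loads. Python threads stand in for GPU ranks and a dict stands in for the flag words in peer memory; the new template parameters decide at compile time whether the wait spin and the clear store are emitted at all.

import threading

flags = {"a": 0, "b": 0}       # stands in for the signal/wait words in peer memory
VALUE = 987751720              # arbitrary, as long as both sides agree on it

def handshake(my_flag, partner_flag, wait=True, clear=True):
    flags[partner_flag] = VALUE            # "signal": tell the partner my data is ready
    if wait:
        while flags[my_flag] != VALUE:     # "wait": spin until the partner has signalled me
            pass
    if clear:
        flags[my_flag] = 0                 # "clear": reset my flag for the next round

t = threading.Thread(target=handshake, args=("b", "a"))
t.start()
handshake("a", "b")
t.join()
print(flags)   # both flags cleared back to 0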
@@ -173,12 +184,14 @@ __launch_bounds__(128, 16)
 __global__ void push_pull_halos_1d_kernel(
         // top halo,
         const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W,  // top output halo
-        T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W,        // top tx buffer
+        T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W,        // top output tx buffer
+        T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W,        // top input tx buffer
         T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W,        // top input halo
         // btm halo
-        const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W,  // top output halo
-        T* box, int box_stride_C, int box_stride_H, int box_stride_W,        // top tx buffer
-        T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W,        // top input halo
+        const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W,  // btm output halo
+        T* box, int box_stride_C, int box_stride_H, int box_stride_W,        // btm output tx buffer
+        T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W,        // btm input tx buffer
+        T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W,        // btm input halo
         // dimensions
         int NC, int NH, int NW,
         // signals
@@ -194,11 +207,11 @@ __global__ void push_pull_halos_1d_kernel(
         strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
         // signal to top and btm neigbhbors that output halos are ready to be read
         // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
-        dual_signal_wait_clear(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358, true);
+        dual_signal_wait_clear<true,true>(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358);
         // pull top halo from transfer buffer in peer memory to input
-        strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, tih, tih_stride_C, tih_stride_H, tih_stride_W, NC, NH, NW);
+        strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
         // pull btm halo from transfer buffer in peer memory to input
-        strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, bih, bih_stride_C, bih_stride_H, bih_stride_W, NC, NH, NW);
+        strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
     }
 }
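A toy sketch (not part of the commit) of the data flow the split buffers are meant to support, as described by the parameter comments: each rank pushes into an output transfer buffer in its own peer pool, signals, then pulls its input halo from the neighbor's buffer, which it sees as its input transfer buffer. Plain Python lists stand in for the mapped peer memory of two neighboring ranks.

rank0_tx = [None]   # rank 0's output tx buffer, visible to rank 1 as its input tx buffer
rank1_tx = [None]   # rank 1's output tx buffer, visible to rank 0 as its input tx buffer

# push: each rank copies its output halo into its own output tx buffer
rank0_tx[0] = "halo pushed by rank 0"
rank1_tx[0] = "halo pushed by rank 1"
# (signal + wait would go here, so nobody pulls from a half-written buffer)
# pull: each rank copies its input halo out of the peer buffer it can see
rank0_input_halo = rank1_tx[0]
rank1_input_halo = rank0_tx[0]
print(rank0_input_halo, "|", rank1_input_halo)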
@@ -246,29 +259,32 @@ std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int6
     return results;
 }
-at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape)
+at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
 {
-    return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
+    return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
 }
-at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape)
+at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
 {
-    return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
+    return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
 }
-at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape)
+at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
 {
-    return blob_view<int>((int*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
+    return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
 }
 void push_pull_halos_1d(
         bool diagnostics,
         bool explicit_nhwc,
         int numSM,                   // number of SMs to use
         at::Tensor top_out_halo,     // top output halo in sender device memory
         at::Tensor top_out_tx,       // top output transfer buffer in sender peer pool memory
+        at::Tensor top_inp_tx,       // top input transfer buffer in top neighbor peer pool memory
         at::Tensor top_inp_halo,     // top input halo in receiver device memory
         at::Tensor btm_out_halo,     // btm output halo in sender device memory
         at::Tensor btm_out_tx,       // btm output transfer buffer in sender peer pool memory
+        at::Tensor btm_inp_tx,       // btm input transfer buffer in btm neighbor peer pool memory
         at::Tensor btm_inp_halo,     // btm input halo in receiver device memory
         at::Tensor top_signal,       // top input signal in receiver device memory
         at::Tensor btm_signal,       // btm input signal in receiver device memory
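A small illustration (not part of the commit) of why blob_view_float and blob_view_int needed their own dtypes rather than torch::kFloat16: viewing a buffer with the wrong element type reinterprets the raw bytes. numpy is used here purely to show the byte-level effect.

import numpy as np

buf = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(buf.view(np.float32))   # [1. 2. 3. 4.] - same bytes, correct dtype
print(buf.view(np.float16))   # eight half-precision values decoded from the same 16 bytes
print(buf.view(np.int32))     # the IEEE-754 bit patterns, not the integers 1..4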
@@ -278,9 +294,11 @@ void push_pull_halos_1d(
     // basic checks of inputs
     TORCH_CHECK(top_out_halo.is_cuda());
     TORCH_CHECK(top_out_tx.is_cuda());
+    TORCH_CHECK(top_inp_tx.is_cuda());
     TORCH_CHECK(top_inp_halo.is_cuda());
     TORCH_CHECK(btm_out_halo.is_cuda());
     TORCH_CHECK(btm_out_tx.is_cuda());
+    TORCH_CHECK(btm_inp_tx.is_cuda());
     TORCH_CHECK(btm_inp_halo.is_cuda());
     TORCH_CHECK(top_signal.is_cuda());
     TORCH_CHECK(btm_signal.is_cuda());
@@ -291,46 +309,56 @@ void push_pull_halos_1d(
     tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
     int tox_N, tox_C, tox_H, tox_W;
     tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
+    int tix_N, tix_C, tix_H, tix_W;
+    tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
     int tih_N, tih_C, tih_H, tih_W;
     tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
     TORCH_CHECK(
-        (toh_N == tox_N && tox_N == tih_N) && (toh_C == tox_C && tox_C == tih_C) && (toh_H == tox_H && tox_H == tih_H) && (toh_W == tox_W && tox_W == tih_W));
+        (toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) && (toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) && (toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) && (toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
     int boh_N, boh_C, boh_H, boh_W;
     tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
     int box_N, box_C, box_H, box_W;
     tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
+    int bix_N, bix_C, bix_H, bix_W;
+    tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
     int bih_N, bih_C, bih_H, bih_W;
     tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
     TORCH_CHECK(
-        (boh_N == box_N && box_N == bih_N) && (boh_C == box_C && box_C == bih_C) && (boh_H == box_H && box_H == bih_H) && (boh_W == box_W && box_W == bih_W));
+        (boh_N == box_N && box_N == bix_N && bix_N == bih_N) && (boh_C == box_C && box_C == bix_C && bix_C == bih_C) && (boh_H == box_H && box_H == bix_H && bix_H == bih_H) && (boh_W == box_W && box_W == bix_W && bix_W == bih_W));
     TORCH_CHECK((toh_N == boh_N) && (toh_C == boh_C) && (toh_H == boh_H) && (toh_W == boh_W));
     int NC = toh_C, NH = toh_H, NW = toh_W;
     if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n", NC, NH, NW);
     int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
     tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
     int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
     tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
+    int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
+    tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
     int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
     tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
     int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
     tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
     int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
     tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
+    int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
+    tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
     int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
     tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
     // determine if nhwc
     auto is_nhwc = (toh_stride_C == 1) ? true : false;
     if (diagnostics) printf("is_nhwc = %s\n", is_nhwc ? "true" : "false");
     // figure out launch parameters
     int device;
@@ -342,35 +370,59 @@ void push_pull_halos_1d(
     const int numThreads = 128;
     dim3 block(numThreads, 1, 1);
     AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
         if (diagnostics) printf("size(scalar_t) = %d\n", sizeof(scalar_t));
         scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
         scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
+        scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
         scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
         scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
         scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
+        scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
         scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
-        int* top_signal_p = top_signal.data_ptr<int>();
-        int* btm_signal_p = btm_signal.data_ptr<int>() + 4;
+        if (diagnostics) printf("waypoint1\n");
+        int* top_signal_p = top_signal.data_ptr<int>() + 4;
+        int* btm_signal_p = btm_signal.data_ptr<int>();
         int* top_wait_p = waits.data_ptr<int>();
         int* btm_wait_p = waits.data_ptr<int>() + 4;
+        if (diagnostics) printf("waypoint2\n");
         // do int4 vector loads if channel count permits
         int elem_size_in_bytes = toh_C * sizeof(scalar_t);
         int elem_size_in_int4 = (elem_size_in_bytes / 16);
         if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n", elem_size_in_bytes, elem_size_in_int4);
         if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
             // can do int4 transfers
-            int divisor = elem_size_in_bytes / elem_size_in_int4;
+            int divisor = toh_C / elem_size_in_int4;
             if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n", divisor);
             toh_stride_N /= divisor;  toh_stride_H /= divisor;  toh_stride_W /= divisor;
             tox_stride_N /= divisor;  tox_stride_H /= divisor;  tox_stride_W /= divisor;
+            tix_stride_N /= divisor;  tix_stride_H /= divisor;  tix_stride_W /= divisor;
             tih_stride_N /= divisor;  tih_stride_H /= divisor;  tih_stride_W /= divisor;
             boh_stride_N /= divisor;  boh_stride_H /= divisor;  boh_stride_W /= divisor;
             box_stride_N /= divisor;  box_stride_H /= divisor;  box_stride_W /= divisor;
+            bix_stride_N /= divisor;  bix_stride_H /= divisor;  bix_stride_W /= divisor;
             bih_stride_N /= divisor;  bih_stride_H /= divisor;  bih_stride_W /= divisor;
             NC /= divisor;
             if (diagnostics) {
                 printf("divisor=%d\n", divisor);
                 printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n", toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
                 printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n", tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
+                printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n", tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
                 printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n", tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
                 printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n", boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
                 printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n", box_stride_N, box_stride_C, box_stride_H, box_stride_W);
+                printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n", bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
                 printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n", bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
                 printf("NC=%d, NH=%d, NW=%d\n", NC, NH, NW);
             }
             void* kernelArgs[] = {
                 (int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                 (int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
+                (int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                 (int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                 (int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                 (int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
+                (int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                 (int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                 &NC, &NH, &NW,
                 &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
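A worked example (not part of the commit) of the int4 fast-path arithmetic after the divisor fix: a halo row of C channel values can be moved as 16-byte int4 vectors only when C*sizeof(T) is a multiple of 16, and the corrected divisor is the number of scalars packed into one int4. The old expression evaluated to 16 whenever the fast path was taken, regardless of dtype.

for name, nbytes in (("half", 2), ("float", 4)):
    C = 16                                                 # channel count of the halo tensors
    elem_size_in_bytes = C * nbytes
    elem_size_in_int4 = elem_size_in_bytes // 16
    can_do_int4 = (elem_size_in_int4 * 16 == elem_size_in_bytes)
    divisor = C // elem_size_in_int4                       # fixed: scalars per int4 (8 for half, 4 for float)
    old_divisor = elem_size_in_bytes // elem_size_in_int4  # old expression: 16 on this path for any dtype
    print(name, can_do_int4, divisor, old_divisor)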
@@ -381,12 +433,15 @@ void push_pull_halos_1d(
             cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
         } else {
             // cannot do int4 transfers
             if (diagnostics) printf("CAN NOT DO INT4\n");
             void* kernelArgs[] = {
                 &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                 &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
+                &tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                 &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                 &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                 &box_p, &box_stride_C, &box_stride_H, &box_stride_W,
+                &bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                 &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                 &NC, &NH, &NW,
                 &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
@@ -24,17 +24,20 @@ namespace apex { namespace peer_memory {
 void free_raw(int64_t raw);
 at::Tensor get_raw_ipc_address(int64_t raw);
 std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
-at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape);
-at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape);
-at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape);
+at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
+at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
+at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
 void push_pull_halos_1d(
         bool diagnostics,
         bool explicit_nhwc,
         int numSM,                   // number of SMs to use
         at::Tensor top_out_halo,     // top output halo in sender device memory
         at::Tensor top_out_tx,       // top output transfer buffer in sender peer pool memory
+        at::Tensor top_inp_tx,       // top input transfer buffer in top neighbor peer pool memory
         at::Tensor top_inp_halo,     // top input halo in receiver device memory
         at::Tensor btm_out_halo,     // btm output halo in sender device memory
         at::Tensor btm_out_tx,       // btm output transfer buffer in sender peer pool memory
+        at::Tensor btm_inp_tx,       // btm input transfer buffer in btm neighbor peer pool memory
         at::Tensor btm_inp_halo,     // btm input halo in receiver device memory
         at::Tensor top_signal,       // top input signal in receiver device memory
         at::Tensor btm_signal,       // btm input signal in receiver device memory
apex/contrib/peer_memory/peer_memory.py
 import torch
 import numpy as np
-import peer_memory
+import peer_memory as pm

 class PeerMemoryPool(object):

-    def __init__(self, rank, world_size, peer_group_size, size):
+    def __init__(self, rank, world_size, peer_group_size, static_size, dynamic_size):
         self.peer_group = rank // peer_group_size
         self.peer_rank = rank % peer_group_size
         self.peer_group_size = peer_group_size
         self.alignment = 256
-        self.size = size
+        self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
+        self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment

         # allocate giant pool of device memory
-        self.raw = allocate_raw(size)
+        self.raw = pm.allocate_raw(self.static_size + self.dynamic_size)

         # exchange peer pointers with nccl
-        raw_ipc = get_raw_ipc_address(self.raw).cuda()
+        raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
         peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
         torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
         peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
-        self.peer_raw = get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
-        self.current = 0
+        self.peer_raw = pm.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
+        self.static_offset = 0
+        self.dynamic_offset = 0

     def __del__(self):
-        free_raw(self.raw)
+        pm.free_raw(self.raw)

     def reset(self):
-        self.current = 0
+        self.dynamic_offset = 0

-    def allocate_peer_tensors(self, shape, dtype):
+    def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):
         nels = np.prod(shape)
         if dtype == torch.float16:
             elem_size = 2
-            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
-            self.current = start + nels * elem_size
-            assert(self.current < self.size), "Peer memory pool exhausted"
-            return [blob_view_half(pr + start, shape) for pr in self.peer_raw]
-        elif dtype == torch.float32:
+            if dynamic:
+                start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.dynamic_offset = start + nels * elem_size
+                assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
+                return [pm.blob_view_half(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
+            else:
+                start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.static_offset = start + nels * elem_size
+                assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
+                return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw]
+        if dtype == torch.float32:
             elem_size = 4
-            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
-            self.current = start + nels * elem_size
-            assert(self.current < self.size), "Peer memory pool exhausted"
-            return [blob_view_float(pr + start, shape) for pr in self.peer_raw]
-        elif dtype == torch.int32:
+            if dynamic:
+                start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.dynamic_offset = start + nels * elem_size
+                assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
+                return [pm.blob_view_float(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
+            else:
+                start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.static_offset = start + nels * elem_size
+                assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
+                return [pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw]
+        if dtype == torch.int32:
             elem_size = 4
-            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
-            self.current = start + nels * elem_size
-            assert(self.current < self.size), "Peer memory pool exhausted"
-            return [blob_view_int(pr + start, shape) for pr in self.peer_raw]
+            if dynamic:
+                start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.dynamic_offset = start + nels * elem_size
+                assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
+                return [pm.blob_view_int(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
+            else:
+                start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.static_offset = start + nels * elem_size
+                assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
+                return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw]
         else:
-            assert(False), "Unknown dtype : %s" % (str(dtype))
+            assert(False), "dtype %s not supported" % (str(dtype))
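A usage sketch (not part of the commit) of the reworked pool API. The import path, pool sizes, and tensor shape are illustrative assumptions; it presumes one process per GPU with torch.distributed already initialized and the apex peer_memory extension built.

import torch
from apex.contrib.peer_memory import PeerMemoryPool   # assumed import path

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

pool = PeerMemoryPool(rank, world_size,
                      peer_group_size=world_size,
                      static_size=64 * 1024 * 1024,    # long-lived allocations
                      dynamic_size=64 * 1024 * 1024)   # reclaimed by pool.reset()

# one tensor view per peer rank, laid out channels_last, carved from the dynamic region
halos = pool.allocate_peer_tensors([1, 16, 1, 64], torch.float16,
                                   channels_last=True, dynamic=True)
pool.reset()   # releases only the dynamic region; static allocations stay put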