OpenDAS / apex / Commits / 40a0e025

Commit 40a0e025, authored Mar 22, 2022 by Thor Johnsen
Peer memory halo exchange
Parent: feae3851

Showing 6 changed files with 546 additions and 0 deletions (+546, -0)
apex/contrib/csrc/peer_memory/peer_memory.cpp       +28  -0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu   +409 -0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh  +44  -0
apex/contrib/peer_memory/__init__.py                +1   -0
apex/contrib/peer_memory/peer_memory.py             +50  -0
setup.py                                            +14  -0
apex/contrib/csrc/peer_memory/peer_memory.cpp (new file, mode 100644)
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"allocate_raw"
,
&
apex
::
peer_memory
::
allocate_raw
,
"allocate_raw"
);
m
.
def
(
"free_raw"
,
&
apex
::
peer_memory
::
free_raw
,
"free_raw"
);
m
.
def
(
"get_raw_ipc_address"
,
&
apex
::
peer_memory
::
get_raw_ipc_address
,
"get_raw_ipc_address"
);
m
.
def
(
"get_raw_peers"
,
&
apex
::
peer_memory
::
get_raw_peers
,
"get_raw_peers"
);
m
.
def
(
"blob_view_half"
,
&
apex
::
peer_memory
::
blob_view_half
,
"blob_view_half"
);
m
.
def
(
"blob_view_float"
,
&
apex
::
peer_memory
::
blob_view_float
,
"blob_view_float"
);
m
.
def
(
"blob_view_int"
,
&
apex
::
peer_memory
::
blob_view_int
,
"blob_view_int"
);
m
.
def
(
"push_pull_halos_1d"
,
&
apex
::
peer_memory
::
push_pull_halos_1d
,
"push_pull_halos_1d"
);
}
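For orientation, a minimal smoke test of these bindings (a sketch, assuming the extension has been built via the setup.py change at the bottom of this commit; the allocation size is arbitrary):

import torch  # import torch first so the extension can resolve libtorch symbols
import peer_memory

raw = peer_memory.allocate_raw(2 * 1024 * 1024)   # raw device pointer, returned as int64
handle = peer_memory.get_raw_ipc_address(raw)     # uint8 tensor wrapping a cudaIpcMemHandle_t
print(handle.shape)                               # torch.Size([64]) == sizeof(cudaIpcMemHandle_t)
peer_memory.free_raw(raw)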
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu (new file, mode 100644)
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cstring>
#include <cassert>
#include <unistd.h>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

#define CUDACHECK(cmd) do { \
    cudaError_t err = cmd; \
    if( err != cudaSuccess ) { \
        char hostname[1024]; \
        gethostname(hostname, 1024); \
        printf("%s: CUDA failure %s:%d '%s'\n", \
            hostname, \
            __FILE__,__LINE__,cudaGetErrorString(err)); \
    } \
} while(0)
namespace {

/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
    printf("deleter(ptr=%p)\n",ptr);
    cudaFree(ptr);
}
*/
template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options)
{
    std::vector<int64_t> strides(shape.size());
    size_t size = 1;
    int idx = strides.size();
    for (auto it = shape.rbegin(); it != shape.rend(); ++it)
    {
        strides[--idx] = size;
        size *= *it;
    }
    size *= sizeof(T);
    // TODO: Implement dynamic reuse of pooled peer memory.
    // We provide no deleter function because all peer memory allocations are static in this implementation.
    return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}
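The stride computation above walks the shape right-to-left to build contiguous row-major strides. The same logic in plain Python (a sketch; contiguous_strides is a hypothetical helper, not part of this commit):

def contiguous_strides(shape):
    strides = [0] * len(shape)
    size = 1
    for idx in range(len(shape) - 1, -1, -1):   # innermost dimension first
        strides[idx] = size
        size *= shape[idx]
    return strides

assert contiguous_strides([2, 3, 4]) == [12, 4, 1]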
void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W)
{
    if (t.dim() == 3)
    {
        N = 1;
        if (explicit_nhwc)
        {
            C = t.size(2);
            H = t.size(0);
            W = t.size(1);
        }
        else
        {
            C = t.size(0);
            H = t.size(1);
            W = t.size(2);
        }
    }
    else if (t.dim() == 4)
    {
        if (explicit_nhwc)
        {
            N = t.size(0);
            C = t.size(3);
            H = t.size(1);
            W = t.size(2);
        }
        else
        {
            N = t.size(0);
            C = t.size(1);
            H = t.size(2);
            W = t.size(3);
        }
    }
    else
    {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}
void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W)
{
    if (t.dim() == 3)
    {
        if (explicit_nhwc)
        {
            stride_C = t.stride(2);
            stride_H = t.stride(0);
            stride_W = t.stride(1);
        }
        else
        {
            stride_C = t.stride(0);
            stride_H = t.stride(1);
            stride_W = t.stride(2);
        }
        stride_N = t.size(0) * t.size(1) * t.size(2);
    }
    else if (t.dim() == 4)
    {
        if (explicit_nhwc)
        {
            stride_N = t.stride(0);
            stride_C = t.stride(3);
            stride_H = t.stride(1);
            stride_W = t.stride(2);
        }
        else
        {
            stride_N = t.stride(0);
            stride_C = t.stride(1);
            stride_H = t.stride(2);
            stride_W = t.stride(3);
        }
    }
    else
    {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}
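Both helpers read sizes and strides straight off the tensor; later, the host code detects NHWC by checking whether the channel stride is 1. The equivalent check in PyTorch, shown for illustration:

import torch

x = torch.empty(2, 8, 4, 4)                     # NCHW-contiguous
print(x.stride())                               # (128, 16, 4, 1): channel stride is 16
y = x.to(memory_format=torch.channels_last)     # NHWC layout, same logical shape
print(y.stride())                               # (128, 1, 32, 8): channel stride is 1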
template<class T, bool is_HWC>
__device__ void strided_copy_kernel(
        T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
        const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
        const int NC, const int NH, const int NW
        )
{
    size_t tot_num_threads = gridDim.x * blockDim.x;
    size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t count = NC * NH * NW;
    for (size_t i = thread_id; i < count; i += tot_num_threads)
    {
        size_t c, h, w;
        if (is_HWC)
        {
            c = i % NC;
            w = i / NC;
            h = w / NW;
            w = w % NW;
        }
        else
        {
            w = i % NW;
            h = i / NW;
            c = h / NH;
            h = h % NH;
        }
        size_t dst_off = c * dst_stride_C + h * dst_stride_H + w * dst_stride_W;
        size_t src_off = c * src_stride_C + h * src_stride_H + w * src_stride_W;
        dst[dst_off] = src[src_off];
    }
}
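Each thread walks a grid-strided range of flat indices and splits each index into (c, h, w), with the fastest-moving dimension chosen by the layout. The decomposition, checked exhaustively on the CPU (a sketch mirroring the kernel's arithmetic):

def decompose(i, NC, NH, NW, is_hwc):
    if is_hwc:                    # channels fastest (NHWC)
        c, t = i % NC, i // NC
        h, w = t // NW, t % NW
    else:                         # width fastest (NCHW)
        w, t = i % NW, i // NW
        c, h = t // NH, t % NH
    return c, h, w

NC, NH, NW = 3, 2, 4
for is_hwc in (True, False):
    seen = {decompose(i, NC, NH, NW, is_hwc) for i in range(NC * NH * NW)}
    assert len(seen) == NC * NH * NW   # every (c, h, w) is visited exactly once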
__device__ void dual_signal_wait_clear(
        volatile int* signal1_flag, volatile int* wait1_flag,
        volatile int* signal2_flag, volatile int* wait2_flag,
        const int v1, const int v2, const int v3, const int v4,
        const bool clear
        )
{
    register int r1, r2, r3, r4, r5, r6, r7, r8;
    bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
    // signal and wait
    if (is_main_thread)
    {
        asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
        asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
        do
        {
            asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory");
            asm volatile ("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory");
        }
        while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
    }
    cg::this_grid().sync();
    // optionally clear wait flag
    if (clear && is_main_thread)
    {
        r1 = 0; r2 = 0; r3 = 0; r4 = 0;
        asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
        asm volatile ("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
    }
}
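The handshake: thread (0, 0) of the grid stores a 4-int pattern into each neighbor's wait flags, spins until this rank's own flags carry the same pattern, then, after a grid-wide sync, zeroes them for the next exchange. A rough host-side analogue with two threads standing in for two ranks (plain Python, purely illustrative; the real code relies on volatile PTX loads and stores over peer-mapped memory):

import threading, time

PATTERN = [-987751720, 840868300, -225529332, 281513358]
flags = {"a": [0] * 4, "b": [0] * 4}

def exchange(mine, theirs):
    flags[theirs][:] = PATTERN        # signal the neighbor
    while flags[mine] != PATTERN:     # spin-wait for the neighbor's signal
        time.sleep(0)
    flags[mine][:] = [0] * 4          # clear for the next iteration

ta = threading.Thread(target=exchange, args=("a", "b"))
tb = threading.Thread(target=exchange, args=("b", "a"))
ta.start(); tb.start(); ta.join(); tb.join()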
template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
        // top halo
        const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W,   // top output halo
        T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W,         // top tx buffer
        T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W,         // top input halo
        // btm halo
        const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W,   // btm output halo
        T* box, int box_stride_C, int box_stride_H, int box_stride_W,         // btm tx buffer
        T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W,         // btm input halo
        // dimensions
        int NC, int NH, int NW,
        // signals
        int* signal1_flag, int* signal2_flag, int* wait1_flag, int* wait2_flag
        )
{
    // push top output halo to transfer buffer
    strided_copy_kernel<T, is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
    // push btm output halo to transfer buffer
    strided_copy_kernel<T, is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
    // signal to top and btm neighbors that output halos are ready to be read
    // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
    dual_signal_wait_clear(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358, true);
    // pull top halo from transfer buffer in peer memory to input
    strided_copy_kernel<T, is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tox, tox_stride_C, tox_stride_H, tox_stride_W, NC, NH, NW);
    // pull btm halo from transfer buffer in peer memory to input
    strided_copy_kernel<T, is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, box, box_stride_C, box_stride_H, box_stride_W, NC, NH, NW);
}

}
namespace apex {
namespace peer_memory {

int64_t allocate_raw(int64_t size)
{
    float* ptr = 0L;
    cudaMalloc(&ptr, size);
    return (int64_t)ptr;
}

void free_raw(int64_t raw)
{
    cudaFree((void*)raw);
}
at::Tensor get_raw_ipc_address(int64_t raw)
{
    cudaIpcMemHandle_t mem_handle;
    CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) );
    const int n = sizeof(cudaIpcMemHandle_t);
    auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));
    auto address_tensor_p = address_tensor.data_ptr<uint8_t>();
    memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);
    return address_tensor;
}
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw)
{
    int peer_group_size = ipc_addresses.size(0);
    std::vector<int64_t> results(peer_group_size);
    for (int i = 0; i < peer_group_size; ++i)
    {
        if (i != peer_rank)
        {
            cudaIpcMemHandle_t mem_handle;
            memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));
            void* p = 0L;
            CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) );
            results[i] = (int64_t)p;
        }
        else
        {
            results[i] = (int64_t)raw;
        }
    }
    return results;
}
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape)
{
    return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
}

at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape)
{
    return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA));
}

at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape)
{
    return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA));
}
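The three wrappers differ only in the element type stamped onto the raw pointer; offsets into the pool are plain byte arithmetic on the int64 pointer value, exactly as PeerMemoryPool does below. A sketch with hypothetical sizes and offsets:

import torch
import peer_memory

raw = peer_memory.allocate_raw(8192)
a = peer_memory.blob_view_half(raw, [16, 16])           # fp16 view at offset 0 (512 bytes)
b = peer_memory.blob_view_float(raw + 4096, [16, 16])   # fp32 view at byte offset 4096
assert a.is_cuda and a.dtype == torch.float16
assert b.is_cuda and b.dtype == torch.float32
peer_memory.free_raw(raw)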
void push_pull_halos_1d(
        bool explicit_nhwc,
        int numSM,                  // number of SMs to use
        at::Tensor top_out_halo,    // top output halo in sender device memory
        at::Tensor top_out_tx,      // top output transfer buffer in sender peer pool memory
        at::Tensor top_inp_halo,    // top input halo in receiver device memory
        at::Tensor btm_out_halo,    // btm output halo in sender device memory
        at::Tensor btm_out_tx,      // btm output transfer buffer in sender peer pool memory
        at::Tensor btm_inp_halo,    // btm input halo in receiver device memory
        at::Tensor top_signal,      // top input signal in receiver device memory
        at::Tensor btm_signal,      // btm input signal in receiver device memory
        at::Tensor waits            // top and btm signals for this rank
        )
{
    // basic checks of inputs
    TORCH_CHECK(top_out_halo.is_cuda());
    TORCH_CHECK(top_out_tx.is_cuda());
    TORCH_CHECK(top_inp_halo.is_cuda());
    TORCH_CHECK(btm_out_halo.is_cuda());
    TORCH_CHECK(btm_out_tx.is_cuda());
    TORCH_CHECK(btm_inp_halo.is_cuda());
    TORCH_CHECK(top_signal.is_cuda());
    TORCH_CHECK(btm_signal.is_cuda());
    TORCH_CHECK(waits.is_cuda());
    // shapes and strides
    int toh_N, toh_C, toh_H, toh_W;
    tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
    int tox_N, tox_C, tox_H, tox_W;
    tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
    int tih_N, tih_C, tih_H, tih_W;
    tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
    TORCH_CHECK(
        (toh_N == tox_N && tox_N == tih_N) &&
        (toh_C == tox_C && tox_C == tih_C) &&
        (toh_H == tox_H && tox_H == tih_H) &&
        (toh_W == tox_W && tox_W == tih_W));
    int boh_N, boh_C, boh_H, boh_W;
    tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
    int box_N, box_C, box_H, box_W;
    tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
    int bih_N, bih_C, bih_H, bih_W;
    tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
    TORCH_CHECK(
        (boh_N == box_N && box_N == bih_N) &&
        (boh_C == box_C && box_C == bih_C) &&
        (boh_H == box_H && box_H == bih_H) &&
        (boh_W == box_W && box_W == bih_W));
    TORCH_CHECK((toh_N == boh_N) && (toh_C == boh_C) && (toh_H == boh_H) && (toh_W == boh_W));
    int NC = toh_C, NH = toh_H, NW = toh_W;
    int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
    tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
    int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
    tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
    int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
    tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
    int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
    tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
    int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
    tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
    int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
    tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
    // determine if nhwc
    auto is_nhwc = (toh_stride_C == 1) ? true : false;
    // figure out launch parameters
    int device;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    assert(numSM > 0 && numSM <= prop.multiProcessorCount);
    auto current_stream = at::cuda::getCurrentCUDAStream();
    const int numThreads = 128;
    dim3 block(numThreads, 1, 1);
    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
        scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
        scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
        scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
        scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
        scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
        scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
        int* top_signal_p = top_signal.data_ptr<int>();
        int* btm_signal_p = btm_signal.data_ptr<int>() + 4;
        int* top_wait_p = waits.data_ptr<int>();
        int* btm_wait_p = waits.data_ptr<int>() + 4;
        // do int4 vector loads if channel count permits
        int elem_size_in_bytes = toh_C * sizeof(scalar_t);
        int elem_size_in_int4 = (elem_size_in_bytes / 16);
        if (is_nhwc && elem_size_in_int4 * 16 == elem_size_in_bytes)
        {
            // can do int4 transfers
            // strides are in scalar elements; rescale them (and the channel
            // count) to int4 units by the number of scalars per int4 vector
            int divisor = toh_C / elem_size_in_int4;
            toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
            tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
            tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
            boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
            box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
            bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
            NC = elem_size_in_int4;
            void* kernelArgs[] = {
                (int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                (int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                (int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                (int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                (int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                (int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            int numBlocksPerSm;
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true>, numThreads, 0);
            dim3 grid(numSM * numBlocksPerSm, 1, 1);
            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true>, grid, block, kernelArgs, 0, current_stream);
        }
        else
        {
            // cannot do int4 transfers
            void* kernelArgs[] = {
                &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                &box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            int numBlocksPerSm;
            if (is_nhwc)
            {
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true>, numThreads, 0);
                dim3 grid(numSM * numBlocksPerSm, 1, 1);
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true>, grid, block, kernelArgs, 0, current_stream);
            }
            else
            {
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false>, numThreads, 0);
                dim3 grid(numSM * numBlocksPerSm, 1, 1);
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false>, grid, block, kernelArgs, 0, current_stream);
            }
        }
    });
}

}
}
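Whether the int4 fast path applies comes down to divisibility: a pixel's channel data can move as 16-byte vectors when C * sizeof(T) is a multiple of 16 and the layout is NHWC. The stride-rescaling arithmetic, with hypothetical numbers:

C, elem_bytes = 64, 2                          # e.g. 64 fp16 channels
elem_size_in_bytes = C * elem_bytes            # 128 bytes per pixel
elem_size_in_int4 = elem_size_in_bytes // 16   # 8 int4 vectors per pixel
assert elem_size_in_int4 * 16 == elem_size_in_bytes   # fast path eligible
divisor = C // elem_size_in_int4               # 8 scalars per int4 vector
# element strides shrink by `divisor`, and the per-pixel channel count
# becomes elem_size_in_int4 int4 units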
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh (new file, mode 100644)
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _peer_memory_h_
#define _peer_memory_h_
namespace apex {
namespace peer_memory {

int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape);
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape);
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape);
void push_pull_halos_1d(
        bool explicit_nhwc,
        int numSM,                  // number of SMs to use
        at::Tensor top_out_halo,    // top output halo in sender device memory
        at::Tensor top_out_tx,      // top output transfer buffer in sender peer pool memory
        at::Tensor top_inp_halo,    // top input halo in receiver device memory
        at::Tensor btm_out_halo,    // btm output halo in sender device memory
        at::Tensor btm_out_tx,      // btm output transfer buffer in sender peer pool memory
        at::Tensor btm_inp_halo,    // btm input halo in receiver device memory
        at::Tensor top_signal,      // top input signal in receiver device memory
        at::Tensor btm_signal,      // btm input signal in receiver device memory
        at::Tensor waits            // top and btm signals for this rank
        );

}
}
#endif
apex/contrib/peer_memory/__init__.py (new file, mode 100644)
from .peer_memory import PeerMemoryPool
apex/contrib/peer_memory/peer_memory.py (new file, mode 100644)
import torch
import numpy as np
import peer_memory

class PeerMemoryPool(object):

    def __init__(self, rank, world_size, peer_group_size, size):
        self.peer_group = rank // peer_group_size
        self.peer_rank = rank % peer_group_size
        self.peer_group_size = peer_group_size
        self.alignment = 256
        self.size = size
        # allocate giant pool of device memory
        self.raw = peer_memory.allocate_raw(size)
        # exchange peer pointers with nccl
        raw_ipc = peer_memory.get_raw_ipc_address(self.raw).cuda()
        peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
        torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
        peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
        self.peer_raw = peer_memory.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
        self.current = 0

    def __del__(self):
        peer_memory.free_raw(self.raw)

    def reset(self):
        self.current = 0

    def allocate_peer_tensors(self, shape, dtype):
        nels = np.prod(shape)
        if dtype == torch.float16:
            elem_size = 2
            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
            self.current = start + nels * elem_size
            assert self.current < self.size, "Peer memory pool exhausted"
            return [peer_memory.blob_view_half(pr + start, shape) for pr in self.peer_raw]
        elif dtype == torch.float32:
            elem_size = 4
            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
            self.current = start + nels * elem_size
            assert self.current < self.size, "Peer memory pool exhausted"
            return [peer_memory.blob_view_float(pr + start, shape) for pr in self.peer_raw]
        elif dtype == torch.int32:
            elem_size = 4
            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
            self.current = start + nels * elem_size
            assert self.current < self.size, "Peer memory pool exhausted"
            return [peer_memory.blob_view_int(pr + start, shape) for pr in self.peer_raw]
        else:
            assert False, "Unknown dtype : %s" % (str(dtype))
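A sketch of how the pool might be used from a distributed training script, assuming apex was built with the --peer_memory flag handled by the setup.py change below (ranks, pool size, and the tensor shape here are hypothetical):

import torch
import torch.distributed as dist
from apex.contrib.peer_memory import PeerMemoryPool

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# 64 MB pool per rank; here the peer group spans all ranks.
pool = PeerMemoryPool(rank, world, world, 64 * 1024 * 1024)

# One view per peer at the same pool offset: entry i addresses rank i's
# memory, so writing into peer_bufs[neighbor] is a direct remote write.
peer_bufs = pool.allocate_peer_tensors([1, 64, 1, 128], torch.float16)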
setup.py (modified)
...
@@ -627,6 +627,20 @@ if "--fast_bottleneck" in sys.argv:
        )
    )
if "--peer_memory" in sys.argv:
    sys.argv.remove("--peer_memory")
    raise_if_cuda_home_none("--peer_memory")
    ext_modules.append(
        CUDAExtension(
            name="peer_memory",
            sources=[
                "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
                "apex/contrib/csrc/peer_memory/peer_memory.cpp",
            ],
            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
        )
    )
setup(
    name="apex",
...