Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
a4eb97fb
"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1209978d97fc0a1bb04182916d6e3332b842ace3"
Commit
a4eb97fb
authored
Mar 23, 2022
by
Thor Johnsen
Browse files
Bug fixes
parent
40a0e025
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
148 additions
and
70 deletions
+148
-70
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
+97
-42
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
+6
-3
apex/contrib/peer_memory/peer_memory.py
apex/contrib/peer_memory/peer_memory.py
+45
-25
No files found.
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu
View file @
a4eb97fb
...
@@ -30,15 +30,23 @@ void deleter(void* ptr)
...
@@ -30,15 +30,23 @@ void deleter(void* ptr)
*/
*/
template
<
class
T
>
template
<
class
T
>
at
::
Tensor
blob_view
(
T
*
raw_ptr
,
std
::
vector
<
int64_t
>
shape
,
const
at
::
TensorOptions
&
options
)
at
::
Tensor
blob_view
(
T
*
raw_ptr
,
std
::
vector
<
int64_t
>
shape
,
const
at
::
TensorOptions
&
options
,
bool
channels_last
)
{
{
std
::
vector
<
int64_t
>
strides
(
shape
.
size
());
size_t
size
=
1
;
size_t
size
=
1
;
int
idx
=
strides
.
size
();
std
::
vector
<
int64_t
>
strides
(
shape
.
size
());
for
(
auto
it
=
shape
.
rbegin
();
it
!=
shape
.
rend
();
++
it
)
if
(
channels_last
)
{
{
assert
(
shape
.
size
()
==
4
);
strides
[
--
idx
]
=
size
;
strides
[
0
]
=
shape
[
1
]
*
shape
[
2
]
*
shape
[
3
];
size
*=
*
it
;
strides
[
1
]
=
1
;
strides
[
2
]
=
shape
[
1
]
*
shape
[
3
];
strides
[
3
]
=
shape
[
1
];
}
else
{
int
idx
=
strides
.
size
();
for
(
auto
it
=
shape
.
rbegin
();
it
!=
shape
.
rend
();
++
it
)
{
strides
[
--
idx
]
=
size
;
size
*=
*
it
;
}
}
}
size
*=
sizeof
(
T
);
size
*=
sizeof
(
T
);
// TODO: Implement dynamic reuse of pooled peer memory.
// TODO: Implement dynamic reuse of pooled peer memory.
...
@@ -139,11 +147,11 @@ __device__ void strided_copy_kernel(
...
@@ -139,11 +147,11 @@ __device__ void strided_copy_kernel(
}
}
}
}
template
<
bool
wait
,
bool
clear
>
__device__
void
dual_signal_wait_clear
(
__device__
void
dual_signal_wait_clear
(
volatile
int
*
signal1_flag
,
volatile
int
*
wait1_flag
,
volatile
int
*
signal1_flag
,
volatile
int
*
wait1_flag
,
volatile
int
*
signal2_flag
,
volatile
int
*
wait2_flag
,
volatile
int
*
signal2_flag
,
volatile
int
*
wait2_flag
,
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
,
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
const
bool
clear
)
)
{
{
register
int
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
,
r8
;
register
int
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
,
r8
;
...
@@ -152,17 +160,20 @@ __device__ void dual_signal_wait_clear(
...
@@ -152,17 +160,20 @@ __device__ void dual_signal_wait_clear(
if
(
is_main_thread
)
{
if
(
is_main_thread
)
{
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
do
{
if
(
wait
)
{
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait1_flag
)
:
"memory"
);
do
{
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r5
),
"=r"
(
r6
),
"=r"
(
r7
),
"=r"
(
r8
)
:
"l"
(
wait2_flag
)
:
"memory"
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait1_flag
)
:
"memory"
);
}
while
(
r1
!=
v1
||
r5
!=
v1
||
r2
!=
v2
||
r6
!=
v2
||
r3
!=
v3
||
r7
!=
v3
||
r4
!=
v4
||
r8
!=
v4
);
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r5
),
"=r"
(
r6
),
"=r"
(
r7
),
"=r"
(
r8
)
:
"l"
(
wait2_flag
)
:
"memory"
);
}
while
(
r1
!=
v1
||
r5
!=
v1
||
r2
!=
v2
||
r6
!=
v2
||
r3
!=
v3
||
r7
!=
v3
||
r4
!=
v4
||
r8
!=
v4
);
}
}
}
cg
::
this_grid
().
sync
();
cg
::
this_grid
().
sync
();
// optionally clear wait flag
if
(
clear
)
{
if
(
clear
&&
is_main_thread
)
{
if
(
is_main_thread
)
{
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait1_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait1_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait2_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait2_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
}
}
}
}
}
...
@@ -173,12 +184,14 @@ __launch_bounds__(128, 16)
...
@@ -173,12 +184,14 @@ __launch_bounds__(128, 16)
__global__
void
push_pull_halos_1d_kernel
(
__global__
void
push_pull_halos_1d_kernel
(
// top halo,
// top halo,
const
T
*
toh
,
int
toh_stride_C
,
int
toh_stride_H
,
int
toh_stride_W
,
// top output halo
const
T
*
toh
,
int
toh_stride_C
,
int
toh_stride_H
,
int
toh_stride_W
,
// top output halo
T
*
tox
,
int
tox_stride_C
,
int
tox_stride_H
,
int
tox_stride_W
,
// top tx buffer
T
*
tox
,
int
tox_stride_C
,
int
tox_stride_H
,
int
tox_stride_W
,
// top output tx buffer
T
*
tix
,
int
tix_stride_C
,
int
tix_stride_H
,
int
tix_stride_W
,
// top input tx buffer
T
*
tih
,
int
tih_stride_C
,
int
tih_stride_H
,
int
tih_stride_W
,
// top input halo
T
*
tih
,
int
tih_stride_C
,
int
tih_stride_H
,
int
tih_stride_W
,
// top input halo
// btm halo
// btm halo
const
T
*
boh
,
int
boh_stride_C
,
int
boh_stride_H
,
int
boh_stride_W
,
// top output halo
const
T
*
boh
,
int
boh_stride_C
,
int
boh_stride_H
,
int
boh_stride_W
,
// btm output halo
T
*
box
,
int
box_stride_C
,
int
box_stride_H
,
int
box_stride_W
,
// top tx buffer
T
*
box
,
int
box_stride_C
,
int
box_stride_H
,
int
box_stride_W
,
// btm output tx buffer
T
*
bih
,
int
bih_stride_C
,
int
bih_stride_H
,
int
bih_stride_W
,
// top input halo
T
*
bix
,
int
bix_stride_C
,
int
bix_stride_H
,
int
bix_stride_W
,
// btm input tx buffer
T
*
bih
,
int
bih_stride_C
,
int
bih_stride_H
,
int
bih_stride_W
,
// btm input halo
// dimensions
// dimensions
int
NC
,
int
NH
,
int
NW
,
int
NC
,
int
NH
,
int
NW
,
// signals
// signals
...
@@ -194,11 +207,11 @@ __global__ void push_pull_halos_1d_kernel(
...
@@ -194,11 +207,11 @@ __global__ void push_pull_halos_1d_kernel(
strided_copy_kernel
<
T
,
is_HWC
>
(
box
,
box_stride_C
,
box_stride_H
,
box_stride_W
,
boh
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
,
NC
,
NH
,
NW
);
strided_copy_kernel
<
T
,
is_HWC
>
(
box
,
box_stride_C
,
box_stride_H
,
box_stride_W
,
boh
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
,
NC
,
NH
,
NW
);
// signal to top and btm neighbors that output halos are ready to be read
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
dual_signal_wait_clear
(
signal1_flag
,
wait1_flag
,
signal2_flag
,
wait2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
,
true
);
dual_signal_wait_clear
<
true
,
true
>
(
signal1_flag
,
wait1_flag
,
signal2_flag
,
wait2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
// pull top halo from transfer buffer in peer memory to input
// pull top halo from transfer buffer in peer memory to input
strided_copy_kernel
<
T
,
is_HWC
>
(
t
ox
,
t
ox
_stride_C
,
t
ox
_stride_H
,
t
ox
_stride_W
,
ti
h
,
ti
h
_stride_C
,
ti
h
_stride_H
,
ti
h
_stride_W
,
NC
,
NH
,
NW
);
strided_copy_kernel
<
T
,
is_HWC
>
(
t
ih
,
t
ih
_stride_C
,
t
ih
_stride_H
,
t
ih
_stride_W
,
ti
x
,
ti
x
_stride_C
,
ti
x
_stride_H
,
ti
x
_stride_W
,
NC
,
NH
,
NW
);
// pull btm halo from transfer buffer in peer memory to input
// pull btm halo from transfer buffer in peer memory to input
strided_copy_kernel
<
T
,
is_HWC
>
(
b
ox
,
b
ox
_stride_C
,
b
ox
_stride_H
,
b
ox
_stride_W
,
bi
h
,
bi
h
_stride_C
,
bi
h
_stride_H
,
bi
h
_stride_W
,
NC
,
NH
,
NW
);
strided_copy_kernel
<
T
,
is_HWC
>
(
b
ih
,
b
ih
_stride_C
,
b
ih
_stride_H
,
b
ih
_stride_W
,
bi
x
,
bi
x
_stride_C
,
bi
x
_stride_H
,
bi
x
_stride_W
,
NC
,
NH
,
NW
);
}
}
}
}
...
@@ -246,29 +259,32 @@ std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int6
...
@@ -246,29 +259,32 @@ std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int6
return
results
;
return
results
;
}
}
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
)
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
)
{
{
return
blob_view
<
at
::
Half
>
((
at
::
Half
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
kFloat16
).
device
(
torch
::
kCUDA
));
return
blob_view
<
at
::
Half
>
((
at
::
Half
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
kFloat16
).
device
(
torch
::
kCUDA
)
,
channels_last
);
}
}
at
::
Tensor
blob_view_float
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
)
at
::
Tensor
blob_view_float
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
)
{
{
return
blob_view
<
float
>
((
float
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
kFloat
16
).
device
(
torch
::
kCUDA
));
return
blob_view
<
float
>
((
float
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
kFloat
32
).
device
(
torch
::
kCUDA
)
,
channels_last
);
}
}
at
::
Tensor
blob_view_int
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
)
at
::
Tensor
blob_view_int
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
)
{
{
return
blob_view
<
int
>
((
int
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
k
Float16
).
device
(
torch
::
kCUDA
));
return
blob_view
<
int
>
((
int
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
k
Int32
).
device
(
torch
::
kCUDA
)
,
channels_last
);
}
}
void
push_pull_halos_1d
(
void
push_pull_halos_1d
(
bool
diagnostics
,
bool
explicit_nhwc
,
bool
explicit_nhwc
,
int
numSM
,
// number of SMs to use
int
numSM
,
// number of SMs to use
at
::
Tensor
top_out_halo
,
// top output halo in sender device memory
at
::
Tensor
top_out_halo
,
// top output halo in sender device memory
at
::
Tensor
top_out_tx
,
// top output transfer buffer in sender peer pool memory
at
::
Tensor
top_out_tx
,
// top output transfer buffer in sender peer pool memory
at
::
Tensor
top_inp_tx
,
// top input transfer buffer in top neighbor peer pool memory
at
::
Tensor
top_inp_halo
,
// top input halo in receiver device memory
at
::
Tensor
top_inp_halo
,
// top input halo in receiver device memory
at
::
Tensor
btm_out_halo
,
// btm output halo in sender device memory
at
::
Tensor
btm_out_halo
,
// btm output halo in sender device memory
at
::
Tensor
btm_out_tx
,
// btm output transfer buffer in sender peer pool memory
at
::
Tensor
btm_out_tx
,
// btm output transfer buffer in sender peer pool memory
at
::
Tensor
btm_inp_tx
,
// btm input transfer buffer in btm neighbor peer pool memory
at
::
Tensor
btm_inp_halo
,
// btm input halo in receiver device memory
at
::
Tensor
btm_inp_halo
,
// btm input halo in receiver device memory
at
::
Tensor
top_signal
,
// top input signal in receiver device memory
at
::
Tensor
top_signal
,
// top input signal in receiver device memory
at
::
Tensor
btm_signal
,
// btm input signal in receiver device memory
at
::
Tensor
btm_signal
,
// btm input signal in receiver device memory
...
@@ -278,9 +294,11 @@ void push_pull_halos_1d(
...
@@ -278,9 +294,11 @@ void push_pull_halos_1d(
// basic checks of inputs
// basic checks of inputs
TORCH_CHECK
(
top_out_halo
.
is_cuda
());
TORCH_CHECK
(
top_out_halo
.
is_cuda
());
TORCH_CHECK
(
top_out_tx
.
is_cuda
());
TORCH_CHECK
(
top_out_tx
.
is_cuda
());
TORCH_CHECK
(
top_inp_tx
.
is_cuda
());
TORCH_CHECK
(
top_inp_halo
.
is_cuda
());
TORCH_CHECK
(
top_inp_halo
.
is_cuda
());
TORCH_CHECK
(
btm_out_halo
.
is_cuda
());
TORCH_CHECK
(
btm_out_halo
.
is_cuda
());
TORCH_CHECK
(
btm_out_tx
.
is_cuda
());
TORCH_CHECK
(
btm_out_tx
.
is_cuda
());
TORCH_CHECK
(
btm_inp_tx
.
is_cuda
());
TORCH_CHECK
(
btm_inp_halo
.
is_cuda
());
TORCH_CHECK
(
btm_inp_halo
.
is_cuda
());
TORCH_CHECK
(
top_signal
.
is_cuda
());
TORCH_CHECK
(
top_signal
.
is_cuda
());
TORCH_CHECK
(
btm_signal
.
is_cuda
());
TORCH_CHECK
(
btm_signal
.
is_cuda
());
...
@@ -291,46 +309,56 @@ void push_pull_halos_1d(
...
@@ -291,46 +309,56 @@ void push_pull_halos_1d(
tensor_shape
(
top_out_halo
,
explicit_nhwc
,
toh_N
,
toh_C
,
toh_H
,
toh_W
);
tensor_shape
(
top_out_halo
,
explicit_nhwc
,
toh_N
,
toh_C
,
toh_H
,
toh_W
);
int
tox_N
,
tox_C
,
tox_H
,
tox_W
;
int
tox_N
,
tox_C
,
tox_H
,
tox_W
;
tensor_shape
(
top_out_tx
,
explicit_nhwc
,
tox_N
,
tox_C
,
tox_H
,
tox_W
);
tensor_shape
(
top_out_tx
,
explicit_nhwc
,
tox_N
,
tox_C
,
tox_H
,
tox_W
);
int
tix_N
,
tix_C
,
tix_H
,
tix_W
;
tensor_shape
(
top_inp_tx
,
explicit_nhwc
,
tix_N
,
tix_C
,
tix_H
,
tix_W
);
int
tih_N
,
tih_C
,
tih_H
,
tih_W
;
int
tih_N
,
tih_C
,
tih_H
,
tih_W
;
tensor_shape
(
top_inp_halo
,
explicit_nhwc
,
tih_N
,
tih_C
,
tih_H
,
tih_W
);
tensor_shape
(
top_inp_halo
,
explicit_nhwc
,
tih_N
,
tih_C
,
tih_H
,
tih_W
);
TORCH_CHECK
(
TORCH_CHECK
(
(
toh_N
==
tox_N
&&
tox_N
==
tih_N
)
&&
(
toh_N
==
tox_N
&&
tox_N
==
tix_N
&&
tix_N
==
tih_N
)
&&
(
toh_C
==
tox_C
&&
tox_C
==
tih_C
)
&&
(
toh_C
==
tox_C
&&
tox_C
==
tix_C
&&
tix_C
==
tih_C
)
&&
(
toh_H
==
tox_H
&&
tox_H
==
tih_H
)
&&
(
toh_H
==
tox_H
&&
tox_H
==
tix_H
&&
tix_H
==
tih_H
)
&&
(
toh_W
==
tox_W
&&
tox_W
==
tih_W
));
(
toh_W
==
tox_W
&&
tox_W
==
tix_W
&&
tix_W
==
tih_W
));
int
boh_N
,
boh_C
,
boh_H
,
boh_W
;
int
boh_N
,
boh_C
,
boh_H
,
boh_W
;
tensor_shape
(
btm_out_halo
,
explicit_nhwc
,
boh_N
,
boh_C
,
boh_H
,
boh_W
);
tensor_shape
(
btm_out_halo
,
explicit_nhwc
,
boh_N
,
boh_C
,
boh_H
,
boh_W
);
int
box_N
,
box_C
,
box_H
,
box_W
;
int
box_N
,
box_C
,
box_H
,
box_W
;
tensor_shape
(
btm_out_tx
,
explicit_nhwc
,
box_N
,
box_C
,
box_H
,
box_W
);
tensor_shape
(
btm_out_tx
,
explicit_nhwc
,
box_N
,
box_C
,
box_H
,
box_W
);
int
bix_N
,
bix_C
,
bix_H
,
bix_W
;
tensor_shape
(
btm_inp_tx
,
explicit_nhwc
,
bix_N
,
bix_C
,
bix_H
,
bix_W
);
int
bih_N
,
bih_C
,
bih_H
,
bih_W
;
int
bih_N
,
bih_C
,
bih_H
,
bih_W
;
tensor_shape
(
btm_inp_halo
,
explicit_nhwc
,
bih_N
,
bih_C
,
bih_H
,
bih_W
);
tensor_shape
(
btm_inp_halo
,
explicit_nhwc
,
bih_N
,
bih_C
,
bih_H
,
bih_W
);
TORCH_CHECK
(
TORCH_CHECK
(
(
boh_N
==
box_N
&&
box_N
==
bih_N
)
&&
(
boh_N
==
box_N
&&
box_N
==
bix_N
&&
bix_N
==
bih_N
)
&&
(
boh_C
==
box_C
&&
box_C
==
bih_C
)
&&
(
boh_C
==
box_C
&&
box_C
==
bix_C
&&
bix_C
==
bih_C
)
&&
(
boh_H
==
box_H
&&
box_H
==
bih_H
)
&&
(
boh_H
==
box_H
&&
box_H
==
bix_H
&&
bix_H
==
bih_H
)
&&
(
boh_W
==
box_W
&&
box_W
==
bih_W
));
(
boh_W
==
box_W
&&
box_W
==
bix_W
&&
bix_W
==
bih_W
));
TORCH_CHECK
(
TORCH_CHECK
(
(
toh_N
==
boh_N
)
&&
(
toh_N
==
boh_N
)
&&
(
toh_C
==
boh_C
)
&&
(
toh_C
==
boh_C
)
&&
(
toh_H
==
boh_H
)
&&
(
toh_H
==
boh_H
)
&&
(
toh_W
==
boh_W
));
(
toh_W
==
boh_W
));
int
NC
=
toh_C
,
NH
=
toh_H
,
NW
=
toh_W
;
int
NC
=
toh_C
,
NH
=
toh_H
,
NW
=
toh_W
;
if
(
diagnostics
)
printf
(
"NC=%d, NH=%d, NW=%d
\n
"
,
NC
,
NH
,
NW
);
int
toh_stride_N
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
;
int
toh_stride_N
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
;
tensor_strides
(
top_out_halo
,
explicit_nhwc
,
toh_stride_N
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
);
tensor_strides
(
top_out_halo
,
explicit_nhwc
,
toh_stride_N
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
);
int
tox_stride_N
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
;
int
tox_stride_N
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
;
tensor_strides
(
top_out_tx
,
explicit_nhwc
,
tox_stride_N
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
);
tensor_strides
(
top_out_tx
,
explicit_nhwc
,
tox_stride_N
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
);
int
tix_stride_N
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
;
tensor_strides
(
top_inp_tx
,
explicit_nhwc
,
tix_stride_N
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
);
int
tih_stride_N
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
;
int
tih_stride_N
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
;
tensor_strides
(
top_inp_halo
,
explicit_nhwc
,
tih_stride_N
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
);
tensor_strides
(
top_inp_halo
,
explicit_nhwc
,
tih_stride_N
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
);
int
boh_stride_N
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
;
int
boh_stride_N
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
;
tensor_strides
(
btm_out_halo
,
explicit_nhwc
,
boh_stride_N
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
);
tensor_strides
(
btm_out_halo
,
explicit_nhwc
,
boh_stride_N
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
);
int
box_stride_N
,
box_stride_C
,
box_stride_H
,
box_stride_W
;
int
box_stride_N
,
box_stride_C
,
box_stride_H
,
box_stride_W
;
tensor_strides
(
btm_out_tx
,
explicit_nhwc
,
box_stride_N
,
box_stride_C
,
box_stride_H
,
box_stride_W
);
tensor_strides
(
btm_out_tx
,
explicit_nhwc
,
box_stride_N
,
box_stride_C
,
box_stride_H
,
box_stride_W
);
int
bix_stride_N
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
;
tensor_strides
(
btm_inp_tx
,
explicit_nhwc
,
bix_stride_N
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
);
int
bih_stride_N
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
;
int
bih_stride_N
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
;
tensor_strides
(
btm_inp_halo
,
explicit_nhwc
,
bih_stride_N
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
);
tensor_strides
(
btm_inp_halo
,
explicit_nhwc
,
bih_stride_N
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
);
// determine if nhwc
// determine if nhwc
auto
is_nhwc
=
(
toh_stride_C
==
1
)
?
true
:
false
;
auto
is_nhwc
=
(
toh_stride_C
==
1
)
?
true
:
false
;
if
(
diagnostics
)
printf
(
"is_nhwc = %s
\n
"
,
is_nhwc
?
"true"
:
"false"
);
// figure out launch parameters
// figure out launch parameters
int
device
;
int
device
;
...
@@ -342,35 +370,59 @@ void push_pull_halos_1d(
...
@@ -342,35 +370,59 @@ void push_pull_halos_1d(
const
int
numThreads
=
128
;
const
int
numThreads
=
128
;
dim3
block
(
numThreads
,
1
,
1
);
dim3
block
(
numThreads
,
1
,
1
);
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
top_out_halo
.
scalar_type
(),
"push_pull_halos_1d_kernel"
,
[
&
]{
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
top_out_halo
.
scalar_type
(),
"push_pull_halos_1d_kernel"
,
[
&
]{
if
(
diagnostics
)
printf
(
"size(scalar_t) = %d
\n
"
,
sizeof
(
scalar_t
));
scalar_t
*
toh_p
=
top_out_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
toh_p
=
top_out_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tox_p
=
top_out_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tox_p
=
top_out_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tix_p
=
top_inp_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tih_p
=
top_inp_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tih_p
=
top_inp_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
boh_p
=
btm_out_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
boh_p
=
btm_out_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
box_p
=
btm_out_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
box_p
=
btm_out_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
bix_p
=
btm_inp_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
bih_p
=
btm_inp_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
bih_p
=
btm_inp_halo
.
data_ptr
<
scalar_t
>
();
int
*
top_signal_p
=
top_signal
.
data_ptr
<
int
>
();
if
(
diagnostics
)
printf
(
"waypoint1
\n
"
);
int
*
btm_signal_p
=
btm_signal
.
data_ptr
<
int
>
()
+
4
;
int
*
top_signal_p
=
top_signal
.
data_ptr
<
int
>
()
+
4
;
int
*
btm_signal_p
=
btm_signal
.
data_ptr
<
int
>
();
int
*
top_wait_p
=
waits
.
data_ptr
<
int
>
();
int
*
top_wait_p
=
waits
.
data_ptr
<
int
>
();
int
*
btm_wait_p
=
waits
.
data_ptr
<
int
>
()
+
4
;
int
*
btm_wait_p
=
waits
.
data_ptr
<
int
>
()
+
4
;
if
(
diagnostics
)
printf
(
"waypoint2
\n
"
);
// do int4 vector loads if channel count permits
// do int4 vector loads if channel count permits
int
elem_size_in_bytes
=
toh_C
*
sizeof
(
scalar_t
);
int
elem_size_in_bytes
=
toh_C
*
sizeof
(
scalar_t
);
int
elem_size_in_int4
=
(
elem_size_in_bytes
/
16
);
int
elem_size_in_int4
=
(
elem_size_in_bytes
/
16
);
if
(
diagnostics
)
printf
(
"elem_size_in_bytes = %d, elem_size_in_int4 = %d
\n
"
,
elem_size_in_bytes
,
elem_size_in_int4
);
if
(
is_nhwc
&&
elem_size_in_int4
*
16
==
elem_size_in_bytes
)
{
if
(
is_nhwc
&&
elem_size_in_int4
*
16
==
elem_size_in_bytes
)
{
// can do int4 transfers
// can do int4 transfers
int
divisor
=
elem_size_in_bytes
/
elem_size_in_int4
;
int
divisor
=
toh_C
/
elem_size_in_int4
;
if
(
diagnostics
)
printf
(
"CAN DO INT4 :: divisor = %d
\n
"
,
divisor
);
toh_stride_N
/=
divisor
;
toh_stride_H
/=
divisor
;
toh_stride_W
/=
divisor
;
toh_stride_N
/=
divisor
;
toh_stride_H
/=
divisor
;
toh_stride_W
/=
divisor
;
tox_stride_N
/=
divisor
;
tox_stride_H
/=
divisor
;
tox_stride_W
/=
divisor
;
tox_stride_N
/=
divisor
;
tox_stride_H
/=
divisor
;
tox_stride_W
/=
divisor
;
tix_stride_N
/=
divisor
;
tix_stride_H
/=
divisor
;
tix_stride_W
/=
divisor
;
tih_stride_N
/=
divisor
;
tih_stride_H
/=
divisor
;
tih_stride_W
/=
divisor
;
tih_stride_N
/=
divisor
;
tih_stride_H
/=
divisor
;
tih_stride_W
/=
divisor
;
boh_stride_N
/=
divisor
;
boh_stride_H
/=
divisor
;
boh_stride_W
/=
divisor
;
boh_stride_N
/=
divisor
;
boh_stride_H
/=
divisor
;
boh_stride_W
/=
divisor
;
box_stride_N
/=
divisor
;
box_stride_H
/=
divisor
;
box_stride_W
/=
divisor
;
box_stride_N
/=
divisor
;
box_stride_H
/=
divisor
;
box_stride_W
/=
divisor
;
bix_stride_N
/=
divisor
;
bix_stride_H
/=
divisor
;
bix_stride_W
/=
divisor
;
bih_stride_N
/=
divisor
;
bih_stride_H
/=
divisor
;
bih_stride_W
/=
divisor
;
bih_stride_N
/=
divisor
;
bih_stride_H
/=
divisor
;
bih_stride_W
/=
divisor
;
NC
/=
divisor
;
if
(
diagnostics
)
{
printf
(
"divisor=%d
\n
"
,
divisor
);
printf
(
"toh_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
toh_stride_N
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
);
printf
(
"tox_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
tox_stride_N
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
);
printf
(
"tix_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
tix_stride_N
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
);
printf
(
"tih_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
tih_stride_N
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
);
printf
(
"boh_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
boh_stride_N
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
);
printf
(
"box_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
box_stride_N
,
box_stride_C
,
box_stride_H
,
box_stride_W
);
printf
(
"bix_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
bix_stride_N
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
);
printf
(
"bih_stride :: N=%d, C=%d, H=%d, W=%d
\n
"
,
bih_stride_N
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
);
printf
(
"NC=%d, NH=%d, NW=%d
\n
"
,
NC
,
NH
,
NW
);
}
void
*
kernelArgs
[]
=
{
void
*
kernelArgs
[]
=
{
(
int4
**
)
&
toh_p
,
&
toh_stride_C
,
&
toh_stride_H
,
&
toh_stride_W
,
(
int4
**
)
&
toh_p
,
&
toh_stride_C
,
&
toh_stride_H
,
&
toh_stride_W
,
(
int4
**
)
&
tox_p
,
&
tox_stride_C
,
&
tox_stride_H
,
&
tox_stride_W
,
(
int4
**
)
&
tox_p
,
&
tox_stride_C
,
&
tox_stride_H
,
&
tox_stride_W
,
(
int4
**
)
&
tix_p
,
&
tix_stride_C
,
&
tix_stride_H
,
&
tix_stride_W
,
(
int4
**
)
&
tih_p
,
&
tih_stride_C
,
&
tih_stride_H
,
&
tih_stride_W
,
(
int4
**
)
&
tih_p
,
&
tih_stride_C
,
&
tih_stride_H
,
&
tih_stride_W
,
(
int4
**
)
&
boh_p
,
&
boh_stride_C
,
&
boh_stride_H
,
&
boh_stride_W
,
(
int4
**
)
&
boh_p
,
&
boh_stride_C
,
&
boh_stride_H
,
&
boh_stride_W
,
(
int4
**
)
&
box_p
,
&
box_stride_C
,
&
box_stride_H
,
&
box_stride_W
,
(
int4
**
)
&
box_p
,
&
box_stride_C
,
&
box_stride_H
,
&
box_stride_W
,
(
int4
**
)
&
bix_p
,
&
bix_stride_C
,
&
bix_stride_H
,
&
bix_stride_W
,
(
int4
**
)
&
bih_p
,
&
bih_stride_C
,
&
bih_stride_H
,
&
bih_stride_W
,
(
int4
**
)
&
bih_p
,
&
bih_stride_C
,
&
bih_stride_H
,
&
bih_stride_W
,
&
NC
,
&
NH
,
&
NW
,
&
NC
,
&
NH
,
&
NW
,
&
top_signal_p
,
&
btm_signal_p
,
&
top_wait_p
,
&
btm_wait_p
&
top_signal_p
,
&
btm_signal_p
,
&
top_wait_p
,
&
btm_wait_p
...
@@ -381,12 +433,15 @@ void push_pull_halos_1d(
...
@@ -381,12 +433,15 @@ void push_pull_halos_1d(
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
int4
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
cudaLaunchCooperativeKernel
((
void
*
)
push_pull_halos_1d_kernel
<
int4
,
true
>
,
grid
,
block
,
kernelArgs
,
0
,
current_stream
);
}
else
{
}
else
{
// cannot do int4 transfers
// cannot do int4 transfers
if
(
diagnostics
)
printf
(
"CAN NOT DO INT4
\n
"
);
void
*
kernelArgs
[]
=
{
void
*
kernelArgs
[]
=
{
&
toh_p
,
&
toh_stride_C
,
&
toh_stride_H
,
&
toh_stride_W
,
&
toh_p
,
&
toh_stride_C
,
&
toh_stride_H
,
&
toh_stride_W
,
&
tox_p
,
&
tox_stride_C
,
&
tox_stride_H
,
&
tox_stride_W
,
&
tox_p
,
&
tox_stride_C
,
&
tox_stride_H
,
&
tox_stride_W
,
&
tix_p
,
&
tix_stride_C
,
&
tix_stride_H
,
&
tix_stride_W
,
&
tih_p
,
&
tih_stride_C
,
&
tih_stride_H
,
&
tih_stride_W
,
&
tih_p
,
&
tih_stride_C
,
&
tih_stride_H
,
&
tih_stride_W
,
&
boh_p
,
&
boh_stride_C
,
&
boh_stride_H
,
&
boh_stride_W
,
&
boh_p
,
&
boh_stride_C
,
&
boh_stride_H
,
&
boh_stride_W
,
&
box_p
,
&
box_stride_C
,
&
box_stride_H
,
&
box_stride_W
,
&
box_p
,
&
box_stride_C
,
&
box_stride_H
,
&
box_stride_W
,
&
bix_p
,
&
bix_stride_C
,
&
bix_stride_H
,
&
bix_stride_W
,
&
bih_p
,
&
bih_stride_C
,
&
bih_stride_H
,
&
bih_stride_W
,
&
bih_p
,
&
bih_stride_C
,
&
bih_stride_H
,
&
bih_stride_W
,
&
NC
,
&
NH
,
&
NW
,
&
NC
,
&
NH
,
&
NW
,
&
top_signal_p
,
&
btm_signal_p
,
&
top_wait_p
,
&
btm_wait_p
&
top_signal_p
,
&
btm_signal_p
,
&
top_wait_p
,
&
btm_wait_p
...
...
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh
View file @
a4eb97fb
...
@@ -24,17 +24,20 @@ namespace apex { namespace peer_memory {
...
@@ -24,17 +24,20 @@ namespace apex { namespace peer_memory {
void
free_raw
(
int64_t
raw
);
void
free_raw
(
int64_t
raw
);
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
);
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
);
std
::
vector
<
int64_t
>
get_raw_peers
(
at
::
Tensor
ipc_addresses
,
int
peer_rank
,
int64_t
raw
);
std
::
vector
<
int64_t
>
get_raw_peers
(
at
::
Tensor
ipc_addresses
,
int
peer_rank
,
int64_t
raw
);
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
);
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
);
at
::
Tensor
blob_view_float
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
);
at
::
Tensor
blob_view_float
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
);
at
::
Tensor
blob_view_int
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
);
at
::
Tensor
blob_view_int
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
);
void
push_pull_halos_1d
(
void
push_pull_halos_1d
(
bool
diagnostics
,
bool
explicit_nhwc
,
bool
explicit_nhwc
,
int
numSM
,
// number of SMs to use
int
numSM
,
// number of SMs to use
at
::
Tensor
top_out_halo
,
// top output halo in sender device memory
at
::
Tensor
top_out_halo
,
// top output halo in sender device memory
at
::
Tensor
top_out_tx
,
// top output transfer buffer in sender peer pool memory
at
::
Tensor
top_out_tx
,
// top output transfer buffer in sender peer pool memory
at
::
Tensor
top_inp_tx
,
// top input transfer buffer in top neighbor peer pool memory
at
::
Tensor
top_inp_halo
,
// top input halo in receiver device memory
at
::
Tensor
top_inp_halo
,
// top input halo in receiver device memory
at
::
Tensor
btm_out_halo
,
// btm output halo in sender device memory
at
::
Tensor
btm_out_halo
,
// btm output halo in sender device memory
at
::
Tensor
btm_out_tx
,
// btm output transfer buffer in sender peer pool memory
at
::
Tensor
btm_out_tx
,
// btm output transfer buffer in sender peer pool memory
at
::
Tensor
btm_inp_tx
,
// btm input transfer buffer in btm neighbor peer pool memory
at
::
Tensor
btm_inp_halo
,
// btm input halo in receiver device memory
at
::
Tensor
btm_inp_halo
,
// btm input halo in receiver device memory
at
::
Tensor
top_signal
,
// top input signal in receiver device memory
at
::
Tensor
top_signal
,
// top input signal in receiver device memory
at
::
Tensor
btm_signal
,
// btm input signal in receiver device memory
at
::
Tensor
btm_signal
,
// btm input signal in receiver device memory
...
...
apex/contrib/peer_memory/peer_memory.py
View file @
a4eb97fb
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
peer_memory
import
peer_memory
as
pm
class
PeerMemoryPool
(
object
):
class
PeerMemoryPool
(
object
):
def
__init__
(
self
,
rank
,
world_size
,
peer_group_size
,
size
):
def
__init__
(
self
,
rank
,
world_size
,
peer_group_size
,
static_size
,
dynamic_
size
):
self
.
peer_group
=
rank
//
peer_group_size
self
.
peer_group
=
rank
//
peer_group_size
self
.
peer_rank
=
rank
%
peer_group_size
self
.
peer_rank
=
rank
%
peer_group_size
self
.
peer_group_size
=
peer_group_size
self
.
peer_group_size
=
peer_group_size
self
.
alignment
=
256
self
.
alignment
=
256
self
.
size
=
size
self
.
static_size
=
((
static_size
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
self
.
dynamic_size
=
((
dynamic_size
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
# allocate giant pool of device memory
# allocate giant pool of device memory
self
.
raw
=
allocate_raw
(
size
)
self
.
raw
=
pm
.
allocate_raw
(
self
.
static_size
+
self
.
dynamic_
size
)
# exchange peer pointers with nccl
# exchange peer pointers with nccl
raw_ipc
=
get_raw_ipc_address
(
self
.
raw
).
cuda
()
raw_ipc
=
pm
.
get_raw_ipc_address
(
self
.
raw
).
cuda
()
peer_raw_ipcs
=
[
torch
.
empty_like
(
raw_ipc
)
for
_
in
range
(
world_size
)]
peer_raw_ipcs
=
[
torch
.
empty_like
(
raw_ipc
)
for
_
in
range
(
world_size
)]
torch
.
distributed
.
all_gather
(
peer_raw_ipcs
,
raw_ipc
)
torch
.
distributed
.
all_gather
(
peer_raw_ipcs
,
raw_ipc
)
peer_raw_ipcs
=
torch
.
stack
(
peer_raw_ipcs
).
cpu
()
peer_raw_ipcs
=
torch
.
stack
(
peer_raw_ipcs
).
cpu
()
self
.
peer_raw
=
get_raw_peers
(
peer_raw_ipcs
,
self
.
peer_rank
,
self
.
raw
)
self
.
peer_raw
=
pm
.
get_raw_peers
(
peer_raw_ipcs
,
self
.
peer_rank
,
self
.
raw
)
self
.
current
=
0
self
.
static_offset
=
0
self
.
dynamic_offset
=
0
def
__del__
(
self
):
def
__del__
(
self
):
free_raw
(
self
.
raw
)
pm
.
free_raw
(
self
.
raw
)
def
reset
(
self
):
def
reset
(
self
):
self
.
curren
t
=
0
self
.
dynamic_offse
t
=
0
def
allocate_peer_tensors
(
self
,
shape
,
dtype
):
def
allocate_peer_tensors
(
self
,
shape
,
dtype
,
channels_last
,
dynamic
):
nels
=
np
.
prod
(
shape
)
nels
=
np
.
prod
(
shape
)
if
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
elem_size
=
2
elem_size
=
2
start
=
((
self
.
current
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
if
dynamic
:
self
.
current
=
start
+
nels
*
elem_size
start
=
((
self
.
dynamic_offset
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
assert
(
self
.
current
<
self
.
size
),
"Peer memory pool exhausted"
self
.
dynamic_offset
=
start
+
nels
*
elem_size
return
[
blob_view_half
(
pr
+
start
,
shape
)
for
pr
in
self
.
peer_raw
]
assert
(
self
.
dynamic_offset
<
self
.
dynamic_size
),
"Dynamic peer memory pool exhausted"
elif
dtype
==
torch
.
float32
:
return
[
pm
.
blob_view_half
(
pr
+
self
.
static_size
+
start
,
shape
,
channels_last
)
for
pr
in
self
.
peer_raw
]
else
:
start
=
((
self
.
static_offset
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
self
.
static_offset
=
start
+
nels
*
elem_size
assert
(
self
.
static_offset
<
self
.
static_size
),
"Static peer memory pool exhausted"
return
[
pm
.
blob_view_half
(
pr
+
start
,
shape
,
channels_last
)
for
pr
in
self
.
peer_raw
]
if
dtype
==
torch
.
float32
:
elem_size
=
4
elem_size
=
4
start
=
((
self
.
current
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
if
dynamic
:
self
.
current
=
start
+
nels
*
elem_size
start
=
((
self
.
dynamic_offset
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
assert
(
self
.
current
<
self
.
size
),
"Peer memory pool exhausted"
self
.
dynamic_offset
=
start
+
nels
*
elem_size
return
[
blob_view_float
(
pr
+
start
,
shape
)
for
pr
in
self
.
peer_raw
]
assert
(
self
.
dynamic_offset
<
self
.
dynamic_size
),
"Dynamic peer memory pool exhausted"
elif
dtype
==
torch
.
int32
:
return
[
pm
.
blob_view_float
(
pr
+
self
.
static_size
+
start
,
shape
,
channels_last
)
for
pr
in
self
.
peer_raw
]
else
:
start
=
((
self
.
static_offset
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
self
.
static_offset
=
start
+
nels
*
elem_size
assert
(
self
.
static_offset
<
self
.
static_size
),
"Static peer memory pool exhausted"
return
[
pm
.
blob_view_float
(
pr
+
start
,
shape
,
channels_last
)
for
pr
in
self
.
peer_raw
]
if
dtype
==
torch
.
int32
:
elem_size
=
4
elem_size
=
4
start
=
((
self
.
current
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
if
dynamic
:
self
.
current
=
start
+
nels
*
elem_size
start
=
((
self
.
dynamic_offset
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
assert
(
self
.
current
<
self
.
size
),
"Peer memory pool exhausted"
self
.
dynamic_offset
=
start
+
nels
*
elem_size
return
[
blob_view_int
(
pr
+
start
,
shape
)
for
pr
in
self
.
peer_raw
]
assert
(
self
.
dynamic_offset
<
self
.
dynamic_size
),
"Dynamic peer memory pool exhausted"
return
[
pm
.
blob_view_int
(
pr
+
self
.
static_size
+
start
,
shape
,
channels_last
)
for
pr
in
self
.
peer_raw
]
else
:
start
=
((
self
.
static_offset
+
self
.
alignment
-
1
)
//
self
.
alignment
)
*
self
.
alignment
self
.
static_offset
=
start
+
nels
*
elem_size
assert
(
self
.
static_offset
<
self
.
static_size
),
"Static peer memory pool exhausted"
return
[
pm
.
blob_view_int
(
pr
+
start
,
shape
,
channels_last
)
for
pr
in
self
.
peer_raw
]
else
:
else
:
assert
(
False
),
"
Unknown dtype : %s
"
%
(
str
(
dtype
))
assert
(
False
),
"
dtype %s not supported
"
%
(
str
(
dtype
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment