OpenDAS / apex / Commits / de6378f5

Commit de6378f5 (unverified), authored Feb 26, 2020 by mcarilli, committed by GitHub on Feb 26, 2020

NHWC support for multi tensor apply (#732)

* NHWC support for multi tensor apply
* compilation fix for version <= 1.4

parent 92b3b9a9
Showing 3 changed files with 68 additions and 9 deletions (+68 / -9):

- csrc/multi_tensor_apply.cuh (+5 / -1)
- setup.py (+4 / -1)
- tests/L0/run_amp/test_multi_tensor_axpby.py (+59 / -7)
csrc/multi_tensor_apply.cuh

```diff
@@ -56,7 +56,11 @@ void multi_tensor_apply(
   for(int t = 0; t < tensor_lists[l].size(); t++)
   {
     // TODO:  Print which tensor fails.
-    TORCH_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
+    bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+#ifdef VERSION_GE_1_5
+    contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
+#endif
+    TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
     TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
     TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
   }
```
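A channels-last (NHWC) 4-D tensor is dense in memory but reports `False` from the default `is_contiguous()`, so the old check rejected it outright; the new code also accepts tensors that are contiguous in the `ChannelsLast` memory format, guarded by `VERSION_GE_1_5` so the file still compiles against torch 1.4 and earlier (per the commit message's "compilation fix"). A minimal PyTorch sketch of the distinction this check relies on (illustrative, not part of the commit):

```python
import torch

# A dense 4-D tensor laid out in channels-last (NHWC) order.
x = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)

print(x.is_contiguous())                                   # False: not dense in default NCHW order
print(x.is_contiguous(memory_format=torch.channels_last))  # True: dense in NHWC order

# The relaxed check from this commit, in effect:
contiguous_memory = x.is_contiguous() or x.is_contiguous(memory_format=torch.channels_last)
assert contiguous_memory
```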
setup.py

```diff
@@ -91,7 +91,10 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
 version_ge_1_3 = []
 if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
     version_ge_1_3 = ['-DVERSION_GE_1_3']
-version_dependent_macros = version_ge_1_1 + version_ge_1_3
+version_ge_1_5 = []
+if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
+    version_ge_1_5 = ['-DVERSION_GE_1_5']
+version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5

 if "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
```
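This is where the `VERSION_GE_1_5` macro consumed by the `.cuh` change above is minted: it is only passed to the compiler when building against torch >= 1.5. A hedged sketch of how such a macro list typically reaches the compiler via `CUDAExtension` (the extension name and source list below are illustrative, not taken from this diff):

```python
import torch
from torch.utils.cpp_extension import CUDAExtension

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

# Same pattern as the diff: define -DVERSION_GE_1_5 only for torch >= 1.5.
version_dependent_macros = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
    version_dependent_macros += ['-DVERSION_GE_1_5']

# Illustrative extension: the macros ride along in both the cxx and nvcc
# flag lists, so #ifdef VERSION_GE_1_5 in multi_tensor_apply.cuh sees them.
ext = CUDAExtension(
    name='amp_C',
    sources=['csrc/amp_C_frontend.cpp'],  # hypothetical subset of sources
    extra_compile_args={'cxx':  ['-O3'] + version_dependent_macros,
                        'nvcc': ['-O3'] + version_dependent_macros})
```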
tests/L0/run_amp/test_multi_tensor_axpby.py

```diff
@@ -7,6 +7,7 @@ from apex import amp
 import torch
 from torch import nn
 import torch.nn.functional as F
+from math import floor

 from utils import common_init, HALF, FLOAT, \
     ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
@@ -20,6 +21,10 @@ except ImportError as err:
     print("amp_C fused kernels unavailable, disabling TestMultiTensorApply.  ImportError was ", err)
     disabled = True

+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+try_nhwc = (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4)
+
 class TestMultiTensorAxpby(unittest.TestCase):
```
```diff
@@ -31,28 +36,36 @@ class TestMultiTensorAxpby(unittest.TestCase):
         self.xval = 4.0
         self.yval = 16.0
         self.overflow_buf = torch.cuda.IntTensor(1).zero_()
-        self.ref = torch.cuda.FloatTensor([136.0])
+        self.ref = torch.full((1,), 136.0, device="cuda", dtype=torch.float32)

     def tearDown(self):
         pass

+    # The tensor creation here is written for convenience, not speed.
     def axpby(self, sizea, sizeb, applier, repeat_tensors,
-              x_type, y_type, out_type, inplace=False):
+              x_type, y_type, out_type, inplace=False, nhwc=False):
         self.overflow_buf.zero_()
-        t1 = torch.cuda.FloatTensor(sizea).fill_(1.0)
-        t2 = torch.cuda.FloatTensor(sizeb).fill_(1.0)
+        sizea = sizea if isinstance(sizea, tuple) else (sizea,)
+        sizeb = sizeb if isinstance(sizeb, tuple) else (sizeb,)
+        t1 = torch.full(sizea, 1.0, device="cuda", dtype=torch.float32)
+        t2 = torch.full(sizeb, 1.0, device="cuda", dtype=torch.float32)
+
+        def to_fmt(t, tp):
+            if nhwc:
+                return t.clone().to(tp, memory_format=torch.channels_last)
+            else:
+                return t.clone().to(tp)

         y_list = []
         for i in range(repeat_tensors):
-            y_list += [t1.clone().to(y_type)*self.yval, t2.clone().to(y_type)*self.yval]
+            y_list += [to_fmt(t1, y_type)*self.yval, to_fmt(t2, y_type)*self.yval]

-        x_list = [x.clone().to(x_type)*(self.xval/self.yval) for x in y_list]
+        x_list = [to_fmt(x, x_type)*(self.xval/self.yval) for x in y_list]

         if inplace:
             out_list = y_list
         else:
-            out_list = [out.clone().to(out_type)*3.0 for out in y_list]
+            out_list = [to_fmt(out, out_type)*3.0 for out in y_list]

         applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)
```
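The new `to_fmt` helper is what flips the whole test between layouts: with `nhwc=True`, every input is converted to channels-last before being handed to the applier. The underlying expectation is that an axpby on channels-last operands agrees elementwise with the contiguous result, since only the stride order changes. A small plain-PyTorch sanity sketch of that expectation (no amp_C required, runs on CPU too; sizes and scalars are illustrative):

```python
import torch

a, b = 2.0, 8.0
x = torch.randn(5, 47, 5, 55)
y = torch.randn(5, 47, 5, 55)

ref = a * x + b * y  # reference in the default (contiguous) layout

# Same computation on channels-last copies: values agree, strides differ.
x_nhwc = x.contiguous(memory_format=torch.channels_last)
y_nhwc = y.contiguous(memory_format=torch.channels_last)
out = a * x_nhwc + b * y_nhwc

assert torch.allclose(ref, out)
```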
```diff
@@ -122,6 +135,45 @@ class TestMultiTensorAxpby(unittest.TestCase):
     #                                 self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
     #                                               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)

+    @unittest.skipIf(disabled, "amp_C is unavailable")
+    @unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc")
+    def test_fuzz_nhwc(self):
+        input_size_pairs = (
+            ((7, 77, 7, 77), (5, 55, 5, 55)),
+            ((1, 1, 777, 1), (1, 1, 555, 1)),
+            ((5, 47, 5, 55), (1, 1, 1, 2048*32 + 1)),
+            ((1, 1, 1, 2048*32 + 1), (55, 47, 5, 55)),
+            ((555, 1, 1, 1), (32, 8, 32, 8)),
+            ((32, 8, 32, 8), (55, 47, 5, 55)),
+            ((1, 1, 33333, 1), (55, 47, 55, 5)),
+            ((55, 47, 55, 5), (1, 1, 33333, 1)))
+        appliers = (
+            MultiTensorApply(2048*32),
+            MultiTensorApply(333),
+            MultiTensorApply(33333))
+        repeat_tensors = (
+            1,
+            55)
+
+        for sizea, sizeb in input_size_pairs:
+            for applier in appliers:
+                for repeat in repeat_tensors:
+                    for x_type in (torch.float32, torch.float16):
+                        for y_type in (torch.float32, torch.float16):
+                            for out_type in (torch.float32, torch.float16):
+                                for inplace in (True, False):
+                                    if inplace is True and (y_type is not out_type):
+                                        continue
+                                    else:
+                                        self.axpby(sizea, sizeb, applier, repeat,
+                                                   x_type, y_type, out_type,
+                                                   inplace=inplace, nhwc=True)
+                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                    #               0, 0, float('nan'), inplace=inplace)
+                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                    #               2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
+                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                    #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)

 if __name__ == '__main__':
```
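For orientation, a hedged sketch of what the fuzz loop drives: `MultiTensorApply` chunks the tensor lists and launches the fused `amp_C.multi_tensor_axpby` kernel over them, and after this commit the operands may be channels-last. This assumes a CUDA build of apex; the sizes and scalars are illustrative, and the trailing -1 simply mirrors the final argument the test passes:

```python
import torch
import amp_C
from apex.multi_tensor_apply import MultiTensorApply

applier = MultiTensorApply(2048 * 32)  # chunk size, one of the values fuzzed above
overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")

# Channels-last operands, now accepted by multi_tensor_apply's contiguity check.
x = torch.full((5, 47, 5, 55), 0.25, device="cuda").contiguous(memory_format=torch.channels_last)
y = torch.full((5, 47, 5, 55), 1.0, device="cuda").contiguous(memory_format=torch.channels_last)
out = torch.empty_like(y)  # empty_like preserves the channels-last layout

# Computes out = a*x + b*y for each tensor triple across the lists.
applier(amp_C.multi_tensor_axpby, overflow_buf, [[x], [y], [out]], 2.0, 8.0, -1)
```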