Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
22920fe0
Commit
22920fe0
authored
Sep 10, 2018
by
Carl Case
Browse files
add more promotion testing
parent
81c788f0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
8 deletions
+25
-8
apex/amp/test/test_promotion.py
apex/amp/test/test_promotion.py
+25
-8
No files found.
apex/amp/test/test_promotion.py
View file @
22920fe0
...
...
@@ -17,18 +17,27 @@ class TestPromotion(unittest.TestCase):
def
tearDown
(
self
):
self
.
handle
.
_deactivate
()
def
run_binary_promote_test
(
self
,
fns
,
input_shape
):
def
run_binary_promote_test
(
self
,
fns
,
input_shape
,
x_inplace
=
False
):
type_pairs
=
it
.
product
(
DTYPES
,
DTYPES
)
for
fn
,
(
xtype
,
ytype
)
in
it
.
product
(
fns
,
type_pairs
):
x
=
torch
.
randn
(
input_shape
,
dtype
=
xtype
).
requires_grad_
()
x_leaf
=
x
if
x_inplace
:
# We need a non-leaf to call in place on
x
=
x
.
clone
()
y
=
torch
.
randn
(
input_shape
,
dtype
=
ytype
)
out
=
fn
(
x
,
y
)
if
xtype
==
torch
.
float
or
ytype
==
torch
.
float
:
self
.
assertEqual
(
out
.
type
(),
FLOAT
)
if
x_inplace
:
# In place: always match xtype
self
.
assertEqual
(
out
.
type
(),
x
.
type
())
else
:
self
.
assertEqual
(
out
.
type
(),
HALF
)
# Out of place: match widest type
if
xtype
==
torch
.
float
or
ytype
==
torch
.
float
:
self
.
assertEqual
(
out
.
type
(),
FLOAT
)
else
:
self
.
assertEqual
(
out
.
type
(),
HALF
)
out
.
float
().
sum
().
backward
()
self
.
assertEqual
(
x
.
grad
.
dtype
,
xtype
)
self
.
assertEqual
(
x
_leaf
.
grad
.
dtype
,
xtype
)
def
test_atan2_matches_widest
(
self
):
fns
=
[
lambda
x
,
y
:
torch
.
atan2
(
x
,
y
),
...
...
@@ -50,9 +59,17 @@ class TestPromotion(unittest.TestCase):
out
=
torch
.
cat
(
ys
+
[
x_half
])
self
.
assertEqual
(
out
.
type
(),
HALF
)
# TODOs:
# In-place methods on fp16 are errors for fp32
# In-place methods match type of self tensor
def
test_inplace_exp_is_error_for_half
(
self
):
xs
=
torch
.
randn
(
self
.
b
)
xs
.
exp_
()
self
.
assertEqual
(
xs
.
type
(),
FLOAT
)
xs
=
torch
.
randn
(
self
.
b
,
dtype
=
torch
.
half
)
with
self
.
assertRaises
(
NotImplementedError
):
xs
.
exp_
()
def
test_inplace_add_matches_self
(
self
):
fn
=
lambda
x
,
y
:
x
.
add_
(
y
)
self
.
run_binary_promote_test
([
fn
],
(
self
.
b
,),
x_inplace
=
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment