Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
464e95f5
Commit
464e95f5
authored
May 19, 2020
by
lcskrishna
Browse files
enable run_amp tests
parent
d0555980
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
15 additions
and
2 deletions
+15
-2
apex/testing/common_utils.py
apex/testing/common_utils.py
+2
-0
tests/L0/run_amp/test_checkpointing.py
tests/L0/run_amp/test_checkpointing.py
+2
-1
tests/L0/run_amp/test_multi_tensor_axpby.py
tests/L0/run_amp/test_multi_tensor_axpby.py
+3
-0
tests/L0/run_amp/test_multi_tensor_l2norm.py
tests/L0/run_amp/test_multi_tensor_l2norm.py
+3
-0
tests/L0/run_amp/test_rnn.py
tests/L0/run_amp/test_rnn.py
+5
-0
tests/L0/run_test.py
tests/L0/run_test.py
+0
-1
No files found.
apex/testing/common_utils.py
View file @
464e95f5
...
...
@@ -5,6 +5,8 @@ This file contains common utility functions for running the unit tests on ROCM.
import
torch
import
os
import
sys
from
functools
import
wraps
import
unittest
TEST_WITH_ROCM
=
os
.
getenv
(
'APEX_TEST_WITH_ROCM'
,
'0'
)
==
'1'
...
...
tests/L0/run_amp/test_checkpointing.py
View file @
464e95f5
...
...
@@ -6,7 +6,7 @@ import torch.nn.functional as F
import
torch.optim
as
optim
from
apex
import
amp
from
apex.testing.common_utils
import
skipIfRocm
from
utils
import
common_init
,
FLOAT
...
...
@@ -161,6 +161,7 @@ class TestCheckpointing(unittest.TestCase):
# skip tests for different opt_levels
continue
@
skipIfRocm
def
test_loss_scale_decrease
(
self
):
num_losses
=
3
nb_decrease_loss_scales
=
[
0
,
1
,
2
]
...
...
tests/L0/run_amp/test_multi_tensor_axpby.py
View file @
464e95f5
...
...
@@ -12,6 +12,8 @@ from math import floor
from
utils
import
common_init
,
HALF
,
FLOAT
,
\
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
from
apex.testing.common_utils
import
skipIfRocm
try
:
import
amp_C
from
amp_C
import
multi_tensor_axpby
...
...
@@ -137,6 +139,7 @@ class TestMultiTensorAxpby(unittest.TestCase):
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
@
unittest
.
skipIf
(
not
try_nhwc
,
"torch version is 1.4 or earlier, may not support nhwc"
)
@
skipIfRocm
def
test_fuzz_nhwc
(
self
):
input_size_pairs
=
(
((
7
,
77
,
7
,
77
),
(
5
,
55
,
5
,
55
)),
...
...
tests/L0/run_amp/test_multi_tensor_l2norm.py
View file @
464e95f5
...
...
@@ -11,6 +11,8 @@ import torch.nn.functional as F
from
utils
import
common_init
,
HALF
,
FLOAT
,
\
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
from
apex.testing.common_utils
import
skipIfRocm
try
:
import
amp_C
from
amp_C
import
multi_tensor_l2norm
...
...
@@ -56,6 +58,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
self
.
assertTrue
(
self
.
overflow_buf
.
item
()
==
0
)
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
@
skipIfRocm
def
test_fuzz
(
self
):
input_size_pairs
=
(
(
7777
*
77
,
555
*
555
),
...
...
tests/L0/run_amp/test_rnn.py
View file @
464e95f5
...
...
@@ -6,6 +6,7 @@ import torch
from
torch
import
nn
from
utils
import
common_init
,
HALF
from
apex.testing.common_utils
import
skipIfRocm
class
TestRnnCells
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
@@ -73,6 +74,7 @@ class TestRnns(unittest.TestCase):
output
[
-
1
,
:,
:].
float
().
sum
().
backward
()
self
.
assertEqual
(
x
.
grad
.
dtype
,
x
.
dtype
)
@
skipIfRocm
def
test_rnn_is_half
(
self
):
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
for
layers
,
bidir
in
configs
:
...
...
@@ -80,6 +82,7 @@ class TestRnns(unittest.TestCase):
nonlinearity
=
'relu'
,
bidirectional
=
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
)
@
skipIfRocm
def
test_gru_is_half
(
self
):
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
for
layers
,
bidir
in
configs
:
...
...
@@ -87,6 +90,7 @@ class TestRnns(unittest.TestCase):
bidirectional
=
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
)
@
skipIfRocm
def
test_lstm_is_half
(
self
):
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
for
layers
,
bidir
in
configs
:
...
...
@@ -94,6 +98,7 @@ class TestRnns(unittest.TestCase):
bidirectional
=
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
,
state_tuple
=
True
)
@
skipIfRocm
def
test_rnn_packed_sequence
(
self
):
num_layers
=
2
rnn
=
nn
.
RNN
(
input_size
=
self
.
h
,
hidden_size
=
self
.
h
,
num_layers
=
num_layers
)
...
...
tests/L0/run_test.py
View file @
464e95f5
...
...
@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
test_dirs
=
[
"run_amp"
,
"run_fp16util"
,
"run_optimizers"
,
"run_fused_layer_norm"
,
"run_pyprof_nvtx"
,
"run_pyprof_data"
,
"run_mlp"
]
ROCM_BLACKLIST
=
[
'run_amp'
,
'run_optimizers'
,
'run_fused_layer_norm'
,
'run_pyprof_nvtx'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment