Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
464e95f5
Commit
464e95f5
authored
May 19, 2020
by
lcskrishna
Browse files
enable run_amp tests
parent
d0555980
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
15 additions
and
2 deletions
+15
-2
apex/testing/common_utils.py
apex/testing/common_utils.py
+2
-0
tests/L0/run_amp/test_checkpointing.py
tests/L0/run_amp/test_checkpointing.py
+2
-1
tests/L0/run_amp/test_multi_tensor_axpby.py
tests/L0/run_amp/test_multi_tensor_axpby.py
+3
-0
tests/L0/run_amp/test_multi_tensor_l2norm.py
tests/L0/run_amp/test_multi_tensor_l2norm.py
+3
-0
tests/L0/run_amp/test_rnn.py
tests/L0/run_amp/test_rnn.py
+5
-0
tests/L0/run_test.py
tests/L0/run_test.py
+0
-1
No files found.
apex/testing/common_utils.py
View file @
464e95f5
...
@@ -5,6 +5,8 @@ This file contains common utility functions for running the unit tests on ROCM.
...
@@ -5,6 +5,8 @@ This file contains common utility functions for running the unit tests on ROCM.
import
torch
import
torch
import
os
import
os
import
sys
import
sys
from
functools
import
wraps
import
unittest
TEST_WITH_ROCM
=
os
.
getenv
(
'APEX_TEST_WITH_ROCM'
,
'0'
)
==
'1'
TEST_WITH_ROCM
=
os
.
getenv
(
'APEX_TEST_WITH_ROCM'
,
'0'
)
==
'1'
...
...
tests/L0/run_amp/test_checkpointing.py
View file @
464e95f5
...
@@ -6,7 +6,7 @@ import torch.nn.functional as F
...
@@ -6,7 +6,7 @@ import torch.nn.functional as F
import
torch.optim
as
optim
import
torch.optim
as
optim
from
apex
import
amp
from
apex
import
amp
from
apex.testing.common_utils
import
skipIfRocm
from
utils
import
common_init
,
FLOAT
from
utils
import
common_init
,
FLOAT
...
@@ -161,6 +161,7 @@ class TestCheckpointing(unittest.TestCase):
...
@@ -161,6 +161,7 @@ class TestCheckpointing(unittest.TestCase):
# skip tests for different opt_levels
# skip tests for different opt_levels
continue
continue
@
skipIfRocm
def
test_loss_scale_decrease
(
self
):
def
test_loss_scale_decrease
(
self
):
num_losses
=
3
num_losses
=
3
nb_decrease_loss_scales
=
[
0
,
1
,
2
]
nb_decrease_loss_scales
=
[
0
,
1
,
2
]
...
...
tests/L0/run_amp/test_multi_tensor_axpby.py
View file @
464e95f5
...
@@ -12,6 +12,8 @@ from math import floor
...
@@ -12,6 +12,8 @@ from math import floor
from
utils
import
common_init
,
HALF
,
FLOAT
,
\
from
utils
import
common_init
,
HALF
,
FLOAT
,
\
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
from
apex.testing.common_utils
import
skipIfRocm
try
:
try
:
import
amp_C
import
amp_C
from
amp_C
import
multi_tensor_axpby
from
amp_C
import
multi_tensor_axpby
...
@@ -137,6 +139,7 @@ class TestMultiTensorAxpby(unittest.TestCase):
...
@@ -137,6 +139,7 @@ class TestMultiTensorAxpby(unittest.TestCase):
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
@
unittest
.
skipIf
(
not
try_nhwc
,
"torch version is 1.4 or earlier, may not support nhwc"
)
@
unittest
.
skipIf
(
not
try_nhwc
,
"torch version is 1.4 or earlier, may not support nhwc"
)
@
skipIfRocm
def
test_fuzz_nhwc
(
self
):
def
test_fuzz_nhwc
(
self
):
input_size_pairs
=
(
input_size_pairs
=
(
((
7
,
77
,
7
,
77
),
(
5
,
55
,
5
,
55
)),
((
7
,
77
,
7
,
77
),
(
5
,
55
,
5
,
55
)),
...
...
tests/L0/run_amp/test_multi_tensor_l2norm.py
View file @
464e95f5
...
@@ -11,6 +11,8 @@ import torch.nn.functional as F
...
@@ -11,6 +11,8 @@ import torch.nn.functional as F
from
utils
import
common_init
,
HALF
,
FLOAT
,
\
from
utils
import
common_init
,
HALF
,
FLOAT
,
\
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
from
apex.testing.common_utils
import
skipIfRocm
try
:
try
:
import
amp_C
import
amp_C
from
amp_C
import
multi_tensor_l2norm
from
amp_C
import
multi_tensor_l2norm
...
@@ -56,6 +58,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
...
@@ -56,6 +58,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
self
.
assertTrue
(
self
.
overflow_buf
.
item
()
==
0
)
self
.
assertTrue
(
self
.
overflow_buf
.
item
()
==
0
)
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
@
unittest
.
skipIf
(
disabled
,
"amp_C is unavailable"
)
@
skipIfRocm
def
test_fuzz
(
self
):
def
test_fuzz
(
self
):
input_size_pairs
=
(
input_size_pairs
=
(
(
7777
*
77
,
555
*
555
),
(
7777
*
77
,
555
*
555
),
...
...
tests/L0/run_amp/test_rnn.py
View file @
464e95f5
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
utils
import
common_init
,
HALF
from
utils
import
common_init
,
HALF
from
apex.testing.common_utils
import
skipIfRocm
class
TestRnnCells
(
unittest
.
TestCase
):
class
TestRnnCells
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -73,6 +74,7 @@ class TestRnns(unittest.TestCase):
...
@@ -73,6 +74,7 @@ class TestRnns(unittest.TestCase):
output
[
-
1
,
:,
:].
float
().
sum
().
backward
()
output
[
-
1
,
:,
:].
float
().
sum
().
backward
()
self
.
assertEqual
(
x
.
grad
.
dtype
,
x
.
dtype
)
self
.
assertEqual
(
x
.
grad
.
dtype
,
x
.
dtype
)
@
skipIfRocm
def
test_rnn_is_half
(
self
):
def
test_rnn_is_half
(
self
):
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
for
layers
,
bidir
in
configs
:
for
layers
,
bidir
in
configs
:
...
@@ -80,6 +82,7 @@ class TestRnns(unittest.TestCase):
...
@@ -80,6 +82,7 @@ class TestRnns(unittest.TestCase):
nonlinearity
=
'relu'
,
bidirectional
=
bidir
)
nonlinearity
=
'relu'
,
bidirectional
=
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
)
@
skipIfRocm
def
test_gru_is_half
(
self
):
def
test_gru_is_half
(
self
):
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
for
layers
,
bidir
in
configs
:
for
layers
,
bidir
in
configs
:
...
@@ -87,6 +90,7 @@ class TestRnns(unittest.TestCase):
...
@@ -87,6 +90,7 @@ class TestRnns(unittest.TestCase):
bidirectional
=
bidir
)
bidirectional
=
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
)
@
skipIfRocm
def
test_lstm_is_half
(
self
):
def
test_lstm_is_half
(
self
):
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
configs
=
[(
1
,
False
),
(
2
,
False
),
(
2
,
True
)]
for
layers
,
bidir
in
configs
:
for
layers
,
bidir
in
configs
:
...
@@ -94,6 +98,7 @@ class TestRnns(unittest.TestCase):
...
@@ -94,6 +98,7 @@ class TestRnns(unittest.TestCase):
bidirectional
=
bidir
)
bidirectional
=
bidir
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
,
state_tuple
=
True
)
self
.
run_rnn_test
(
rnn
,
layers
,
bidir
,
state_tuple
=
True
)
@
skipIfRocm
def
test_rnn_packed_sequence
(
self
):
def
test_rnn_packed_sequence
(
self
):
num_layers
=
2
num_layers
=
2
rnn
=
nn
.
RNN
(
input_size
=
self
.
h
,
hidden_size
=
self
.
h
,
num_layers
=
num_layers
)
rnn
=
nn
.
RNN
(
input_size
=
self
.
h
,
hidden_size
=
self
.
h
,
num_layers
=
num_layers
)
...
...
tests/L0/run_test.py
View file @
464e95f5
...
@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
...
@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
test_dirs
=
[
"run_amp"
,
"run_fp16util"
,
"run_optimizers"
,
"run_fused_layer_norm"
,
"run_pyprof_nvtx"
,
"run_pyprof_data"
,
"run_mlp"
]
test_dirs
=
[
"run_amp"
,
"run_fp16util"
,
"run_optimizers"
,
"run_fused_layer_norm"
,
"run_pyprof_nvtx"
,
"run_pyprof_data"
,
"run_mlp"
]
ROCM_BLACKLIST
=
[
ROCM_BLACKLIST
=
[
'run_amp'
,
'run_optimizers'
,
'run_optimizers'
,
'run_fused_layer_norm'
,
'run_fused_layer_norm'
,
'run_pyprof_nvtx'
,
'run_pyprof_nvtx'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment