Commit 464e95f5 authored by lcskrishna

enable run_amp tests

parent d0555980
@@ -5,6 +5,8 @@ This file contains common utility functions for running the unit tests on ROCM.
import torch
import os
import sys
+from functools import wraps
+import unittest
TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1'
...
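The `wraps` and `unittest` imports added here are the building blocks for the `skipIfRocm` decorator that the test files below now import from `apex.testing.common_utils`. The decorator body is not part of this hunk, so the following is only a sketch of what it most likely looks like, keyed off the existing `TEST_WITH_ROCM` flag:

```python
import os
import unittest
from functools import wraps

# As in the hunk above: tests are considered to be running on ROCm when this
# environment variable is set to '1'.
TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1'

# Sketch only: the real definition lives in apex.testing.common_utils and is
# not shown in this hunk. It is assumed to follow the usual skip-decorator
# pattern built on TEST_WITH_ROCM, wraps, and unittest.SkipTest.
def skipIfRocm(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
        return fn(*args, **kwargs)
    return wrapper
```

Applied to a test method, such a decorator raises `unittest.SkipTest` at call time instead of running the body, so the test is reported as skipped rather than failed on ROCm.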
@@ -6,7 +6,7 @@ import torch.nn.functional as F
import torch.optim as optim
from apex import amp
+from apex.testing.common_utils import skipIfRocm
from utils import common_init, FLOAT
@@ -161,6 +161,7 @@ class TestCheckpointing(unittest.TestCase):
# skip tests for different opt_levels
continue
+@skipIfRocm
def test_loss_scale_decrease(self):
num_losses = 3
nb_decrease_loss_scales = [0, 1, 2]
...
@@ -12,6 +12,8 @@ from math import floor
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+from apex.testing.common_utils import skipIfRocm
try:
import amp_C
from amp_C import multi_tensor_axpby
@@ -137,6 +139,7 @@ class TestMultiTensorAxpby(unittest.TestCase):
@unittest.skipIf(disabled, "amp_C is unavailable")
@unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc")
+@skipIfRocm
def test_fuzz_nhwc(self):
input_size_pairs = (
((7, 77, 7, 77), (5, 55, 5, 55)),
...
@@ -11,6 +11,8 @@ import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+from apex.testing.common_utils import skipIfRocm
try:
import amp_C
from amp_C import multi_tensor_l2norm
@@ -56,6 +58,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
self.assertTrue(self.overflow_buf.item() == 0)
@unittest.skipIf(disabled, "amp_C is unavailable")
+@skipIfRocm
def test_fuzz(self):
input_size_pairs = (
(7777*77, 555*555),
...
@@ -6,6 +6,7 @@ import torch
from torch import nn
from utils import common_init, HALF
+from apex.testing.common_utils import skipIfRocm
class TestRnnCells(unittest.TestCase):
def setUp(self):
@@ -73,6 +74,7 @@ class TestRnns(unittest.TestCase):
output[-1, :, :].float().sum().backward()
self.assertEqual(x.grad.dtype, x.dtype)
+@skipIfRocm
def test_rnn_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
@@ -80,6 +82,7 @@ class TestRnns(unittest.TestCase):
nonlinearity='relu', bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
+@skipIfRocm
def test_gru_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
@@ -87,6 +90,7 @@ class TestRnns(unittest.TestCase):
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
+@skipIfRocm
def test_lstm_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
@@ -94,6 +98,7 @@ class TestRnns(unittest.TestCase):
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir, state_tuple=True)
+@skipIfRocm
def test_rnn_packed_sequence(self):
num_layers = 2
rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
...
@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
ROCM_BLACKLIST = [
-'run_amp',
'run_optimizers',
'run_fused_layer_norm',
'run_pyprof_nvtx',
...
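Removing `'run_amp'` from `ROCM_BLACKLIST` is what actually enables the run_amp suite on ROCm: the runner loop in this file (not shown in the hunk) is assumed to skip any blacklisted directory when `TEST_WITH_ROCM` is set, roughly along these lines:

```python
import unittest

from apex.testing.common_utils import TEST_WITH_ROCM

# Illustrative sketch of the discovery loop; the actual runner code is outside
# this hunk, and the blacklist below is truncated to the entries visible above.
test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm",
             "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
ROCM_BLACKLIST = ["run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx"]

runner = unittest.TextTestRunner(verbosity=2)
for test_dir in test_dirs:
    if TEST_WITH_ROCM and test_dir in ROCM_BLACKLIST:
        continue  # directory is still disabled on ROCm
    suite = unittest.TestLoader().discover(test_dir)
    runner.run(suite)
```

With `'run_amp'` no longer blacklisted, its tests are discovered and run on ROCm, and the `@skipIfRocm` decorators added above carve out only the individual tests that still fail there.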