Commit 464e95f5 authored by lcskrishna's avatar lcskrishna
Browse files

enable run_amp tests

parent d0555980
......@@ -5,6 +5,8 @@ This file contains common utility functions for running the unit tests on ROCM.
import torch
import os
import sys
from functools import wraps
import unittest
TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1'
......
......@@ -6,7 +6,7 @@ import torch.nn.functional as F
import torch.optim as optim
from apex import amp
from apex.testing.common_utils import skipIfRocm
from utils import common_init, FLOAT
......@@ -161,6 +161,7 @@ class TestCheckpointing(unittest.TestCase):
# skip tests for different opt_levels
continue
@skipIfRocm
def test_loss_scale_decrease(self):
num_losses = 3
nb_decrease_loss_scales = [0, 1, 2]
......
......@@ -12,6 +12,8 @@ from math import floor
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
try:
import amp_C
from amp_C import multi_tensor_axpby
......@@ -137,6 +139,7 @@ class TestMultiTensorAxpby(unittest.TestCase):
@unittest.skipIf(disabled, "amp_C is unavailable")
@unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc")
@skipIfRocm
def test_fuzz_nhwc(self):
input_size_pairs = (
((7, 77, 7, 77), (5, 55, 5, 55)),
......
......@@ -11,6 +11,8 @@ import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
try:
import amp_C
from amp_C import multi_tensor_l2norm
......@@ -56,6 +58,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
self.assertTrue(self.overflow_buf.item() == 0)
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_fuzz(self):
input_size_pairs = (
(7777*77, 555*555),
......
......@@ -6,6 +6,7 @@ import torch
from torch import nn
from utils import common_init, HALF
from apex.testing.common_utils import skipIfRocm
class TestRnnCells(unittest.TestCase):
def setUp(self):
......@@ -73,6 +74,7 @@ class TestRnns(unittest.TestCase):
output[-1, :, :].float().sum().backward()
self.assertEqual(x.grad.dtype, x.dtype)
@skipIfRocm
def test_rnn_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
......@@ -80,6 +82,7 @@ class TestRnns(unittest.TestCase):
nonlinearity='relu', bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
@skipIfRocm
def test_gru_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
......@@ -87,6 +90,7 @@ class TestRnns(unittest.TestCase):
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
@skipIfRocm
def test_lstm_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
......@@ -94,6 +98,7 @@ class TestRnns(unittest.TestCase):
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir, state_tuple=True)
@skipIfRocm
def test_rnn_packed_sequence(self):
num_layers = 2
rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
......
......@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
ROCM_BLACKLIST = [
'run_amp',
'run_optimizers',
'run_fused_layer_norm',
'run_pyprof_nvtx',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment