Commit eba809d7 authored by lcskrishna's avatar lcskrishna
Browse files

skip newer tests

parent 8d5c2624
......@@ -5,6 +5,7 @@ import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier
from apex.testing.common_utils import skipIfRocm
class RefLAMB(Optimizer):
r"""Implements Lamb algorithm.
......@@ -207,6 +208,7 @@ class TestFusedLAMB(unittest.TestCase):
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
@skipIfRocm
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
......@@ -214,6 +216,7 @@ class TestFusedLAMB(unittest.TestCase):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@skipIfRocm
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
weight_decay = [0, 0.01]
......@@ -234,6 +237,7 @@ class TestFusedLAMB(unittest.TestCase):
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
@skipIfRocm
def test_lamb_option(self):
nelem = 1
tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment