Commit 9187ea1d authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Merge remote-tracking branch 'origin/master' into dev/hubertlu/focal_loss_and_index_mul_2d_cuda

parents ebb4e88a a27b4e43
......@@ -103,9 +103,12 @@ def cached_cast(cast_fn, x, cache):
return type(x)([cached_cast(y) for y in x])
if x in cache:
cached_x = cache[x]
next_functions_available = False
if x.requires_grad and cached_x.requires_grad:
if len(cached_x.grad_fn.next_functions) > 1:
next_functions_available = True
# Make sure x is actually cached_x's autograd parent.
if cached_x.grad_fn.next_functions[1][0].variable is not x:
if next_functions_available and cached_x.grad_fn.next_functions[1][0].variable is not x:
raise RuntimeError("x and cache[x] both require grad, but x is not "
"cache[x]'s parent. This is likely an error.")
# During eval, it's possible to end up caching casted weights with
......@@ -125,6 +128,8 @@ def cached_cast(cast_fn, x, cache):
# connection between x and cached_x.
if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
del cache[x]
elif x.requires_grad and cached_x.requires_grad and not next_functions_available:
del cache[x]
else:
return cached_x
......
......@@ -74,11 +74,9 @@ class TestBasicCastsHalf(_TestBasicCasts):
def tearDown(self):
self.handle._deactivate()
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_linear_is_half(self):
self._test_linear(ALWAYS_HALF)
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_conv2d_is_half(self):
self._test_conv2d(ALWAYS_HALF)
......
......@@ -138,7 +138,6 @@ class TestCache(unittest.TestCase):
def test_whitelist_module_bfp16_weight(self):
self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4")
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_whitelist_module_fp32_weight(self):
self.train_eval_train_test(WhitelistModule, torch.float32, "O4")
......
......@@ -69,7 +69,6 @@ class TestCheckpointing(unittest.TestCase):
'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
key, paramA, paramB, paramA - paramB, test_setup))
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_restoring(self):
nb_epochs = 10
nb_epochs_restore = nb_epochs // 2
......@@ -222,7 +221,6 @@ class TestCheckpointing(unittest.TestCase):
unskipped_target = 0
self.assertEqual(scaler['unskipped'], unskipped_target)
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_state_dict(self):
for opt_level in self.test_opt_levels:
# Skip O3
......
......@@ -40,17 +40,14 @@ class TestRnnCells(unittest.TestCase):
for i, x in enumerate(xs):
self.assertEqual(x.grad.dtype, x.dtype)
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_rnn_cell_is_half(self):
cell = nn.RNNCell(self.h, self.h)
self.run_cell_test(cell)
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_gru_cell_is_half(self):
cell = nn.GRUCell(self.h, self.h)
self.run_cell_test(cell)
@unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
def test_lstm_cell_is_half(self):
cell = nn.LSTMCell(self.h, self.h)
self.run_cell_test(cell, state_tuple=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment