"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "58fc8244881f2225803c85fa3179ac32b310cb9d"
Commit c088c257 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

use torch.testing.assert_close in test_modeling_distillation

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/455

The test can be flaky due to numerical mismatch if using `self.assertEqual`, e.g. https://www.internalfb.com/intern/testinfra/diagnostics/1688850007977704.562950031998292.1672749571/

```
Traceback (most recent call last):
  File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/104a4d5c3a690252/mobile-vision/d2go/tests/__modeling_test_modeling_distillation__/modeling_test_modeling_distillation#link-tree/d2go/tests/modeling/test_modeling_distillation.py", line 674, in test_da_train
    self.assertEqual(
AssertionError: {'rea[14 chars]2894], grad_fn=<MulBackward0>), 'synthetic': t[85 chars]d0>)} != {'rea[14 chars]2894]), 'synthetic': tensor([1.4532]), 'add': [13 chars]64])}
- {'add': tensor([18.0064], grad_fn=<MulBackward0>),
-  'real': tensor([0.2894], grad_fn=<MulBackward0>),
-  'synthetic': tensor([1.4532], grad_fn=<MulBackward0>)}
+ {'add': tensor([18.0064]),
+  'real': tensor([0.2894]),
+  'synthetic': tensor([1.4532])}
```

Change to use `torch.testing.assert_close` instead for tensor comparison.

Reviewed By: YanjunChen329

Differential Revision: D42352509

fbshipit-source-id: 8a647685d1347a9bd493f2faed7e066eb9159e14
parent 9e93852d
...@@ -365,7 +365,7 @@ class TestDistillation(unittest.TestCase): ...@@ -365,7 +365,7 @@ class TestDistillation(unittest.TestCase):
) )
input = torch.randn(1) input = torch.randn(1)
output = model(input) output = model(input)
self.assertEqual(output, cache["test_layer"]) torch.testing.assert_close(output, cache["test_layer"])
def test_cached_layer_list(self): def test_cached_layer_list(self):
"""Check cached layer saves list""" """Check cached layer saves list"""
...@@ -378,7 +378,7 @@ class TestDistillation(unittest.TestCase): ...@@ -378,7 +378,7 @@ class TestDistillation(unittest.TestCase):
) )
input = [torch.randn(1) for _ in range(2)] input = [torch.randn(1) for _ in range(2)]
output = model(input) output = model(input)
self.assertEqual(output, cache["test_layer"]) torch.testing.assert_close(output, cache["test_layer"])
def test_cached_layer_tuple(self): def test_cached_layer_tuple(self):
"""Check cached layer saves list""" """Check cached layer saves list"""
...@@ -391,7 +391,7 @@ class TestDistillation(unittest.TestCase): ...@@ -391,7 +391,7 @@ class TestDistillation(unittest.TestCase):
) )
input = (torch.randn(1) for _ in range(2)) input = (torch.randn(1) for _ in range(2))
output = model(input) output = model(input)
self.assertEqual(output, cache["test_layer"]) torch.testing.assert_close(output, cache["test_layer"])
def test_cached_layer_dict(self): def test_cached_layer_dict(self):
"""Check cached layer saves dict""" """Check cached layer saves dict"""
...@@ -404,7 +404,7 @@ class TestDistillation(unittest.TestCase): ...@@ -404,7 +404,7 @@ class TestDistillation(unittest.TestCase):
) )
input = [torch.randn(1) for _ in range(2)] input = [torch.randn(1) for _ in range(2)]
output = model(input) output = model(input)
self.assertEqual(output, cache["test_layer"]) torch.testing.assert_close(output, cache["test_layer"])
def test_cached_layer_arbitrary(self): def test_cached_layer_arbitrary(self):
"""Check cached layer saves arbitrary nested data structure""" """Check cached layer saves arbitrary nested data structure"""
...@@ -417,7 +417,7 @@ class TestDistillation(unittest.TestCase): ...@@ -417,7 +417,7 @@ class TestDistillation(unittest.TestCase):
) )
input = [torch.randn(1) for _ in range(2)] input = [torch.randn(1) for _ in range(2)]
output = model(input) output = model(input)
self.assertEqual(output, cache["test_layer"]) torch.testing.assert_close(output, cache["test_layer"])
def test_cached_layer_unsupported(self): def test_cached_layer_unsupported(self):
"""Check cached layer doesn't save unsupported data type like strings""" """Check cached layer doesn't save unsupported data type like strings"""
...@@ -438,10 +438,10 @@ class TestDistillation(unittest.TestCase): ...@@ -438,10 +438,10 @@ class TestDistillation(unittest.TestCase):
input = torch.Tensor([0]) input = torch.Tensor([0])
output = model(input) output = model(input)
self.assertEqual(cache["layer0"], torch.Tensor([1])) torch.testing.assert_close(cache["layer0"], torch.Tensor([1]))
self.assertEqual(cache["layer1"], torch.Tensor([2])) torch.testing.assert_close(cache["layer1"], torch.Tensor([2]))
self.assertEqual(cache["layer2"], torch.Tensor([3])) torch.testing.assert_close(cache["layer2"], torch.Tensor([3]))
self.assertEqual(cache[""], output) torch.testing.assert_close(cache[""], output)
def test_unrecord_layers(self): def test_unrecord_layers(self):
"""Check we can remove a recorded layer""" """Check we can remove a recorded layer"""
...@@ -463,8 +463,12 @@ class TestDistillation(unittest.TestCase): ...@@ -463,8 +463,12 @@ class TestDistillation(unittest.TestCase):
layer0_cache = {"l00": torch.randn(1), "l01": torch.randn(1)} layer0_cache = {"l00": torch.randn(1), "l01": torch.randn(1)}
layer1_cache = {"l10": torch.randn(1), "l11": torch.randn(1)} layer1_cache = {"l10": torch.randn(1), "l11": torch.randn(1)}
output = compute_layer_losses(layer_losses, layer0_cache, layer1_cache) output = compute_layer_losses(layer_losses, layer0_cache, layer1_cache)
self.assertEqual(output["add"], layer0_cache["l00"] + layer1_cache["l10"]) torch.testing.assert_close(
self.assertEqual(output["div"], layer0_cache["l01"] / layer1_cache["l11"]) output["add"], layer0_cache["l00"] + layer1_cache["l10"]
)
torch.testing.assert_close(
output["div"], layer0_cache["l01"] / layer1_cache["l11"]
)
def test_set_cache_dict(self): def test_set_cache_dict(self):
"""Check we can swap the cache dict used when recording layers""" """Check we can swap the cache dict used when recording layers"""
...@@ -518,7 +522,7 @@ class TestPseudoLabeler(unittest.TestCase): ...@@ -518,7 +522,7 @@ class TestPseudoLabeler(unittest.TestCase):
pseudo_labeler = NoopPseudoLabeler() pseudo_labeler = NoopPseudoLabeler()
x = np.random.randn(1) x = np.random.randn(1)
output = pseudo_labeler.label(x) output = pseudo_labeler.label(x)
self.assertEqual(x, output) torch.testing.assert_close(x, output)
def test_relabeltargetinbatch(self): def test_relabeltargetinbatch(self):
"""Check target is relabed using teacher""" """Check target is relabed using teacher"""
...@@ -529,7 +533,7 @@ class TestPseudoLabeler(unittest.TestCase): ...@@ -529,7 +533,7 @@ class TestPseudoLabeler(unittest.TestCase):
batched_inputs = _get_input_data(n=2, use_input_target=True) batched_inputs = _get_input_data(n=2, use_input_target=True)
gt = [{"input": d["input"], "target": d["input"] / 2.0} for d in batched_inputs] gt = [{"input": d["input"], "target": d["input"] / 2.0} for d in batched_inputs]
outputs = relabeler.label(batched_inputs) outputs = relabeler.label(batched_inputs)
self.assertEqual(outputs, gt) torch.testing.assert_close(outputs, gt)
class TestDistillationHelper(unittest.TestCase): class TestDistillationHelper(unittest.TestCase):
...@@ -657,7 +661,7 @@ class TestDistillationAlgorithm(unittest.TestCase): ...@@ -657,7 +661,7 @@ class TestDistillationAlgorithm(unittest.TestCase):
model.eval() model.eval()
input = {"real": torch.randn(1), "synthetic": torch.randn(1)} input = {"real": torch.randn(1), "synthetic": torch.randn(1)}
output = model(input) output = model(input)
self.assertEqual(output, input["real"] + 3.0) torch.testing.assert_close(output, input["real"] + 3.0)
def test_da_train(self): def test_da_train(self):
"""Check train pass results in updated loss output""" """Check train pass results in updated loss output"""
...@@ -671,13 +675,13 @@ class TestDistillationAlgorithm(unittest.TestCase): ...@@ -671,13 +675,13 @@ class TestDistillationAlgorithm(unittest.TestCase):
model.train() model.train()
input = {"real": torch.randn(1), "synthetic": torch.randn(1)} input = {"real": torch.randn(1), "synthetic": torch.randn(1)}
output = model(input) output = model(input)
self.assertEqual( self.assertEqual(set(output.keys()), {"real", "synthetic", "add"})
output, torch.testing.assert_close(output["real"], (input["real"] + 3.0) * 0.1)
{ torch.testing.assert_close(
"real": (input["real"] + 3.0) * 0.1, output["synthetic"], (input["synthetic"] + 3.0) * 0.5
"synthetic": (input["synthetic"] + 3.0) * 0.5, )
"add": ((input["real"] + 1.0) + (input["synthetic"] + 1.0)) * 10.0, torch.testing.assert_close(
}, output["add"], ((input["real"] + 1.0) + (input["synthetic"] + 1.0)) * 10.0
) )
def test_da_remove_dynamic_mixin(self): def test_da_remove_dynamic_mixin(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment