Unverified Commit 1fcf279d authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix mps tests on torch 2.0 (#2766)

parent 58bcf46a
...@@ -255,10 +255,7 @@ class SimpleCrossAttnUpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase): ...@@ -255,10 +255,7 @@ class SimpleCrossAttnUpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
return init_dict, inputs_dict return init_dict, inputs_dict
def test_output(self): def test_output(self):
if torch_device == "mps": expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402]
expected_slice = [0.4327, 0.5538, 0.3919, 0.5682, 0.2704, 0.1573, -0.8768, -0.4615, -0.4146]
else:
expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402]
super().test_output(expected_slice) super().test_output(expected_slice)
...@@ -336,8 +333,5 @@ class AttnUpDecoderBlock2DTests(UNetBlockTesterMixin, unittest.TestCase): ...@@ -336,8 +333,5 @@ class AttnUpDecoderBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
return init_dict, inputs_dict return init_dict, inputs_dict
def test_output(self): def test_output(self):
if torch_device == "mps": expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568]
expected_slice = [-0.3669, -0.3387, 0.1029, -0.6564, 0.2728, -0.3233, 0.5977, -0.1784, 0.5482]
else:
expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568]
super().test_output(expected_slice) super().test_output(expected_slice)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment