Unverified commit 49b7ccfb authored by Will Berman, committed by GitHub

parameterize pass single args through tuple (#3477)

parent 7200985e
@@ -321,7 +321,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

-    @parameterized.expand([13, 16, 27])
+    @parameterized.expand([(13,), (16,), (27,)])
     @require_torch_gpu
     @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
     def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
@@ -339,7 +339,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
         assert torch_all_close(sample, sample_2, atol=1e-1)

-    @parameterized.expand([13, 16, 37])
+    @parameterized.expand([(13,), (16,), (37,)])
     @require_torch_gpu
     @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
     def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
...
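The change itself is mechanical: parameterized.expand unpacks each list entry as the positional arguments of the decorated test method, so a single argument is written explicitly as a one-element tuple, as in the diff above. A minimal sketch of the pattern outside of diffusers follows; the class name and test body are hypothetical placeholders, not the repository's code.

import unittest

from parameterized import parameterized


class SeedSmokeTest(unittest.TestCase):
    # Each entry in the list is the argument tuple for one generated test case,
    # so a single seed value is wrapped as a one-element tuple: (13,), (16,), (27,).
    @parameterized.expand([(13,), (16,), (27,)])
    def test_seed_is_positive(self, seed):
        # Placeholder assertion standing in for the real decode comparison tests.
        self.assertGreater(seed, 0)


if __name__ == "__main__":
    unittest.main()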