"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bc4039886d4a1163b6f99912a17d4b82ad00adce"
Unverified Commit 0b64c2c6 authored by Nipun Jindal's avatar Nipun Jindal Committed by GitHub
Browse files

[Stochastic Sampler][Slow Test]: Cuda test fixes (#3257)



[Slow Test]: Cuda test fixes
Co-authored-by: default avatarnjindal <njindal@adobe.com>
parent fd512d74
...@@ -65,6 +65,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest): ...@@ -65,6 +65,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]: if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.47821044921875) < 1e-2 assert abs(result_sum.item() - 167.47821044921875) < 1e-2
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3 assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
else: else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2 assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3 assert abs(result_mean.item() - 0.211619570851326) < 1e-3
...@@ -94,6 +97,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest): ...@@ -94,6 +97,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]: if torch_device in ["mps"]:
assert abs(result_sum.item() - 124.77149200439453) < 1e-2 assert abs(result_sum.item() - 124.77149200439453) < 1e-2
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3 assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
else: else:
assert abs(result_sum.item() - 119.8487548828125) < 1e-2 assert abs(result_sum.item() - 119.8487548828125) < 1e-2
assert abs(result_mean.item() - 0.1560530662536621) < 1e-3 assert abs(result_mean.item() - 0.1560530662536621) < 1e-3
...@@ -122,6 +128,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest): ...@@ -122,6 +128,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]: if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.46957397460938) < 1e-2 assert abs(result_sum.item() - 167.46957397460938) < 1e-2
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3 assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
else: else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2 assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3 assert abs(result_mean.item() - 0.211619570851326) < 1e-3
...@@ -151,6 +160,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest): ...@@ -151,6 +160,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]: if torch_device in ["mps"]:
assert abs(result_sum.item() - 176.66974135742188) < 1e-2 assert abs(result_sum.item() - 176.66974135742188) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
else: else:
assert abs(result_sum.item() - 170.3135223388672) < 1e-2 assert abs(result_sum.item() - 170.3135223388672) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment