Unverified Commit 0b64c2c6 authored by Nipun Jindal's avatar Nipun Jindal Committed by GitHub
Browse files

[Stochastic Sampler][Slow Test]: Cuda test fixes (#3257)



[Slow Test]: Cuda test fixes
Co-authored-by: default avatarnjindal <njindal@adobe.com>
parent fd512d74
......@@ -65,6 +65,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
......@@ -94,6 +97,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
else:
assert abs(result_sum.item() - 119.8487548828125) < 1e-2
assert abs(result_mean.item() - 0.1560530662536621) < 1e-3
......@@ -122,6 +128,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
......@@ -151,6 +160,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
else:
assert abs(result_sum.item() - 170.3135223388672) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment