Unverified Commit 19ab04ff authored by Beinsezii's avatar Beinsezii Committed by GitHub
Browse files

UniPC Multistep fix tensor dtype/device on order=3 (#7532)

* UniPC UTs iterate solvers on FP16

It wasn't catching errors on order==3. Might be excessive?

* UniPC Multistep fix tensor dtype/device on order=3

* UniPC UTs Add v_pred to fp16 test iter

For completeness' sake. Probably overkill?
parent 4a343077
...@@ -576,7 +576,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -576,7 +576,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if order == 2: if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else: else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
else: else:
D1s = None D1s = None
...@@ -714,7 +714,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -714,7 +714,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if order == 1: if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else: else:
rhos_c = torch.linalg.solve(R, b) rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
if self.predict_x0: if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
......
...@@ -229,20 +229,29 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): ...@@ -229,20 +229,29 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.1966) < 1e-3 assert abs(result_mean.item() - 0.1966) < 1e-3
def test_fp16_support(self): def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0] for order in [1, 2, 3]:
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) for solver_type in ["bh1", "bh2"]:
scheduler = scheduler_class(**scheduler_config) for prediction_type in ["epsilon", "sample", "v_prediction"]:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(
thresholding=True,
dynamic_thresholding_ratio=0,
prediction_type=prediction_type,
solver_order=order,
solver_type=solver_type,
)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10 num_inference_steps = 10
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter.half() sample = self.dummy_sample_deter.half()
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps): for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t) residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample sample = scheduler.step(residual, t, sample).prev_sample
assert sample.dtype == torch.float16 assert sample.dtype == torch.float16
def test_full_loop_with_noise(self): def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment