Unverified Commit b85bb075 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

support v prediction in other schedulers (#1505)

* support v prediction in other schedulers

* v heun

* add tests for v pred

* fix tests

* fix test euler a

* v ddpm
parent 52eb0348
...@@ -280,10 +280,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -280,10 +280,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
" for the DDPMScheduler." " `v_prediction` for the DDPMScheduler."
) )
# 3. Clip "predicted x_0" # 3. Clip "predicted x_0"
......
...@@ -78,6 +78,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -78,6 +78,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -202,7 +203,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -202,7 +203,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[step_index] sigma = self.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
sigma_from = self.sigmas[step_index] sigma_from = self.sigmas[step_index]
sigma_to = self.sigmas[step_index + 1] sigma_to = self.sigmas[step_index + 1]
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
......
...@@ -54,6 +54,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -54,6 +54,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.012, beta_end: float = 0.012,
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -184,7 +185,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -184,7 +185,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma_hat * model_output if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
if self.state_in_first_order: if self.state_in_first_order:
# 2. Convert to an ODE derivative # 2. Convert to an ODE derivative
......
...@@ -78,6 +78,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -78,6 +78,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -215,7 +216,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -215,7 +216,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[step_index] sigma = self.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. Convert to an ODE derivative # 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma derivative = (sample - pred_original_sample) / sigma
......
...@@ -102,6 +102,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -102,6 +102,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
skip_prk_steps: bool = False, skip_prk_steps: bool = False,
set_alpha_to_one: bool = False, set_alpha_to_one: bool = False,
prediction_type: str = "epsilon",
steps_offset: int = 0, steps_offset: int = 0,
): ):
if trained_betas is not None: if trained_betas is not None:
...@@ -368,6 +369,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -368,6 +369,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev beta_prod_t_prev = 1 - alpha_prod_t_prev
if self.config.prediction_type == "v_prediction":
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
elif self.config.prediction_type != "epsilon":
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
)
# corresponds to (α_(t−δ) - α_t) divided by # corresponds to (α_(t−δ) - α_t) divided by
# denominator of x_t in formula (9) and plus 1 # denominator of x_t in formula (9) and plus 1
# Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
......
...@@ -635,7 +635,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -635,7 +635,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
def test_prediction_type(self): def test_prediction_type(self):
for prediction_type in ["epsilon", "sample"]: for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self): def test_deprecated_predict_epsilon(self):
...@@ -711,6 +711,37 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -711,6 +711,37 @@ class DDPMSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 258.9070) < 1e-2 assert abs(result_sum.item() - 258.9070) < 1e-2
assert abs(result_mean.item() - 0.3374) < 1e-3 assert abs(result_mean.item() - 0.3374) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Run a full reverse-diffusion loop with ``prediction_type="v_prediction"``.

    The final sample's absolute-sum and absolute-mean are compared against
    fixed reference statistics.
    """
    scheduler_class = self.scheduler_classes[0]
    scheduler = scheduler_class(**self.get_scheduler_config(prediction_type="v_prediction"))

    model = self.dummy_model()
    sample = self.dummy_sample_deter
    generator = torch.manual_seed(0)

    # Walk the trained timesteps from T-1 down to 0, denoising one step at a time:
    # predict the residual, then step the scheduler to the previous sample.
    for timestep in reversed(range(len(scheduler))):
        residual = model(sample, timestep)
        sample = scheduler.step(residual, timestep, sample, generator=generator).prev_sample

    abs_sample = torch.abs(sample)
    assert abs(torch.sum(abs_sample).item() - 201.9864) < 1e-2
    assert abs(torch.mean(abs_sample).item() - 0.2630) < 1e-3
class DDIMSchedulerTest(SchedulerCommonTest): class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDIMScheduler,) scheduler_classes = (DDIMScheduler,)
...@@ -768,6 +799,10 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -768,6 +799,10 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "squaredcos_cap_v2"]: for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
    """The scheduler config must round-trip for both supported prediction types."""
    for pred_type in ("epsilon", "v_prediction"):
        self.check_over_configs(prediction_type=pred_type)
def test_clip_sample(self): def test_clip_sample(self):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
...@@ -805,6 +840,15 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -805,6 +840,15 @@ class DDIMSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 172.0067) < 1e-2 assert abs(result_sum.item() - 172.0067) < 1e-2
assert abs(result_mean.item() - 0.223967) < 1e-3 assert abs(result_mean.item() - 0.223967) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Full sampling loop with v-prediction, checked against reference statistics."""
    sample = self.full_loop(prediction_type="v_prediction")

    abs_sample = sample.abs()
    assert abs(abs_sample.sum().item() - 52.5302) < 1e-2
    assert abs(abs_sample.mean().item() - 0.0684) < 1e-3
def test_full_loop_with_set_alpha_to_one(self): def test_full_loop_with_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99 # We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
...@@ -971,6 +1015,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -971,6 +1015,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
solver_type=solver_type, solver_type=solver_type,
) )
def test_prediction_type(self):
    """Both ``epsilon`` and ``v_prediction`` configs must be accepted."""
    for pred_type in ("epsilon", "v_prediction"):
        self.check_over_configs(prediction_type=pred_type)
def test_solver_order_and_type(self): def test_solver_order_and_type(self):
for algorithm_type in ["dpmsolver", "dpmsolver++"]: for algorithm_type in ["dpmsolver", "dpmsolver++"]:
for solver_type in ["midpoint", "heun"]: for solver_type in ["midpoint", "heun"]:
...@@ -1004,6 +1052,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -1004,6 +1052,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.3301) < 1e-3 assert abs(result_mean.item() - 0.3301) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Full multistep sampling loop with v-prediction; only the mean is pinned."""
    sample = self.full_loop(prediction_type="v_prediction")
    assert abs(sample.abs().mean().item() - 0.2251) < 1e-3
def test_fp16_support(self): def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
...@@ -1184,6 +1238,10 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -1184,6 +1238,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "squaredcos_cap_v2"]: for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
    """Check the scheduler handles every supported prediction type."""
    for pred_type in ("epsilon", "v_prediction"):
        self.check_over_configs(prediction_type=pred_type)
def test_time_indices(self): def test_time_indices(self):
for t in [1, 5, 10]: for t in [1, 5, 10]:
self.check_over_forward(time_step=t) self.check_over_forward(time_step=t)
...@@ -1225,6 +1283,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -1225,6 +1283,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 198.1318) < 1e-2 assert abs(result_sum.item() - 198.1318) < 1e-2
assert abs(result_mean.item() - 0.2580) < 1e-3 assert abs(result_mean.item() - 0.2580) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Full PNDM sampling loop with v-prediction, checked against reference statistics."""
    sample = self.full_loop(prediction_type="v_prediction")

    abs_sample = sample.abs()
    assert abs(abs_sample.sum().item() - 67.3986) < 1e-2
    assert abs(abs_sample.mean().item() - 0.0878) < 1e-3
def test_full_loop_with_set_alpha_to_one(self): def test_full_loop_with_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99 # We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
...@@ -1453,6 +1519,10 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1453,6 +1519,10 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "scaled_linear"]: for schedule in ["linear", "scaled_linear"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
    """Exercise configuration round-trips for each supported prediction type."""
    for pred_type in ("epsilon", "v_prediction"):
        self.check_over_configs(prediction_type=pred_type)
def test_time_indices(self): def test_time_indices(self):
for t in [0, 500, 800]: for t in [0, 500, 800]:
self.check_over_forward(time_step=t) self.check_over_forward(time_step=t)
...@@ -1481,6 +1551,30 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1481,6 +1551,30 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3 assert abs(result_mean.item() - 1.31) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Full LMS sampling loop with ``prediction_type="v_prediction"``."""
    scheduler_class = self.scheduler_classes[0]
    scheduler = scheduler_class(**self.get_scheduler_config(prediction_type="v_prediction"))
    scheduler.set_timesteps(self.num_inference_steps)

    model = self.dummy_model()
    sample = self.dummy_sample_deter * scheduler.init_noise_sigma

    for t in scheduler.timesteps:
        # Scale first, then predict — ``step`` receives the scaled sample.
        sample = scheduler.scale_model_input(sample, t)
        model_output = model(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample

    abs_sample = torch.abs(sample)
    # NOTE(review): the tolerances (1e-2 / 1e-3) dwarf the expected values
    # (0.0017 / 2.27e-06), so these effectively assert "close to zero" — confirm intended.
    assert abs(torch.sum(abs_sample).item() - 0.0017) < 1e-2
    assert abs(torch.mean(abs_sample).item() - 2.2676e-06) < 1e-3
def test_full_loop_device(self): def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
...@@ -1534,6 +1628,10 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1534,6 +1628,10 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "scaled_linear"]: for schedule in ["linear", "scaled_linear"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
    """Both supported prediction types must survive a config round-trip."""
    for pred_type in ("epsilon", "v_prediction"):
        self.check_over_configs(prediction_type=pred_type)
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
...@@ -1565,6 +1663,37 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1565,6 +1663,37 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 10.0807) < 1e-2 assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3 assert abs(result_mean.item() - 0.0131) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Full Euler sampling loop with v-prediction on the current torch device."""
    scheduler_class = self.scheduler_classes[0]
    scheduler = scheduler_class(**self.get_scheduler_config(prediction_type="v_prediction"))
    scheduler.set_timesteps(self.num_inference_steps)

    # torch.Generator does not support the MPS backend; fall back to the
    # global CPU generator there.
    generator = (
        torch.manual_seed(0)
        if torch_device == "mps"
        else torch.Generator(device=torch_device).manual_seed(0)
    )

    model = self.dummy_model()
    sample = (self.dummy_sample_deter * scheduler.init_noise_sigma).to(torch_device)

    for t in scheduler.timesteps:
        sample = scheduler.scale_model_input(sample, t)
        model_output = model(sample, t)
        sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample

    abs_sample = torch.abs(sample)
    # NOTE(review): tolerances are far looser than the expected values — these
    # effectively assert the result is near zero; confirm intended.
    assert abs(torch.sum(abs_sample).item() - 0.0002) < 1e-2
    assert abs(torch.mean(abs_sample).item() - 2.2676e-06) < 1e-3
def test_full_loop_device(self): def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
...@@ -1624,6 +1753,10 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1624,6 +1753,10 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "scaled_linear"]: for schedule in ["linear", "scaled_linear"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
    """Configuration must round-trip for every supported prediction type."""
    for pred_type in ("epsilon", "v_prediction"):
        self.check_over_configs(prediction_type=pred_type)
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
...@@ -1660,6 +1793,42 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1660,6 +1793,42 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 144.8084) < 1e-2 assert abs(result_sum.item() - 144.8084) < 1e-2
assert abs(result_mean.item() - 0.18855) < 1e-3 assert abs(result_mean.item() - 0.18855) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Full Euler-ancestral sampling loop with v-prediction.

    Reference statistics differ between CPU/MPS and CUDA because the
    ancestral sampler draws device-dependent noise.
    """
    scheduler_class = self.scheduler_classes[0]
    scheduler = scheduler_class(**self.get_scheduler_config(prediction_type="v_prediction"))
    scheduler.set_timesteps(self.num_inference_steps)

    # torch.Generator does not support the MPS backend; fall back to the
    # global CPU generator there.
    generator = (
        torch.manual_seed(0)
        if torch_device == "mps"
        else torch.Generator(device=torch_device).manual_seed(0)
    )

    model = self.dummy_model()
    sample = (self.dummy_sample_deter * scheduler.init_noise_sigma).to(torch_device)

    for t in scheduler.timesteps:
        sample = scheduler.scale_model_input(sample, t)
        model_output = model(sample, t)
        sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample

    result_sum = torch.sum(torch.abs(sample)).item()
    result_mean = torch.mean(torch.abs(sample)).item()

    if torch_device in ["cpu", "mps"]:
        assert abs(result_sum - 108.4439) < 1e-2
        assert abs(result_mean - 0.1412) < 1e-3
    else:
        # CUDA
        assert abs(result_sum - 102.5807) < 1e-2
        assert abs(result_mean - 0.1335) < 1e-3
def test_full_loop_device(self): def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
...@@ -1932,6 +2101,10 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1932,6 +2101,10 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "scaled_linear"]: for schedule in ["linear", "scaled_linear"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
    """Verify config round-trips for ``epsilon`` and ``v_prediction``."""
    for pred_type in ("epsilon", "v_prediction"):
        self.check_over_configs(prediction_type=pred_type)
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
...@@ -1962,6 +2135,36 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1962,6 +2135,36 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 0.1233) < 1e-2 assert abs(result_sum.item() - 0.1233) < 1e-2
assert abs(result_mean.item() - 0.0002) < 1e-3 assert abs(result_mean.item() - 0.0002) < 1e-3
def test_full_loop_with_v_prediction(self):
    """Full Heun sampling loop with v-prediction on the current torch device."""
    scheduler_class = self.scheduler_classes[0]
    scheduler = scheduler_class(**self.get_scheduler_config(prediction_type="v_prediction"))
    scheduler.set_timesteps(self.num_inference_steps)

    model = self.dummy_model()
    sample = (self.dummy_sample_deter * scheduler.init_noise_sigma).to(torch_device)

    for t in scheduler.timesteps:
        # Scale first, then predict — ``step`` receives the scaled sample.
        sample = scheduler.scale_model_input(sample, t)
        model_output = model(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample

    result_sum = torch.sum(torch.abs(sample)).item()
    result_mean = torch.mean(torch.abs(sample)).item()

    if torch_device in ["cpu", "mps"]:
        assert abs(result_sum - 4.6934e-07) < 1e-2
        assert abs(result_mean - 6.1112e-10) < 1e-3
    else:
        # CUDA
        # NOTE(review): the 1e-2 tolerance dwarfs the expected sum (~4.7e-07),
        # so this effectively asserts "near zero"; the mean reference (0.0002)
        # is also inconsistent with that sum — confirm the intended values.
        assert abs(result_sum - 4.693428650170972e-07) < 1e-2
        assert abs(result_mean - 0.0002) < 1e-3
def test_full_loop_device(self): def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment