Unverified commit 63f767ef, authored by Suraj Patil and committed by GitHub
Browse files

Add SVD (#5895)

* begin model

* finish blocks

* add_embedding

* addition_time_embed_dim

* use TimestepEmbedding

* fix temporal res block

* fix time_pos_embed

* fix add_embedding

* add conversion script

* fix model

* up

* add new resnet blocks

* make forward work

* return sample in original shape

* fix temb shape in TemporalResnetBlock

* add spatio temporal transformers

* add vae blocks

* fix blocks

* update

* update

* fix shapes in Alphablender and add time activation in res block

* use new blocks

* style

* fix temb shape

* fix SpatioTemporalResBlock

* reuse TemporalBasicTransformerBlock

* fix TemporalBasicTransformerBlock

* use TransformerSpatioTemporalModel

* fix TransformerSpatioTemporalModel

* fix time_context dim

* clean up

* make temb optional

* add blocks

* rename model

* update conversion script

* remove UNetMidBlockSpatioTemporal

* add in init

* remove unused arg

* remove unused arg
...
parent d1b2a1a9
...@@ -323,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -323,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() # Hack to make sure that other schedulers which copy this function don't break
sigma_max: float = in_sigmas[0].item() # TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
......
...@@ -358,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -358,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() # Hack to make sure that other schedulers which copy this function don't break
sigma_max: float = in_sigmas[0].item() # TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
......
...@@ -358,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -358,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() # Hack to make sure that other schedulers which copy this function don't break
sigma_max: float = in_sigmas[0].item() # TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
......
...@@ -357,8 +357,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -357,8 +357,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() # Hack to make sure that other schedulers which copy this function don't break
sigma_max: float = in_sigmas[0].item() # TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
......
...@@ -144,7 +144,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -144,7 +144,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
interpolation_type: str = "linear", interpolation_type: str = "linear",
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace", timestep_spacing: str = "linspace",
timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0, steps_offset: int = 0,
): ):
if trained_betas is not None: if trained_betas is not None:
...@@ -164,13 +167,22 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -164,13 +167,22 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.sigmas = torch.from_numpy(sigmas)
sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps) # TODO: Support the full EDM scalings for all prediction types and timestep types
if timestep_type == "continuous" and prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
else:
self.timesteps = timesteps
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.is_scale_input_called = False self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas self.use_karras_sigmas = use_karras_sigmas
...@@ -268,10 +280,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -268,10 +280,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device) # TODO: Support the full EDM scalings for all prediction types and timestep types
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
else:
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas): def _sigma_to_t(self, sigma, log_sigmas):
...@@ -301,8 +318,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -301,8 +318,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() # Hack to make sure that other schedulers which copy this function don't break
sigma_max: float = in_sigmas[0].item() # TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
...@@ -412,7 +441,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -412,7 +441,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.prediction_type == "epsilon": elif self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip # denoised = model_output * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else: else:
raise ValueError( raise ValueError(
......
...@@ -303,8 +303,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -303,8 +303,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() # Hack to make sure that other schedulers which copy this function don't break
sigma_max: float = in_sigmas[0].item() # TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
......
...@@ -324,8 +324,20 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -324,8 +324,20 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() # Hack to make sure that other schedulers which copy this function don't break
sigma_max: float = in_sigmas[0].item() # TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -37,6 +37,14 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -37,6 +37,14 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
for prediction_type in ["epsilon", "v_prediction"]: for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
def test_timestep_type(self):
timestep_types = ["discrete", "continuous"]
for timestep_type in timestep_types:
self.check_over_configs(timestep_type=timestep_type)
def test_karras_sigmas(self):
self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0)
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment