Unverified Commit 4836cfad authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Sigmas] Keep sigmas on CPU (#6173)

* correct

* Apply suggestions from code review

* make style
parent 1ccbfbb6
......@@ -98,6 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.custom_timesteps = False
self.is_scale_input_called = False
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
......@@ -230,6 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Modified _convert_to_karras implementation that takes in ramp as argument
def _convert_to_karras(self, ramp):
......
......@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
......@@ -254,6 +255,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
......
......@@ -214,6 +214,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
......@@ -290,6 +291,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
......
......@@ -209,6 +209,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.use_karras_sigmas = use_karras_sigmas
@property
......@@ -289,6 +290,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
......
......@@ -198,6 +198,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
......@@ -347,6 +348,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.mid_point_sigma = None
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.noise_sampler = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
......
......@@ -197,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
self.order_list = self.get_order_list(num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def get_order_list(self, num_inference_steps: int) -> List[int]:
"""
......@@ -288,6 +289,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
......
......@@ -166,6 +166,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
......@@ -249,6 +250,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
......
......@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
......@@ -341,6 +342,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
......
......@@ -148,6 +148,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
......@@ -269,6 +270,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.dt = None
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
......
......@@ -140,6 +140,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
......@@ -295,6 +296,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._index_counter = defaultdict(int)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
......
......@@ -140,6 +140,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
......@@ -284,6 +285,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self._index_counter = defaultdict(int)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def state_in_first_order(self):
......
......@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
......@@ -279,6 +280,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.derivatives = []
......
......@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
......@@ -268,6 +269,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment