Commit 7c0a8618 authored by anton-l's avatar anton-l
Browse files

Add torch_device to the VE pipeline

parent a73ae3e5
...@@ -11,22 +11,23 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -11,22 +11,23 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self.register_modules(model=model, scheduler=scheduler) self.register_modules(model=model, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"): def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
img_size = self.model.config.sample_size img_size = self.model.config.sample_size
shape = (1, 3, img_size, img_size) shape = (batch_size, 3, img_size, img_size)
model = self.model.to(device) model = self.model.to(torch_device)
sample = torch.randn(*shape) * self.scheduler.config.sigma_max sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = sample.to(device) sample = sample.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_sigmas(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps)
for i, t in tqdm(enumerate(self.scheduler.timesteps)): for i, t in tqdm(enumerate(self.scheduler.timesteps)):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
# correction step # correction step
for _ in range(self.scheduler.correct_steps): for _ in range(self.scheduler.correct_steps):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment