"docs/vscode:/vscode.git/clone" did not exist on "d70f8ee18b50c38f377a18a9fa8da0ae15b6426d"
Commit 20d91782 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

correct readme

parent 7764669c
...@@ -48,7 +48,11 @@ noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") ...@@ -48,7 +48,11 @@ noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise # 2. Sample gaussian noise
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator) image = torch.randn(
(1, unet.in_channels, unet.resolution, unet.resolution)
generator=generator,
)
image = image.to(torch_device)
# 3. Denoise # 3. Denoise
num_prediction_steps = len(noise_scheduler) num_prediction_steps = len(noise_scheduler)
...@@ -63,7 +67,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s ...@@ -63,7 +67,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s
# optionally sample variance # optionally sample variance
variance = 0 variance = 0
if t > 0: if t > 0:
noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = noise_scheduler.get_variance(t).sqrt() * noise variance = noise_scheduler.get_variance(t).sqrt() * noise
# set current image to prev_image: x_t -> x_t-1 # set current image to prev_image: x_t -> x_t-1
...@@ -96,7 +100,11 @@ noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq") ...@@ -96,7 +100,11 @@ noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq")
unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device) unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device)
# 2. Sample gaussian noise # 2. Sample gaussian noise
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator) image = torch.randn(
(1, unet.in_channels, unet.resolution, unet.resolution)
generator=generator,
)
image = image.to(torch_device)
# 3. Denoise # 3. Denoise
num_inference_steps = 50 num_inference_steps = 50
...@@ -114,7 +122,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste ...@@ -114,7 +122,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste
# 3. optionally sample variance # 3. optionally sample variance
variance = 0 variance = 0
if eta > 0: if eta > 0:
noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = noise_scheduler.get_variance(t).sqrt() * eta * noise variance = noise_scheduler.get_variance(t).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment