Unverified commit 051b3463, authored by Patrick von Platen, committed by GitHub

[Half precision] Make sure half-precision is correct (#182)



* [Half precision] Make sure half-precision is correct

* Update src/diffusers/models/unet_2d.py

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* correct some tests

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* finalize

* finish
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 5f25818a
@@ -32,10 +32,10 @@ def get_timestep_embedding(
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

     half_dim = embedding_dim // 2
-    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
-    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
-    emb = torch.exp(emb * emb_coeff)
+    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+    exponent = exponent / (half_dim - downscale_freq_shift)
+    emb = torch.exp(exponent).to(device=timesteps.device)

     emb = timesteps[:, None].float() * emb[None, :]

     # scale embeddings
......
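For reference, a minimal self-contained sketch of the refactored embedding math in the hunk above; the wrapper name and the sin/cos concatenation at the end are illustrative assumptions, not the library's exact API:

import math
import torch

def sinusoidal_embedding(timesteps, embedding_dim, max_period=10000, downscale_freq_shift=1):
    # frequencies are built in float32 so a half-precision model cannot
    # underflow the exponent table
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = torch.exp(exponent).to(device=timesteps.device)
    emb = timesteps[:, None].float() * emb[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

print(sinusoidal_embedding(torch.arange(4), 32).shape)  # torch.Size([4, 32])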
@@ -331,7 +331,9 @@ class ResnetBlock(nn.Module):
     def forward(self, x, temb, hey=False):
         h = x

-        h = self.norm1(h)
+        # make sure hidden states is in float32
+        # when running in half-precision
+        h = self.norm1(h.float()).type(h.dtype)
         h = self.nonlinearity(h)

         if self.upsample is not None:
@@ -347,7 +349,9 @@ class ResnetBlock(nn.Module):
             temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
             h = h + temb

-        h = self.norm2(h)
+        # make sure hidden states is in float32
+        # when running in half-precision
+        h = self.norm2(h.float()).type(h.dtype)
         h = self.nonlinearity(h)

         h = self.dropout(h)
......
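The norm1/norm2 edits follow a common mixed-precision pattern: normalization statistics (mean/variance) are numerically fragile in fp16, so the input is upcast to float32 for the norm and cast back afterwards. A minimal sketch of the trick (layer sizes invented for illustration; the norm's own weights stay in float32):

import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=8, num_channels=32)   # weights kept in float32
h = torch.randn(2, 32, 16, 16, dtype=torch.float16)  # fp16 hidden states

# upcast for the statistics, then cast back so the rest of the block stays fp16
out = norm(h.float()).type(h.dtype)
assert out.dtype == torch.float16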
@@ -132,6 +132,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

+        # broadcast to batch dimension
+        timesteps = timesteps.broadcast_to(sample.shape[0])
+
         t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
@@ -166,7 +169,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             sample = upsample_block(sample, res_samples, emb)

         # 6. post-process
-        sample = self.conv_norm_out(sample)
+        # make sure hidden states is in float32
+        # when running in half-precision
+        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
......
@@ -133,6 +133,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

+        # broadcast to batch dimension
+        timesteps = timesteps.broadcast_to(sample.shape[0])
+
         t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
@@ -172,8 +175,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)

         # 6. post-process
-        sample = self.conv_norm_out(sample)
+        # make sure hidden states is in float32
+        # when running in half-precision
+        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
......
@@ -55,7 +55,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
         self.text_encoder.to(torch_device)

         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
         text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
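Padding to the tokenizer's fixed `model_max_length` (77 for CLIP) instead of to the longest prompt in the batch keeps the text-embedding shape constant across prompts, which classifier-free guidance relies on when concatenating conditional and unconditional embeddings. A hedged sketch using the transformers CLIP tokenizer (model id assumed; downloads the tokenizer files on first use):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

text_input = tokenizer(
    ["A painting of a squirrel eating a burger"],
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 77 for CLIP
    truncation=True,
    return_tensors="pt",
)
print(text_input.input_ids.shape)  # torch.Size([1, 77])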
@@ -79,19 +85,25 @@ class StableDiffusionPipeline(DiffusionPipeline):
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
+            device=torch_device,
         )
-        latents = latents.to(torch_device)

+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        if accepts_offset:
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwargs = {}
+        extra_step_kwargs = {}
         if accepts_eta:
-            extra_kwargs["eta"] = eta
+            extra_step_kwargs["eta"] = eta

-        self.scheduler.set_timesteps(num_inference_steps)
         for t in tqdm(self.scheduler.timesteps):
             # expand the latents if we are doing classifier free guidance
@@ -106,7 +118,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
......
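The `accepts_offset` / `accepts_eta` checks are a signature-inspection pattern: a keyword is passed only if the callee actually declares it, so one pipeline works across schedulers with different signatures. A minimal standalone sketch of the pattern (the helper and the toy function are invented for illustration):

import inspect

def call_with_supported_kwargs(fn, *args, **kwargs):
    # drop any keyword the callee does not declare in its signature
    accepted = set(inspect.signature(fn).parameters.keys())
    return fn(*args, **{k: v for k, v in kwargs.items() if k in accepted})

def set_timesteps(num_inference_steps, offset=0):
    return num_inference_steps, offset

print(call_with_supported_kwargs(set_timesteps, 50, offset=1, eta=0.0))  # (50, 1)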
@@ -59,6 +59,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas=None,
         timestep_values=None,
         clip_sample=True,
+        clip_alpha_at_one=True,
         tensor_format="pt",
     ):
@@ -75,7 +76,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

-        self.one = np.array(1.0)
+        # At every step in DDIM, we look at the previous alphas_cumprod.
+        # For the final step, there is no previous alphas_cumprod because we are already at 0.
+        # `clip_alpha_at_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = np.array(1.0) if clip_alpha_at_one else self.alphas_cumprod[0]

         # setable values
         self.num_inference_steps = None
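Concretely, `final_alpha_cumprod` is what the last denoising step uses as its "previous" cumulative alpha. A toy numeric sketch (the beta schedule is invented for illustration):

import numpy as np

betas = np.linspace(0.0001, 0.02, 10)  # toy 10-step training schedule
alphas_cumprod = np.cumprod(1.0 - betas)

# clip_alpha_at_one=True  -> pretend a "step -1" with alpha_bar = 1.0
# clip_alpha_at_one=False -> reuse the first real alpha_bar instead
for clip_alpha_at_one in (True, False):
    final_alpha_cumprod = np.array(1.0) if clip_alpha_at_one else alphas_cumprod[0]
    print(clip_alpha_at_one, final_alpha_cumprod)  # True 1.0 / False 0.9999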
@@ -86,7 +92,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -94,11 +100,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance

-    def set_timesteps(self, num_inference_steps):
+    def set_timesteps(self, num_inference_steps, offset=0):
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
         )[::-1].copy()
+        self.timesteps += offset
         self.set_format(tensor_format=self.tensor_format)

     def step(
@@ -126,7 +133,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called
......
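The new `offset` shifts every inference timestep by a constant; the pipeline change above passes offset=1 when the scheduler supports it, so the first DDIM step lands at t=981 rather than t=980 (assuming the default 1000 training timesteps). A sketch of the resulting schedule:

import numpy as np

num_train_timesteps, num_inference_steps, offset = 1000, 50, 1
timesteps = np.arange(
    0, num_train_timesteps, num_train_timesteps // num_inference_steps
)[::-1].copy()
timesteps += offset
print(timesteps[:3], timesteps[-1])  # [981 961 941] 1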
@@ -37,6 +37,7 @@ from diffusers import (
     PNDMScheduler,
     ScoreSdeVePipeline,
     ScoreSdeVeScheduler,
+    StableDiffusionPipeline,
     UNet2DModel,
     VQModel,
 )
@@ -45,8 +46,6 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel

-from ..src.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
-

 torch.backends.cuda.matmul.allow_tf32 = False
@@ -667,7 +666,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()

         # fmt: off
         expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
         # fmt: on

         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
@@ -842,38 +841,51 @@ class PipelineTesterMixin(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion(self):
-        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

         prompt = "A painting of a squirrel eating a burger"
-        generator = torch.manual_seed(0)
-        image = pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
-            "sample"
-        ]
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast("cuda"):
+            output = sd_pipe(
+                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
+            )
+
+        image = output["sample"]

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 512, 512, 3)
-        # fmt: off
-        expected_slice = np.array([0.09609553, 0.09020892, 0.07902172, 0.07634321, 0.08755809, 0.06491277, 0.07687345, 0.07173461, 0.07374045])
-        # fmt: on
+        expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
-    def test_stable_diffusion_fast(self):
-        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_fast_ddim(self):
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            clip_alpha_at_one=False,
+        )
+        sd_pipe.scheduler = scheduler

         prompt = "A painting of a squirrel eating a burger"
-        generator = torch.manual_seed(0)
-        image = pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        with torch.autocast("cuda"):
+            output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
+        image = output["sample"]

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 512, 512, 3)
-        # fmt: off
-        expected_slice = np.array([0.16537648, 0.17572534, 0.14657784, 0.20084214, 0.19819549, 0.16032678, 0.30438453, 0.22730353, 0.21307352])
-        # fmt: on
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.8354, 0.83, 0.866, 0.838, 0.8315, 0.867, 0.836, 0.8584, 0.869])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     @slow
     def test_score_sde_ve_pipeline(self):
@@ -890,6 +902,7 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 256, 256, 3)
         expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
......
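The reworked tests seed a device-bound `torch.Generator` and run the pipeline under `torch.autocast`, which is what actually exercises the fp16 paths patched above. A hedged usage sketch mirroring the test (requires a CUDA device, network access for the checkpoint, and a diffusers version with this pipeline signature):

import torch
from diffusers import StableDiffusionPipeline

sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

# seed on the device the pipeline samples on, so GPU results are reproducible
generator = torch.Generator(device="cuda").manual_seed(0)

# autocast runs matmuls/convs in fp16 while the norms above fall back to float32
with torch.autocast("cuda"):
    output = sd_pipe(
        ["A painting of a squirrel eating a burger"],
        generator=generator,
        guidance_scale=6.0,
        num_inference_steps=20,
        output_type="np",
    )
image = output["sample"]  # numpy array of shape (1, 512, 512, 3)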