"vscode:/vscode.git/clone" did not exist on "c7ba6ba2678ca7e4e58320da8209be8883a56322"
Unverified Commit 9e234d80 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

handle fp16 in `UNet2DModel` (#1216)



* make sure fp16 runs well

* add fp16 test for superes

* Update src/diffusers/models/unet_2d.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* gen on cuda

* always run fast inference test on cpu

* run on cpu
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent 8fd3a743
...@@ -209,6 +209,11 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -209,6 +209,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
# 2. pre-process
...@@ -242,9 +247,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -242,9 +247,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
# make sure hidden states is in float32 sample = self.conv_norm_out(sample)
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
......
...@@ -87,6 +87,27 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase ...@@ -87,6 +87,27 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase
expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_inference_superresolution_fp16(self):
    """Smoke-test the LDM super-resolution pipeline end to end in half precision.

    Builds the dummy UNet and VQ model, casts both to fp16, runs two
    inference steps on the GPU, and checks only the output image shape
    (exact pixel values are not asserted for the fp16 path).
    """
    scheduler = DDIMScheduler()
    # Cast both sub-models to fp16 before assembling the pipeline.
    unet = self.dummy_uncond_unet.half()
    vqvae = self.dummy_vq_model.half()

    pipeline = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
    pipeline.to(torch_device)
    pipeline.set_progress_bar_config(disable=None)

    init_image = self.dummy_image.to(torch_device)
    # Seeded generator on the target device for deterministic sampling.
    generator = torch.Generator(device=torch_device).manual_seed(0)
    image = pipeline(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

    assert image.shape == (1, 64, 64, 3)
@slow @slow
@require_torch @require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment