"dgl_sparse/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d02e560e07754d4d01b25cf99c4d6457e2e9cdd1"
Unverified Commit 3e69e241 authored by Seongbin Lim's avatar Seongbin Lim Committed by GitHub
Browse files

Allow DDPMPipeline half precision (#9222)


Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 65f9439b
...@@ -101,10 +101,10 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -101,10 +101,10 @@ class DDPMPipeline(DiffusionPipeline):
if self.device.type == "mps": if self.device.type == "mps":
# randn does not work reproducibly on mps # randn does not work reproducibly on mps
image = randn_tensor(image_shape, generator=generator) image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
image = image.to(self.device) image = image.to(self.device)
else: else:
image = randn_tensor(image_shape, generator=generator, device=self.device) image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment