Unverified Commit 1997614a authored by Prathik Rao's avatar Prathik Rao Committed by GitHub
Browse files

avoid upcasting by assigning dtype to noise tensor (#3713)



* avoid upcasting by assigning dtype to noise tensor

* make style

* Update train_unconditional.py

* Update train_unconditional.py

* make style

* add unit test for pickle

* revert change

---------
Co-authored-by: default avatarroot <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
parent 4e898560
......@@ -568,7 +568,9 @@ def main(args):
clean_images = batch["input"]
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
noise = torch.randn(
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
).to(clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
......
......@@ -557,7 +557,9 @@ def main(args):
clean_images = batch["input"]
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
noise = torch.randn(
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
).to(clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment