"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "005a334f99b236eea3f7fe2cc8e5ad946116dd10"
Unverified Commit 1997614a authored by Prathik Rao's avatar Prathik Rao Committed by GitHub
Browse files

avoid upcasting by assigning dtype to noise tensor (#3713)



* avoid upcasting by assigning dtype to noise tensor

* make style

* Update train_unconditional.py

* Update train_unconditional.py

* make style

* add unit test for pickle

* revert change

---------
Co-authored-by: default avatarroot <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
parent 4e898560
...@@ -568,7 +568,9 @@ def main(args): ...@@ -568,7 +568,9 @@ def main(args):
clean_images = batch["input"] clean_images = batch["input"]
# Sample noise that we'll add to the images # Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device) noise = torch.randn(
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
).to(clean_images.device)
bsz = clean_images.shape[0] bsz = clean_images.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint( timesteps = torch.randint(
......
...@@ -557,7 +557,9 @@ def main(args): ...@@ -557,7 +557,9 @@ def main(args):
clean_images = batch["input"] clean_images = batch["input"]
# Sample noise that we'll add to the images # Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device) noise = torch.randn(
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
).to(clean_images.device)
bsz = clean_images.shape[0] bsz = clean_images.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint( timesteps = torch.randint(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment