Unverified Commit e0d8c9ef authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Support for Offset Noise in examples (#2753)

* add noise offset

* make style
parent 92e1164e
...@@ -297,6 +297,7 @@ def parse_args(): ...@@ -297,6 +297,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
) )
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
args = parser.parse_args() args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
...@@ -705,6 +706,12 @@ def main(): ...@@ -705,6 +706,12 @@ def main():
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)
bsz = latents.shape[0] bsz = latents.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
......
...@@ -333,6 +333,7 @@ def parse_args(): ...@@ -333,6 +333,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
) )
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
args = parser.parse_args() args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
...@@ -718,6 +719,12 @@ def main(): ...@@ -718,6 +719,12 @@ def main():
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)
bsz = latents.shape[0] bsz = latents.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment