Unverified Commit 3be48918 authored by Yaman Ahlawat's avatar Yaman Ahlawat Committed by GitHub
Browse files

feat: allow offset_noise in dreambooth training example (#2826)

parent d82b0323
...@@ -417,6 +417,16 @@ def parse_args(input_args=None): ...@@ -417,6 +417,16 @@ def parse_args(input_args=None):
), ),
) )
parser.add_argument(
"--offset_noise",
action="store_true",
default=False,
help=(
"Fine-tuning against a modified noise"
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
),
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
else: else:
...@@ -943,6 +953,11 @@ def main(args): ...@@ -943,6 +953,11 @@ def main(args):
latents = latents * vae.config.scaling_factor latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
if args.offset_noise:
noise = torch.randn_like(latents) + 0.1 * torch.randn(
latents.shape[0], latents.shape[1], 1, 1, device=latents.device
)
else:
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
bsz = latents.shape[0] bsz = latents.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment