Unverified Commit fa9e35fc authored by Isamu Isozaki's avatar Isamu Isozaki Committed by GitHub
Browse files

Added input perturbation (#3292)

* Added input perturbation

* Fixed spelling
parent 4bae76e4
...@@ -112,6 +112,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight ...@@ -112,6 +112,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
)
parser.add_argument( parser.add_argument(
"--pretrained_model_name_or_path", "--pretrained_model_name_or_path",
type=str, type=str,
...@@ -801,7 +804,8 @@ def main(): ...@@ -801,7 +804,8 @@ def main():
noise += args.noise_offset * torch.randn( noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
) )
if args.input_pertubation:
new_noise = noise + args.input_pertubation * torch.randn_like(noise)
bsz = latents.shape[0] bsz = latents.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
...@@ -809,7 +813,10 @@ def main(): ...@@ -809,7 +813,10 @@ def main():
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) if args.input_pertubation:
noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning # Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0] encoder_hidden_states = text_encoder(batch["input_ids"])[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment