Unverified Commit ccb93dca authored by Sayak Paul, committed by GitHub

Support EDM-style training in DreamBooth LoRA SDXL script (#7126)



* add: dreambooth lora script for Playground v2.5

* fix: kwarg

* address suraj's comments.

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* apply suraj's suggestion

* incorporate changes in the canonical script.

* tracker naming

* fix: schedule determination

* add: two simple tests

* remove playground script

* note about edm-style training

* address pedro's comments.

* address part of Suraj's comments.

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* remove guidance_scale.

* use mse_loss.

* add comments for preconditioning.

* quality

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* tackle v-pred.

* Empty-Commit

* support edm for sdxl too.

* address suraj's comments.

* Empty-Commit

---------
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent ec953047
@@ -206,3 +206,40 @@ You can explore the results from a couple of our internal experiments by checkin

## Running on a free-tier Colab Notebook

Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).
## Conducting EDM-style training
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
For the SDXL model, simply set:
```diff
+ --do_edm_style_training \
```
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
```bash
accelerate launch train_dreambooth_lora_sdxl.py \
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
--instance_data_dir="dog" \
--output_dir="dog-playground-lora" \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--use_8bit_adam \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
> [!CAUTION]
> Min-SNR gamma is not supported with EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass any "variant".
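For intuition, here is a minimal standalone sketch (not part of the script) of the preconditioning that EDM-style training applies when the underlying scheduler is not already an EDM one; `sigma` denotes the noise level looked up for the sampled timestep:

```python
import torch


# Illustrative sketch of the EDM-style preconditioning the script uses for
# non-EDM schedulers (Section 5 of https://arxiv.org/abs/2206.00364); native
# EDM schedulers expose equivalent precondition_inputs/precondition_outputs.
def precondition_input(noisy_latents: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # c_in scaling: keeps the effective input variance roughly constant across noise levels.
    return noisy_latents / ((sigma**2 + 1) ** 0.5)


def precondition_output(noisy_latents, model_pred, sigma, prediction_type="epsilon"):
    # Map the raw network output to an estimate of the clean latents x_0,
    # which is then regressed against the true clean latents with MSE.
    if prediction_type == "epsilon":
        return model_pred * (-sigma) + noisy_latents
    if prediction_type == "v_prediction":
        return model_pred * (-sigma / (sigma**2 + 1) ** 0.5) + noisy_latents / (sigma**2 + 1)
    raise ValueError(f"Unknown prediction type {prediction_type}")
```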
The commit also adds two smoke tests, one for EDM-style SDXL training and one for the Playground model:

```python
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
    def test_dreambooth_lora_sdxl_with_edm(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --do_edm_style_training
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"unet"` in their names.
            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_unet)

    def test_dreambooth_lora_playground(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"unet"` in their names.
            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_unet)
```
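Assuming the standard layout of the examples test suite, these tests can be run from the repository root with something like:

```bash
# Hypothetical invocation; adjust the path to wherever the test module lives.
pytest examples/dreambooth -k "edm or playground"
```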
Changes to `examples/dreambooth/train_dreambooth_lora_sdxl.py`:

```diff
@@ -14,8 +14,10 @@
 # See the License for the specific language governing permissions and
 import argparse
+import contextlib
 import gc
 import itertools
+import json
 import logging
 import math
 import os
@@ -32,7 +34,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
 from peft import LoraConfig, set_peft_model_state_dict
@@ -50,6 +52,8 @@ from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
+    EDMEulerScheduler,
+    EulerDiscreteScheduler,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
@@ -76,6 +80,20 @@ check_min_version("0.27.0.dev0")
 logger = get_logger(__name__)
 
 
+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+    model_index_filename = "model_index.json"
+    if os.path.isdir(pretrained_model_name_or_path):
+        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+    else:
+        model_index = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+        )
+
+    with open(model_index, "r") as f:
+        scheduler_type = json.load(f)["scheduler"][1]
+    return scheduler_type
+
+
 def save_model_card(
     repo_id: str,
     images=None,
```
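For reference, `determine_scheduler_type` keys off the `scheduler` entry of the pipeline's `model_index.json`, which stores each component as a `[library, class_name]` pair. A toy illustration (the EDM class name shown is an example, not a guarantee of any particular repo's contents):

```python
import json

# Illustrative slice of a model_index.json; index 1 is the scheduler class name.
model_index = json.loads('{"scheduler": ["diffusers", "EDMDPMSolverMultistepScheduler"]}')
scheduler_type = model_index["scheduler"][1]

# An "EDM" class name is what later flips args.do_edm_style_training on automatically.
assert "EDM" in scheduler_type
```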
```diff
@@ -95,7 +113,7 @@ def save_model_card(
     )
 
     model_description = f"""
-# SDXL LoRA DreamBooth - {repo_id}
+# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
 
 <Gallery />
@@ -119,11 +137,17 @@ Weights for this model are available in Safetensors format.
 [Download]({repo_id}/tree/main) them in the Files & versions tab.
 
+"""
+    if "playgroundai" in args.pretrained_model_name_or_path:
+        model_description += """\n
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
 """
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
-        license="openrail++",
+        license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
         base_model=base_model,
         prompt=instance_prompt,
         model_description=model_description,
@@ -131,15 +155,17 @@
     )
 
     tags = [
         "text-to-image",
-        "stable-diffusion-xl",
-        "stable-diffusion-xl-diffusers",
         "text-to-image",
         "diffusers",
         "lora",
         "template:sd-lora",
     ]
-    model_card = populate_model_card(model_card, tags=tags)
+    if "playgroundai" in base_model:
+        tags.extend(["playground", "playground-diffusers"])
+    else:
+        tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
+    model_card = populate_model_card(model_card, tags=tags)
 
     model_card.save(os.path.join(repo_folder, "README.md"))
@@ -159,23 +185,29 @@ def log_validation(
     # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
     scheduler_args = {}
 
-    if "variance_type" in pipeline.scheduler.config:
-        variance_type = pipeline.scheduler.config.variance_type
+    if not args.do_edm_style_training:
+        if "variance_type" in pipeline.scheduler.config:
+            variance_type = pipeline.scheduler.config.variance_type
 
-        if variance_type in ["learned", "learned_range"]:
-            variance_type = "fixed_small"
+            if variance_type in ["learned", "learned_range"]:
+                variance_type = "fixed_small"
 
-        scheduler_args["variance_type"] = variance_type
+            scheduler_args["variance_type"] = variance_type
 
-    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
 
     pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
 
-    with torch.cuda.amp.autocast():
+    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+    inference_ctx = (
+        contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
+    )
+
+    with inference_ctx:
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
 
     for tracker in accelerator.trackers:
@@ -334,6 +366,12 @@ def parse_args(input_args=None):
             " `args.validation_prompt` multiple times: `args.num_validation_images`."
         ),
     )
+    parser.add_argument(
+        "--do_edm_style_training",
+        default=False,
+        action="store_true",
+        help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+    )
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -905,6 +943,9 @@ def main(args):
             " Please use `huggingface-cli login` to authenticate with the Hub."
         )
 
+    if args.do_edm_style_training and args.snr_gamma is not None:
+        raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
+
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1018,7 +1059,19 @@ def main(args):
     )
 
     # Load scheduler and models
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+    if "EDM" in scheduler_type:
+        args.do_edm_style_training = True
+        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+        logger.info("Performing EDM-style training!")
+    elif args.do_edm_style_training:
+        noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="scheduler"
+        )
+        logger.info("Performing EDM-style training!")
+    else:
+        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
     text_encoder_one = text_encoder_cls_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
@@ -1036,6 +1089,12 @@ def main(args):
         revision=args.revision,
         variant=args.variant,
     )
+    latents_mean = latents_std = None
+    if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+    if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
```
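The `latents_mean`/`latents_std` lookup exists because some VAE configs (Playground v2.5's, for instance) carry per-channel latent statistics. A toy illustration of the normalization they enable later in the training loop (all values made up):

```python
import torch

# Stand-ins for vae.config.latents_mean / latents_std and vae.config.scaling_factor;
# the real values come from the pretrained model's config.
latents = torch.randn(1, 4, 8, 8) * 2.5 + 1.0
latents_mean = torch.full((1, 4, 1, 1), 1.0)
latents_std = torch.full((1, 4, 1, 1), 2.5)
scaling_factor = 0.5

# Shift and scale the latents channel-wise before they enter the diffusion loss.
model_input = (latents - latents_mean) * scaling_factor / latents_std
print(round(model_input.mean().item(), 2), round(model_input.std().item(), 2))  # ~0.0, ~0.5
```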
```diff
@@ -1433,7 +1492,12 @@ def main(args):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
+        tracker_name = (
+            "dreambooth-lora-sd-xl"
+            if "playgroundai" not in args.pretrained_model_name_or_path
+            else "dreambooth-lora-playground"
+        )
+        accelerator.init_trackers(tracker_name, config=vars(args))
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1485,6 +1549,18 @@ def main(args):
         disable=not accelerator.is_local_main_process,
     )
 
+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+        timesteps = timesteps.to(accelerator.device)
+
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
```
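`get_sigmas` simply looks up the sigma associated with each sampled timestep and broadcasts it to the rank of the latents so it can scale them elementwise. A standalone illustration with made-up schedule values:

```python
import torch

sigmas = torch.tensor([14.6, 1.0, 0.03])          # stand-in for noise_scheduler.sigmas
schedule_timesteps = torch.tensor([999, 500, 1])  # stand-in for noise_scheduler.timesteps

timesteps = torch.tensor([500, 999])              # sampled for a batch of two
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()            # tensor([ 1.0000, 14.6000])

# Broadcast to (batch, 1, 1, 1) so it can scale 4D latents.
while len(sigma.shape) < 4:
    sigma = sigma.unsqueeze(-1)
print(sigma.shape)  # torch.Size([2, 1, 1, 1])
```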
```diff
@@ -1512,22 +1588,46 @@ def main(args):
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
-                if args.pretrained_vae_model_name_or_path is None:
-                    model_input = model_input.to(weight_dtype)
+
+                if latents_mean is None and latents_std is None:
+                    model_input = model_input * vae.config.scaling_factor
+                    if args.pretrained_vae_model_name_or_path is None:
+                        model_input = model_input.to(weight_dtype)
+                else:
+                    latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+                    model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
+
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
-                )
-                timesteps = timesteps.long()
+                if not args.do_edm_style_training:
+                    timesteps = torch.randint(
+                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                    )
+                    timesteps = timesteps.long()
+                else:
+                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+                    # instead of discrete timesteps, so here we sample indices to get the noise levels
+                    # from `scheduler.timesteps`
+                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
 
                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
+                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+                # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                if args.do_edm_style_training:
+                    sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+                    if "EDM" in scheduler_type:
+                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+                    else:
+                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
+
                 # time ids
                 add_time_ids = torch.cat(
```
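The `(sigmas**2 + 1) ** 0.5` divisor matches how `EulerDiscreteScheduler.add_noise` perturbs samples (`x0 + sigma * noise`), so it renormalizes the input variance for unit-variance latents. A quick numerical check under that assumption:

```python
import torch

torch.manual_seed(0)
sigma = 5.0
x0 = torch.randn(100_000)     # pretend unit-variance latents
noise = torch.randn(100_000)
noisy = x0 + sigma * noise    # how EulerDiscreteScheduler.add_noise combines them

print(noisy.std())                             # ~ sqrt(sigma**2 + 1) ≈ 5.10
print((noisy / (sigma**2 + 1) ** 0.5).std())   # ~ 1.0: the c_in-normalized model input
```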
```diff
@@ -1551,7 +1651,7 @@ def main(args):
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input,
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
@@ -1570,18 +1670,43 @@ def main(args):
                     )
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input,
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
                         return_dict=False,
                     )[0]
 
+                weighting = None
+                if args.do_edm_style_training:
+                    # Similar to the input preconditioning, the model predictions are also preconditioned
+                    # on noised model inputs (before preconditioning) and the sigmas.
+                    # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                    if "EDM" in scheduler_type:
+                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+                    else:
+                        if noise_scheduler.config.prediction_type == "epsilon":
+                            model_pred = model_pred * (-sigmas) + noisy_model_input
+                        elif noise_scheduler.config.prediction_type == "v_prediction":
+                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+                                noisy_model_input / (sigmas**2 + 1)
+                            )
+                    # We are not doing weighting here because it tends to result in numerical problems.
+                    # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+                    # There might be other alternatives for weighting as well:
+                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+                    if "EDM" not in scheduler_type:
+                        weighting = (sigmas**-2.0).float()
+
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
+                    target = model_input if args.do_edm_style_training else noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    target = (
+                        model_input
+                        if args.do_edm_style_training
+                        else noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    )
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
```
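As a sanity check on the output preconditioning: if the network predicted the true epsilon (or the true v, under the usual `v = alpha * eps - beta * x0` parameterization rewritten for the sigma-only noising above; that rewriting is our assumption, not from the PR), the preconditioned prediction recovers the clean latents exactly, which is why the target switches to `model_input`:

```python
import torch

torch.manual_seed(0)
sigma = torch.tensor(2.0)
x0, eps = torch.randn(4), torch.randn(4)
noisy = x0 + sigma * eps

# epsilon-prediction branch: a perfect eps prediction yields x0 back.
denoised = eps * (-sigma) + noisy
assert torch.allclose(denoised, x0, atol=1e-6)

# v-prediction branch, with alpha = 1/sqrt(sigma^2+1) and beta = sigma/sqrt(sigma^2+1).
alpha, beta = 1 / (sigma**2 + 1) ** 0.5, sigma / (sigma**2 + 1) ** 0.5
v = alpha * eps - beta * x0
denoised_v = v * (-sigma / (sigma**2 + 1) ** 0.5) + noisy / (sigma**2 + 1)
assert torch.allclose(denoised_v, x0, atol=1e-6)
```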
```diff
@@ -1591,10 +1716,28 @@ def main(args):
                     target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+                    if weighting is not None:
+                        prior_loss = torch.mean(
+                            (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+                                target_prior.shape[0], -1
+                            ),
+                            1,
+                        )
+                        prior_loss = prior_loss.mean()
+                    else:
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                 if args.snr_gamma is None:
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                    if weighting is not None:
+                        loss = torch.mean(
+                            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
+                                target.shape[0], -1
+                            ),
+                            1,
+                        )
+                        loss = loss.mean()
+                    else:
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 else:
                     # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
```
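One way to read the `sigmas**-2.0` weighting (our observation, not stated in the PR): for epsilon models, weighting the x0-space squared error by `sigma**-2` makes it coincide with the familiar epsilon-space MSE, since the preconditioned output differs from `x0` by `sigma * (eps - eps_pred)`:

```python
import torch

torch.manual_seed(0)
sigma = torch.tensor(3.0)
x0, eps, eps_pred = torch.randn(4), torch.randn(4), torch.randn(4)
noisy = x0 + sigma * eps

x0_hat = eps_pred * (-sigma) + noisy           # preconditioned epsilon prediction
weighted = (sigma**-2.0) * (x0_hat - x0) ** 2  # the script's weighting in x0 space
assert torch.allclose(weighted, (eps_pred - eps) ** 2, atol=1e-5)
```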
```diff
@@ -1696,7 +1839,6 @@ def main(args):
                         variant=args.variant,
                         torch_dtype=weight_dtype,
                     )
-
                     pipeline_args = {"prompt": args.validation_prompt}
 
                     images = log_validation(
```