Unverified Commit dcfa6e1d authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

add Min-SNR loss to Controlnet flax train script (#3016)

* add wandb team and min-snr loss

* make style

* apply feedbacks
parent 1c96f82e
......@@ -409,3 +409,7 @@ You can then start your training from this saved checkpoint with
```bash
--controlnet_model_name_or_path="./control_out/500"
```
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
\ No newline at end of file
......@@ -289,6 +289,13 @@ def parse_args():
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
......@@ -328,11 +335,8 @@ def parse_args():
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
default="wandb",
help=('The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'),
)
parser.add_argument(
"--mixed_precision",
......@@ -442,6 +446,7 @@ def parse_args():
" `args.validation_prompt` and logging the images."
),
)
parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
parser.add_argument(
"--tracker_project_name",
type=str,
......@@ -668,6 +673,7 @@ def main():
# wandb init
if jax.process_index() == 0 and args.report_to == "wandb":
wandb.init(
entity=args.wandb_entity,
project=args.tracker_project_name,
job_type="train",
config=args,
......@@ -806,6 +812,20 @@ def main():
validation_rng, train_rngs = jax.random.split(rng)
train_rngs = jax.random.split(train_rngs, jax.local_device_count())
def compute_snr(timesteps):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
alpha = sqrt_alphas_cumprod[timesteps]
sigma = sqrt_one_minus_alphas_cumprod[timesteps]
# Compute SNR.
snr = (alpha / sigma) ** 2
return snr
def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
# reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
if args.gradient_accumulation_steps > 1:
......@@ -876,6 +896,12 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = (target - model_pred) ** 2
if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps))
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
loss = loss * snr_loss_weights
loss = loss.mean()
return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment