Unverified Commit 6c56f050 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

v-prediction training support (#1455)

* add get_velocity

* add v prediction for training

* fix saving

* add revision arg

* fix saving

* save checkpoints dreambooth

* fix saving embeds

* add instruction in readme

* quality

* noise_pred -> model_pred
parent 77fc197f
......@@ -39,6 +39,8 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
......
......@@ -124,6 +124,7 @@ def parse_args(input_args=None):
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
......@@ -603,23 +604,31 @@ def main(args):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
......@@ -638,6 +647,17 @@ def main(args):
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
......
......@@ -42,6 +42,8 @@ If you have already cloned the repo, then you won't need to go through these ste
#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
......
......@@ -15,13 +15,12 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer
logger = get_logger(__name__)
......@@ -36,6 +35,13 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
......@@ -335,10 +341,24 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
# Freeze vae and text_encoder
vae.requires_grad_(False)
......@@ -562,9 +582,17 @@ def main():
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
......@@ -600,14 +628,12 @@ def main():
if args.use_ema:
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
......
......@@ -47,6 +47,8 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"
......
......@@ -16,9 +16,8 @@ import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
......@@ -26,7 +25,7 @@ from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
......@@ -51,11 +50,11 @@ else:
logger = get_logger(__name__)
def save_progress(text_encoder, placeholder_token_id, accelerator, args):
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
torch.save(learned_embeds_dict, save_path)
def parse_args():
......@@ -73,6 +72,13 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
......@@ -405,9 +411,21 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
......@@ -532,9 +550,17 @@ def main():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
......@@ -556,7 +582,8 @@ def main():
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
save_progress(text_encoder, placeholder_token_id, accelerator, args)
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
......@@ -569,18 +596,18 @@ def main():
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
save_progress(text_encoder, placeholder_token_id, accelerator, args)
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
......
......@@ -355,5 +355,25 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps
......@@ -345,5 +345,25 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment