Unverified Commit 6c56f050 authored by Suraj Patil, committed by GitHub

v-prediction training support (#1455)

* add get_velocity

* add v prediction for training

* fix saving

* add revision arg

* fix saving

* save checkpoints dreambooth

* fix saving embeds

* add instruction in readme

* quality

* noise_pred -> model_pred
parent 77fc197f
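The core change, repeated in each training script below, is that the regression target for the UNet output now follows the noise scheduler's configured `prediction_type` instead of always being the sampled noise. A minimal sketch of that branch, wrapped in a hypothetical `compute_loss` helper that the scripts themselves do not define (they inline the same logic):

```python
import torch.nn.functional as F


def compute_loss(noise_scheduler, model_pred, latents, noise, timesteps):
    # Pick the regression target from the scheduler config, as in the diffs below.
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise  # classic objective: predict the added noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        # velocity target: sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
    return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
```

Because the target is read from the pretrained model's scheduler config, no extra flag appears to be needed for checkpoints that ship a `v_prediction` scheduler, such as the stable-diffusion-2 768x768 model referenced in the README notes below.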
@@ -39,6 +39,8 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
...
@@ -124,6 +124,7 @@ def parse_args(input_args=None):
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
@@ -603,23 +604,31 @@ def main(args):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -638,6 +647,17 @@ def main(args):
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
......
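Not part of the diff, but the intermediate DreamBooth checkpoints written by the `save_steps` block above are full pipelines saved via `save_pretrained`, so they can be loaded for a quick qualitative check; the path and prompt below are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical path: <output_dir>/checkpoint-<global_step>, as written by the loop above.
pipe = StableDiffusionPipeline.from_pretrained(
    "path-to-save-model/checkpoint-500", torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("checkpoint-500-sample.png")
```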
@@ -42,6 +42,8 @@ If you have already cloned the repo, then you won't need to go through these steps.
#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
...
@@ -15,13 +15,12 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
logger = get_logger(__name__)
@@ -36,6 +35,13 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -335,10 +341,24 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
# Freeze vae and text_encoder
vae.requires_grad_(False)
@@ -562,9 +582,17 @@ def main():
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -600,14 +628,12 @@ def main():
if args.use_ema:
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=text_encoder,
vae=vae,
unet=unet,
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
...
@@ -47,6 +47,8 @@ Now let's get our dataset. Download 3-4 images from [here](https://drive.google.c
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"
...
@@ -16,9 +16,8 @@ import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -26,7 +25,7 @@ from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -51,11 +50,11 @@ else:
logger = get_logger(__name__)
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
def parse_args():
@@ -73,6 +72,13 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -405,9 +411,21 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -532,9 +550,17 @@ def main():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
@@ -556,7 +582,8 @@ def main():
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -569,18 +596,18 @@ def main():
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
vae=vae,
unet=unet,
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
...
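Also not part of the diff: the embedding files written by `save_progress` (`learned_embeds-steps-*.bin` during training and `learned_embeds.bin` at the end) contain a single `{placeholder_token: embedding}` entry, so loading one into a fresh pipeline looks roughly like the sketch below; the model id and file name are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

learned_embeds = torch.load("learned_embeds.bin", map_location="cpu")
placeholder_token, embedding = next(iter(learned_embeds.items()))

# Register the placeholder token and copy its trained embedding into the text encoder.
pipe.tokenizer.add_tokens(placeholder_token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
token_id = pipe.tokenizer.convert_tokens_to_ids(placeholder_token)
embeds = pipe.text_encoder.get_input_embeddings().weight
embeds.data[token_id] = embedding.to(embeds.dtype)
```

A prompt containing the placeholder token should then reproduce the learned concept.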
@@ -355,5 +355,25 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps
@@ -345,5 +345,25 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps
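A small self-contained check (not from the commit) of the velocity parameterization implemented by `get_velocity` above: with `a = sqrt(alphas_cumprod[t])` and `b = sqrt(1 - alphas_cumprod[t])`, the noisy sample is `a * x0 + b * noise` and the velocity is `a * noise - b * x0`; since `a**2 + b**2 == 1`, both `x0` and the noise can be recovered from them. The toy schedule below is an assumption for the check, not the schedulers' actual beta schedule:

```python
import torch

torch.manual_seed(0)
alphas_cumprod = torch.linspace(0.999, 0.001, 1000)  # stand-in schedule for the check

x0 = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (2,))

a = alphas_cumprod[t].view(-1, 1, 1, 1) ** 0.5
b = (1 - alphas_cumprod[t]).view(-1, 1, 1, 1) ** 0.5

x_t = a * x0 + b * noise  # mirrors add_noise
v = a * noise - b * x0    # mirrors get_velocity

assert torch.allclose(a * x_t - b * v, x0, atol=1e-5)     # recovers x0
assert torch.allclose(b * x_t + a * v, noise, atol=1e-5)  # recovers the noise
```

This is why a model trained to predict `v` implicitly predicts both the clean latents and the noise at every timestep.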