Unverified commit ab6672fe, authored by Pedro Cuenca, committed by GitHub

Use CC12M for LCM WDS training example (#5908)

* Fix SD scripts - there are only 2 items per batch

* Adjustments to make the SDXL scripts work with other datasets

* Use public webdataset dataset for examples

* make style

* Minor tweaks to the readmes.

* Stress that the dataset is illustrative.
parent f90a5139
# Latent Consistency Distillation Example:

[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.

## Full model distillation
Then cd in the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
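The portion of the README elided by this diff view also mentions `write_basic_config()`; as a sketch, a default Accelerate configuration can be written non-interactively with the standard `accelerate` API:

```python
from accelerate.utils import write_basic_config

# Writes a default Accelerate config file without the interactive questionnaire.
write_basic_config()
```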
When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
#### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
  --pretrained_teacher_model=$MODEL_NAME \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision=fp16 \
  --resolution=512 \
  ...
  --max_train_steps=1000 \
  --max_train_samples=4000000 \
  --dataloader_num_workers=8 \
  --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
  --validation_steps=200 \
  --checkpointing_steps=200 --checkpoints_total_limit=10 \
  --train_batch_size=12 \
  ...
  --resume_from_checkpoint=latest \
  --report_to=wandb \
  --seed=453645634 \
  --push_to_hub
```
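Note that `--train_shards_path_or_url` is a WebDataset spec: the `pipe:` prefix makes WebDataset read the tar stream from the stdout of a shell command, and the `{00000..01099}` brace pattern enumerates the shard files. A quick way to see the expansion (the `braceexpand` package is a WebDataset dependency):

```python
from braceexpand import braceexpand

# Expand the shard pattern exactly as WebDataset does internally.
urls = list(braceexpand("data/{00000..01099}.tar"))
print(len(urls))  # 1100 shards
print(urls[0])    # data/00000.tar
```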
## LCM-LoRA

Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any model based on Stable Diffusion v1.5.

### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_lora_sd_wds.py \
  --pretrained_teacher_model=$MODEL_NAME \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision=fp16 \
  --resolution=512 \
  ...
  --max_train_steps=1000 \
  --max_train_samples=4000000 \
  --dataloader_num_workers=8 \
  --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
  --validation_steps=200 \
  --checkpointing_steps=200 --checkpoints_total_limit=10 \
  --train_batch_size=12 \
  ...
```
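Once trained, the LCM-LoRA can be loaded into a regular pipeline together with `LCMScheduler` for few-step inference. A minimal sketch, assuming the LoRA weights were saved to the output directory above:

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# LCM inference uses the dedicated LCM scheduler.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Load the distilled LCM-LoRA weights produced by the training script.
pipe.load_lora_weights("path/to/saved/model")

# Few-step generation; LCMs typically use 2-8 steps and low guidance.
image = pipe("a photo of a cat", num_inference_steps=4, guidance_scale=1.0).images[0]
image.save("cat.png")
```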
# Latent Consistency Distillation Example:

[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.

## Full model distillation
Then cd in the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
#### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sdxl_wds.py \
  --pretrained_teacher_model=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision=fp16 \
  ...
  --max_train_steps=1000 \
  --max_train_samples=4000000 \
  --dataloader_num_workers=8 \
  --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
  --validation_steps=200 \
  --checkpointing_steps=200 --checkpoints_total_limit=10 \
  --train_batch_size=12 \
  ...
```
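The separate `--pretrained_vae_model_name_or_path` is passed because the original SDXL VAE is prone to producing NaNs in fp16; `madebyollin/sdxl-vae-fp16-fix` is a drop-in replacement patched for half precision. As a sketch, it loads like any other VAE:

```python
import torch
from diffusers import AutoencoderKL

# fp16-safe SDXL VAE, used here in place of the teacher's original VAE.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
```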
## LCM-LoRA

Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.

### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_lora_sdxl_wds.py \
  --pretrained_teacher_model=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir=$OUTPUT_DIR \
  ...
  --max_train_steps=1000 \
  --max_train_samples=4000000 \
  --dataloader_num_workers=8 \
  --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
  --validation_steps=200 \
  --checkpointing_steps=200 --checkpoints_total_limit=10 \
  --train_batch_size=12 \
  ...
```
In the SD scripts (per the first commit message, there are only two items per batch), the training-loop unpacking is fixed:

```diff
@@ -1123,7 +1123,7 @@ def main(args):
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                image, text, _, _ = batch
+                image, text = batch
                 image = image.to(accelerator.device, non_blocking=True)
                 encoded_text = compute_embeddings_fn(text)
```
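This matches what the SD data pipeline actually yields: `(image, text)` pairs. The four-value unpacking presumably came from the SDXL scripts, whose batches additionally carry original-size and crop metadata for SDXL's micro-conditioning.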
In the SDXL scripts (per the second commit message), the WebDataset metadata keys and minimum size become dataset-configurable constants, and a missing watermark score no longer excludes a sample:

```diff
@@ -68,6 +68,11 @@ from diffusers.utils.import_utils import is_xformers_available
 MAX_SEQ_LENGTH = 77
 
+# Adjust for your dataset
+WDS_JSON_WIDTH = "width"  # original_width for LAION
+WDS_JSON_HEIGHT = "height"  # original_height for LAION
+MIN_SIZE = 700  # ~960 for LAION, ideal: 1024 if the dataset contains large images
+
 if is_wandb_available():
     import wandb
@@ -146,10 +151,10 @@ class WebdatasetFilter:
         try:
             if "json" in x:
                 x_json = json.loads(x["json"])
-                filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
-                    "original_height", 0
+                filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
+                    WDS_JSON_HEIGHT, 0
                 ) >= self.min_size
-                filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+                filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
                 return filter_size and filter_watermark
             else:
                 return False
@@ -180,7 +185,7 @@ class Text2ImageDataset:
             if use_fix_crop_and_size:
                 return (resolution, resolution)
             else:
-                return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+                return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
 
         def transform(example):
             # resize image
@@ -212,7 +217,7 @@ class Text2ImageDataset:
         pipeline = [
             wds.ResampledShards(train_shards_path_or_url),
             tarfile_to_samples_nothrow,
-            wds.select(WebdatasetFilter(min_size=960)),
+            wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
             wds.shuffle(shuffle_buffer_size),
             *processing_pipeline,
             wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
```
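The key behavioral change is the `pwatermark` default: a sample with no watermark score previously defaulted to 1.0 and was filtered out, which would discard every sample in a dataset that lacks the field (as this CC12M conversion appears to); now it defaults to 0.0 and passes. A self-contained sketch of the filter logic under the new defaults, where `MAX_PWATERMARK = 0.5` is an assumed threshold for illustration (the script takes it as a parameter):

```python
import json

WDS_JSON_WIDTH, WDS_JSON_HEIGHT, MIN_SIZE = "width", "height", 700
MAX_PWATERMARK = 0.5  # hypothetical threshold for illustration

def passes_filter(sample_json: str) -> bool:
    x = json.loads(sample_json)
    big_enough = (x.get(WDS_JSON_WIDTH, 0) or 0) >= MIN_SIZE and (
        x.get(WDS_JSON_HEIGHT, 0) or 0
    ) >= MIN_SIZE
    # Missing pwatermark now counts as 0.0 (no watermark) instead of 1.0.
    no_watermark = (x.get("pwatermark", 0.0) or 0.0) <= MAX_PWATERMARK
    return big_enough and no_watermark

print(passes_filter('{"width": 800, "height": 1024}'))  # True: big enough, no pwatermark field
print(passes_filter('{"width": 512, "height": 512}'))   # False: below MIN_SIZE
```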
The same two-items-per-batch fix in the second SD script:

```diff
@@ -1106,7 +1106,7 @@ def main(args):
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                image, text, _, _ = batch
+                image, text = batch
                 image = image.to(accelerator.device, non_blocking=True)
                 encoded_text = compute_embeddings_fn(text)
```
And the same dataset-key and filtering adjustments in the second SDXL script:

```diff
@@ -67,6 +67,11 @@ from diffusers.utils.import_utils import is_xformers_available
 MAX_SEQ_LENGTH = 77
 
+# Adjust for your dataset
+WDS_JSON_WIDTH = "width"  # original_width for LAION
+WDS_JSON_HEIGHT = "height"  # original_height for LAION
+MIN_SIZE = 700  # ~960 for LAION, ideal: 1024 if the dataset contains large images
+
 if is_wandb_available():
     import wandb
@@ -128,10 +133,10 @@ class WebdatasetFilter:
         try:
             if "json" in x:
                 x_json = json.loads(x["json"])
-                filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
-                    "original_height", 0
+                filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
+                    WDS_JSON_HEIGHT, 0
                 ) >= self.min_size
-                filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+                filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
                 return filter_size and filter_watermark
             else:
                 return False
@@ -162,7 +167,7 @@ class Text2ImageDataset:
             if use_fix_crop_and_size:
                 return (resolution, resolution)
             else:
-                return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+                return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
 
         def transform(example):
             # resize image
@@ -194,7 +199,7 @@ class Text2ImageDataset:
         pipeline = [
             wds.ResampledShards(train_shards_path_or_url),
             tarfile_to_samples_nothrow,
-            wds.select(WebdatasetFilter(min_size=960)),
+            wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
             wds.shuffle(shuffle_buffer_size),
             *processing_pipeline,
             wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
```
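To choose `WDS_JSON_WIDTH`, `WDS_JSON_HEIGHT`, and `MIN_SIZE` for a new dataset, it helps to inspect the JSON metadata of a few samples first. A minimal sketch with the `webdataset` library (field names vary across datasets, so verify against your own shards):

```python
import json
import webdataset as wds

# Stream a single shard and print the metadata keys of the first samples.
url = (
    "pipe:curl -L -s https://huggingface.co/datasets/laion/"
    "conceptual-captions-12m-webdataset/resolve/main/data/00000.tar?download=true"
)
for i, sample in enumerate(wds.WebDataset(url)):
    meta = json.loads(sample["json"])
    print(sorted(meta.keys()))
    if i >= 2:
        break
```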