Unverified commit f3d1333e, authored by dg845 and committed by GitHub

Improve LCM(-LoRA) Distillation Scripts (#6420)

* Make WDS pipeline interpolation type configurable.

* Make the VAE encoding batch size configurable.

* Make lora_alpha and lora_dropout configurable for LCM LoRA scripts.

* Generalize scalings_for_boundary_conditions function and make the timestep scaling configurable.

* Make LoRA target modules configurable for LCM-LoRA scripts.

* Move resolve_interpolation_mode to src/diffusers/training_utils.py and make interpolation type configurable in non-WDS script.

* Apply suggestions from review.
parent acd926f4
@@ -61,6 +61,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -165,6 +166,7 @@ class SDText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 512,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -174,10 +176,12 @@ class SDText2ImageDataset:
             # flatten list using itertools
             train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
@@ -353,8 +357,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
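
The generalization here is worth spelling out: the old code hard-coded `timestep / 0.1`, which is exactly `timestep * 10`, so the default `timestep_scaling=10.0` reproduces the previous behavior while letting users tune the factor. Below is a minimal, self-contained sketch with a sanity check of the consistency-model boundary condition (at `timestep = 0`, `c_skip` must be 1 and `c_out` must be 0 so the parameterized model reduces to the identity):

```python
import torch

def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    # c_skip weights the skip connection, c_out weights the model output;
    # together they enforce the consistency-model boundary condition.
    scaled_timestep = timestep_scaling * timestep
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

# Boundary condition: the consistency function is the identity at t = 0.
c_skip, c_out = scalings_for_boundary_conditions(torch.tensor(0.0))
assert c_skip == 1.0 and c_out == 0.0

# The default scaling of 10.0 matches the old hard-coded `timestep / 0.1`.
t = torch.tensor(500.0)
old_c_skip = 0.5**2 / ((t / 0.1) ** 2 + 0.5**2)
new_c_skip, _ = scalings_for_boundary_conditions(t)
assert torch.isclose(old_c_skip, new_c_skip)
```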
@@ -572,6 +577,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -710,6 +724,50 @@ def parse_args():
         default=64,
         help="The rank of the LoRA projection matrix.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=64,
+        help=(
+            "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+            " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+        ),
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default=None,
+        help=(
+            "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+            " be used. By default, LoRA will be applied to all conv and linear layers."
+        ),
+    )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=32,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Mixed Precision----
     parser.add_argument(
         "--mixed_precision",
@@ -915,9 +973,10 @@ def main(args):
     )

     # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
-    lora_config = LoraConfig(
-        r=args.lora_rank,
-        target_modules=[
+    if args.lora_target_modules is not None:
+        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+    else:
+        lora_target_modules = [
             "to_q",
             "to_k",
             "to_v",
@@ -932,7 +991,12 @@ def main(args):
             "downsamplers.0.conv",
             "upsamplers.0.conv",
             "time_emb_proj",
-        ],
+        ]
+    lora_config = LoraConfig(
+        r=args.lora_rank,
+        target_modules=lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
     )
     unet = get_peft_model(unet, lora_config)
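
As context for the new flags: in PEFT, the LoRA update is scaled by `lora_alpha / r`, so the defaults (`--lora_alpha 64`, `--lora_rank 64`) keep an effective scaling of 1.0, matching the help text above. The sketch below shows how the comma-separated `--lora_target_modules` value might be parsed into a `LoraConfig`; the restricted module list is a hypothetical example, not the script's default:

```python
from peft import LoraConfig

# Hypothetical CLI value restricting LoRA to the attention projections.
raw_target_modules = "to_q, to_k, to_v"
lora_target_modules = [module_key.strip() for module_key in raw_target_modules.split(",")]

lora_config = LoraConfig(
    r=64,                                # LoRA projection rank
    target_modules=lora_target_modules,  # parsed from the comma-separated string
    lora_alpha=64,                       # alpha == r => update scaled by alpha / r = 1.0
    lora_dropout=0.0,                    # dropout on each adapted layer's input
)
```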
@@ -1051,6 +1115,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1162,10 +1227,10 @@ def main(args):
                 if vae.dtype != weight_dtype:
                     vae.to(dtype=weight_dtype)

-                # encode pixel values with batch size of at most 32
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 32):
-                    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)
                 latents = latents * vae.config.scaling_factor
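
The same chunking pattern recurs in every script, so it reads naturally as a small helper. This is a sketch, not code from the PR; it assumes a diffusers `AutoencoderKL`-style `vae` whose `encode(...)` returns an output with a `latent_dist` attribute:

```python
import torch

def encode_in_chunks(vae, pixel_values: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Encode a large image batch in smaller chunks to avoid VAE OOMs, then
    # concatenate the latents and apply the VAE's latent scaling factor.
    latents = []
    for i in range(0, pixel_values.shape[0], batch_size):
        latents.append(vae.encode(pixel_values[i : i + batch_size]).latent_dist.sample())
    return torch.cat(latents, dim=0) * vae.config.scaling_factor
```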
@@ -1181,9 +1246,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
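
`append_dims` (referenced in the hunk headers above) right-pads a tensor with singleton dimensions so that the per-sample `c_skip`/`c_out` scalars broadcast over the 4-D latents. A sketch consistent with the k-diffusion-style helper these scripts define:

```python
import torch

def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    # Append trailing singleton dimensions until x has target_dims dimensions.
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]

c_skip = torch.rand(4)               # one scalar per batch sample
latents = torch.randn(4, 4, 64, 64)  # (batch, channels, height, width)
c_skip = append_dims(c_skip, latents.ndim)  # -> shape (4, 1, 1, 1)
scaled = c_skip * latents            # broadcasts over channels and spatial dims
```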
...
@@ -51,6 +51,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -193,8 +194,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
@@ -396,6 +398,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -534,6 +545,50 @@ def parse_args():
         default=64,
         help="The rank of the LoRA projection matrix.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=64,
+        help=(
+            "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+            " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+        ),
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default=None,
+        help=(
+            "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+            " be used. By default, LoRA will be applied to all conv and linear layers."
+        ),
+    )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=8,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Mixed Precision----
     parser.add_argument(
         "--mixed_precision",
@@ -776,10 +831,10 @@ def main(args):
         text_encoder_two.to(accelerator.device, dtype=weight_dtype)

     # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
-    lora_config = LoraConfig(
-        r=args.lora_rank,
-        lora_alpha=args.lora_rank,
-        target_modules=[
+    if args.lora_target_modules is not None:
+        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+    else:
+        lora_target_modules = [
             "to_q",
             "to_k",
             "to_v",
@@ -794,7 +849,12 @@ def main(args):
             "downsamplers.0.conv",
             "upsamplers.0.conv",
             "time_emb_proj",
-        ],
+        ]
+    lora_config = LoraConfig(
+        r=args.lora_rank,
+        target_modules=lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
     )
     unet.add_adapter(lora_config)
@@ -929,7 +989,8 @@ def main(args):
     )

     # Preprocessing the datasets.
-    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+    interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
+    train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode)
     train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
     train_flip = transforms.RandomHorizontalFlip(p=1.0)
     train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
@@ -1121,11 +1182,11 @@ def main(args):
                 encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)

-                # encode pixel values with batch size of at most 8
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 pixel_values = pixel_values.to(dtype=vae.dtype)
                 latents = []
-                for i in range(0, pixel_values.shape[0], args.encode_batch_size):
-                    latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)
                 latents = latents * vae.config.scaling_factor
@@ -1142,9 +1203,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
...
@@ -62,6 +62,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -171,6 +172,7 @@ class SDXLText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 1024,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -187,10 +189,12 @@ class SDXLText2ImageDataset:
             else:
                 return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
@@ -340,8 +344,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
@@ -546,6 +551,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--use_fix_crop_and_size",
         action="store_true",
@@ -690,6 +704,50 @@ def parse_args():
         default=64,
         help="The rank of the LoRA projection matrix.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=64,
+        help=(
+            "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+            " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+        ),
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default=None,
+        help=(
+            "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+            " be used. By default, LoRA will be applied to all conv and linear layers."
+        ),
+    )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=8,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Mixed Precision----
     parser.add_argument(
         "--mixed_precision",
@@ -929,9 +987,10 @@ def main(args):
     )

     # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
-    lora_config = LoraConfig(
-        r=args.lora_rank,
-        target_modules=[
+    if args.lora_target_modules is not None:
+        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+    else:
+        lora_target_modules = [
             "to_q",
             "to_k",
             "to_v",
@@ -946,7 +1005,12 @@ def main(args):
             "downsamplers.0.conv",
             "upsamplers.0.conv",
             "time_emb_proj",
-        ],
+        ]
+    lora_config = LoraConfig(
+        r=args.lora_rank,
+        target_modules=lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
     )
     unet = get_peft_model(unet, lora_config)
@@ -1090,6 +1154,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1214,10 +1279,10 @@ def main(args):
                 else:
                     pixel_values = image

-                # encode pixel values with batch size of at most 8
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 8):
-                    latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)
                 latents = latents * vae.config.scaling_factor
@@ -1234,9 +1299,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
...
@@ -60,6 +60,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -147,6 +148,7 @@ class SDText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 512,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -156,10 +158,12 @@ class SDText2ImageDataset:
             # flatten list using itertools
             train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
@@ -330,8 +334,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
@@ -549,6 +554,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -690,6 +704,26 @@ def parse_args():
             " does not have `time_cond_proj_dim` set."
         ),
     )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=32,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1034,6 +1068,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1145,10 +1180,10 @@ def main(args):
                 if vae.dtype != weight_dtype:
                     vae.to(dtype=weight_dtype)

-                # encode pixel values with batch size of at most 32
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 32):
-                    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)
                 latents = latents * vae.config.scaling_factor
@@ -1164,9 +1199,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
...
@@ -61,6 +61,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -153,6 +154,7 @@ class SDXLText2ImageDataset:
         global_batch_size: int,
         num_workers: int,
         resolution: int = 1024,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -169,10 +171,12 @@ class SDXLText2ImageDataset:
             else:
                 return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))

+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)

             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
@@ -318,8 +322,9 @@ def append_dims(x, target_dims):

 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
@@ -568,6 +573,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--use_fix_crop_and_size",
         action="store_true",
@@ -715,6 +729,26 @@ def parse_args():
             " does not have `time_cond_proj_dim` set."
         ),
     )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=8,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1118,6 +1152,7 @@ def main(args):
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1242,10 +1277,10 @@ def main(args):
                 else:
                     pixel_values = image

-                # encode pixel values with batch size of at most 8
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 8):
-                    latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)
                 latents = latents * vae.config.scaling_factor
@@ -1262,9 +1297,13 @@ def main(args):
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
...
@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, Optional, Union

 import numpy as np
 import torch
+from torchvision import transforms

 from .models import UNet2DConditionModel
 from .utils import deprecate, is_transformers_available
@@ -53,6 +54,45 @@ def compute_snr(noise_scheduler, timesteps):
     return snr

+def resolve_interpolation_mode(interpolation_type: str):
+    """
+    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
+    full list of supported enums is documented at
+    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.
+
+    Args:
+        interpolation_type (`str`):
+            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
+            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation
+            modes in torchvision.
+
+    Returns:
+        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
+        transform.
+    """
+    if interpolation_type == "bilinear":
+        interpolation_mode = transforms.InterpolationMode.BILINEAR
+    elif interpolation_type == "bicubic":
+        interpolation_mode = transforms.InterpolationMode.BICUBIC
+    elif interpolation_type == "box":
+        interpolation_mode = transforms.InterpolationMode.BOX
+    elif interpolation_type == "nearest":
+        interpolation_mode = transforms.InterpolationMode.NEAREST
+    elif interpolation_type == "nearest_exact":
+        interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT
+    elif interpolation_type == "hamming":
+        interpolation_mode = transforms.InterpolationMode.HAMMING
+    elif interpolation_type == "lanczos":
+        interpolation_mode = transforms.InterpolationMode.LANCZOS
+    else:
+        raise ValueError(
+            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
+            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        )
+
+    return interpolation_mode
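
A short usage sketch for the new helper. One caveat, stated as an assumption about torchvision rather than something this PR enforces: `InterpolationMode.NEAREST_EXACT` only exists in newer torchvision releases, and the `box`, `hamming`, and `lanczos` modes apply to PIL images only, not tensors:

```python
from PIL import Image
from torchvision.transforms import functional as TF

from diffusers.training_utils import resolve_interpolation_mode

interpolation_mode = resolve_interpolation_mode("lanczos")
image = Image.new("RGB", (768, 512))
resized = TF.resize(image, 512, interpolation=interpolation_mode)  # shorter side -> 512

# Unsupported strings raise a ValueError listing the valid modes.
try:
    resolve_interpolation_mode("cubic")
except ValueError as err:
    print(err)
```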
 def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     r"""
     Returns:
...