"src/vscode:/vscode.git/clone" did not exist on "3d08d8dc4e7c25b28ccfba1631e72a77ae6478be"
Unverified Commit d78acded authored by Bagheera, committed by GitHub

apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) (#7447)



* apple mps: training support for SDXL LoRA

* sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet on apple mps

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 6df103de
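The same bfloat16-on-MPS guard is added near the top of main() in every script touched by this commit. For reference, a minimal standalone version of that check (the helper name is illustrative, only plain torch is assumed):

import torch


def validate_mixed_precision(mixed_precision: str) -> None:
    # MPS cannot run bfloat16 yet (pytorch#99272), so fail fast with a clear
    # message instead of erroring out mid-training.
    if torch.backends.mps.is_available() and mixed_precision == "bf16":
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. "
            "Please use fp16 (recommended) or fp32 instead."
        )


# e.g. validate_mixed_precision(args.mixed_precision) right after parse_args()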
@@ -125,7 +125,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
         )

     image_logs = []
-    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
+    inference_ctx = (
+        contextlib.nullcontext()
+        if (is_final_validation or torch.backends.mps.is_available())
+        else torch.autocast("cuda")
+    )

     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
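The hunk above swaps the unconditional torch.autocast("cuda") for a null context whenever MPS is in use, since the CUDA autocast context cannot be entered on Apple hardware. A minimal sketch of the same idiom, with an illustrative helper name:

import contextlib

import torch


def make_inference_ctx(is_final_validation: bool):
    # Final validation runs in full precision, and MPS has no "cuda" autocast,
    # so both cases fall back to a no-op context manager.
    if is_final_validation or torch.backends.mps.is_available():
        return contextlib.nullcontext()
    return torch.autocast("cuda")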
@@ -792,6 +796,12 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
...
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and

 import argparse
-import contextlib
 import gc
 import itertools
 import json
@@ -208,11 +207,18 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
     # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-    inference_ctx = (
-        contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
-    )
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+    if "playground" in args.pretrained_model_name_or_path:
+        enable_autocast = False

-    with inference_ctx:
+    with torch.autocast(
+        accelerator.device.type,
+        enabled=enable_autocast,
+    ):
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

     for tracker in accelerator.trackers:
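The replacement above computes an enable_autocast flag and hands it to a device-agnostic torch.autocast(accelerator.device.type, enabled=...) call: autocast is skipped on MPS, when the accelerator already trains in fp16/bf16 mixed precision, and for "playground" checkpoints. A rough standalone sketch of that decision (the helper name and usage are illustrative, not from the script):

import torch


def should_autocast(mixed_precision: str, pretrained_model_name_or_path: str) -> bool:
    # Mirrors the logic above: mixed-precision runs already cast activations,
    # MPS does not support this autocast path, and playground checkpoints are
    # excluded as well.
    if torch.backends.mps.is_available() or mixed_precision in ("fp16", "bf16"):
        return False
    if "playground" in pretrained_model_name_or_path:
        return False
    return True


# with torch.autocast(device_type, enabled=should_autocast(precision, model_name)):
#     images = [pipeline(**pipeline_args).images[0] for _ in range(num_images)]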
@@ -230,6 +236,7 @@ def log_validation(
                 )

     del pipeline
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()

     return images
@@ -967,6 +974,12 @@ def main(args):
     if args.do_edm_style_training and args.snr_gamma is not None:
         raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     logging_dir = Path(args.output_dir, args.logging_dir)

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1009,7 +1022,8 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
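This hunk broadens the dtype choice for generating missing prior-preservation class images: fp16 is now the default whenever CUDA or MPS is available, not just CUDA, while the explicit --prior_generation_precision override keeps working. Condensed into a small helper (the name is illustrative; only the fp32/fp16 overrides visible in the context are shown):

import torch


def prior_generation_dtype(prior_generation_precision=None) -> torch.dtype:
    # Default to half precision on any fp16-capable accelerator (CUDA or MPS),
    # full precision otherwise; an explicit flag always wins.
    has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
    torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
    if prior_generation_precision == "fp32":
        torch_dtype = torch.float32
    elif prior_generation_precision == "fp16":
        torch_dtype = torch.float16
    return torch_dtype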
@@ -1134,6 +1148,12 @@ def main(args):
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

+    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
@@ -1278,7 +1298,7 @@ def main(args):
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True

     if args.scale_lr:
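TF32 is a CUDA-only feature of Ampere (and newer) NVIDIA GPUs, so the toggle is now additionally gated on torch.cuda.is_available() and is simply skipped on MPS or CPU runs. As a standalone snippet:

import torch

allow_tf32 = True  # stand-in for args.allow_tf32

# Only touch the CUDA matmul flag when a CUDA device is actually present;
# on Apple MPS or plain CPU this block is a no-op.
if allow_tf32 and torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True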
@@ -1455,6 +1475,7 @@ def main(args):
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         del tokenizers, text_encoders
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
...
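Throughout the scripts, bare torch.cuda.empty_cache() calls are now guarded by torch.cuda.is_available() so they no longer fail on CUDA-less machines. A small device-agnostic cleanup helper in the same spirit (the helper itself and its MPS branch are my own sketch, not part of this commit):

import gc

import torch


def free_memory() -> None:
    # Drop Python references first, then ask whichever backend is present to
    # release its cached allocations.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available() and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        # torch.mps.empty_cache() only exists on newer PyTorch builds.
        torch.mps.empty_cache()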
@@ -71,12 +71,7 @@ TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": tor

 def log_validation(
-    pipeline,
-    args,
-    accelerator,
-    generator,
-    global_step,
-    is_final_validation=False,
+    pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
@@ -96,7 +91,7 @@ def log_validation(
         else Image.open(image_url_or_path).convert("RGB")
     )(args.val_image_url_or_path)

-    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
+    with torch.autocast(accelerator.device.type, enabled=enable_autocast):
         edited_images = []
         # Run inference
         for val_img_idx in range(args.num_validation_images):
@@ -497,6 +492,13 @@ def main():
         ),
     )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -981,6 +983,13 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
...@@ -1193,6 +1202,7 @@ def main(): ...@@ -1193,6 +1202,7 @@ def main():
generator, generator,
global_step, global_step,
is_final_validation=False, is_final_validation=False,
enable_autocast=enable_autocast,
) )
if args.use_ema: if args.use_ema:
...@@ -1242,6 +1252,7 @@ def main(): ...@@ -1242,6 +1252,7 @@ def main():
generator, generator,
global_step, global_step,
is_final_validation=True, is_final_validation=True,
enable_autocast=enable_autocast,
) )
accelerator.end_training() accelerator.end_training()
......
@@ -501,6 +501,12 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
@@ -973,6 +979,13 @@ def main(args):
     if accelerator.is_main_process:
         accelerator.init_trackers("text2image-fine-tune", config=vars(args))

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
...@@ -1199,7 +1212,10 @@ def main(args): ...@@ -1199,7 +1212,10 @@ def main(args):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt} pipeline_args = {"prompt": args.validation_prompt}
with torch.cuda.amp.autocast(): with torch.autocast(
accelerator.device.type,
enabled=enable_autocast,
):
images = [ images = [
pipeline(**pipeline_args, generator=generator).images[0] pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images) for _ in range(args.num_validation_images)
......
@@ -590,6 +590,12 @@ def main(args):
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -980,6 +986,13 @@ def main(args):
         model = model._orig_mod if is_compiled_module(model) else model
         return model

+    # Some configurations require autocast to be disabled.
+    enable_autocast = True
+    if torch.backends.mps.is_available() or (
+        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
+    ):
+        enable_autocast = False
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
...@@ -1213,7 +1226,10 @@ def main(args): ...@@ -1213,7 +1226,10 @@ def main(args):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt} pipeline_args = {"prompt": args.validation_prompt}
with torch.cuda.amp.autocast(): with torch.autocast(
accelerator.device.type,
enabled=enable_autocast,
):
images = [ images = [
pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
for _ in range(args.num_validation_images) for _ in range(args.num_validation_images)
...@@ -1268,7 +1284,7 @@ def main(args): ...@@ -1268,7 +1284,7 @@ def main(args):
if args.validation_prompt and args.num_validation_images > 0: if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
with torch.cuda.amp.autocast(): with torch.autocast(accelerator.device.type, enabled=enable_autocast):
images = [ images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images) for _ in range(args.num_validation_images)
......
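Taken together, these changes let the SDXL ControlNet, DreamBooth, LoRA, text-to-image, and InstructPix2Pix example scripts train on Apple silicon with fp16 or fp32. A quick pre-flight check one might run before launching any of them (standard torch APIs only; the printed messages are mine):

import torch

if not torch.backends.mps.is_built():
    print("This PyTorch build was compiled without MPS support.")
elif not torch.backends.mps.is_available():
    print("MPS is built but unavailable; check the macOS version and hardware.")
else:
    x = torch.ones(2, 2, device="mps", dtype=torch.float16)
    print(f"MPS ready; created a {x.dtype} tensor on {x.device}.")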