Unverified commit 8e963d1c authored by Bagheera, committed by GitHub

7529 do not disable autocast for cuda devices (#7530)



* 7529 do not disable autocast for cuda devices

* Remove the typecasting error check for non-MPS platforms, as a correct autocast implementation makes it a non-issue

* add autocast fix to other training examples

* disable native_amp for DreamBooth (SDXL)

* disable native_amp for pix2pix (SDXL)

* remove tests from remaining files

* disable native_amp on the Hugging Face Accelerator for every training example that uses it

* convert more usages of autocast to nullcontext, make style fixes

* make style fixes

* style.

* Empty-Commit

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 2b04ec2f
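The change is the same across every touched script: the hard-coded torch.cuda.amp.autocast() context is replaced with one derived from the device actually in use, and Accelerate's native AMP is switched off on Apple Silicon, where it is unsupported. A minimal sketch of the pattern, assuming an accelerate.Accelerator instance, with the surrounding training code omitted:

from contextlib import nullcontext

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Accelerate's native AMP is not supported on MPS, so turn it off there
# instead of letting mixed-precision training fail mid-run.
if torch.backends.mps.is_available():
    accelerator.native_amp = False

# Derive the inference context from the real device type: a no-op on MPS,
# torch.autocast elsewhere. Autocast stays enabled on CUDA devices.
if torch.backends.mps.is_available():
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
    pass  # validation inference goes here, e.g. images = pipeline(prompt).images

The hunks below add dtype=weight_dtype to torch.autocast where a script previously pinned a dtype, and extend the nullcontext branch with extra guards (is_final_validation, "playground" in args.pretrained_model_name_or_path) where autocast must stay disabled.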
@@ -23,6 +23,7 @@ import os
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
@@ -1844,7 +1845,12 @@ def main(args):
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
             pipeline_args = {"prompt": args.validation_prompt}

-            with torch.cuda.amp.autocast():
+            if torch.backends.mps.is_available():
+                autocast_ctx = nullcontext()
+            else:
+                autocast_ctx = torch.autocast(accelerator.device.type)
+
+            with autocast_ctx:
                 images = [
                     pipeline(**pipeline_args, generator=generator).images[0]
                     for _ in range(args.num_validation_images)
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and

 import argparse
-import contextlib
 import gc
 import hashlib
 import itertools
@@ -26,6 +25,7 @@ import random
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
@@ -2192,13 +2192,12 @@ def main(args):
         # run inference
         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
         pipeline_args = {"prompt": args.validation_prompt}
-        inference_ctx = (
-            contextlib.nullcontext()
-            if "playground" in args.pretrained_model_name_or_path
-            else torch.cuda.amp.autocast()
-        )
+        if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)

-        with inference_ctx:
+        with autocast_ctx:
             images = [
                 pipeline(**pipeline_args, generator=generator).images[0]
                 for _ in range(args.num_validation_images)
@@ -430,6 +430,9 @@ def main(args):
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if accelerator.is_main_process:
         os.makedirs(args.output_dir, exist_ok=True)
@@ -23,6 +23,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
@@ -238,6 +239,10 @@ class SDText2ImageDataset:
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")

+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1172,6 +1177,11 @@ def main(args):
         ).input_ids.to(accelerator.device)
     uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]

+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
     # 16. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1300,7 +1310,7 @@ def main(args):
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@ def main(args):
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 # Note that we do not use a separate target network for LCM-LoRA distillation.
                 with torch.no_grad():
-                    with torch.autocast("cuda", dtype=weight_dtype):
+                    with autocast_ctx:
                         target_noise_pred = unet(
                             x_prev.float(),
                             timesteps,
@@ -22,6 +22,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path

 import accelerate
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -24,6 +24,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
@@ -256,6 +257,10 @@ class SDXLText2ImageDataset:
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")

+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 unet = accelerator.unwrap_model(unet)
 pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1353,7 +1358,12 @@ def main(args):
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1416,7 +1426,12 @@ def main(args):
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 # Note that we do not use a separate target network for LCM-LoRA distillation.
                 with torch.no_grad():
-                    with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                    with autocast_ctx:
                         target_noise_pred = unet(
                             x_prev.float(),
                             timesteps,
@@ -23,6 +23,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
......@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
for _, prompt in enumerate(validation_prompts):
images = []
with torch.autocast("cuda"):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
@@ -1257,7 +1263,12 @@ def main(args):
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1315,7 +1326,12 @@ def main(args):
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 with torch.no_grad():
-                    with torch.autocast("cuda", dtype=weight_dtype):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                    with autocast_ctx:
                         target_noise_pred = target_unet(
                             x_prev.float(),
                             timesteps,
@@ -24,6 +24,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
......@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
for _, prompt in enumerate(validation_prompts):
images = []
with torch.autocast("cuda"):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
@@ -1355,7 +1361,12 @@ def main(args):
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1417,7 +1428,12 @@ def main(args):
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 with torch.no_grad():
-                    with torch.autocast("cuda", dtype=weight_dtype):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                    with autocast_ctx:
                         target_noise_pred = target_unet(
                             x_prev.float(),
                             timesteps,
@@ -752,6 +752,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and

 import argparse
-import contextlib
 import functools
 import gc
 import logging
@@ -22,6 +21,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path

 import accelerate
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
     )
     image_logs = []
-    inference_ctx = (
-        contextlib.nullcontext()
-        if (is_final_validation or torch.backends.mps.is_available())
-        else torch.autocast("cuda")
-    )
+    if is_final_validation or torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)

     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
         images = []

         for _ in range(args.num_validation_images):
-            with inference_ctx:
+            with autocast_ctx:
                 image = pipeline(
                     prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
@@ -811,6 +810,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -676,6 +676,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -821,6 +821,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -749,6 +749,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -23,6 +23,7 @@ import os
 import random
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path

 import numpy as np
@@ -207,18 +208,12 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
     # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-    enable_autocast = True
-    if torch.backends.mps.is_available() or (
-        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
-    ):
-        enable_autocast = False
-    if "playground" in args.pretrained_model_name_or_path:
-        enable_autocast = False
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)

-    with torch.autocast(
-        accelerator.device.type,
-        enabled=enable_autocast,
-    ):
+    with autocast_ctx:
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

     for tracker in accelerator.trackers:
@@ -992,6 +987,10 @@ def main(args):
         kwargs_handlers=[kwargs],
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -21,6 +21,7 @@ import logging
 import math
 import os
 import shutil
+from contextlib import nullcontext
 from pathlib import Path

 import accelerate
@@ -404,6 +405,10 @@ def main():
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

     if args.report_to == "wandb":
@@ -943,9 +948,12 @@ def main():
                     # run inference
                     original_image = download_image(args.val_image_url)
                     edited_images = []
-                    with torch.autocast(
-                        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
-                    ):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         for _ in range(args.num_validation_images):
                             edited_images.append(
                                 pipeline(
@@ -20,6 +20,7 @@ import math
 import os
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from urllib.parse import urlparse
@@ -70,9 +71,7 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
 TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


-def log_validation(
-    pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
-):
+def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
@@ -91,7 +90,12 @@ def log_validation(
         else Image.open(image_url_or_path).convert("RGB")
     )(args.val_image_url_or_path)

-    with torch.autocast(accelerator.device.type, enabled=enable_autocast):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         edited_images = []
         # Run inference
         for val_img_idx in range(args.num_validation_images):
@@ -507,6 +511,10 @@ def main():
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

     # Make one log on every process with the configuration for debugging.
@@ -983,13 +991,6 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))

-    # Some configurations require autocast to be disabled.
-    enable_autocast = True
-    if torch.backends.mps.is_available() or (
-        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
-    ):
-        enable_autocast = False
-
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1202,7 +1203,6 @@ def main():
                         generator,
                         global_step,
                         is_final_validation=False,
-                        enable_autocast=enable_autocast,
                     )

                 if args.use_ema:
@@ -1252,7 +1252,6 @@ def main():
             generator,
             global_step,
             is_final_validation=True,
-            enable_autocast=enable_autocast,
         )

     accelerator.end_training()
@@ -458,6 +458,10 @@ def main():
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -343,6 +343,11 @@ def main():
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -356,6 +356,11 @@ def main():
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -459,6 +459,10 @@ def main():
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",