Unverified commit 8e963d1c, authored by Bagheera, committed by GitHub

7529 do not disable autocast for cuda devices (#7530)



* 7529 do not disable autocast for cuda devices

* Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue

* add autocast fix to other training examples

* disable native_amp for dreambooth (sdxl)

* disable native_amp for pix2pix (sdxl)

* remove tests from remaining files

* disable native_amp on huggingface accelerator for every training example that uses it

* convert more usages of autocast to nullcontext, make style fixes

* make style fixes

* style.

* Empty-Commit

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 2b04ec2f
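
The change applies two recurring patterns across the training examples. First, PyTorch's native AMP is not implemented for Apple's MPS backend, so each script now switches it off on the Accelerator when MPS is active. Second, validation blocks that hard-coded CUDA autocast, or keyed enabled= off the mixed-precision flag (which is what disabled autocast on CUDA devices in issue #7529), now build a device-agnostic context instead. A minimal sketch of the combined pattern, assuming a bare Accelerator (each script passes its own logging and project configuration):

from contextlib import nullcontext

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Disable AMP for MPS: native autocast support does not cover the MPS backend.
if torch.backends.mps.is_available():
    accelerator.native_amp = False

# Build one context that is valid on CUDA, CPU, and MPS alike.
if torch.backends.mps.is_available():
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
    pass  # validation inference goes here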
@@ -916,6 +916,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -484,6 +484,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -526,6 +526,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -516,6 +516,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -623,6 +623,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -21,6 +21,7 @@ import logging
 import math
 import os
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -410,6 +411,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 
     if args.report_to == "wandb":
@@ -967,9 +972,12 @@ def main():
                 # run inference
                 original_image = download_image(args.val_image_url)
                 edited_images = []
-                with torch.autocast(
-                    str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
-                ):
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type)
+
+                with autocast_ctx:
                     for _ in range(args.num_validation_images):
                         edited_images.append(
                             pipeline(
...
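
Note on the removed block above: this is where issue #7529 originated. The device string was rebuilt with str(accelerator.device).replace(":0", ""), and enabled=accelerator.mixed_precision == "fp16" turned autocast off for every other precision setting, CUDA included. torch.device.type already carries the bare backend name, so the replacement needs no string handling. A small illustrative comparison (the device index is arbitrary; constructing the device object needs no GPU):

import torch

device = torch.device("cuda:1")
print(str(device).replace(":0", ""))  # "cuda:1" -- the old trick only stripped index 0
print(device.type)                    # "cuda"   -- correct for any index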
@@ -378,6 +378,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
...
@@ -411,6 +411,11 @@ def main():
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -698,6 +698,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -566,6 +566,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -439,6 +439,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -581,6 +581,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -295,6 +295,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.logger == "tensorboard":
         if not is_tensorboard_available():
             raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
...
@@ -799,6 +799,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -20,6 +20,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -164,7 +165,12 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
     images = []
     for i in range(len(args.validation_prompts)):
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
         images.append(image)
@@ -523,6 +529,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -21,6 +21,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import datasets
@@ -408,6 +409,11 @@ def main():
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -878,7 +884,12 @@ def main():
                     if args.seed is not None:
                         generator = generator.manual_seed(args.seed)
                     images = []
-                    with torch.cuda.amp.autocast():
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         for _ in range(args.num_validation_images):
                             images.append(
                                 pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
@@ -948,7 +959,12 @@ def main():
                 if args.seed is not None:
                     generator = generator.manual_seed(args.seed)
                 images = []
-                with torch.cuda.amp.autocast():
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type)
+
+                with autocast_ctx:
                     for _ in range(args.num_validation_images):
                         images.append(
                             pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
...
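
The two hunks above replace torch.cuda.amp.autocast(), which pins mixed-precision inference to the CUDA backend; on CPU-only or MPS machines it warns and disables itself rather than autocasting for the device actually in use. torch.autocast(device_type) follows the active backend instead. A minimal sketch of the device-agnostic form (the matmul stands in for pipeline work):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ops run in float16 under CUDA autocast and bfloat16 under CPU autocast.
with torch.autocast(device.type):
    x = torch.ones(8, 8, device=device) @ torch.ones(8, 8, device=device)
print(x.dtype)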
@@ -21,6 +21,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import datasets
@@ -979,13 +980,6 @@ def main(args):
     if accelerator.is_main_process:
         accelerator.init_trackers("text2image-fine-tune", config=vars(args))
 
-    # Some configurations require autocast to be disabled.
-    enable_autocast = True
-    if torch.backends.mps.is_available() or (
-        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
-    ):
-        enable_autocast = False
-
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1211,11 +1205,12 @@ def main(args):
                 # run inference
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type)
 
-                with torch.autocast(
-                    accelerator.device.type,
-                    enabled=enable_autocast,
-                ):
+                with autocast_ctx:
                     images = [
                         pipeline(**pipeline_args, generator=generator).images[0]
                         for _ in range(args.num_validation_images)
...
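
The deleted enable_autocast block above is the regression this PR unwinds: autocast was forced off not only on MPS but whenever mixed_precision was "fp16" or "bf16", i.e. exactly the runs where CUDA users expected validation to use autocast. A condensed, illustrative restatement of the old predicate:

import torch

def old_autocast_enabled(mixed_precision: str) -> bool:
    # Condensed form of the deleted logic: off on MPS, and off for any fp16/bf16 run.
    return not (torch.backends.mps.is_available() or mixed_precision in ("fp16", "bf16"))

print(old_autocast_enabled("fp16"))  # False -- autocast disabled during fp16 training
print(old_autocast_enabled("no"))    # True  -- only full-precision runs kept it on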
@@ -23,6 +23,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -603,6 +604,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -986,12 +991,10 @@ def main(args):
         model = model._orig_mod if is_compiled_module(model) else model
         return model
 
-    # Some configurations require autocast to be disabled.
-    enable_autocast = True
-    if torch.backends.mps.is_available() or (
-        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
-    ):
-        enable_autocast = False
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1226,10 +1229,7 @@ def main(args):
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}
 
-                with torch.autocast(
-                    accelerator.device.type,
-                    enabled=enable_autocast,
-                ):
+                with autocast_ctx:
                     images = [
                         pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
                         for _ in range(args.num_validation_images)
@@ -1284,7 +1284,8 @@ def main(args):
         if args.validation_prompt and args.num_validation_images > 0:
             pipeline = pipeline.to(accelerator.device)
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            with torch.autocast(accelerator.device.type, enabled=enable_autocast):
+
+            with autocast_ctx:
                 images = [
                     pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                     for _ in range(args.num_validation_images)
...
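
In this script the context is built once, just before training, and reused at the validation sites that follow; Playground checkpoints are excluded from autocast alongside MPS. Reusing one context object is sound since nullcontext() and torch.autocast instances can both be entered repeatedly. A sketch under those assumptions (the model id is illustrative):

from contextlib import nullcontext

import torch

pretrained_model_name_or_path = "playgroundai/playground-v2.5-1024px-aesthetic"  # illustrative

if torch.backends.mps.is_available() or "playground" in pretrained_model_name_or_path:
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")

for _ in range(2):  # the same object is re-entered once per validation pass
    with autocast_ctx:
        pass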
@@ -20,6 +20,7 @@ import os
 import random
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 
 import numpy as np
@@ -143,7 +144,12 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
     generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
     images = []
     for _ in range(args.num_validation_images):
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
         images.append(image)
@@ -600,6 +606,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -605,6 +605,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...