"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c15cda03ca36a5e344c8f26179c9ef48d3a88c69"
Unverified Commit 8e963d1c authored by Bagheera, committed by GitHub

7529 do not disable autocast for cuda devices (#7530)



* 7529 do not disable autocast for cuda devices

* Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue

* add autocast fix to other training examples

* disable native_amp for dreambooth (sdxl)

* disable native_amp for pix2pix (sdxl)

* remove tests from remaining files

* disable native_amp on huggingface accelerator for every training example that uses it

* convert more usages of autocast to nullcontext, make style fixes

* make style fixes

* style.

* Empty-Commit

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 2b04ec2f
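
In practice the change is one small pattern applied across the training scripts: build a device-agnostic autocast context instead of hard-coding torch.cuda.amp.autocast() / torch.autocast("cuda"), and turn off Accelerate's native AMP when running on MPS. A minimal sketch of that pattern, assuming a bare Accelerator() and a placeholder inference body (both illustrative, not taken from any one script):

from contextlib import nullcontext

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Assumption carried by this commit: native AMP is not usable on MPS,
# so switch it off on Apple Silicon.
if torch.backends.mps.is_available():
    accelerator.native_amp = False

# Use a no-op context on MPS, otherwise autocast on whichever device
# Accelerate selected (cuda, cpu, ...).
if torch.backends.mps.is_available():
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
    pass  # validation or teacher-model inference goes here

Some scripts widen the first branch, keeping nullcontext() also for final validation or for "playground" checkpoints, as the hunks below show.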
@@ -23,6 +23,7 @@ import os
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
@@ -1844,7 +1845,12 @@ def main(args):
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     pipeline_args = {"prompt": args.validation_prompt}
-    with torch.cuda.amp.autocast():
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+    with autocast_ctx:
         images = [
             pipeline(**pipeline_args, generator=generator).images[0]
             for _ in range(args.num_validation_images)
...
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 import argparse
-import contextlib
 import gc
 import hashlib
 import itertools
@@ -26,6 +25,7 @@ import random
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
@@ -2192,13 +2192,12 @@ def main(args):
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     pipeline_args = {"prompt": args.validation_prompt}
-    inference_ctx = (
-        contextlib.nullcontext()
-        if "playground" in args.pretrained_model_name_or_path
-        else torch.cuda.amp.autocast()
-    )
-    with inference_ctx:
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+    with autocast_ctx:
         images = [
             pipeline(**pipeline_args, generator=generator).images[0]
             for _ in range(args.num_validation_images)
...
@@ -430,6 +430,9 @@
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if accelerator.is_main_process:
         os.makedirs(args.output_dir, exist_ok=True)
...
@@ -23,6 +23,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
@@ -238,6 +239,10 @@ class SDText2ImageDataset:
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1172,6 +1177,11 @@
     ).input_ids.to(accelerator.device)
     uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
     # 16. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1300,7 +1310,7 @@
        # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
        # solver timestep.
        with torch.no_grad():
-           with torch.autocast("cuda"):
+           with autocast_ctx:
                # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                cond_teacher_output = teacher_unet(
                    noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@
        # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
        # Note that we do not use a separate target network for LCM-LoRA distillation.
        with torch.no_grad():
-           with torch.autocast("cuda", dtype=weight_dtype):
+           with autocast_ctx:
                target_noise_pred = unet(
                    x_prev.float(),
                    timesteps,
...
@@ -22,6 +22,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 import accelerate
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_final_validation=False):
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
...
@@ -24,6 +24,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
@@ -256,6 +257,10 @@ class SDXLText2ImageDataset:
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1353,7 +1358,12 @@
        # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
        # solver timestep.
        with torch.no_grad():
-           with torch.autocast("cuda"):
+           if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+               autocast_ctx = nullcontext()
+           else:
+               autocast_ctx = torch.autocast(accelerator.device.type)
+           with autocast_ctx:
                # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                cond_teacher_output = teacher_unet(
                    noisy_model_input.to(weight_dtype),
@@ -1416,7 +1426,12 @@
        # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
        # Note that we do not use a separate target network for LCM-LoRA distillation.
        with torch.no_grad():
-           with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+           if torch.backends.mps.is_available():
+               autocast_ctx = nullcontext()
+           else:
+               autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+           with autocast_ctx:
                target_noise_pred = unet(
                    x_prev.float(),
                    timesteps,
...
@@ -23,6 +23,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1257,7 +1263,12 @@
        # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
        # solver timestep.
        with torch.no_grad():
-           with torch.autocast("cuda"):
+           if torch.backends.mps.is_available():
+               autocast_ctx = nullcontext()
+           else:
+               autocast_ctx = torch.autocast(accelerator.device.type)
+           with autocast_ctx:
                # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                cond_teacher_output = teacher_unet(
                    noisy_model_input.to(weight_dtype),
@@ -1315,7 +1326,12 @@
        # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
        with torch.no_grad():
-           with torch.autocast("cuda", dtype=weight_dtype):
+           if torch.backends.mps.is_available():
+               autocast_ctx = nullcontext()
+           else:
+               autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+           with autocast_ctx:
                target_noise_pred = target_unet(
                    x_prev.float(),
                    timesteps,
...
@@ -24,6 +24,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1355,7 +1361,12 @@
        # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
        # solver timestep.
        with torch.no_grad():
-           with torch.autocast("cuda"):
+           if torch.backends.mps.is_available():
+               autocast_ctx = nullcontext()
+           else:
+               autocast_ctx = torch.autocast(accelerator.device.type)
+           with autocast_ctx:
                # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                cond_teacher_output = teacher_unet(
                    noisy_model_input.to(weight_dtype),
@@ -1417,7 +1428,12 @@
        # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
        with torch.no_grad():
-           with torch.autocast("cuda", dtype=weight_dtype):
+           if torch.backends.mps.is_available():
+               autocast_ctx = nullcontext()
+           else:
+               autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+           with autocast_ctx:
                target_noise_pred = target_unet(
                    x_prev.float(),
                    timesteps,
...
@@ -752,6 +752,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 import argparse
-import contextlib
 import functools
 import gc
 import logging
@@ -22,6 +21,7 @@ import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 import accelerate
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
     )
     image_logs = []
-    inference_ctx = (
-        contextlib.nullcontext()
-        if (is_final_validation or torch.backends.mps.is_available())
-        else torch.autocast("cuda")
-    )
+    if is_final_validation or torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
         images = []
         for _ in range(args.num_validation_images):
-            with inference_ctx:
+            with autocast_ctx:
                 image = pipeline(
                     prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
@@ -811,6 +810,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -676,6 +676,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -821,6 +821,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -749,6 +749,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -23,6 +23,7 @@ import os
 import random
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 import numpy as np
@@ -207,18 +208,12 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
     # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-    enable_autocast = True
-    if torch.backends.mps.is_available() or (
-        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
-    ):
-        enable_autocast = False
-    if "playground" in args.pretrained_model_name_or_path:
-        enable_autocast = False
-    with torch.autocast(
-        accelerator.device.type,
-        enabled=enable_autocast,
-    ):
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+    with autocast_ctx:
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
     for tracker in accelerator.trackers:
@@ -992,6 +987,10 @@
         kwargs_handlers=[kwargs],
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -21,6 +21,7 @@ import logging
 import math
 import os
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 import accelerate
@@ -404,6 +405,10 @@
         project_config=accelerator_project_config,
    )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
     if args.report_to == "wandb":
@@ -943,9 +948,12 @@
     # run inference
     original_image = download_image(args.val_image_url)
     edited_images = []
-    with torch.autocast(
-        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
-    ):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+    with autocast_ctx:
         for _ in range(args.num_validation_images):
             edited_images.append(
                 pipeline(
...
@@ -20,6 +20,7 @@ import math
 import os
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from urllib.parse import urlparse
@@ -70,9 +71,7 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
 TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
-def log_validation(
-    pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
-):
+def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
@@ -91,7 +90,12 @@ def log_validation(
         else Image.open(image_url_or_path).convert("RGB")
     )(args.val_image_url_or_path)
-    with torch.autocast(accelerator.device.type, enabled=enable_autocast):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+    with autocast_ctx:
         edited_images = []
         # Run inference
         for val_img_idx in range(args.num_validation_images):
@@ -507,6 +511,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
     # Make one log on every process with the configuration for debugging.
@@ -983,13 +991,6 @@
     if accelerator.is_main_process:
         accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))
-    # Some configurations require autocast to be disabled.
-    enable_autocast = True
-    if torch.backends.mps.is_available() or (
-        accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
-    ):
-        enable_autocast = False
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1202,7 +1203,6 @@
            generator,
            global_step,
            is_final_validation=False,
-           enable_autocast=enable_autocast,
        )
        if args.use_ema:
@@ -1252,7 +1252,6 @@
            generator,
            global_step,
            is_final_validation=True,
-           enable_autocast=enable_autocast,
        )
     accelerator.end_training()
...
@@ -458,6 +458,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...
@@ -343,6 +343,11 @@
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -356,6 +356,11 @@
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
...
@@ -459,6 +459,10 @@
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...