"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7071b7461b224bdc82b9dd2bde2c1842320ccc66"
Unverified commit 012d08b1, authored by jiqing-feng, committed by GitHub
Browse files

Enable dreambooth lora finetune example on other devices (#10602)



* enable dreambooth_lora on other devices
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable xpu
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check cuda device before empty cache
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* import free_memory
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
parent 4ace7d04
@@ -54,7 +54,11 @@ from diffusers import (
 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.training_utils import (
+    _set_state_dict_into_text_encoder,
+    cast_training_params,
+    free_memory,
+)
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -151,14 +155,14 @@ def log_validation(
     if args.validation_images is None:
         images = []
         for _ in range(args.num_validation_images):
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, generator=generator).images[0]
             images.append(image)
     else:
         images = []
         for image in args.validation_images:
             image = Image.open(image)
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
             images.append(image)
@@ -177,7 +181,7 @@ def log_validation(
     )

     del pipeline
-    torch.cuda.empty_cache()
+    free_memory()

     return images
@@ -793,7 +797,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
@@ -829,8 +833,7 @@ def main(args):
                 image.save(image_filename)

         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def main(args):
         tokenizer = None
         gc.collect()
-        torch.cuda.empty_cache()
+        free_memory()
     else:
         pre_computed_encoder_hidden_states = None
         validation_prompt_encoder_hidden_states = None
......
@@ -299,6 +299,8 @@ def free_memory():
         torch.mps.empty_cache()
     elif is_torch_npu_available():
         torch_npu.npu.empty_cache()
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()


 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment