"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "6862e372597d1baeca3ae17e8d7956d6c23755b1"
Unverified commit 8e7d6c03, authored by Sayak Paul and committed by GitHub

[chore] fix: retain memory utility. (#9543)

* fix: retain memory utility.

* fix

* quality

* free_memory.
parent b28675c6
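The diff below swaps the old list-based `clear_objs_and_retain_memory(objs)` helper for an explicit `del` at the call site followed by a parameterless `free_memory()`. A minimal sketch of the new calling pattern, where `pipe` is an illustrative stand-in for a validation pipeline rather than code from this commit:

import torch
from diffusers.training_utils import free_memory  # utility introduced by this commit

# Stand-in object; in the training scripts this is a pipeline kept only for validation.
pipe = torch.nn.Linear(8, 8)

# Old pattern (removed in this commit): clear_objs_and_retain_memory([pipe])
# New pattern: delete the reference where it is owned, then clear caches.
del pipe
free_memory()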
@@ -38,10 +38,7 @@ from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPi
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -726,7 +723,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()

     return videos
...
@@ -54,7 +54,7 @@ from diffusers import (
 from diffusers.models.controlnet_flux import FluxControlNetModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
-from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
+from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
@@ -193,7 +193,8 @@ def log_validation(
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory([pipeline])
+    del pipeline
+    free_memory()

     return image_logs
@@ -1103,7 +1104,8 @@ def main(args):
             compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
         )

-    clear_objs_and_retain_memory([text_encoders, tokenizers])
+    del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+    free_memory()

     # Then get the training dataset ready to be passed to the dataloader.
     train_dataset = prepare_train_dataset(train_dataset, accelerator)
...
@@ -49,11 +49,7 @@ from diffusers import (
     StableDiffusion3ControlNetPipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import (
-    clear_objs_and_retain_memory,
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
-)
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory(pipeline)
+    del pipeline
+    free_memory()

     if not is_final_validation:
         controlnet.to(accelerator.device)
@@ -1131,7 +1128,9 @@ def main(args):
         new_fingerprint = Hasher.hash(args)
         train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

-    clear_objs_and_retain_memory(text_encoders + tokenizers)
+    del text_encoder_one, text_encoder_two, text_encoder_three
+    del tokenizer_one, tokenizer_two, tokenizer_three
+    free_memory()

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
...
...@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler ...@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import ( from diffusers.training_utils import (
_set_state_dict_into_text_encoder, _set_state_dict_into_text_encoder,
cast_training_params, cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling, compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3, compute_loss_weighting_for_sd3,
free_memory,
) )
from diffusers.utils import ( from diffusers.utils import (
check_min_version, check_min_version,
@@ -1437,7 +1437,8 @@ def main(args):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
+        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1480,7 +1481,8 @@ def main(args):
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

         if args.validation_prompt is None:
-            clear_objs_and_retain_memory([vae])
+            del vae
+            free_memory()

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -1817,7 +1819,8 @@ def main(args):
                 torch_dtype=weight_dtype,
             )
         if not args.train_text_encoder:
-            clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
+            del text_encoder_one, text_encoder_two
+            free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
...
...@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler ...@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import ( from diffusers.training_utils import (
_set_state_dict_into_text_encoder, _set_state_dict_into_text_encoder,
cast_training_params, cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling, compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3, compute_loss_weighting_for_sd3,
free_memory,
) )
from diffusers.utils import ( from diffusers.utils import (
check_min_version, check_min_version,
@@ -211,7 +211,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory(objs=[pipeline])
+    del pipeline
+    free_memory()

     return images
@@ -1106,7 +1107,8 @@ def main(args):
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)

-        clear_objs_and_retain_memory(objs=[pipeline])
+        del pipeline
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1453,9 +1455,9 @@ def main(args):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        clear_objs_and_retain_memory(
-            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
-        )
+        del tokenizers, text_encoders
+        del text_encoder_one, text_encoder_two, text_encoder_three
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
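The comment in the hunk above is the crux of the change: passing objects to a helper that merely runs `del obj` on its own local names never drops the caller's references, so nothing is actually collected. A minimal sketch illustrating the pitfall; `broken_cleanup` and the bytearray stand-in are illustrative, not code from the diff:

import gc

def broken_cleanup(objs):
    # `del obj` only unbinds the helper's local loop variable; the caller's
    # reference (and the list element) survive, so nothing is freed.
    for obj in objs:
        del obj
    gc.collect()

big = bytearray(100 * 1024 * 1024)  # stand-in for a large model or pipeline
broken_cleanup([big])
print(len(big))  # still alive: 104857600

# Pattern adopted in this commit: drop the reference at the call site,
# then run collection and cache clearing (free_memory() in diffusers).
del big
gc.collect()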
@@ -1791,11 +1793,9 @@ def main(args):
                 epoch=epoch,
                 torch_dtype=weight_dtype,
             )
-            objs = []
-            if not args.train_text_encoder:
-                objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
-            clear_objs_and_retain_memory(objs=objs)
+            del text_encoder_one, text_encoder_two, text_encoder_three
+            free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
...
@@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting


-def clear_objs_and_retain_memory(objs: List[Any]):
-    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
-    if len(objs) >= 1:
-        for obj in objs:
-            del obj
-
+def free_memory():
+    """Runs garbage collection. Then clears the cache of the available accelerator."""
     gc.collect()

     if torch.cuda.is_available():
...
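The new helper takes no arguments: callers are now responsible for `del`, and the function only collects garbage and clears device caches. A self-contained sketch consistent with the visible lines above; everything past `if torch.cuda.is_available():` is truncated in the diff, so the cache-clearing calls below are an assumption, and the real function may also handle other accelerator backends:

import gc
import torch

def free_memory():
    """Runs garbage collection. Then clears the cache of the available accelerator."""
    gc.collect()
    if torch.cuda.is_available():
        # Assumed continuation (not visible in the truncated diff):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()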