Unverified Commit 8ba90aa7 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

chore: add a cleaning utility to be useful during training. (#9240)

parent 9d49b45b
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import argparse import argparse
import copy import copy
import gc
import itertools import itertools
import logging import logging
import math import math
...@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler ...@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import ( from diffusers.training_utils import (
_set_state_dict_into_text_encoder, _set_state_dict_into_text_encoder,
cast_training_params, cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling, compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3, compute_loss_weighting_for_sd3,
) )
...@@ -210,9 +210,7 @@ def log_validation( ...@@ -210,9 +210,7 @@ def log_validation(
} }
) )
del pipeline clear_objs_and_retain_memory(objs=[pipeline])
if torch.cuda.is_available():
torch.cuda.empty_cache()
return images return images
...@@ -1107,9 +1105,7 @@ def main(args): ...@@ -1107,9 +1105,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename) image.save(image_filename)
del pipeline clear_objs_and_retain_memory(objs=[pipeline])
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
...@@ -1455,12 +1451,10 @@ def main(args): ...@@ -1455,12 +1451,10 @@ def main(args):
# Clear the memory here # Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts: if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three clear_objs_and_retain_memory(
gc.collect() objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
if torch.cuda.is_available(): )
torch.cuda.empty_cache()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't # pack the statically computed variables appropriately here. This is so that we don't
...@@ -1795,11 +1789,11 @@ def main(args): ...@@ -1795,11 +1789,11 @@ def main(args):
pipeline_args=pipeline_args, pipeline_args=pipeline_args,
epoch=epoch, epoch=epoch,
) )
objs = []
if not args.train_text_encoder: if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
torch.cuda.empty_cache() clear_objs_and_retain_memory(objs=objs)
gc.collect()
# Save the lora layers # Save the lora layers
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
import contextlib import contextlib
import copy import copy
import gc
import math import math
import random import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
...@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): ...@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting return weighting
def clear_objs_and_retain_memory(objs: List[Any]):
    """Drop the references held in `objs`, force a garbage-collection pass, and
    empty the cache of whichever accelerator backend is available.

    Args:
        objs (`List[Any]`): objects to release. The list is emptied in place so
            that it no longer keeps its elements alive during collection.

    NOTE(review): this can only release the references *this* function can see.
    `del obj` on a loop variable merely unbinds the loop-local name, and the
    original implementation left `objs` itself still referencing every element
    while `gc.collect()` ran — so nothing passed in could actually be reclaimed.
    Clearing the list fixes that, but any names the caller still holds (e.g. a
    local `pipeline` variable at the call site) must be deleted by the caller.
    """
    # Release the list's own references before collecting; otherwise the GC
    # still sees every object as reachable through `objs`.
    objs.clear()
    gc.collect()

    # Clear the cache of the first available accelerator backend, if any.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif is_torch_npu_available():
        torch_npu.empty_cache()
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel: class EMAModel:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment