Unverified Commit 8ba90aa7 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

chore: add a cleaning utility to be useful during training. (#9240)

parent 9d49b45b
......@@ -15,7 +15,6 @@
import argparse
import copy
import gc
import itertools
import logging
import math
......@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
......@@ -210,9 +210,7 @@ def log_validation(
}
)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])
return images
......@@ -1107,9 +1105,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])
# Handle the repository creation
if accelerator.is_main_process:
......@@ -1455,12 +1451,10 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
......@@ -1795,11 +1789,11 @@ def main(args):
pipeline_args=pipeline_args,
epoch=epoch,
)
objs = []
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
torch.cuda.empty_cache()
gc.collect()
clear_objs_and_retain_memory(objs=objs)
# Save the lora layers
accelerator.wait_for_everyone()
......
import contextlib
import copy
import gc
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
......@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def clear_objs_and_retain_memory(objs: List[Any]):
"""Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
if len(objs) >= 1:
for obj in objs:
del obj
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
elif is_torch_npu_available():
torch_npu.empty_cache()
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment