Unverified Commit df55f053 authored by Justin Ruan's avatar Justin Ruan Committed by GitHub
Browse files

Fix wrong indent for examples of controlnet script (#11632)

fix wrong indent for training controlnet
parent 89ddb6c0
...@@ -178,11 +178,11 @@ def log_validation( ...@@ -178,11 +178,11 @@ def log_validation(
else: else:
logger.warning(f"image logging not implemented for {tracker.name}") logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline del pipeline
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return image_logs return image_logs
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
......
...@@ -192,9 +192,9 @@ def log_validation( ...@@ -192,9 +192,9 @@ def log_validation(
else: else:
logger.warning(f"image logging not implemented for {tracker.name}") logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline del pipeline
free_memory() free_memory()
return image_logs return image_logs
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
......
...@@ -199,13 +199,13 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v ...@@ -199,13 +199,13 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
else: else:
logger.warning(f"image logging not implemented for {tracker.name}") logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline del pipeline
free_memory() free_memory()
if not is_final_validation: if not is_final_validation:
controlnet.to(accelerator.device) controlnet.to(accelerator.device)
return image_logs return image_logs
# Copied from dreambooth sd3 example # Copied from dreambooth sd3 example
......
...@@ -201,11 +201,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, ...@@ -201,11 +201,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
else: else:
logger.warning(f"image logging not implemented for {tracker.name}") logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline del pipeline
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return image_logs return image_logs
def import_model_class_from_model_name_or_path( def import_model_class_from_model_name_or_path(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment