Unverified Commit 825979dd authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[training] fix: registration of out_channels in the control flux scripts. (#10367)

* fix: registration of out_channels in the control flux scripts.

* free memory.
parent 023b0e0d
...@@ -795,7 +795,7 @@ def main(args): ...@@ -795,7 +795,7 @@ def main(args):
flux_transformer.x_embedder = new_linear flux_transformer.x_embedder = new_linear
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0) assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2) flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
def unwrap_model(model): def unwrap_model(model):
model = accelerator.unwrap_model(model) model = accelerator.unwrap_model(model)
...@@ -1166,6 +1166,11 @@ def main(args): ...@@ -1166,6 +1166,11 @@ def main(args):
flux_transformer.to(torch.float32) flux_transformer.to(torch.float32)
flux_transformer.save_pretrained(args.output_dir) flux_transformer.save_pretrained(args.output_dir)
del flux_transformer
del text_encoding_pipeline
del vae
free_memory()
# Run a final round of validation. # Run a final round of validation.
image_logs = None image_logs = None
if args.validation_prompt is not None: if args.validation_prompt is not None:
......
...@@ -830,7 +830,7 @@ def main(args): ...@@ -830,7 +830,7 @@ def main(args):
flux_transformer.x_embedder = new_linear flux_transformer.x_embedder = new_linear
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0) assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2) flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
if args.train_norm_layers: if args.train_norm_layers:
for name, param in flux_transformer.named_parameters(): for name, param in flux_transformer.named_parameters():
...@@ -1319,6 +1319,11 @@ def main(args): ...@@ -1319,6 +1319,11 @@ def main(args):
transformer_lora_layers=transformer_lora_layers, transformer_lora_layers=transformer_lora_layers,
) )
del flux_transformer
del text_encoding_pipeline
del vae
free_memory()
# Run a final round of validation. # Run a final round of validation.
image_logs = None image_logs = None
if args.validation_prompt is not None: if args.validation_prompt is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment