Unverified Commit 825979dd authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[training] fix: registration of out_channels in the control flux scripts. (#10367)

* fix: registration of out_channels in the control flux scripts.

* free memory.
parent 023b0e0d
......@@ -795,7 +795,7 @@ def main(args):
flux_transformer.x_embedder = new_linear
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
......@@ -1166,6 +1166,11 @@ def main(args):
flux_transformer.to(torch.float32)
flux_transformer.save_pretrained(args.output_dir)
del flux_transformer
del text_encoding_pipeline
del vae
free_memory()
# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
......
......@@ -830,7 +830,7 @@ def main(args):
flux_transformer.x_embedder = new_linear
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
if args.train_norm_layers:
for name, param in flux_transformer.named_parameters():
......@@ -1319,6 +1319,11 @@ def main(args):
transformer_lora_layers=transformer_lora_layers,
)
del flux_transformer
del text_encoding_pipeline
del vae
free_memory()
# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment