Unverified Commit 4b9f1c7d authored by Dev Rajput, committed by GitHub

Add correct number of channels when resuming from checkpoint for Flux Control LoRA training (#10422)

* Add correct number of channels when resuming from checkpoint

* Fix Formatting
parent 91008aab
@@ -923,11 +923,28 @@ def main(args):
                    transformer_ = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
        else:
            transformer_ = FluxTransformer2DModel.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="transformer"
            ).to(accelerator.device, weight_dtype)

            # Handle input dimension doubling before adding adapter
            with torch.no_grad():
                initial_input_channels = transformer_.config.in_channels
                new_linear = torch.nn.Linear(
                    transformer_.x_embedder.in_features * 2,
                    transformer_.x_embedder.out_features,
                    bias=transformer_.x_embedder.bias is not None,
                    dtype=transformer_.dtype,
                    device=transformer_.device,
                )
                new_linear.weight.zero_()
                new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
                if transformer_.x_embedder.bias is not None:
                    new_linear.bias.copy_(transformer_.x_embedder.bias)
                transformer_.x_embedder = new_linear
                transformer_.register_to_config(in_channels=initial_input_channels * 2)

            transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
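The expansion pattern above can be exercised in isolation. Below is a minimal, self-contained sketch; the helper name expand_input_linear and the toy layer sizes are illustrative and not part of the training script. It widens a Linear input projection to twice as many input features, copies the original weights into the first half, and zero-initializes the new half, which is why the expanded layer still reproduces the original outputs when the extra control channels are zero.

import torch


def expand_input_linear(linear: torch.nn.Linear) -> torch.nn.Linear:
    # Build a twice-as-wide input projection with the same output size.
    new_linear = torch.nn.Linear(
        linear.in_features * 2,
        linear.out_features,
        bias=linear.bias is not None,
        dtype=linear.weight.dtype,
        device=linear.weight.device,
    )
    with torch.no_grad():
        # New weight columns start at zero so the extra input channels
        # contribute nothing until they are trained.
        new_linear.weight.zero_()
        new_linear.weight[:, : linear.in_features].copy_(linear.weight)
        if linear.bias is not None:
            new_linear.bias.copy_(linear.bias)
    return new_linear


# Toy sizes for illustration; in Flux, x_embedder maps the packed latent
# channels to the transformer's hidden size.
old = torch.nn.Linear(64, 3072)
new = expand_input_linear(old)

# Padding the input with zeros in the new channels leaves the output unchanged.
x = torch.randn(2, 64)
padded = torch.cat([x, torch.zeros_like(x)], dim=-1)
assert torch.allclose(old(x), new(padded), atol=1e-6)

Because the load hook now performs this expansion and records the doubled in_channels via register_to_config before add_adapter is called, the LoRA layers created on resume have the same shapes as those saved in the checkpoint, so the saved LoRA weights can be loaded back without shape mismatches.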