Unverified Commit 5ef74fd5 authored by Luo Yihang, committed by GitHub

fix norm not training in train_control_lora_flux.py (#11832)

parent 64a92103
@@ -837,11 +837,6 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
-    if args.train_norm_layers:
-        for name, param in flux_transformer.named_parameters():
-            if any(k in name for k in NORM_LAYER_PREFIXES):
-                param.requires_grad = True
-
     if args.lora_layers is not None:
         if args.lora_layers != "all-linear":
             target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
     )
     flux_transformer.add_adapter(transformer_lora_config)
 
+    if args.train_norm_layers:
+        for name, param in flux_transformer.named_parameters():
+            if any(k in name for k in NORM_LAYER_PREFIXES):
+                param.requires_grad = True
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
...
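Why the block is moved rather than changed: PEFT marks only the adapter parameters as trainable when the LoRA layers are injected, so any requires_grad = True set on the norm layers before add_adapter is reverted and the norms never train. Re-enabling them after add_adapter keeps them trainable alongside the LoRA weights. Below is a minimal, self-contained sketch of that ordering effect; the tiny model, the "proj" target module, and the stand-in NORM_LAYER_PREFIXES tuple are illustrative assumptions, not the training script's actual definitions.

# Minimal sketch (not the training script) of why norm layers must be
# unfrozen *after* the LoRA adapter is injected. Assumes torch and peft
# are installed; the model and prefixes below are illustrative only.
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

NORM_LAYER_PREFIXES = ("norm",)  # assumption: stand-in for the script's constant


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

    def forward(self, x):
        return self.norm(self.proj(x))


model = TinyBlock()

# Old order: unfreeze norms first ...
for name, param in model.named_parameters():
    if any(k in name for k in NORM_LAYER_PREFIXES):
        param.requires_grad = True

# ... then inject LoRA. PEFT marks only the adapter parameters as trainable,
# which resets requires_grad=False on the norm weights enabled above.
model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["proj"]), model)
print(model.norm.weight.requires_grad)  # False -> norms are silently frozen

# New order (what the commit does): unfreeze norms after the adapter is added.
for name, param in model.named_parameters():
    if any(k in name for k in NORM_LAYER_PREFIXES):
        param.requires_grad = True
print(model.norm.weight.requires_grad)  # True -> norms train alongside the LoRA weights

The same ordering applies in the script: diffusers' add_adapter goes through PEFT's adapter injection, so the train_norm_layers loop only has an effect when it runs after that call.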