Unverified Commit 328e0d20 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[training] set rest of the blocks with `requires_grad` False. (#10607)

set rest of the blocks with requires_grad False.
parent 23b467c7
...@@ -812,6 +812,8 @@ def main(args): ...@@ -812,6 +812,8 @@ def main(args):
for name, module in flux_transformer.named_modules(): for name, module in flux_transformer.named_modules():
if "transformer_blocks" in name: if "transformer_blocks" in name:
module.requires_grad_(True) module.requires_grad_(True)
else:
module.requirs_grad_(False)
def unwrap_model(model): def unwrap_model(model):
model = accelerator.unwrap_model(model) model = accelerator.unwrap_model(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment