Unverified commit 96a90974, authored by Junjie, committed by GitHub
Browse files

Add offload option in flux-control training (#10225)



* Add offload option in flux-control training

* Update examples/flux-control/train_control_flux.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* modify help message

* fix format

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent a5f35ee4
......@@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
--offload \
--seed="0" \
--push_to_hub
```
......@@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
--validation_steps=200 \
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
--offload \
--seed="0" \
--push_to_hub
```
......
......@@ -541,6 +541,11 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--offload",
action="store_true",
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
)
if input_args is not None:
args = parser.parse_args(input_args)
......@@ -999,8 +1004,9 @@ def main(args):
control_latents = encode_images(
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
)
# offload vae to CPU.
vae.cpu()
if args.offload:
# offload vae to CPU.
vae.cpu()
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
......@@ -1064,7 +1070,8 @@ def main(args):
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
prompt_embeds.zero_()
pooled_prompt_embeds.zero_()
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# Predict.
model_pred = flux_transformer(
......
......@@ -573,6 +573,11 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--offload",
action="store_true",
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
)
if input_args is not None:
args = parser.parse_args(input_args)
......@@ -1140,8 +1145,10 @@ def main(args):
control_latents = encode_images(
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
)
# offload vae to CPU.
vae.cpu()
if args.offload:
# offload vae to CPU.
vae.cpu()
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
......@@ -1205,7 +1212,8 @@ def main(args):
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
prompt_embeds.zero_()
pooled_prompt_embeds.zero_()
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# Predict.
model_pred = flux_transformer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment