Unverified Commit 8b84f851 authored by Suraj Patil, committed by GitHub

[examples] fix mixed_precision arg (#1359)

* use accelerator to check mixed_precision

* default `mixed_precision` to `None`

* pass mixed_precision to accelerate launch
parent e50c25d8
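
In short, the scripts now defer precision selection to Accelerate and read the resolved value back from the `Accelerator` object instead of the raw CLI argument. A minimal sketch of the pattern the diff below adopts (variable names are illustrative, not copied verbatim from the scripts):

```python
import torch
from accelerate import Accelerator

# With the argparse default now None, an unset script flag means
# "inherit whatever Accelerate resolved" from the launch flag or the
# saved accelerate config; an explicit value still overrides both.
accelerator = Accelerator(mixed_precision=None)

# Read the *resolved* setting from the accelerator, not the raw CLI arg,
# so `accelerate launch --mixed_precision="fp16"` is actually honored.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
```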
````diff
@@ -141,7 +141,7 @@ export INSTANCE_DIR="path-to-instance-images"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 
-accelerate launch train_dreambooth.py \
+accelerate launch --mixed_precision="fp16" train_dreambooth.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --class_data_dir=$CLASS_DIR \
@@ -157,8 +157,7 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800 \
-  --mixed_precision=fp16
+  --max_train_steps=800
 ```
 
 ### Fine-tune text encoder with the UNet.
````
````diff
@@ -187,12 +187,12 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--mixed_precision",
         type=str,
-        default="no",
+        default=None,
         choices=["no", "fp16", "bf16"],
         help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate launch` command. Use this argument to override the accelerate config."
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -538,9 +538,9 @@ def main(args):
     )
 
     weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
+    if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
+    elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
     # Move text_encode and vae to gpu.
````
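The motivation for the change above, sketched under an assumed launch setup: when precision is configured only at the launcher level, the old argparse default of `"no"` never changed, so the script cast nothing to half precision even though autocast ran in fp16. A hedged illustration of the mismatch (the launch command in the comment is hypothetical):

```python
from accelerate import Accelerator

# Hypothetical launch: accelerate launch --mixed_precision="fp16" train.py
# (the flag is consumed by the launcher, not by the script's argparse).
args_mixed_precision = "no"  # old default: stays "no" unless the user repeats the flag

accelerator = Accelerator()  # resolves "fp16" from the launch flag / env

# Old check compared the stale argparse value and kept frozen models in fp32;
# the new check consults the value Accelerate actually configured.
stale = args_mixed_precision == "fp16"            # False -> wrong dtype chosen
resolved = accelerator.mixed_precision == "fp16"  # True under this launch
```
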
````diff
@@ -46,7 +46,7 @@ With `gradient_checkpointing` and `mixed_precision` it should be possible to fin
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 export dataset_name="lambdalabs/pokemon-blip-captions"
 
-accelerate launch train_text_to_image.py \
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --dataset_name=$dataset_name \
   --use_ema \
@@ -54,7 +54,6 @@ accelerate launch train_text_to_image.py \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
   --gradient_checkpointing \
-  --mixed_precision="fp16" \
   --max_train_steps=15000 \
   --learning_rate=1e-05 \
   --max_grad_norm=1 \
@@ -70,7 +69,7 @@ If you wish to use custom loading logic, you should modify the script, we have l
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 export TRAIN_DIR="path_to_your_dataset"
 
-accelerate launch train_text_to_image.py \
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --train_data_dir=$TRAIN_DIR \
   --use_ema \
@@ -78,7 +77,6 @@ accelerate launch train_text_to_image.py \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
   --gradient_checkpointing \
-  --mixed_precision="fp16" \
   --max_train_steps=15000 \
   --learning_rate=1e-05 \
   --max_grad_norm=1 \
````
````diff
@@ -186,12 +186,12 @@ def parse_args():
     parser.add_argument(
         "--mixed_precision",
         type=str,
-        default="no",
+        default=None,
         choices=["no", "fp16", "bf16"],
         help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate launch` command. Use this argument to override the accelerate config."
         ),
     )
     parser.add_argument(
@@ -496,9 +496,9 @@ def main():
     )
 
     weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
+    if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
+    elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
     # Move text_encode and vae to gpu.
````
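As a usage note, the resolved dtype is what the scripts go on to use when casting their frozen components. A hedged sketch of that step (the `Linear` module is a stand-in for the scripts' frozen vae/text encoder, which are not reproduced here):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Map the resolved precision onto a torch dtype; trainable weights stay fp32,
# only frozen components are cast to save memory.
dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16}
weight_dtype = dtype_map.get(accelerator.mixed_precision, torch.float32)

frozen = torch.nn.Linear(4, 4).requires_grad_(False)  # stand-in frozen module
frozen.to(accelerator.device, dtype=weight_dtype)
```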