Unverified Commit 1e07b6b3 authored by Duong A. Nguyen, committed by GitHub

[Flax SD finetune] Fix dtype (#1038)

fix jnp dtype
parent fb38bb16
@@ -371,11 +371,11 @@ def main():
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
     )
 
-    weight_dtype = torch.float32
+    weight_dtype = jnp.float32
     if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
+        weight_dtype = jnp.float16
     elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
+        weight_dtype = jnp.bfloat16
 
     # Load models and create wrapper for stable diffusion
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
...
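For context: Flax modules expect JAX dtypes, so assigning torch dtypes here (the pre-fix code) would break the Flax training pipeline. Below is a minimal sketch of the selection logic after this fix; resolve_weight_dtype is a hypothetical helper name introduced for illustration (the actual script inlines this logic in main()):

import jax.numpy as jnp

def resolve_weight_dtype(mixed_precision):
    # Map the --mixed_precision flag to a JAX dtype; torch dtypes
    # are not valid for Flax model weights, which is what this commit fixes.
    if mixed_precision == "fp16":
        return jnp.float16
    elif mixed_precision == "bf16":
        return jnp.bfloat16
    return jnp.float32

# Example: the resolved dtype is what gets passed to the Flax models.
print(resolve_weight_dtype("bf16"))  # bfloat16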