Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1e07b6b3
Unverified
Commit
1e07b6b3
authored
Oct 28, 2022
by
Duong A. Nguyen
Committed by
GitHub
Oct 28, 2022
Browse files
[Flax SD finetune] Fix dtype (#1038)
fix jnp dtype
parent
fb38bb16
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
examples/text_to_image/train_text_to_image_flax.py
examples/text_to_image/train_text_to_image_flax.py
+3
-3
No files found.
examples/text_to_image/train_text_to_image_flax.py
View file @
1e07b6b3
...
...
@@ -371,11 +371,11 @@ def main():
train_dataset
,
shuffle
=
True
,
collate_fn
=
collate_fn
,
batch_size
=
total_train_batch_size
,
drop_last
=
True
)
weight_dtype
=
torch
.
float32
weight_dtype
=
jnp
.
float32
if
args
.
mixed_precision
==
"fp16"
:
weight_dtype
=
torch
.
float16
weight_dtype
=
jnp
.
float16
elif
args
.
mixed_precision
==
"bf16"
:
weight_dtype
=
torch
.
bfloat16
weight_dtype
=
jnp
.
bfloat16
# Load models and create wrapper for stable diffusion
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment