Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d3ce6f4b
Unverified
Commit
d3ce6f4b
authored
Mar 07, 2023
by
Pedro Cuenca
Committed by
GitHub
Mar 07, 2023
Browse files
Support revision in Flax text-to-image training (#2567)
Support revision in Flax text-to-image training.
parent
ff91f154
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
examples/text_to_image/train_text_to_image_flax.py
examples/text_to_image/train_text_to_image_flax.py
+13
-4
No files found.
examples/text_to_image/train_text_to_image_flax.py
View file @
d3ce6f4b
...
...
@@ -48,6 +48,13 @@ def parse_args():
required
=
True
,
help
=
"Path to pretrained model or model identifier from huggingface.co/models."
,
)
parser
.
add_argument
(
"--revision"
,
type
=
str
,
default
=
None
,
required
=
False
,
help
=
"Revision of pretrained model identifier from huggingface.co/models."
,
)
parser
.
add_argument
(
"--dataset_name"
,
type
=
str
,
...
...
@@ -386,15 +393,17 @@ def main():
weight_dtype
=
jnp
.
bfloat16
# Load models and create wrapper for stable diffusion
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"tokenizer"
)
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
revision
=
args
.
revision
,
subfolder
=
"tokenizer"
)
text_encoder
=
FlaxCLIPTextModel
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
dtype
=
weight_dtype
args
.
pretrained_model_name_or_path
,
revision
=
args
.
revision
,
subfolder
=
"text_encoder"
,
dtype
=
weight_dtype
)
vae
,
vae_params
=
FlaxAutoencoderKL
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"vae"
,
dtype
=
weight_dtype
args
.
pretrained_model_name_or_path
,
revision
=
args
.
revision
,
subfolder
=
"vae"
,
dtype
=
weight_dtype
)
unet
,
unet_params
=
FlaxUNet2DConditionModel
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
dtype
=
weight_dtype
args
.
pretrained_model_name_or_path
,
revision
=
args
.
revision
,
subfolder
=
"unet"
,
dtype
=
weight_dtype
)
# Optimization
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment