Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8ecdd3ef
Unverified
Commit
8ecdd3ef
authored
Apr 18, 2023
by
Cristian Garcia
Committed by
GitHub
Apr 18, 2023
Browse files
Optimize log_validation in train_controlnet_flax (#3110)
extract pipeline from log_validation
parent
cd8b7507
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
16 deletions
+19
-16
examples/controlnet/train_controlnet_flax.py
examples/controlnet/train_controlnet_flax.py
+19
-16
No files found.
examples/controlnet/train_controlnet_flax.py
View file @
8ecdd3ef
...
...
@@ -76,20 +76,11 @@ def image_grid(imgs, rows, cols):
return
grid
def
log_validation
(
controlnet
,
controlnet_params
,
tokenizer
,
args
,
rng
,
weight_dtype
):
logger
.
info
(
"Running validation...
"
)
def
log_validation
(
pipeline
,
pipeline_params
,
controlnet_params
,
tokenizer
,
args
,
rng
,
weight_dtype
):
logger
.
info
(
"Running validation..."
)
pipeline
,
params
=
FlaxStableDiffusionControlNetPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
tokenizer
=
tokenizer
,
controlnet
=
controlnet
,
safety_checker
=
None
,
dtype
=
weight_dtype
,
revision
=
args
.
revision
,
from_pt
=
args
.
from_pt
,
)
params
=
jax_utils
.
replicate
(
params
)
params
[
"controlnet"
]
=
controlnet_params
pipeline_params
=
pipeline_params
.
copy
()
pipeline_params
[
"controlnet"
]
=
controlnet_params
num_samples
=
jax
.
device_count
()
prng_seed
=
jax
.
random
.
split
(
rng
,
jax
.
device_count
())
...
...
@@ -121,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
images
=
pipeline
(
prompt_ids
=
prompt_ids
,
image
=
processed_image
,
params
=
params
,
params
=
pipeline_
params
,
prng_seed
=
prng_seed
,
num_inference_steps
=
50
,
jit
=
True
,
...
...
@@ -176,6 +167,7 @@ tags:
- text-to-image
- diffusers
- controlnet
- jax-diffusers-event
inference: true
---
"""
...
...
@@ -800,6 +792,17 @@ def main():
]:
controlnet_params
[
key
]
=
unet_params
[
key
]
pipeline
,
pipeline_params
=
FlaxStableDiffusionControlNetPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
tokenizer
=
tokenizer
,
controlnet
=
controlnet
,
safety_checker
=
None
,
dtype
=
weight_dtype
,
revision
=
args
.
revision
,
from_pt
=
args
.
from_pt
,
)
pipeline_params
=
jax_utils
.
replicate
(
pipeline_params
)
# Optimization
if
args
.
scale_lr
:
args
.
learning_rate
=
args
.
learning_rate
*
total_train_batch_size
...
...
@@ -1073,7 +1076,7 @@ def main():
and
global_step
%
args
.
validation_steps
==
0
and
jax
.
process_index
()
==
0
):
_
=
log_validation
(
controlnet
,
state
.
params
,
tokenizer
,
args
,
validation_rng
,
weight_dtype
)
_
=
log_validation
(
pipeline
,
pipeline_params
,
state
.
params
,
tokenizer
,
args
,
validation_rng
,
weight_dtype
)
if
global_step
%
args
.
logging_steps
==
0
and
jax
.
process_index
()
==
0
:
if
args
.
report_to
==
"wandb"
:
...
...
@@ -1105,7 +1108,7 @@ def main():
if
args
.
validation_prompt
is
not
None
:
if
args
.
profile_validation
:
jax
.
profiler
.
start_trace
(
args
.
output_dir
)
image_logs
=
log_validation
(
controlnet
,
state
.
params
,
tokenizer
,
args
,
validation_rng
,
weight_dtype
)
image_logs
=
log_validation
(
pipeline
,
pipeline_params
,
state
.
params
,
tokenizer
,
args
,
validation_rng
,
weight_dtype
)
if
args
.
profile_validation
:
jax
.
profiler
.
stop_trace
()
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment