OpenDAS / diffusers · Commit c9f939bf (unverified)
Authored May 17, 2023 by Will Berman; committed by GitHub, May 17, 2023

Update full dreambooth script to work with IF (#3425)

Parent: 2858d7e1
Showing 3 changed files with 344 additions and 57 deletions (+344 -57):

examples/dreambooth/train_dreambooth.py   +266 -40
examples/test_examples.py                 +26  -0
src/diffusers/models/unet_2d_blocks.py    +52  -17
examples/dreambooth/train_dreambooth.py

This diff is collapsed.
examples/test_examples.py

@@ -147,6 +147,32 @@ class ExamplesTestsAccelerate(unittest.TestCase):

            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

    def test_dreambooth_if(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --pre_compute_text_embeddings
                --tokenizer_max_length=77
                --text_encoder_use_attention_mask
                """.split()

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

    def test_dreambooth_checkpointing(self):
        instance_prompt = "photo"
        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
        ...
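The new `test_dreambooth_if` case runs the updated script end to end against the tiny `hf-internal-testing/tiny-if-pipe` checkpoint, exercising the IF-specific options (`--pre_compute_text_embeddings`, `--tokenizer_max_length=77`, `--text_encoder_use_attention_mask`). Below is a minimal sketch of the same invocation outside the test harness; the flag values are copied from the test, while the `subprocess` wrapper, the `accelerate launch` entry point, and the output-directory handling are illustrative assumptions rather than part of this diff.

```python
# Sketch only: mirror test_dreambooth_if outside the test harness.
# Flag values come from the test above; the accelerate/subprocess plumbing is assumed.
import subprocess
import tempfile

with tempfile.TemporaryDirectory() as output_dir:
    script_args = f"""
        examples/dreambooth/train_dreambooth.py
        --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
        --instance_data_dir docs/source/en/imgs
        --instance_prompt photo
        --resolution 64
        --train_batch_size 1
        --gradient_accumulation_steps 1
        --max_train_steps 2
        --learning_rate 5.0e-04
        --scale_lr
        --lr_scheduler constant
        --lr_warmup_steps 0
        --output_dir {output_dir}
        --pre_compute_text_embeddings
        --tokenizer_max_length=77
        --text_encoder_use_attention_mask
        """.split()

    # The test prepends accelerate's launcher arguments (self._launch_args);
    # "accelerate launch" is the usual equivalent outside the test suite.
    subprocess.run(["accelerate", "launch"] + script_args, check=True)
```

After the run, the test's assertions amount to checking that `save_pretrained` wrote `unet/diffusion_pytorch_model.bin` and `scheduler/scheduler_config.json` under the output directory.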
src/diffusers/models/unet_2d_blocks.py

@@ -1507,10 +1507,27 @@ class SimpleCrossAttnDownBlock2D(nn.Module):

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        for resnet, attn in zip(self.resnets, self.attentions):
            # resnet
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)

                # attn
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    ...
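Both new branches route the resnet and attention calls through `torch.utils.checkpoint.checkpoint`. Since `checkpoint` only forwards positional inputs to the wrapped callable, the diff binds keyword arguments such as `return_dict=False` through the `create_custom_forward` closure. A self-contained sketch of that pattern, using a toy stand-in module that is not diffusers code:

```python
# Minimal sketch of the create_custom_forward pattern used in the diff.
# ToyAttn stands in for the attention block; shapes and names are made up.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states, return_dict=True):
        out = self.proj(hidden_states)
        return {"sample": out} if return_dict else (out,)


def create_custom_forward(module, return_dict=None):
    # checkpoint() only passes positional tensors through, so keyword
    # arguments have to be captured in a closure like this one.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


attn = ToyAttn()
hidden_states = torch.randn(2, 8, requires_grad=True)

# Recompute the block's activations during backward instead of storing them.
out = torch.utils.checkpoint.checkpoint(create_custom_forward(attn, return_dict=False), hidden_states)[0]
out.sum().backward()
```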
@@ -2593,9 +2610,27 @@ class SimpleCrossAttnUpBlock2D(nn.Module):

            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)

                # attn
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    ...
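With these branches in place, `SimpleCrossAttnDownBlock2D` and `SimpleCrossAttnUpBlock2D` honor the model-level `gradient_checkpointing` flag during training. A small sketch of how that flag is typically enabled in diffusers, assuming the tiny IF test checkpoint from the test above loads as a `UNet2DConditionModel` from a `unet` subfolder:

```python
# Sketch: enable gradient checkpointing on the IF UNet so the new checkpointed
# branches are taken during training. The checkpoint name comes from the test
# above; treating it as having a loadable "unet" subfolder is an assumption.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("hf-internal-testing/tiny-if-pipe", subfolder="unet")
unet.enable_gradient_checkpointing()  # sets gradient_checkpointing=True on blocks that support it
unet.train()  # the checkpointed path only runs when self.training is True
```

The DreamBooth example script exposes the same toggle through its `--gradient_checkpointing` flag.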