Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c9f939bf
Unverified
Commit
c9f939bf
authored
May 17, 2023
by
Will Berman
Committed by
GitHub
May 17, 2023
Browse files
Update full dreambooth script to work with IF (#3425)
parent
2858d7e1
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
344 additions
and
57 deletions
+344
-57
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+266
-40
examples/test_examples.py
examples/test_examples.py
+26
-0
src/diffusers/models/unet_2d_blocks.py
src/diffusers/models/unet_2d_blocks.py
+52
-17
No files found.
examples/dreambooth/train_dreambooth.py
View file @
c9f939bf
This diff is collapsed.
Click to expand it.
examples/test_examples.py
View file @
c9f939bf
...
@@ -147,6 +147,32 @@ class ExamplesTestsAccelerate(unittest.TestCase):
...
@@ -147,6 +147,32 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"unet"
,
"diffusion_pytorch_model.bin"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"unet"
,
"diffusion_pytorch_model.bin"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
def
test_dreambooth_if
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
test_args
=
f
"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir
{
tmpdir
}
--pre_compute_text_embeddings
--tokenizer_max_length=77
--text_encoder_use_attention_mask
"""
.
split
()
run_command
(
self
.
_launch_args
+
test_args
)
# save_pretrained smoke test
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"unet"
,
"diffusion_pytorch_model.bin"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
def
test_dreambooth_checkpointing
(
self
):
def
test_dreambooth_checkpointing
(
self
):
instance_prompt
=
"photo"
instance_prompt
=
"photo"
pretrained_model_name_or_path
=
"hf-internal-testing/tiny-stable-diffusion-pipe"
pretrained_model_name_or_path
=
"hf-internal-testing/tiny-stable-diffusion-pipe"
...
...
src/diffusers/models/unet_2d_blocks.py
View file @
c9f939bf
...
@@ -1507,16 +1507,33 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
...
@@ -1507,16 +1507,33 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
cross_attention_kwargs
=
cross_attention_kwargs
if
cross_attention_kwargs
is
not
None
else
{}
cross_attention_kwargs
=
cross_attention_kwargs
if
cross_attention_kwargs
is
not
None
else
{}
for
resnet
,
attn
in
zip
(
self
.
resnets
,
self
.
attentions
):
for
resnet
,
attn
in
zip
(
self
.
resnets
,
self
.
attentions
):
# resnet
if
self
.
training
and
self
.
gradient_checkpointing
:
hidden_states
=
resnet
(
hidden_states
,
temb
)
# attn
def
create_custom_forward
(
module
,
return_dict
=
None
):
hidden_states
=
attn
(
def
custom_forward
(
*
inputs
):
hidden_states
,
if
return_dict
is
not
None
:
encoder_hidden_states
=
encoder_hidden_states
,
return
module
(
*
inputs
,
return_dict
=
return_dict
)
attention_mask
=
attention_mask
,
else
:
**
cross_attention_kwargs
,
return
module
(
*
inputs
)
)
return
custom_forward
hidden_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
resnet
),
hidden_states
,
temb
)
hidden_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
attn
,
return_dict
=
False
),
hidden_states
,
encoder_hidden_states
,
cross_attention_kwargs
,
)[
0
]
else
:
hidden_states
=
resnet
(
hidden_states
,
temb
)
hidden_states
=
attn
(
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
attention_mask
=
attention_mask
,
**
cross_attention_kwargs
,
)
output_states
=
output_states
+
(
hidden_states
,)
output_states
=
output_states
+
(
hidden_states
,)
...
@@ -2593,15 +2610,33 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
...
@@ -2593,15 +2610,33 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
res_hidden_states_tuple
=
res_hidden_states_tuple
[:
-
1
]
res_hidden_states_tuple
=
res_hidden_states_tuple
[:
-
1
]
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
resnet
(
hidden_states
,
temb
)
if
self
.
training
and
self
.
gradient_checkpointing
:
# attn
def
create_custom_forward
(
module
,
return_dict
=
None
):
hidden_states
=
attn
(
def
custom_forward
(
*
inputs
):
hidden_states
,
if
return_dict
is
not
None
:
encoder_hidden_states
=
encoder_hidden_states
,
return
module
(
*
inputs
,
return_dict
=
return_dict
)
attention_mask
=
attention_mask
,
else
:
**
cross_attention_kwargs
,
return
module
(
*
inputs
)
)
return
custom_forward
hidden_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
resnet
),
hidden_states
,
temb
)
hidden_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
attn
,
return_dict
=
False
),
hidden_states
,
encoder_hidden_states
,
cross_attention_kwargs
,
)[
0
]
else
:
hidden_states
=
resnet
(
hidden_states
,
temb
)
hidden_states
=
attn
(
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
attention_mask
=
attention_mask
,
**
cross_attention_kwargs
,
)
if
self
.
upsamplers
is
not
None
:
if
self
.
upsamplers
is
not
None
:
for
upsampler
in
self
.
upsamplers
:
for
upsampler
in
self
.
upsamplers
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment