Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d0cf681a
Unverified
Commit
d0cf681a
authored
Sep 08, 2023
by
Sayak Paul
Committed by
GitHub
Sep 08, 2023
Browse files
[Tests] add: tests for t2i adapter training. (#4947)
add: tests for t2i adapter training.
parent
dfec61f4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
8 deletions
+38
-8
examples/t2i_adapter/train_t2i_adapter_sdxl.py
examples/t2i_adapter/train_t2i_adapter_sdxl.py
+19
-8
examples/test_examples.py
examples/test_examples.py
+19
-0
No files found.
examples/t2i_adapter/train_t2i_adapter_sdxl.py
View file @
d0cf681a
...
...
@@ -245,6 +245,13 @@ def parse_args(input_args=None):
default
=
None
,
help
=
"Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038."
,
)
parser
.
add_argument
(
"--adapter_model_name_or_path"
,
type
=
str
,
default
=
None
,
help
=
"Path to pretrained adapter model or model identifier from huggingface.co/models."
" If not specified adapter weights are initialized w.r.t the configurations of SDXL."
,
)
parser
.
add_argument
(
"--revision"
,
type
=
str
,
...
...
@@ -840,7 +847,11 @@ def main(args):
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
)
logger
.
info
(
"Initializing t2iadapter weights from unet"
)
if
args
.
adapter_model_name_or_path
:
logger
.
info
(
"Loading existing adapter weights."
)
t2iadapter
=
T2IAdapter
.
from_pretrained
(
args
.
adapter_model_name_or_path
)
else
:
logger
.
info
(
"Initializing t2iadapter weights."
)
t2iadapter
=
T2IAdapter
(
in_channels
=
3
,
channels
=
(
320
,
640
,
1280
,
1280
),
...
...
examples/test_examples.py
View file @
d0cf681a
...
...
@@ -1528,6 +1528,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"diffusion_pytorch_model.safetensors"
)))
def
test_t2i_adapter_sdxl
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
test_args
=
f
"""
examples/t2i_adapter/train_t2i_adapter_sdxl.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
--adapter_model_name_or_path=hf-internal-testing/tiny-adapter
--dataset_name=hf-internal-testing/fill10
--output_dir=
{
tmpdir
}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
"""
.
split
()
run_command
(
self
.
_launch_args
+
test_args
)
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"diffusion_pytorch_model.safetensors"
)))
def
test_custom_diffusion_checkpointing_checkpoints_total_limit
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
test_args
=
f
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment