Unverified Commit d0cf681a authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] add: tests for t2i adapter training. (#4947)

add: tests for t2i adapter training.
parent dfec61f4
...@@ -245,6 +245,13 @@ def parse_args(input_args=None): ...@@ -245,6 +245,13 @@ def parse_args(input_args=None):
default=None, default=None,
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
) )
parser.add_argument(
"--adapter_model_name_or_path",
type=str,
default=None,
help="Path to pretrained adapter model or model identifier from huggingface.co/models."
" If not specified adapter weights are initialized w.r.t the configurations of SDXL.",
)
parser.add_argument( parser.add_argument(
"--revision", "--revision",
type=str, type=str,
...@@ -840,7 +847,11 @@ def main(args): ...@@ -840,7 +847,11 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
) )
logger.info("Initializing t2iadapter weights from unet") if args.adapter_model_name_or_path:
logger.info("Loading existing adapter weights.")
t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)
else:
logger.info("Initializing t2iadapter weights.")
t2iadapter = T2IAdapter( t2iadapter = T2IAdapter(
in_channels=3, in_channels=3,
channels=(320, 640, 1280, 1280), channels=(320, 640, 1280, 1280),
......
...@@ -1528,6 +1528,25 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -1528,6 +1528,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
def test_t2i_adapter_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/t2i_adapter/train_t2i_adapter_sdxl.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
--adapter_model_name_or_path=hf-internal-testing/tiny-adapter
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self): def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
test_args = f""" test_args = f"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment