Unverified commit fed12376 authored by Sayak Paul, committed by GitHub

[ControlNet SDXL training] fixes in the training script (#4223)

* fix: #4206

* add: sdxl controlnet training smoketest.

* remove unnecessary token inits.

* add: licensing to model card.

* include SDXL licensing in the model card and make public visibility default

* debugging

* debugging

* disable local file download.

* fix: training test.

* fix: ckpt prefix.
parent 95b7de88
@@ -124,7 +124,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
         for _ in range(args.num_validation_images):
             with torch.autocast("cuda"):
                 image = pipeline(
-                    validation_prompt, validation_image, num_inference_steps=20, generator=generator
+                    prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
             images.append(image)
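For context on the `fix: #4206` entry above: the SDXL ControlNet pipeline's call signature differs from the SD 1.5 one (`prompt_2` sits between `prompt` and `image` in the positional order), so passing the conditioning image positionally sends it to the wrong parameter. The keyword-argument form is what the fixed validation loop relies on. A minimal, self-contained sketch of that call pattern; the checkpoint IDs and the blank conditioning image are illustrative placeholders, not taken from the training script:

```python
# Minimal sketch of SDXL ControlNet validation inference with keyword arguments.
# The checkpoint IDs and the blank conditioning image below are illustrative only.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16  # placeholder ControlNet
)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

validation_image = Image.new("RGB", (1024, 1024))  # stand-in for a real conditioning image
generator = torch.Generator(device="cuda").manual_seed(0)

with torch.autocast("cuda"):
    # Keyword arguments keep the call correct even though the SDXL signature
    # places `prompt_2` before `image` in the positional order.
    image = pipeline(
        prompt="a photo of a red brick house",
        image=validation_image,
        num_inference_steps=20,
        generator=generator,
    ).images[0]
```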
@@ -178,7 +178,7 @@ def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
     text_encoder_config = PretrainedConfig.from_pretrained(
-        pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
     )
     model_class = text_encoder_config.architectures[0]
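On the `remove unnecessary token inits` entry: once you are logged in via `huggingface-cli login` (or `HF_TOKEN` is set), `from_pretrained` picks the token up automatically, so `use_auth_token=True` adds nothing here. A small sketch of the equivalent call, using the public tiny test checkpoint as an assumed stand-in for a real SDXL checkpoint:

```python
# Sketch: resolving the text-encoder class without an explicit auth-token kwarg.
# The tiny testing repo below is an assumed stand-in for a real SDXL checkpoint.
from transformers import PretrainedConfig

text_encoder_config = PretrainedConfig.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
    subfolder="text_encoder",
)
# For SDXL checkpoints this typically resolves to "CLIPTextModel"
# (and "CLIPTextModelWithProjection" for the second text encoder).
print(text_encoder_config.architectures[0])
```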
@@ -226,6 +226,12 @@ inference: true
 These are controlnet weights trained on {base_model} with new type of conditioning.
 {img_str}
+"""
+
+    model_card += """
+## License
+
+[SDXL 0.9 Research License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md)
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)
@@ -798,10 +804,7 @@ def main(args):
         if args.push_to_hub:
             repo_id = create_repo(
-                repo_id=args.hub_model_id or Path(args.output_dir).name,
-                exist_ok=True,
-                token=args.hub_token,
-                private=True,
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
             ).repo_id

     # Load the tokenizers
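On `make public visibility default`: dropping `private=True` means `create_repo` falls back to huggingface_hub's default visibility (public), and `exist_ok=True` keeps reruns from failing when the repo already exists. A hedged sketch of the call in isolation; the output directory name and token handling are placeholders:

```python
# Sketch of the repo-creation step on its own; paths and token handling are illustrative.
from pathlib import Path
from huggingface_hub import create_repo

output_dir = "controlnet-sdxl-example"  # placeholder for args.output_dir
hub_model_id = None                     # placeholder for args.hub_model_id

repo_id = create_repo(
    repo_id=hub_model_id or Path(output_dir).name,
    exist_ok=True,  # do not error if the repo already exists
    token=None,     # None -> use the locally cached token; no `private=` -> public repo
).repo_id
print(repo_id)
```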
@@ -839,7 +842,7 @@ def main(args):
         revision=args.revision,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )

     if args.controlnet_model_name_or_path:
@@ -1296,6 +1296,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
         )

+    def test_controlnet_sdxl(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/controlnet/train_controlnet_sdxl.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --dataset_name=hf-internal-testing/fill10
+                --output_dir={tmpdir}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
+                --max_train_steps=9
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
+
     def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
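If you want to reproduce the new smoke test outside the unittest harness, an equivalent standalone launch would look roughly like this (a sketch assuming `accelerate` is installed and configured and that you run it from the repository root; it reuses the same tiny `hf-internal-testing` checkpoints as the test):

```python
# Sketch: run the same short SDXL ControlNet training job the smoke test runs,
# but via a plain `accelerate launch` subprocess instead of the test harness.
import os
import subprocess
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    cmd = [
        "accelerate", "launch", "--num_processes", "1",
        "examples/controlnet/train_controlnet_sdxl.py",
        "--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe",
        "--dataset_name=hf-internal-testing/fill10",
        f"--output_dir={tmpdir}",
        "--resolution=64",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=1",
        "--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl",
        "--max_train_steps=9",
        "--checkpointing_steps=2",
    ]
    subprocess.run(cmd, check=True)

    # The trained ControlNet weights should end up in the output directory.
    print(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
```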
@@ -751,7 +751,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         sample = self.conv_in(sample)

         controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
         sample = sample + controlnet_cond
-
         # 3. down