Unverified Commit c9f939bf authored by Will Berman, committed by GitHub

Update full dreambooth script to work with IF (#3425)

parent 2858d7e1
@@ -147,6 +147,32 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+    def test_dreambooth_if(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --pre_compute_text_embeddings
+                --tokenizer_max_length=77
+                --text_encoder_use_attention_mask
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
     def test_dreambooth_checkpointing(self):
         instance_prompt = "photo"
         pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
...
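For context on the flags exercised by the new test, the following is a minimal illustrative sketch, not code from this PR, of what --pre_compute_text_embeddings together with --tokenizer_max_length=77 and --text_encoder_use_attention_mask amount to: tokenize the instance prompt to a fixed length and run the frozen text encoder once, with an explicit attention mask, before the training loop. DeepFloyd IF uses a T5-style text encoder; the "t5-small" checkpoint and the "photo" prompt below are stand-ins, not values taken from train_dreambooth.py.

    # Illustrative sketch only; model name and prompt are hypothetical stand-ins.
    import torch
    from transformers import T5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    text_encoder = T5EncoderModel.from_pretrained("t5-small")

    # --tokenizer_max_length=77: pad/truncate the prompt to a fixed length
    inputs = tokenizer(
        "photo",
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )

    # --pre_compute_text_embeddings: run the text encoder once, outside the training loop
    with torch.no_grad():
        # --text_encoder_use_attention_mask: pass the mask so padding tokens are ignored
        prompt_embeds = text_encoder(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
        )[0]

Precomputing the embeddings is attractive for IF because the T5 encoder is large and frozen, so it does not need to stay in memory while the UNet is being trained.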
@@ -1507,16 +1507,33 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            # resnet
-            hidden_states = resnet(hidden_states, temb)
-
-            # attn
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                **cross_attention_kwargs,
-            )
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    cross_attention_kwargs,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    **cross_attention_kwargs,
+                )
 
             output_states = output_states + (hidden_states,)
@@ -2593,15 +2610,33 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb)
-
-            # attn
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                **cross_attention_kwargs,
-            )
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    cross_attention_kwargs,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    **cross_attention_kwargs,
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
...
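Both hunks above use the same wrapper pattern. As a standalone illustration, here is a minimal sketch, not diffusers code, of why create_custom_forward exists: torch.utils.checkpoint.checkpoint calls the wrapped module with positional tensor inputs and re-runs it during the backward pass, so keyword options such as return_dict=False are bound in a closure, and the resulting tuple output is indexed with [0]. ToyBlock below is a hypothetical stand-in module.

    # Illustrative sketch only; ToyBlock is a made-up module, not a diffusers class.
    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    class ToyBlock(nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states, return_dict=True):
            out = self.proj(hidden_states)
            # Mimic diffusers-style outputs: a tuple when return_dict=False
            return out if return_dict else (out,)

    def create_custom_forward(module, return_dict=None):
        # Bind keyword arguments in a closure so checkpoint() can pass tensors positionally.
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    block = ToyBlock()
    x = torch.randn(2, 16, requires_grad=True)

    # Activations are recomputed during backward instead of being stored in forward.
    y = torch.utils.checkpoint.checkpoint(
        create_custom_forward(block, return_dict=False), x, use_reentrant=False
    )[0]
    y.sum().backward()

The trade-off is extra forward compute during backward in exchange for lower peak activation memory, which matters when fine-tuning the full IF UNet.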