Unverified commit c9f939bf authored by Will Berman, committed by GitHub

Update full dreambooth script to work with IF (#3425)

parent 2858d7e1
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 
 import argparse
+import gc
 import hashlib
 import itertools
 import logging
@@ -30,7 +31,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, model_info, upload_folder
 from packaging import version
 from PIL import Image
 from torch.utils.data import Dataset
@@ -93,31 +94,61 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}.
         f.write(yaml + model_card)
 
 
-def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+def log_validation(
+    text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds
+):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
+
+    pipeline_args = {}
+
+    if text_encoder is not None:
+        pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+    if vae is not None:
+        pipeline_args["vae"] = vae
+
     # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
-        text_encoder=accelerator.unwrap_model(text_encoder),
         tokenizer=tokenizer,
         unet=accelerator.unwrap_model(unet),
-        vae=vae,
         revision=args.revision,
         torch_dtype=weight_dtype,
+        **pipeline_args,
     )
-    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
     pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
+    if args.pre_compute_text_embeddings:
+        pipeline_args = {
+            "prompt_embeds": prompt_embeds,
+            "negative_prompt_embeds": negative_prompt_embeds,
+        }
+    else:
+        pipeline_args = {"prompt": args.validation_prompt}
+
     # run inference
     generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
     images = []
     for _ in range(args.num_validation_images):
         with torch.autocast("cuda"):
-            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+            image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
         images.append(image)
 
     for tracker in accelerator.trackers:
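The scheduler handling above is central to IF support: IF's DDPM scheduler predicts a learned variance, which the simplified training objective used here never updates, so sampling has to fall back to a fixed variance. A minimal sketch of the same override, assuming the public DeepFloyd/IF-I-XL-v1.0 checkpoint (the repo id is illustrative, not part of this commit):

from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

ddpm = DDPMScheduler.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="scheduler")
print(ddpm.config.variance_type)  # "learned_range" for IF

# from_config accepts config overrides, which is exactly what **scheduler_args does above
dpm = DPMSolverMultistepScheduler.from_config(ddpm.config, variance_type="fixed_small")
print(dpm.config.variance_type)  # "fixed_small"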
@@ -155,6 +186,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str
         from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
 
         return RobertaSeriesModelWithTransformation
+    elif model_class == "T5EncoderModel":
+        from transformers import T5EncoderModel
+
+        return T5EncoderModel
     else:
         raise ValueError(f"{model_class} is not supported.")
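The new T5EncoderModel branch is the one IF takes: the helper reads the text encoder's config from the checkpoint and dispatches on the declared architecture. A small sketch, again assuming the public IF checkpoint for illustration:

from transformers import PretrainedConfig

config = PretrainedConfig.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder")
print(config.architectures[0])  # "T5EncoderModel", so the new elif branch is taken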
@@ -459,6 +494,27 @@ def parse_args(input_args=None):
             " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
         ),
     )
+    parser.add_argument(
+        "--pre_compute_text_embeddings",
+        action="store_true",
+        help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+    )
+    parser.add_argument(
+        "--tokenizer_max_length",
+        type=int,
+        default=None,
+        required=False,
+        help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+    )
+    parser.add_argument(
+        "--text_encoder_use_attention_mask",
+        action="store_true",
+        required=False,
+        help="Whether to use attention mask for the text encoder",
+    )
+    parser.add_argument(
+        "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -481,6 +537,9 @@
         if args.class_prompt is not None:
             warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
 
+    if args.train_text_encoder and args.pre_compute_text_embeddings:
+        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
     return args
@@ -500,10 +559,16 @@ class DreamBoothDataset(Dataset):
         class_num=None,
         size=512,
         center_crop=False,
+        encoder_hidden_states=None,
+        instance_prompt_encoder_hidden_states=None,
+        tokenizer_max_length=None,
     ):
         self.size = size
         self.center_crop = center_crop
         self.tokenizer = tokenizer
+        self.encoder_hidden_states = encoder_hidden_states
+        self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
+        self.tokenizer_max_length = tokenizer_max_length
 
         self.instance_data_root = Path(instance_data_root)
         if not self.instance_data_root.exists():
@@ -545,40 +610,52 @@ class DreamBoothDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
         example["instance_images"] = self.image_transforms(instance_image)
-        example["instance_prompt_ids"] = self.tokenizer(
-            self.instance_prompt,
-            truncation=True,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        ).input_ids
+
+        if self.encoder_hidden_states is not None:
+            example["instance_prompt_ids"] = self.encoder_hidden_states
+        else:
+            text_inputs = tokenize_prompt(
+                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+            )
+            example["instance_prompt_ids"] = text_inputs.input_ids
+            example["instance_attention_mask"] = text_inputs.attention_mask
 
         if self.class_data_root:
             class_image = Image.open(self.class_images_path[index % self.num_class_images])
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
             example["class_images"] = self.image_transforms(class_image)
-            example["class_prompt_ids"] = self.tokenizer(
-                self.class_prompt,
-                truncation=True,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            ).input_ids
+
+            if self.instance_prompt_encoder_hidden_states is not None:
+                example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
+            else:
+                class_text_inputs = tokenize_prompt(
+                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+                )
+                example["class_prompt_ids"] = class_text_inputs.input_ids
+                example["class_attention_mask"] = class_text_inputs.attention_mask
 
         return example
 
 
 def collate_fn(examples, with_prior_preservation=False):
+    has_attention_mask = "instance_attention_mask" in examples[0]
+
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
+    if has_attention_mask:
+        attention_mask = [example["instance_attention_mask"] for example in examples]
+
     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
     if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
+
+        if has_attention_mask:
+            attention_mask += [example["class_attention_mask"] for example in examples]
 
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -588,6 +665,10 @@ def collate_fn(examples, with_prior_preservation=False):
     batch = {
         "input_ids": input_ids,
         "pixel_values": pixel_values,
     }
 
+    if has_attention_mask:
+        batch["attention_mask"] = attention_mask
+
     return batch
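A toy illustration of the new mask plumbing, with made-up shapes: whenever the dataset emits instance_attention_mask, collate_fn carries a matching attention_mask entry through to the batch.

import torch

example = {
    "instance_prompt_ids": torch.ones(1, 77, dtype=torch.long),
    "instance_attention_mask": torch.ones(1, 77, dtype=torch.long),
    "instance_images": torch.randn(3, 64, 64),
}
batch = collate_fn([example])
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']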
@@ -608,6 +689,50 @@ class PromptDataset(Dataset):
         return example
 
 
+def model_has_vae(args):
+    config_file_name = os.path.join("vae", AutoencoderKL.config_name)
+    if os.path.isdir(args.pretrained_model_name_or_path):
+        config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
+        return os.path.isfile(config_file_name)
+    else:
+        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+        return any(file.rfilename == config_file_name for file in files_in_repo)
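A hedged usage sketch of model_has_vae: a pixel-space model like IF ships no vae/ subfolder, so the check comes back False and the script later sets vae = None (the repo id is an assumption for illustration):

from huggingface_hub import model_info

siblings = model_info("DeepFloyd/IF-I-XL-v1.0").siblings
print(any(f.rfilename == "vae/config.json" for f in siblings))  # False: nothing to load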
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+    if tokenizer_max_length is not None:
+        max_length = tokenizer_max_length
+    else:
+        max_length = tokenizer.model_max_length
+
+    text_inputs = tokenizer(
+        prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=max_length,
+        return_tensors="pt",
+    )
+
+    return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+    text_input_ids = input_ids.to(text_encoder.device)
+
+    if text_encoder_use_attention_mask:
+        attention_mask = attention_mask.to(text_encoder.device)
+    else:
+        attention_mask = None
+
+    prompt_embeds = text_encoder(
+        text_input_ids,
+        attention_mask=attention_mask,
+    )
+    prompt_embeds = prompt_embeds[0]
+
+    return prompt_embeds
+
+
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
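A minimal sketch of the two helpers working together, assuming an off-the-shelf T5 checkpoint ("t5-small" is illustrative only; IF itself uses a much larger T5-XXL encoder). The flag matters because T5 is normally run with its padding mask, whereas the CLIP-based Stable Diffusion scripts never passed one:

import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("t5-small")
text_encoder = T5EncoderModel.from_pretrained("t5-small")

with torch.no_grad():
    text_inputs = tokenize_prompt(tokenizer, "a photo of sks dog", tokenizer_max_length=77)
    prompt_embeds = encode_prompt(
        text_encoder,
        text_inputs.input_ids,
        text_inputs.attention_mask,
        text_encoder_use_attention_mask=True,
    )

print(prompt_embeds.shape)  # torch.Size([1, 77, 512]) for t5-small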
@@ -727,7 +852,14 @@ def main(args):
     text_encoder = text_encoder_cls.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+
+    if model_has_vae(args):
+        vae = AutoencoderKL.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+        )
+    else:
+        vae = None
+
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
@@ -761,7 +893,9 @@
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
 
-    vae.requires_grad_(False)
+    if vae is not None:
+        vae.requires_grad_(False)
+
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
@@ -835,6 +969,44 @@
         eps=args.adam_epsilon,
     )
 
+    if args.pre_compute_text_embeddings:
+
+        def compute_text_embeddings(prompt):
+            with torch.no_grad():
+                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+                prompt_embeds = encode_prompt(
+                    text_encoder,
+                    text_inputs.input_ids,
+                    text_inputs.attention_mask,
+                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+                )
+
+            return prompt_embeds
+
+        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+        validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+        if args.validation_prompt is not None:
+            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+        else:
+            validation_prompt_encoder_hidden_states = None
+
+        if args.instance_prompt is not None:
+            pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+        else:
+            pre_computed_instance_prompt_encoder_hidden_states = None
+
+        text_encoder = None
+        tokenizer = None
+
+        gc.collect()
+        torch.cuda.empty_cache()
+    else:
+        pre_computed_encoder_hidden_states = None
+        validation_prompt_encoder_hidden_states = None
+        validation_prompt_negative_prompt_embeds = None
+        pre_computed_instance_prompt_encoder_hidden_states = None
+
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
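What the block above buys is memory: every prompt is encoded exactly once, the text encoder and tokenizer are then dropped, and the cached tensors stand in for token ids downstream. A hypothetical shape check (IF's T5-XXL encoder has hidden size 4096; the tiny test models differ):

embeds = compute_text_embeddings("a photo of sks dog")
print(embeds.shape)  # e.g. torch.Size([1, 77, 4096]) with --tokenizer_max_length=77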
@@ -845,6 +1017,9 @@
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        encoder_hidden_states=pre_computed_encoder_hidden_states,
+        instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
+        tokenizer_max_length=args.tokenizer_max_length,
     )
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -890,8 +1065,10 @@
         weight_dtype = torch.bfloat16
 
     # Move vae and text_encoder to device and cast to weight_dtype
-    vae.to(accelerator.device, dtype=weight_dtype)
-    if not args.train_text_encoder:
+    if vae is not None:
+        vae.to(accelerator.device, dtype=weight_dtype)
+
+    if not args.train_text_encoder and text_encoder is not None:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -961,37 +1138,55 @@
                 continue
 
             with accelerator.accumulate(unet):
-                # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * vae.config.scaling_factor
+                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
 
-                # Sample noise that we'll add to the latents
+                if vae is not None:
+                    # Convert images to latent space
+                    model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                    model_input = model_input * vae.config.scaling_factor
+                else:
+                    model_input = pixel_values
+
+                # Sample noise that we'll add to the model input
                 if args.offset_noise:
-                    noise = torch.randn_like(latents) + 0.1 * torch.randn(
-                        latents.shape[0], latents.shape[1], 1, 1, device=latents.device
+                    noise = torch.randn_like(model_input) + 0.1 * torch.randn(
+                        model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
                     )
                 else:
-                    noise = torch.randn_like(latents)
+                    noise = torch.randn_like(model_input)
 
-                bsz = latents.shape[0]
+                bsz = model_input.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                )
                 timesteps = timesteps.long()
 
-                # Add noise to the latents according to the noise magnitude at each timestep
+                # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                if args.pre_compute_text_embeddings:
+                    encoder_hidden_states = batch["input_ids"]
+                else:
+                    encoder_hidden_states = encode_prompt(
+                        text_encoder,
+                        batch["input_ids"],
+                        batch["attention_mask"],
+                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+                    )
 
                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
+
+                if model_pred.shape[1] == 6:
+                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
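The shape[1] == 6 branch is IF-specific: the UNet predicts noise and variance as 3 + 3 channels, and only the noise half should be compared against target. A toy illustration with made-up tensors:

import torch

model_pred = torch.randn(4, 6, 64, 64)  # (batch, 2 * rgb channels, height, width)
noise_pred, predicted_variance = torch.chunk(model_pred, 2, dim=1)
print(noise_pred.shape)  # torch.Size([4, 3, 64, 64]), matching the noise target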
@@ -1037,7 +1232,16 @@
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                         images = log_validation(
-                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+                            text_encoder,
+                            tokenizer,
+                            unet,
+                            vae,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            epoch,
+                            validation_prompt_encoder_hidden_states,
+                            validation_prompt_negative_prompt_embeds,
                         )
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1050,12 +1254,34 @@
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        pipeline_args = {}
+
+        if text_encoder is not None:
+            pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+        if args.skip_save_text_encoder:
+            pipeline_args["text_encoder"] = None
+
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=accelerator.unwrap_model(unet),
-            text_encoder=accelerator.unwrap_model(text_encoder),
             revision=args.revision,
+            **pipeline_args,
         )
+
+        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+        scheduler_args = {}
+
+        if "variance_type" in pipeline.scheduler.config:
+            variance_type = pipeline.scheduler.config.variance_type
+
+            if variance_type in ["learned", "learned_range"]:
+                variance_type = "fixed_small"
+
+            scheduler_args["variance_type"] = variance_type
+
+        pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
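Once saved, the output directory loads back like any other pipeline. A hedged inference sketch, assuming the text encoder was saved (path and prompt are placeholders; a stage-I IF model produces 64x64 images intended for the upscaling stages):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("path/to/output_dir", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
image.save("sample.png")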
...
@@ -147,6 +147,32 @@ class ExamplesTestsAccelerate(unittest.TestCase):
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
+    def test_dreambooth_if(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --pre_compute_text_embeddings
+                --tokenizer_max_length=77
+                --text_encoder_use_attention_mask
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
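The tiny-model test above mirrors what a real run looks like. A sketch of a full-scale launch with assumed paths and illustrative hyperparameters (none of these values are mandated by this commit):

accelerate launch examples/dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path DeepFloyd/IF-I-XL-v1.0 \
  --instance_data_dir ./dog \
  --instance_prompt "a photo of sks dog" \
  --resolution 64 \
  --train_batch_size 1 \
  --gradient_accumulation_steps 4 \
  --max_train_steps 400 \
  --learning_rate 1e-6 \
  --output_dir ./if-dreambooth-output \
  --pre_compute_text_embeddings \
  --tokenizer_max_length 77 \
  --text_encoder_use_attention_mask \
  --skip_save_text_encoder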
 
     def test_dreambooth_checkpointing(self):
         instance_prompt = "photo"
         pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
...
@@ -1507,16 +1507,33 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         for resnet, attn in zip(self.resnets, self.attentions):
-            # resnet
-            hidden_states = resnet(hidden_states, temb)
-
-            # attn
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                **cross_attention_kwargs,
-            )
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    cross_attention_kwargs,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    **cross_attention_kwargs,
+                )
 
             output_states = output_states + (hidden_states,)
@@ -2593,15 +2610,33 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb)
-
-            # attn
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                **cross_attention_kwargs,
-            )
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    cross_attention_kwargs,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    **cross_attention_kwargs,
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
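The create_custom_forward wrapper used in both SimpleCrossAttnDownBlock2D and SimpleCrossAttnUpBlock2D is the standard torch.utils.checkpoint idiom: checkpoint() takes a callable plus positional inputs, discards intermediate activations on the forward pass, and recomputes them during backward, trading compute for activation memory. A self-contained sketch of the same pattern:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32))
x = torch.randn(8, 32, requires_grad=True)

# Activations inside `block` are recomputed during backward rather than stored.
y = checkpoint(create_custom_forward(block), x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 32])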
...