"llm/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "2ff45d571de4463fcebf779373ae7337cf969ebf"
Unverified Commit 5a7d35e2 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Fix InstructPix2Pix training in multi-GPU mode (#2978)

* fix: norm group test for UNet3D.

* fix: unet rejig.

* fix: unwrapping when running validation inputs.

* unwrapping the unet too.

* fix: device.

* better unwrapping.

* unwrapping before ema.

* unwrapping.
parent 0c72006e
...@@ -451,19 +451,18 @@ def main(): ...@@ -451,19 +451,18 @@ def main():
# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
# initialized to zero. # initialized to zero.
if accelerator.is_main_process: logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") in_channels = 8
in_channels = 8 out_channels = unet.conv_in.out_channels
out_channels = unet.conv_in.out_channels unet.register_to_config(in_channels=in_channels)
unet.register_to_config(in_channels=in_channels)
with torch.no_grad():
with torch.no_grad(): new_conv_in = nn.Conv2d(
new_conv_in = nn.Conv2d( in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding )
) new_conv_in.weight.zero_()
new_conv_in.weight.zero_() new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) unet.conv_in = new_conv_in
unet.conv_in = new_conv_in
# Freeze vae and text_encoder # Freeze vae and text_encoder
vae.requires_grad_(False) vae.requires_grad_(False)
...@@ -892,9 +891,12 @@ def main(): ...@@ -892,9 +891,12 @@ def main():
# Store the UNet parameters temporarily and load the EMA parameters to perform inference. # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters()) ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters()) ema_unet.copy_to(unet.parameters())
# The models need unwrapping for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
unet=unet, unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
revision=args.revision, revision=args.revision,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
) )
...@@ -904,7 +906,9 @@ def main(): ...@@ -904,7 +906,9 @@ def main():
# run inference # run inference
original_image = download_image(args.val_image_url) original_image = download_image(args.val_image_url)
edited_images = [] edited_images = []
with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): with torch.autocast(
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
):
for _ in range(args.num_validation_images): for _ in range(args.num_validation_images):
edited_images.append( edited_images.append(
pipeline( pipeline(
...@@ -959,7 +963,7 @@ def main(): ...@@ -959,7 +963,7 @@ def main():
if args.validation_prompt is not None: if args.validation_prompt is not None:
edited_images = [] edited_images = []
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device)): with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images): for _ in range(args.num_validation_images):
edited_images.append( edited_images.append(
pipeline( pipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment