Unverified Commit 44f6bc81 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Don't copy when unwrapping model (#2166)

* Don't copy when unwrapping model.

Otherwise an exception is raised when using fp16.

* Remove unused import
parent 164b6e05
import argparse import argparse
import copy
import inspect import inspect
import logging import logging
import math import math
...@@ -530,7 +529,7 @@ def main(args): ...@@ -530,7 +529,7 @@ def main(args):
# Generate sample images for visual inspection # Generate sample images for visual inspection
if accelerator.is_main_process: if accelerator.is_main_process:
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
unet = copy.deepcopy(accelerator.unwrap_model(model)) unet = accelerator.unwrap_model(model)
if args.use_ema: if args.use_ema:
ema_model.copy_to(unet.parameters()) ema_model.copy_to(unet.parameters())
pipeline = DDPMPipeline( pipeline = DDPMPipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment