Unverified commit cbcd0512, authored by Denis and committed by GitHub

Training to predict x0 in training example (#1031)



* changed training example to add option to train model that predicts x0 (instead of eps), changed DDPM pipeline accordingly

* Revert "changed training example to add option to train model that predicts x0 (instead of eps), changed DDPM pipeline accordingly"

This reverts commit c5efb525648885f2e7df71f4483a9f248515ad61.

* changed training example to add option to train model that predicts x0 (instead of eps), changed DDPM pipeline accordingly

* fixed code style
Co-authored-by: lukovnikov <lukovnikov@users.noreply.github.com>
parent 0b61cea3
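For context, a minimal sketch (illustrative only, not part of the commit) of why the two prediction targets are interchangeable: under the DDPM forward process, either the noise eps or the clean image x0 determines the other once the noisy sample is fixed. All variable names below are illustrative.

# Illustrative sketch: the two prediction targets are related through the DDPM
# forward process x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps.
import torch

x0 = torch.randn(2, 3, 8, 8)        # clean images
eps = torch.randn_like(x0)          # Gaussian noise
alpha_bar_t = torch.tensor(0.5)     # example cumulative alpha at some timestep t

x_t = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * eps  # noisy images

# "eps" mode trains the UNet to output eps; "x0" mode trains it to output x0.
# Given x_t, either target determines the other:
x0_from_eps = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
assert torch.allclose(x0_from_eps, x0, atol=1e-5)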
@@ -29,6 +29,24 @@ from tqdm.auto import tqdm
 logger = get_logger(__name__)
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+    :param arr: the 1-D numpy array.
+    :param timesteps: a tensor of indices into the array to extract.
+    :param broadcast_shape: a larger shape of K dimensions with the batch
+                            dimension equal to the length of timesteps.
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+    """
+    if not isinstance(arr, torch.Tensor):
+        arr = torch.from_numpy(arr)
+    res = arr[timesteps].float().to(timesteps.device)
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res.expand(broadcast_shape)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -171,6 +189,16 @@ def parse_args():
         ),
     )
+    parser.add_argument(
+        "--predict_mode",
+        type=str,
+        default="eps",
+        help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
+    )
+    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
@@ -224,7 +252,7 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
+    noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -257,6 +285,8 @@ def main(args):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
         return {"input": images}
+    logger.info(f"Dataset size: {len(dataset)}")
+
     dataset.set_transform(transforms)
     train_dataloader = torch.utils.data.DataLoader(
         dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
@@ -319,8 +349,20 @@ def main(args):
             with accelerator.accumulate(model):
                 # Predict the noise residual
-                noise_pred = model(noisy_images, timesteps).sample
-                loss = F.mse_loss(noise_pred, noise)
+                model_output = model(noisy_images, timesteps).sample
+
+                if args.predict_mode == "eps":
+                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
+                elif args.predict_mode == "x0":
+                    alpha_t = _extract_into_tensor(
+                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+                    )
+                    snr_weights = alpha_t / (1 - alpha_t)
+                    loss = snr_weights * F.mse_loss(
+                        model_output, clean_images, reduction="none"
+                    )  # use SNR weighting from distillation paper
+                    loss = loss.mean()
+
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
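A brief note on the SNR weight above (an illustrative check, not part of the commit): for a fixed noisy sample, an error in x0-space is sqrt((1 - alpha_bar) / alpha_bar) times the corresponding error in eps-space, so weighting the x0 MSE by alpha_bar / (1 - alpha_bar) puts it back on the usual eps-MSE scale. The variable names here are not from the script.

# Numerical sanity check of the SNR weight (illustrative only).
import torch

alpha_bar = torch.tensor(0.7)                              # example cumulative alpha
eps_err = torch.randn(4, 3, 8, 8)                          # hypothetical eps prediction error
x0_err = ((1 - alpha_bar) / alpha_bar).sqrt() * eps_err    # induced x0 prediction error

snr = alpha_bar / (1 - alpha_bar)
assert torch.allclose(snr * x0_err.pow(2).mean(), eps_err.pow(2).mean(), atol=1e-4)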
@@ -355,7 +397,12 @@ def main(args):
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
+                images = pipeline(
+                    generator=generator,
+                    batch_size=args.eval_batch_size,
+                    output_type="numpy",
+                    predict_epsilon=args.predict_mode == "eps",
+                ).images
                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")
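The second file in the diff is the DDPM pipeline, whose __call__ gains a predict_epsilon flag that is forwarded to the scheduler step. A hedged usage sketch (the checkpoint path is hypothetical):

from diffusers import DDPMPipeline

# Sample from a model trained with --predict_mode x0; pass predict_epsilon=False so the
# scheduler treats the UNet output as the predicted clean image rather than the noise.
pipeline = DDPMPipeline.from_pretrained("path/to/ddpm-x0-checkpoint")  # hypothetical path
images = pipeline(batch_size=4, output_type="numpy", predict_epsilon=False).images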
@@ -45,6 +45,7 @@ class DDPMPipeline(DiffusionPipeline):
         num_inference_steps: int = 1000,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        predict_epsilon: bool = True,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
@@ -84,7 +85,9 @@ class DDPMPipeline(DiffusionPipeline):
             model_output = self.unet(image, t).sample
             # 2. compute previous image: x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+            image = self.scheduler.step(
+                model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
+            ).prev_sample
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
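For reference, the predict_epsilon switch comes down to how the scheduler recovers the predicted clean image before computing x_{t-1}. A simplified sketch of that branch (an illustration, not the actual DDPMScheduler.step implementation):

# Simplified sketch of the x0-vs-eps branch inside a DDPM denoising step (illustrative only).
def predicted_original_sample(model_output, sample, alpha_bar_t, predict_epsilon=True):
    if predict_epsilon:
        # model output is eps: invert the forward process to recover x0
        return (sample - (1 - alpha_bar_t) ** 0.5 * model_output) / alpha_bar_t ** 0.5
    # model output is already the predicted x0
    return model_output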