"...dcu-process-montor.git" did not exist on "b632c3a72b4209edd592405a42b2d344804340a1"
Commit 4e08e0ca authored by Patrick von Platen's avatar Patrick von Platen
Browse files

merge

parents af6c1439 07ff0abf
import argparse
import os
import torch
import torch.nn.functional as F
import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, Glide, GlideUNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm
logger = logging.get_logger(__name__)
def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = Glide.from_pretrained("fusing/glide-base")
    model = pipeline.text_unet
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)
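    # Image preprocessing: resize, center-crop, random horizontal flip, then scale pixels to [-1, 1].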
    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")
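    # The pretrained text encoder stays in eval mode and is not passed to the optimizer; only the UNet is updated.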
    text_encoder = pipeline.text_encoder.eval()

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
        text_inputs = text_inputs.input_ids.to(accelerator.device)
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
        return {"images": images, "text_embeddings": text_embeddings}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
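    # Linear warmup followed by linear decay over the total number of optimizer updates.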
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, optimizer, train_dataloader, lr_scheduler
    )
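    # accelerator.prepare wraps the model, optimizer, and dataloader for mixed-precision and (if launched that way) distributed training.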
    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num Epochs = {args.num_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps}")
    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["images"]
                batch_size, n_channels, height, width = clean_images.shape
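                # Sample Gaussian noise and a uniformly random diffusion timestep for every image in the batch.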
                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                timesteps = torch.randint(
                    0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
                ).long()

                # add noise onto the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
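                # On intermediate accumulation steps, skip gradient synchronization across processes to save communication.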
                if step % args.gradient_accumulation_steps != 0:
                    with accelerator.no_sync(model):
                        model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                        model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                        # Learn the variance using the variational bound, but don't let
                        # it affect our mean prediction.
                        frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                        # predict the noise residual
                        loss = F.mse_loss(model_output, noise_samples)
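                        # Scale the loss so gradients accumulated over several micro-batches average correctly.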
                        loss = loss / args.gradient_accumulation_steps
                        accelerator.backward(loss)
                        optimizer.step()
                else:
                    model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                    model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                    # Learn the variance using the variational bound, but don't let
                    # it affect our mean prediction.
                    frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                    # predict the noise residual
                    loss = F.mse_loss(model_output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)
                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)

                # process image to PIL
                image_processed = image.squeeze(0)
                image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
                image_pil = PIL.Image.fromarray(image_processed)

                # save image
                test_dir = os.path.join(args.output_dir, "test_samples")
                os.makedirs(test_dir, exist_ok=True)
                image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="glide-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
@@ -4,19 +4,19 @@ import os
import torch
import torch.nn.functional as F
import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetLDMModel
from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.modeling_utils import unwrap_model
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Lambda,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
@@ -30,6 +30,8 @@ logger = logging.get_logger(__name__)
def main(args):
accelerator = Accelerator(mixed_precision=args.mixed_precision)
pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")
pipeline.unet = None # this model will be trained from scratch now
model = UNetLDMModel(
attention_resolutions=[4, 2, 1],
channel_mult=[1, 2, 4, 4],
@@ -37,7 +39,7 @@ def main(args):
conv_resample=True,
dims=2,
dropout=0,
image_size=32,
image_size=8,
in_channels=4,
model_channels=320,
num_heads=8,
@@ -51,7 +53,7 @@ def main(args):
legacy=False,
)
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)
augmentations = Compose(
[
@@ -59,14 +61,22 @@ def main(args):
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Lambda(lambda x: x * 2 - 1),
Normalize([0.5], [0.5]),
]
)
dataset = load_dataset(args.dataset, split="train")
text_encoder = pipeline.bert.eval()
vqvae = pipeline.vqvae.eval()
def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state
images = 1 / 0.18215 * torch.stack(images, dim=0)
latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode()
return {"images": images, "text_embeddings": text_embeddings, "latents": latents}
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
@@ -78,9 +88,11 @@ def main(args):
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoder.cpu()
vqvae = vqvae.cpu()
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
@@ -98,29 +110,31 @@ def main(args):
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
global_step = 0
for epoch in range(args.num_epochs):
model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
pbar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
clean_latents = batch["latents"]
noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device)
bsz = clean_latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long()
# add noise onto the clean images according to the noise magnitude at each timestep
# add noise onto the clean latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
output = model(noisy_images, timesteps)
output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
optimizer.step()
else:
output = model(noisy_images, timesteps)
output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
@@ -131,24 +145,25 @@ def main(args):
optimizer.zero_grad()
pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
global_step += 1
optimizer.step()
if is_distributed:
torch.distributed.barrier()
accelerator.wait_for_everyone()
# Generate a sample image for visual inspection
if args.local_rank in [-1, 0]:
if accelerator.is_main_process:
model.eval()
with torch.no_grad():
pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
pipeline.unet = accelerator.unwrap_model(model)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
image = pipeline(generator=generator)
image = pipeline(
["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50
)
# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed * 255.0
image_processed = image_processed.type(torch.uint8).numpy()
image_pil = PIL.Image.fromarray(image_processed[0])
@@ -162,20 +177,19 @@ def main(args):
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
if is_distributed:
torch.distributed.barrier()
accelerator.wait_for_everyone()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--output_dir", type=str, default="ddpm-model")
parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
parser.add_argument("--output_dir", type=str, default="ldm-text2image")
parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--resolution", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--push_to_hub", action="store_true")
......
@@ -7,7 +7,7 @@ import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
@@ -71,7 +71,7 @@ def main(args):
model, optimizer, train_dataloader, lr_scheduler
)
ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
@@ -133,7 +133,7 @@ def main(args):
# Generate a sample image for visual inspection
if accelerator.is_main_process:
with torch.no_grad():
pipeline = DDPM(
pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
)
@@ -172,6 +172,9 @@ if __name__ == "__main__":
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3/4)
parser.add_argument("--ema_max_decay", type=float, default=0.999)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
......
@@ -64,7 +64,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -73,7 +73,7 @@ class Upsample(nn.Module):
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
@@ -207,6 +207,7 @@ class GradTTSUpsample(torch.nn.Module):
return self.conv(x)
# TODO (patil-suraj): needs test
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
......
@@ -31,6 +31,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
def nonlinearity(x):
@@ -42,20 +43,6 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
......
@@ -8,6 +8,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
def convert_module_to_f16(l):
@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock):
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
......
@@ -3,6 +3,7 @@ import torch
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
class Mish(torch.nn.Module):
@@ -10,15 +11,6 @@ class Mish(torch.nn.Module):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Upsample(torch.nn.Module):
def __init__(self, dim):
super(Upsample, self).__init__()
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Downsample(torch.nn.Module):
def __init__(self, dim):
super(Downsample, self).__init__()
@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in),
Upsample(dim_in, use_conv_transpose=True),
]
)
)
......
@@ -10,6 +10,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
# try:
@@ -403,35 +404,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
@@ -506,8 +478,8 @@ class ResBlock(TimestepBlock):
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
@@ -974,7 +946,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
......
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample
def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
if timesteps.dim() != 1:
raise ValueError("`timesteps` must be a 1D tensor")
device = original_samples.device
batch_size = original_samples.shape[0]
timesteps = timesteps.reshape(batch_size, 1, 1, 1)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
......
@@ -14,6 +14,8 @@
import numpy as np
import torch
from typing import Union
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -50,3 +52,29 @@ class SchedulerMixin:
return torch.log(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
    def match_shape(
        self,
        values: Union[np.ndarray, torch.Tensor],
        broadcast_array: Union[np.ndarray, torch.Tensor]
    ):
        """
        Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.

        Args:
            values: an array or tensor of values to extract.
            broadcast_array: an array with a larger shape of K dimensions with the batch
                dimension equal to the length of values.

        Returns:
            a tensor of shape [batch_size, 1, ...] where the shape has K dims.
        """
        tensor_format = getattr(self, "tensor_format", "pt")
        values = values.flatten()
        while len(values.shape) < len(broadcast_array.shape):
            values = values[..., None]
        if tensor_format == "pt":
            values = values.to(broadcast_array.device)

        return values
@@ -22,6 +22,7 @@ import numpy as np
import torch
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Upsample
from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
1e-3,
)
class UpsampleBlockTests(unittest.TestCase):
    def test_upsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample(channels=32, use_conv=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 64, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_transpose(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
@@ -21,7 +21,7 @@ import unittest
import numpy as np
import torch
from diffusers import (
from diffusers import ( # GradTTSPipeline,
BDDMPipeline,
DDIMPipeline,
DDIMScheduler,
@@ -30,7 +30,6 @@ from diffusers import (
GlidePipeline,
GlideSuperResUNetModel,
GlideTextToImageUNetModel,
GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline,
NCSNpp,
......