"docs/source/vscode:/vscode.git/clone" did not exist on "311bd88a040d42619f5eccd5a3f53852ec2be3e1"
Commit abcb2597 authored by patil-suraj's avatar patil-suraj
Browse files

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents 183056f2 c991ffd4
...@@ -9,14 +9,14 @@ from accelerate import Accelerator ...@@ -9,14 +9,14 @@ from accelerate import Accelerator
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel from diffusers import DDPM, DDPMScheduler, UNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.modeling_utils import unwrap_model
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import logging from diffusers.utils import logging
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
Compose, Compose,
InterpolationMode, InterpolationMode,
Lambda, Normalize,
RandomHorizontalFlip, RandomHorizontalFlip,
Resize, Resize,
ToTensor, ToTensor,
...@@ -48,7 +48,7 @@ def main(args): ...@@ -48,7 +48,7 @@ def main(args):
CenterCrop(args.resolution), CenterCrop(args.resolution),
RandomHorizontalFlip(), RandomHorizontalFlip(),
ToTensor(), ToTensor(),
Lambda(lambda x: x * 2 - 1), Normalize([0.5], [0.5]),
] ]
) )
dataset = load_dataset(args.dataset, split="train") dataset = load_dataset(args.dataset, split="train")
...@@ -71,6 +71,8 @@ def main(args): ...@@ -71,6 +71,8 @@ def main(args):
model, optimizer, train_dataloader, lr_scheduler model, optimizer, train_dataloader, lr_scheduler
) )
ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
if args.push_to_hub: if args.push_to_hub:
repo = init_git_repo(args, at_init=True) repo = init_git_repo(args, at_init=True)
...@@ -87,6 +89,7 @@ def main(args): ...@@ -87,6 +89,7 @@ def main(args):
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}") logger.info(f" Total optimization steps = {max_steps}")
global_step = 0
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
model.train() model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar: with tqdm(total=len(train_dataloader), unit="ba") as pbar:
...@@ -117,19 +120,22 @@ def main(args): ...@@ -117,19 +120,22 @@ def main(args):
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
ema_model.step(model, global_step)
optimizer.zero_grad() optimizer.zero_grad()
pbar.update(1) pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) pbar.set_postfix(
loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
)
global_step += 1
optimizer.step() accelerator.wait_for_everyone()
if is_distributed:
torch.distributed.barrier()
# Generate a sample image for visual inspection # Generate a sample image for visual inspection
if args.local_rank in [-1, 0]: if accelerator.is_main_process:
model.eval()
with torch.no_grad(): with torch.no_grad():
pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler) pipeline = DDPM(
unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise) # run pipeline in inference (sample random noise and denoise)
...@@ -151,8 +157,7 @@ def main(args): ...@@ -151,8 +157,7 @@ def main(args):
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else: else:
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
if is_distributed: accelerator.wait_for_everyone()
torch.distributed.barrier()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -82,61 +82,62 @@ def Normalize(in_channels): ...@@ -82,61 +82,62 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module): #class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32): # def __init__(self, dim, heads=4, dim_head=32):
super().__init__() # super().__init__()
self.heads = heads # self.heads = heads
hidden_dim = dim_head * heads # hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) # self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1) # self.to_out = nn.Conv2d(hidden_dim, dim, 1)
#
def forward(self, x): # def forward(self, x):
b, c, h, w = x.shape # b, c, h, w = x.shape
qkv = self.to_qkv(x) # qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1) # import ipdb; ipdb.set_trace()
context = torch.einsum("bhdn,bhen->bhde", k, v) # k = k.softmax(dim=-1)
out = torch.einsum("bhde,bhdn->bhen", context, q) # context = torch.einsum("bhdn,bhen->bhde", k, v)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) # out = torch.einsum("bhde,bhdn->bhen", context, q)
return self.to_out(out) # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
# return self.to_out(out)
#
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels): #class SpatialSelfAttention(nn.Module):
super().__init__() # def __init__(self, in_channels):
self.in_channels = in_channels # super().__init__()
# self.in_channels = in_channels
self.norm = Normalize(in_channels) #
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) # self.norm = Normalize(in_channels)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) # self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) # self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) # self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x): #
h_ = x # def forward(self, x):
h_ = self.norm(h_) # h_ = x
q = self.q(h_) # h_ = self.norm(h_)
k = self.k(h_) # q = self.q(h_)
v = self.v(h_) # k = self.k(h_)
# v = self.v(h_)
#
# compute attention # compute attention
b, c, h, w = q.shape # b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c") # q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)") # k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k) # w_ = torch.einsum("bij,bjk->bik", q, k)
#
w_ = w_ * (int(c) ** (-0.5)) # w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2) # w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values # attend to values
v = rearrange(v, "b c h w -> b c (h w)") # v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i") # w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_) # h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) # h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_) # h_ = self.proj_out(h_)
#
return x + h_ # return x + h_
#
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
......
import copy
import torch
class EMAModel:
"""
Exponential Moving Average of models weights
"""
def __init__(
self,
model,
update_after_step=0,
inv_gamma=1.0,
power=2 / 3,
min_value=0.0,
max_value=0.9999,
device=None,
):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = copy.deepcopy(model)
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
if device is not None:
self.averaged_model = self.averaged_model.to(device=device)
self.decay = 0.0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model, optimization_step):
ema_state_dict = {}
ema_params = self.averaged_model.state_dict()
self.decay = self.get_decay(optimization_step)
for key, param in new_model.named_parameters():
if isinstance(param, dict):
continue
try:
ema_param = ema_params[key]
except KeyError:
ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
ema_params[key] = ema_param
if not param.requires_grad:
ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
ema_param = ema_params[key]
else:
ema_param.mul_(self.decay)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
ema_state_dict[key] = ema_param
for key, param in new_model.named_buffers():
ema_state_dict[key] = param
self.averaged_model.load_state_dict(ema_state_dict, strict=False)
...@@ -510,6 +510,28 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -510,6 +510,28 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_output_pretrained_spatial_transformer(self):
model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
context = torch.ones((1, 16, 64), dtype=torch.float32)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step, context=context)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetGradTTSModel model_class = UNetGradTTSModel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment