"tests/compute/test_specialization.py" did not exist on "2c489fadec58fd6be7f555897b222cdef31d98b5"
Commit 7339f0b0 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
def extract(v, t, x_shape):
"""
Extract some coefficients at specified timesteps, then reshape to
[batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
device = t.device
out = torch.gather(v, index=t, dim=0).float().to(device)
return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
class GaussianDiffusionTrainer(nn.Module):
def __init__(self, model, beta_1, beta_T, T):
super().__init__()
self.model = model
self.T = T
self.register_buffer(
'betas', torch.linspace(beta_1, beta_T, T).double())
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_bar', torch.sqrt(alphas_bar))
self.register_buffer(
'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
def forward(self, x_0):
"""
Algorithm 1.
"""
t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
noise = torch.randn_like(x_0)
x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
return loss
class GaussianDiffusionSampler(nn.Module):
def __init__(self, model, beta_1, beta_T, T):
super().__init__()
self.model = model
self.T = T
self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
self.register_buffer('coeff1', torch.sqrt(1. / alphas))
self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
extract(self.coeff1, t, x_t.shape) * x_t -
extract(self.coeff2, t, x_t.shape) * eps
)
def p_mean_variance(self, x_t, t):
# below: only log_variance is used in the KL computations
var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
var = extract(var, t, x_t.shape)
eps = self.model(x_t, t)
xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
return xt_prev_mean, var
def forward(self, x_T):
"""
Algorithm 2.
"""
x_t = x_T
print('Start Sampling')
for time_step in tqdm(reversed(range(self.T))):
t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
mean, var= self.p_mean_variance(x_t=x_t, t=t)
# no noise when t == 0
if time_step > 0:
noise = torch.randn_like(x_t)
else:
noise = 0
x_t = mean + torch.sqrt(var) * noise
assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
x_0 = x_t
return torch.clip(x_0, -1, 1)
This diff is collapsed.
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from Diffusion.kan_utils.fastkanconv import FastKANConvLayer, SplineConv2D
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
def __init__(self, T, d_model, dim):
assert d_model % 2 == 0
super().__init__()
emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
emb = torch.exp(-emb)
pos = torch.arange(T).float()
emb = pos[:, None] * emb[None, :]
assert list(emb.shape) == [T, d_model // 2]
emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
assert list(emb.shape) == [T, d_model // 2, 2]
emb = emb.view(T, d_model)
self.timembedding = nn.Sequential(
nn.Embedding.from_pretrained(emb),
nn.Linear(d_model, dim),
Swish(),
nn.Linear(dim, dim),
)
self.initialize()
def initialize(self):
for module in self.modules():
if isinstance(module, nn.Linear):
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
def forward(self, t):
emb = self.timembedding(t)
return emb
class DownSample(nn.Module):
def __init__(self, in_ch):
super().__init__()
# self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=2, padding=1)
# self.initialize()
def initialize(self):
init.xavier_uniform_(self.main.weight)
init.zeros_(self.main.bias)
def forward(self, x, temb):
x = self.main(x)
return x
class UpSample(nn.Module):
def __init__(self, in_ch):
super().__init__()
# self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=1, padding=1)
# self.initialize()
def initialize(self):
init.xavier_uniform_(self.main.weight)
init.zeros_(self.main.bias)
def forward(self, x, temb):
_, _, H, W = x.shape
x = F.interpolate(
x, scale_factor=2, mode='nearest')
x = self.main(x)
return x
class AttnBlock(nn.Module):
def __init__(self, in_ch):
super().__init__()
self.group_norm = nn.GroupNorm(32, in_ch)
self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.initialize()
def initialize(self):
for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
init.xavier_uniform_(self.proj.weight, gain=1e-5)
def forward(self, x):
B, C, H, W = x.shape
h = self.group_norm(x)
q = self.proj_q(h)
k = self.proj_k(h)
v = self.proj_v(h)
q = q.permute(0, 2, 3, 1).view(B, H * W, C)
k = k.view(B, C, H * W)
w = torch.bmm(q, k) * (int(C) ** (-0.5))
assert list(w.shape) == [B, H * W, H * W]
w = F.softmax(w, dim=-1)
v = v.permute(0, 2, 3, 1).view(B, H * W, C)
h = torch.bmm(w, v)
assert list(h.shape) == [B, H * W, C]
h = h.view(B, H, W, C).permute(0, 3, 1, 2)
h = self.proj(h)
return x + h
class ResBlock(nn.Module):
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
super().__init__()
self.block1 = nn.Sequential(
nn.GroupNorm(32, in_ch),
# Swish(),
# nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
FastKANConvLayer(in_ch, out_ch, 3, stride=1, padding=1),
)
self.temb_proj = nn.Sequential(
Swish(),
nn.Linear(tdim, out_ch),
)
self.block2 = nn.Sequential(
nn.GroupNorm(32, out_ch),
# Swish(),
nn.Dropout(dropout),
# nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
FastKANConvLayer(out_ch, out_ch, 3, stride=1, padding=1),
)
if in_ch != out_ch:
# self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
self.shortcut = FastKANConvLayer(in_ch, out_ch, 1, stride=1, padding=0)
else:
self.shortcut = nn.Identity()
if attn:
self.attn = AttnBlock(out_ch)
else:
self.attn = nn.Identity()
self.initialize()
def initialize(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)) and not isinstance(module, (SplineConv2D)):
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
# init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
def forward(self, x, temb):
h = self.block1(x)
h += self.temb_proj(temb)[:, :, None, None]
h = self.block2(h)
h = h + self.shortcut(x)
h = self.attn(h)
return h
# return x
class UNet_ConvKan(nn.Module):
def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
super().__init__()
assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
tdim = ch * 4
self.time_embedding = TimeEmbedding(T, ch, tdim)
self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
self.downblocks = nn.ModuleList()
chs = [ch] # record output channel when dowmsample for upsample
now_ch = ch
for i, mult in enumerate(ch_mult):
out_ch = ch * mult
for _ in range(num_res_blocks):
self.downblocks.append(ResBlock(
in_ch=now_ch, out_ch=out_ch, tdim=tdim,
dropout=dropout, attn=(i in attn)))
now_ch = out_ch
chs.append(now_ch)
if i != len(ch_mult) - 1:
self.downblocks.append(DownSample(now_ch))
chs.append(now_ch)
self.middleblocks = nn.ModuleList([
ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
])
self.upblocks = nn.ModuleList()
for i, mult in reversed(list(enumerate(ch_mult))):
out_ch = ch * mult
for _ in range(num_res_blocks + 1):
self.upblocks.append(ResBlock(
in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
dropout=dropout, attn=(i in attn)))
now_ch = out_ch
if i != 0:
self.upblocks.append(UpSample(now_ch))
assert len(chs) == 0
self.tail = nn.Sequential(
nn.GroupNorm(32, now_ch),
# Swish(),
# nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
FastKANConvLayer(now_ch, 3, 3, stride=1, padding=1)
)
# self.initialize()
def initialize(self):
init.xavier_uniform_(self.head.weight)
init.zeros_(self.head.bias)
init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
init.zeros_(self.tail[-1].bias)
def forward(self, x, t):
# Timestep embedding
temb = self.time_embedding(t)
# Downsampling
h = self.head(x)
hs = [h]
for layer in self.downblocks:
h = layer(h, temb)
hs.append(h)
# Middle
for layer in self.middleblocks:
h = layer(h, temb)
# Upsampling
for layer in self.upblocks:
if isinstance(layer, ResBlock):
h = torch.cat([h, hs.pop()], dim=1)
h = layer(h, temb)
h = self.tail(h)
assert len(hs) == 0
return h
if __name__ == '__main__':
batch_size = 8
model = UNet_ConvKan(
T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
num_res_blocks=2, dropout=0.1)
x = torch.randn(batch_size, 3, 32, 32)
t = torch.randint(1000, (batch_size, ))
y = model(x, t)
print(y.shape)
This diff is collapsed.
This diff is collapsed.
import os
from typing import Dict
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, transforms
# from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
from Diffusion.UNet import UNet, UNe t_Baseline
from Diffusion.Model_ConvKan import UNet_ConvKan
from Diffusion.Model_UMLP import UMLP
from Diffusion.Model_UKAN_Hybrid import UKan_Hybrid
from Scheduler import GradualWarmupScheduler
from skimage import io
import os
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import Dataset
import sys
model_dict = {
'UNet': UNet,
'UNet_ConvKan': UNet_ConvKan, # dose not work
'UMLP': UMLP,
'UKan_Hybrid': UKan_Hybrid,
'UNet_Baseline': UNet_Baseline,
}
class UnlabeledDataset(Dataset):
def __init__(self, folder, transform=None, repeat_n=1):
self.folder = folder
self.transform = transform
# self.image_files = os.listdir(folder) * repeat_n
self.image_files = os.listdir(folder)
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
image_file = self.image_files[idx]
image_path = os.path.join(self.folder, image_file)
image = io.imread(image_path)
if self.transform:
image = self.transform(image)
return image, torch.Tensor([0])
def train(modelConfig: Dict):
device = torch.device(modelConfig["device"])
log_print = True
if log_print:
file = open(modelConfig["save_weight_dir"]+'log.txt', "w")
sys.stdout = file
transform = Compose([
ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
if modelConfig["dataset"] == 'cvc':
dataset = UnlabeledDataset('data/cvc/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
elif modelConfig["dataset"] == 'glas':
dataset = UnlabeledDataset('data/glas/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
elif modelConfig["dataset"] == 'glas_resize':
dataset = UnlabeledDataset('data/glas/images_64_resize/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
elif modelConfig["dataset"] == 'busi':
dataset = UnlabeledDataset('data/busi/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
else:
raise ValueError('dataset not found')
print('modelConfig: ')
for key, value in modelConfig.items():
print(key, ' : ', value)
dataloader = DataLoader(
dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
print('Using {}'.format(modelConfig["model"]))
# model setup
net_model =model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
if modelConfig["training_load_weight"] is not None:
net_model.load_state_dict(torch.load(os.path.join(
modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
optimizer = torch.optim.AdamW(
net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
warmUpScheduler = GradualWarmupScheduler(
optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
trainer = GaussianDiffusionTrainer(
net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
# start training
for e in range(1,modelConfig["epoch"]+1):
with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
for images, labels in tqdmDataLoader:
# train
optimizer.zero_grad()
x_0 = images.to(device)
loss = trainer(x_0).sum() / 1000.
loss.backward()
torch.nn.utils.clip_grad_norm_(
net_model.parameters(), modelConfig["grad_clip"])
optimizer.step()
tqdmDataLoader.set_postfix(ordered_dict={
"epoch": e,
"loss: ": loss.item(),
"img shape: ": x_0.shape,
"LR": optimizer.state_dict()['param_groups'][0]["lr"]
})
# print version
if log_print:
print("epoch: ", e, "loss: ", loss.item(), "img shape: ", x_0.shape, "LR: ", optimizer.state_dict()['param_groups'][0]["lr"])
warmUpScheduler.step()
if e % 50 ==0:
torch.save(net_model.state_dict(), os.path.join(
modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))
modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(e)
eval_tmp(modelConfig, e)
torch.save(net_model.state_dict(), os.path.join(
modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))
if log_print:
file.close()
sys.stdout = sys.__stdout__
def eval_tmp(modelConfig: Dict, nme: int):
# load model and evaluate
with torch.no_grad():
device = torch.device(modelConfig["device"])
model = model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
ckpt = torch.load(os.path.join(
modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)
model.load_state_dict(ckpt)
print("model load weight done.")
model.eval()
sampler = GaussianDiffusionSampler(
model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
# Sampled from standard normal distribution
noisyImage = torch.randn(
size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
# saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
# save_image(saveNoisy, os.path.join(
# modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
sampledImgs = sampler(noisyImage)
sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1]
save_root = modelConfig["sampled_dir"].replace('Gens','Tmp')
os.makedirs(save_root, exist_ok=True)
save_image(sampledImgs, os.path.join(
save_root, modelConfig["sampledImgName"].replace('.png','_{}.png').format(nme)), nrow=modelConfig["nrow"])
if nme < 0.95 * modelConfig["epoch"]:
os.remove(os.path.join(
modelConfig["save_weight_dir"], modelConfig["test_load_weight"]))
def eval(modelConfig: Dict):
# load model and evaluate
with torch.no_grad():
device = torch.device(modelConfig["device"])
model = model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
ckpt = torch.load(os.path.join(
modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)
model.load_state_dict(ckpt)
print("model load weight done.")
model.eval()
sampler = GaussianDiffusionSampler(
model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
# Sampled from standard normal distribution
noisyImage = torch.randn(
size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
# saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
# save_image(saveNoisy, os.path.join(
# modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
sampledImgs = sampler(noisyImage)
sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1]
for i, image in enumerate(sampledImgs):
save_image(image, os.path.join(modelConfig["sampled_dir"], modelConfig["sampledImgName"].replace('.png','_{}.png').format(i)), nrow=modelConfig["nrow"])
\ No newline at end of file
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
def __init__(self, T, d_model, dim):
assert d_model % 2 == 0
super().__init__()
emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
emb = torch.exp(-emb)
pos = torch.arange(T).float()
emb = pos[:, None] * emb[None, :]
assert list(emb.shape) == [T, d_model // 2]
emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
assert list(emb.shape) == [T, d_model // 2, 2]
emb = emb.view(T, d_model)
self.timembedding = nn.Sequential(
nn.Embedding.from_pretrained(emb),
nn.Linear(d_model, dim),
Swish(),
nn.Linear(dim, dim),
)
self.initialize()
def initialize(self):
for module in self.modules():
if isinstance(module, nn.Linear):
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
def forward(self, t):
emb = self.timembedding(t)
return emb
class DownSample(nn.Module):
def __init__(self, in_ch):
super().__init__()
self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
self.initialize()
def initialize(self):
init.xavier_uniform_(self.main.weight)
init.zeros_(self.main.bias)
def forward(self, x, temb):
x = self.main(x)
return x
class UpSample(nn.Module):
def __init__(self, in_ch):
super().__init__()
self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
self.initialize()
def initialize(self):
init.xavier_uniform_(self.main.weight)
init.zeros_(self.main.bias)
def forward(self, x, temb):
_, _, H, W = x.shape
x = F.interpolate(
x, scale_factor=2, mode='nearest')
x = self.main(x)
return x
class AttnBlock(nn.Module):
def __init__(self, in_ch):
super().__init__()
self.group_norm = nn.GroupNorm(32, in_ch)
self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.initialize()
def initialize(self):
for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
init.xavier_uniform_(self.proj.weight, gain=1e-5)
def forward(self, x):
B, C, H, W = x.shape
h = self.group_norm(x)
q = self.proj_q(h)
k = self.proj_k(h)
v = self.proj_v(h)
q = q.permute(0, 2, 3, 1).view(B, H * W, C)
k = k.view(B, C, H * W)
w = torch.bmm(q, k) * (int(C) ** (-0.5))
assert list(w.shape) == [B, H * W, H * W]
w = F.softmax(w, dim=-1)
v = v.permute(0, 2, 3, 1).view(B, H * W, C)
h = torch.bmm(w, v)
assert list(h.shape) == [B, H * W, C]
h = h.view(B, H, W, C).permute(0, 3, 1, 2)
h = self.proj(h)
return x + h
class ResBlock(nn.Module):
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
super().__init__()
self.block1 = nn.Sequential(
nn.GroupNorm(32, in_ch),
Swish(),
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
)
self.temb_proj = nn.Sequential(
Swish(),
nn.Linear(tdim, out_ch),
)
self.block2 = nn.Sequential(
nn.GroupNorm(32, out_ch),
Swish(),
nn.Dropout(dropout),
nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
)
if in_ch != out_ch:
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
else:
self.shortcut = nn.Identity()
if attn:
self.attn = AttnBlock(out_ch)
else:
self.attn = nn.Identity()
self.initialize()
def initialize(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
def forward(self, x, temb):
h = self.block1(x)
h += self.temb_proj(temb)[:, :, None, None]
h = self.block2(h)
h = h + self.shortcut(x)
h = self.attn(h)
return h
class UNet(nn.Module):
def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
super().__init__()
assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
tdim = ch * 4
self.time_embedding = TimeEmbedding(T, ch, tdim)
self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
self.downblocks = nn.ModuleList()
chs = [ch] # record output channel when dowmsample for upsample
now_ch = ch
for i, mult in enumerate(ch_mult):
out_ch = ch * mult
for _ in range(num_res_blocks):
self.downblocks.append(ResBlock(
in_ch=now_ch, out_ch=out_ch, tdim=tdim,
dropout=dropout, attn=(i in attn)))
now_ch = out_ch
chs.append(now_ch)
if i != len(ch_mult) - 1:
self.downblocks.append(DownSample(now_ch))
chs.append(now_ch)
self.middleblocks = nn.ModuleList([
ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
])
self.upblocks = nn.ModuleList()
for i, mult in reversed(list(enumerate(ch_mult))):
out_ch = ch * mult
for _ in range(num_res_blocks + 1):
self.upblocks.append(ResBlock(
in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
dropout=dropout, attn=(i in attn)))
now_ch = out_ch
if i != 0:
self.upblocks.append(UpSample(now_ch))
assert len(chs) == 0
self.tail = nn.Sequential(
nn.GroupNorm(32, now_ch),
Swish(),
nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
)
self.initialize()
def initialize(self):
init.xavier_uniform_(self.head.weight)
init.zeros_(self.head.bias)
init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
init.zeros_(self.tail[-1].bias)
def forward(self, x, t):
# Timestep embedding
temb = self.time_embedding(t)
# Downsampling
h = self.head(x)
hs = [h]
for layer in self.downblocks:
h = layer(h, temb)
hs.append(h)
# Middle
# torch.Size([8, 512, 4, 4])
for layer in self.middleblocks:
h = layer(h, temb)
# torch.Size([8, 512, 4, 4])
# Upsampling
for layer in self.upblocks:
if isinstance(layer, ResBlock):
h = torch.cat([h, hs.pop()], dim=1)
h = layer(h, temb)
h = self.tail(h)
assert len(hs) == 0
return h
class UNet_Baseline(nn.Module):
# Remove the middle blocks
def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
super().__init__()
assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
tdim = ch * 4
self.time_embedding = TimeEmbedding(T, ch, tdim)
self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
self.downblocks = nn.ModuleList()
chs = [ch] # record output channel when dowmsample for upsample
now_ch = ch
for i, mult in enumerate(ch_mult):
out_ch = ch * mult
for _ in range(num_res_blocks):
self.downblocks.append(ResBlock(
in_ch=now_ch, out_ch=out_ch, tdim=tdim,
dropout=dropout, attn=(i in attn)))
now_ch = out_ch
chs.append(now_ch)
if i != len(ch_mult) - 1:
self.downblocks.append(DownSample(now_ch))
chs.append(now_ch)
self.upblocks = nn.ModuleList()
for i, mult in reversed(list(enumerate(ch_mult))):
out_ch = ch * mult
for _ in range(num_res_blocks + 1):
self.upblocks.append(ResBlock(
in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
dropout=dropout, attn=(i in attn)))
now_ch = out_ch
if i != 0:
self.upblocks.append(UpSample(now_ch))
assert len(chs) == 0
self.tail = nn.Sequential(
nn.GroupNorm(32, now_ch),
Swish(),
nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
)
self.initialize()
def initialize(self):
init.xavier_uniform_(self.head.weight)
init.zeros_(self.head.bias)
init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
init.zeros_(self.tail[-1].bias)
def forward(self, x, t):
# Timestep embedding
temb = self.time_embedding(t)
# Downsampling
h = self.head(x)
hs = [h]
for layer in self.downblocks:
h = layer(h, temb)
hs.append(h)
# Upsampling
for layer in self.upblocks:
if isinstance(layer, ResBlock):
h = torch.cat([h, hs.pop()], dim=1)
h = layer(h, temb)
h = self.tail(h)
assert len(hs) == 0
return h
if __name__ == '__main__':
batch_size = 8
model = UNet(
T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[1],
num_res_blocks=2, dropout=0.1)
\ No newline at end of file
from .Diffusion import *
from .UNet import *
from .Train import *
from .kan import *
from .fastkanconv import *
# from .kan_convolutional import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Union
class PolynomialFunction(nn.Module):
def __init__(self,
degree: int = 3):
super().__init__()
self.degree = degree
def forward(self, x):
return torch.stack([x ** i for i in range(self.degree)], dim=-1)
class BSplineFunction(nn.Module):
def __init__(self, grid_min: float = -2.,
grid_max: float = 2., degree: int = 3, num_basis: int = 8):
super(BSplineFunction, self).__init__()
self.degree = degree
self.num_basis = num_basis
self.knots = torch.linspace(grid_min, grid_max, num_basis + degree + 1) # Uniform knots
def basis_function(self, i, k, t):
if k == 0:
return ((self.knots[i] <= t) & (t < self.knots[i + 1])).float()
else:
left_num = (t - self.knots[i]) * self.basis_function(i, k - 1, t)
left_den = self.knots[i + k] - self.knots[i]
left = left_num / left_den if left_den != 0 else 0
right_num = (self.knots[i + k + 1] - t) * self.basis_function(i + 1, k - 1, t)
right_den = self.knots[i + k + 1] - self.knots[i + 1]
right = right_num / right_den if right_den != 0 else 0
return left + right
def forward(self, x):
x = x.squeeze() # Assuming x is of shape (B, 1)
basis_functions = torch.stack([self.basis_function(i, self.degree, x) for i in range(self.num_basis)], dim=-1)
return basis_functions
class ChebyshevFunction(nn.Module):
def __init__(self, degree: int = 4):
super(ChebyshevFunction, self).__init__()
self.degree = degree
def forward(self, x):
chebyshev_polynomials = [torch.ones_like(x), x]
for n in range(2, self.degree):
chebyshev_polynomials.append(2 * x * chebyshev_polynomials[-1] - chebyshev_polynomials[-2])
return torch.stack(chebyshev_polynomials, dim=-1)
class FourierBasisFunction(nn.Module):
def __init__(self,
num_frequencies: int = 4,
period: float = 1.0):
super(FourierBasisFunction, self).__init__()
assert num_frequencies % 2 == 0, "num_frequencies must be even"
self.num_frequencies = num_frequencies
self.period = nn.Parameter(torch.Tensor([period]), requires_grad=False)
def forward(self, x):
frequencies = torch.arange(1, self.num_frequencies // 2 + 1, device=x.device)
sin_components = torch.sin(2 * torch.pi * frequencies * x[..., None] / self.period)
cos_components = torch.cos(2 * torch.pi * frequencies * x[..., None] / self.period)
basis_functions = torch.cat([sin_components, cos_components], dim=-1)
return basis_functions
class RadialBasisFunction(nn.Module):
def __init__(
self,
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 4,
denominator: float = None,
):
super().__init__()
grid = torch.linspace(grid_min, grid_max, num_grids)
self.grid = torch.nn.Parameter(grid, requires_grad=False)
self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)
def forward(self, x):
return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)
class SplineConv2D(nn.Conv2d):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]] = 3,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
init_scale: float = 0.1,
padding_mode: str = "zeros",
**kw
) -> None:
self.init_scale = init_scale
super().__init__(in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
**kw
)
def reset_parameters(self) -> None:
nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
if self.bias is not None:
nn.init.zeros_(self.bias)
class FastKANConvLayer(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]] = 3,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 4,
use_base_update: bool = True,
base_activation = F.silu,
spline_weight_init_scale: float = 0.1,
padding_mode: str = "zeros",
kan_type: str = "BSpline",
# kan_type: str = "RBF",
) -> None:
super().__init__()
if kan_type == "RBF":
self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
elif kan_type == "Fourier":
self.rbf = FourierBasisFunction(num_grids)
elif kan_type == "Poly":
self.rbf = PolynomialFunction(num_grids)
elif kan_type == "Chebyshev":
self.rbf = ChebyshevFunction(num_grids)
elif kan_type == "BSpline":
self.rbf = BSplineFunction(grid_min, grid_max, 3, num_grids)
self.spline_conv = SplineConv2D(in_channels * num_grids,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
spline_weight_init_scale,
padding_mode)
self.use_base_update = use_base_update
if use_base_update:
self.base_activation = base_activation
self.base_conv = nn.Conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode)
def forward(self, x):
batch_size, channels, height, width = x.shape
x_rbf = self.rbf(x.view(batch_size, channels, -1)).view(batch_size, channels, height, width, -1)
x_rbf = x_rbf.permute(0, 4, 1, 2, 3).contiguous().view(batch_size, -1, height, width)
# Apply spline convolution
ret = self.spline_conv(x_rbf)
if self.use_base_update:
base = self.base_conv(self.base_activation(x))
ret = ret + base
return ret
import torch
import torch.nn.functional as F
import math
class KANLinear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
enable_standalone_scale_spline=True,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KANLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
with torch.no_grad():
noise = (
(
torch.rand(self.grid_size + 1, self.in_features, self.out_features)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)
A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)
def forward(self, x: torch.Tensor):
assert x.dim() == 2 and x.size(1) == self.in_features
base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
return base_output + spline_output
@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)
splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)
self.grid.copy_(grid.T)
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
"""
Compute the regularization loss.
This is a dumb simulation of the original L1 regularization as stated in the
paper, since the original one requires computing absolutes and entropy from the
expanded (batch, in_features, out_features) intermediate tensor, which is hidden
behind the F.linear function if we want an memory efficient implementation.
The L1 regularization is now computed as mean absolute value of the spline
weights. The authors implementation also includes this term in addition to the
sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
class KAN(torch.nn.Module):
def __init__(
self,
layers_hidden,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KAN, self).__init__()
self.grid_size = grid_size
self.spline_order = spline_order
self.layers = torch.nn.ModuleList()
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
self.layers.append(
KANLinear(
in_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
)
def forward(self, x: torch.Tensor, update_grid=False):
for layer in self.layers:
if update_grid:
layer.update_grid(x)
x = layer(x)
return x
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
return sum(
layer.regularization_loss(regularize_activation, regularize_entropy)
for layer in self.layers
)
import argparse
import torch.nn as nn
class qkv_transform(nn.Conv1d):
"""Conv1d for qkv_transform"""
def str2bool(v):
if v.lower() in ['true', 1]:
return True
elif v.lower() in ['false', 0]:
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def count_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
from Diffusion.Train import train, eval
import os
import argparse
import torch
import numpy as np
def main(model_config = None):
if model_config is not None:
modelConfig = model_config
if modelConfig["state"] == "train":
train(modelConfig)
modelConfig['batch_size'] = 64
modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch'])
for i in range(32):
modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
eval(modelConfig)
else:
for i in range(32):
modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
eval(modelConfig)
def seed_all(args):
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--state', type=str, default='train') # train or eval
parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc
parser.add_argument('--epoch', type=int, default=1000) # 1000 for cvc/glas, 5000 for busi
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--T', type=int, default=1000)
parser.add_argument('--channel', type=int, default=64) # 64 or 128
parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt')
parser.add_argument('--num_res_blocks', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.15)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--img_size', type=float, default=64)
parser.add_argument('--dataset_repeat', type=int, default=1) # did not use
parser.add_argument('--seed', type=int, default=0) # did not use
parser.add_argument('--model', type=str, default='UKAN_Hybrid')
parser.add_argument('--exp_nme', type=str, default='UKAN_Hybrid')
parser.add_argument('--save_root', type=str, default='./Output/')
args = parser.parse_args()
save_root = args.save_root
if args.seed != 0:
seed_all(args)
modelConfig = {
"dataset": args.dataset,
"state": args.state, # or eval
"epoch": args.epoch,
"batch_size": args.batch_size,
"T": args.T,
"channel": args.channel,
"channel_mult": [1, 2, 3, 4],
"attn": [2],
"num_res_blocks": args.num_res_blocks,
"dropout": args.dropout,
"lr": args.lr,
"multiplier": 2.,
"beta_1": 1e-4,
"beta_T": 0.02,
"img_size": 64,
"grad_clip": 1.,
"device": "cuda", ### MAKE SURE YOU HAVE A GPU !!!
"training_load_weight": None,
"save_weight_dir": os.path.join(save_root, args.exp_nme, "Weights"),
"sampled_dir": os.path.join(save_root, args.exp_nme, "Gens"),
"test_load_weight": args.test_load_weight,
"sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
"sampledImgName": "SampledNoGuidenceImgs.png",
"nrow": 8,
"model":args.model,
"version": 1,
"dataset_repeat": args.dataset_repeat,
"seed": args.seed,
"save_root": args.save_root,
}
os.makedirs(modelConfig["save_weight_dir"], exist_ok=True)
os.makedirs(modelConfig["sampled_dir"], exist_ok=True)
# backup
import shutil
shutil.copy("Diffusion/Model_UKAN_Hybrid.py", os.path.join(save_root, args.exp_nme))
shutil.copy("Diffusion/Train.py", os.path.join(save_root, args.exp_nme))
main(modelConfig)
from Diffusion.Train import train, eval, eval_tmp
import os
import argparse
import torch
def main(model_config = None):
if model_config is not None:
modelConfig = model_config
if modelConfig["state"] == "train":
train(modelConfig)
modelConfig['batch_size'] = 64
modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch'])
for i in range(32):
modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
eval(modelConfig)
else:
for i in range(1):
modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
eval_tmp(modelConfig,1000) # for grid visualization
# eval(modelConfig) # for metric evaluation
def seed_all(args):
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import numpy as np
np.random.seed(args.seed)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--state', type=str, default='eval')
parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc
parser.add_argument('--epoch', type=int, default=1000)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--T', type=int, default=1000)
parser.add_argument('--channel', type=int, default=64)
parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt')
parser.add_argument('--num_res_blocks', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.15)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--img_size', type=float, default=64) # 64 or 128
parser.add_argument('--dataset_repeat', type=int, default=1) # didnot use
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--model', type=str, default='UKan_Hybrid')
parser.add_argument('--exp_nme', type=str, default='./')
parser.add_argument('--save_root', type=str, default='released_models/ukan_cvc')
# parser.add_argument('--save_root', type=str, default='released_models/ukan_glas')
# parser.add_argument('--save_root', type=str, default='released_models/ukan_busi')
args = parser.parse_args()
save_root = args.save_root
if args.seed != 0:
seed_all(args)
modelConfig = {
"dataset": args.dataset,
"state": args.state, # or eval
"epoch": args.epoch,
"batch_size": args.batch_size,
"T": args.T,
"channel": args.channel,
"channel_mult": [1, 2, 3, 4],
"attn": [2],
"num_res_blocks": args.num_res_blocks,
"dropout": args.dropout,
"lr": args.lr,
"multiplier": 2.,
"beta_1": 1e-4,
"beta_T": 0.02,
"img_size": 64,
"grad_clip": 1.,
"device": "cuda", ### MAKE SURE YOU HAVE A GPU !!!
"training_load_weight": None,
"save_weight_dir": os.path.join(save_root, args.exp_nme, "Weights"),
"sampled_dir": os.path.join(save_root, args.exp_nme, "FinalCheck"),
"test_load_weight": args.test_load_weight,
"sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
"sampledImgName": "SampledNoGuidenceImgs.png",
"nrow": 8,
"model":args.model,
"version": 1,
"dataset_repeat": args.dataset_repeat,
"seed": args.seed,
"save_root": args.save_root,
}
os.makedirs(modelConfig["save_weight_dir"], exist_ok=True)
os.makedirs(modelConfig["sampled_dir"], exist_ok=True)
main(modelConfig)
# Diffusion UKAN (arxiv)
> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)<br>
> [Chenxin Li](https://xggnet.github.io/)\*, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)\*, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)\*, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)\*, [Hengyu Liu](), [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)<sup>✉</sup><br>The Chinese Univerisity of Hong Kong
Contact: wuyangli@cuhk.edu.hk
## 💡 Environment
You can change the torch and Cuda versions to satisfy your device.
```bash
conda create --name UKAN python=3.10
conda activate UKAN
conda install cudatoolkit=11.3
pip install -r requirement.txt
```
## 🖼️ Gallery of Diffusion UKAN
![image](./assets/gen.png)
## 📚 Prepare datasets
Download the pre-processed dataset from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/ESqX-V_eLSBEuaJXAzf64JMB16xF9kz3661pJSwQ-hOspg?e=XdABCH) and unzip it into the project folder. The data is pre-processed by the scripts in [tools](./tools).
```
Diffusion_UKAN
| data
| └─ cvc
| └─ images_64
| └─ busi
| └─ images_64
| └─ glas
| └─ images_64
```
## 📦 Prepare pre-trained models
Download released_models from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/EUVSH8QFUmpJlxyoEj8Pr2IB8PzGbVJg53rc6GcqxGgLDg?e=a4glNt) and unzip it in the project folder.
```
Diffusion_UKAN
| released_models
| └─ ukan_cvc
| └─ FinalCheck   # generated toy images (see next section)
| └─ Gens         # the generated images used for evaluation in our paper
| └─ Tmp          # saved generated images during model training with a 50-epoch interval
| └─ Weights      # The final checkpoint
| └─ FID.txt      # raw evaluation data
| └─ IS.txt       # raw evaluation data  
| └─ ukan_busi
| └─ ukan_glas
```
## 🧸 Toy example
Images will be generated in `released_models/ukan_cvc/FinalCheck` by running this:
```python
python Main_Test.py
```
## 🔥 Training
<!-- You may need to modify the dirs slightly. -->
Please refer to the [training_scripts](./training_scripts) folder. Besides, you can play with different network variations by modifying `MODEL` according to the following dictionary,
```python
model_dict = {
'UNet': UNet,
'UNet_ConvKan': UNet_ConvKan,
'UMLP': UMLP,
'UKan_Hybrid': UKan_Hybrid,
'UNet_Baseline': UNet_Baseline,
}
```
## 🤞 Acknowledgement
Thanks for
We mainly appreciate these excellent projects
- [Simple DDPM](https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-)
- [Kolmogorov-Arnold Network](https://github.com/mintisan/awesome-kan)
- [Efficient Kolmogorov-Arnold Network](https://github.com/Blealtan/efficient-kan.git)
from torch.optim.lr_scheduler import _LRScheduler
class GradualWarmupScheduler(_LRScheduler):
def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
self.multiplier = multiplier
self.total_epoch = warm_epoch
self.after_scheduler = after_scheduler
self.finished = False
self.last_epoch = None
self.base_lrs = None
super().__init__(optimizer)
def get_lr(self):
if self.last_epoch > self.total_epoch:
if self.after_scheduler:
if not self.finished:
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
self.finished = True
return self.after_scheduler.get_lr()
return [base_lr * self.multiplier for base_lr in self.base_lrs]
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
def step(self, epoch=None, metrics=None):
if self.finished and self.after_scheduler:
if epoch is None:
self.after_scheduler.step(None)
else:
self.after_scheduler.step(epoch - self.total_epoch)
else:
return super(GradualWarmupScheduler, self).step(epoch)
\ No newline at end of file
download data.zip and unzip here
\ No newline at end of file
Copyright 2017 Shane T. Barratt
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
# Inception Score Pytorch
Pytorch was lacking code to calculate the Inception Score for GANs. This repository fills this gap.
However, we do not recommend using the Inception Score to evaluate generative models, see [our note](https://arxiv.org/abs/1801.01973) for why.
## Getting Started
Clone the repository and navigate to it:
```
$ git clone git@github.com:sbarratt/inception-score-pytorch.git
$ cd inception-score-pytorch
```
To generate random 64x64 images and calculate the inception score, do the following:
```
$ python inception_score.py
```
The only function is `inception_score`. It takes a list of numpy images normalized to the range [0,1] and a set of arguments and then calculates the inception score. Please assure your images are 3x299x299 and if not (e.g. your GAN was trained on CIFAR), pass `resize=True` to the function to have it automatically resize using bilinear interpolation before passing the images to the inception network.
```python
def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
"""Computes the inception score of the generated images imgs
imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
cuda -- whether or not to run on GPU
batch_size -- batch size for feeding into Inception v3
splits -- number of splits
"""
```
### Prerequisites
You will need [torch](http://pytorch.org/), [torchvision](https://github.com/pytorch/vision), [numpy/scipy](https://scipy.org/).
## License
This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details
## Acknowledgments
* Inception Score from [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment