import math

import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Fixed sinusoidal embedding table of shape (T, d_model), followed by a
        # two-layer MLP that projects it to `dim`.
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.timembedding(t)
        return emb


class DownSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = self.main(x)
        return x


class UpSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x


class AttnBlock(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        # Single-head self-attention over the H * W spatial positions.
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h
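
# Illustrative sketch (not part of the original model code): a tiny shape check
# for the sinusoidal TimeEmbedding table and the single-head AttnBlock above.
# The sizes used here (T=1000, d_model=128, dim=512, 64 channels, 16x16 maps)
# are arbitrary example values, not settings prescribed by this file.
def _demo_time_embedding_and_attention():
    temb_layer = TimeEmbedding(T=1000, d_model=128, dim=512)
    t = torch.randint(0, 1000, (4,))
    temb = temb_layer(t)
    assert temb.shape == (4, 512)  # one embedding vector per timestep index

    attn = AttnBlock(in_ch=64)  # in_ch must be divisible by 32 for GroupNorm
    x = torch.randn(2, 64, 16, 16)
    assert attn(x).shape == x.shape  # attention is shape-preserving (residual)
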
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()
""" assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. 
""" l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy ) class Ukan(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks1 = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), # ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.middleblocks2 = nn.ModuleList([ # ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=torch.nn.SiLU grid_eps=0.02 grid_range=[-1, 1] kan_c=512 self.fc1 = KANLinear( kan_c, kan_c *2, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) # print(now_ch) # self.dwconv = DWConv(kan_c *2) self.act = nn.GELU() self.fc2 = KANLinear( kan_c *2, kan_c, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle # for layer in self.middleblocks1: # h = layer(h, temb) B, C, H, W = h.shape # transform B, C, H, W into B*H*W, C h = h.permute(0, 2, 3, 1).reshape(B*H*W, C) h =self.fc2( self.fc1(h)) # transform B*H*W, C into B, C, H, W h = h.reshape(B, H, W, C).permute(0, 3, 1, 2) # for layer in self.middleblocks2: # h = layer(h, temb) ### Stage 3 # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h class DW_bn_relu(nn.Module): def __init__(self, dim=768): super(DW_bn_relu, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) self.bn = 
class DW_bn_relu(nn.Module):
    def __init__(self, dim=768):
        super(DW_bn_relu, self).__init__()
        # Depthwise 3x3 conv + BatchNorm + ReLU, applied to token input reshaped
        # back into a (B, C, H, W) feature map.
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class kan(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0., shift_size=5, version=4):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features

        grid_size = 5
        spline_order = 3
        scale_noise = 0.1
        scale_base = 1.0
        scale_spline = 1.0
        base_activation = torch.nn.SiLU
        grid_eps = 0.02
        grid_range = [-1, 1]

        # self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc1 = KANLinear(
            in_features,
            hidden_features,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        # self.fc2 = nn.Linear(hidden_features, out_features)
        self.fc2 = KANLinear(
            hidden_features,
            out_features,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        self.fc3 = KANLinear(
            hidden_features,
            out_features,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )

        # ##############################################
        self.version = 4  # version 4 is hard-coded here; the `version` argument is ignored
        # ##############################################
        # NOTE: DWConv (used by versions 1, 2 and 5-7) is not defined in this file;
        # only the DW_bn_relu variants (versions 3, 4 and 8) are usable as-is.
        if self.version == 1:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()
            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()
            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
            self.dwconv_4 = DWConv(hidden_features)
            self.act_4 = act_layer()
        elif self.version == 2:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()
            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()
            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
        elif self.version == 3:
            self.dwconv_1 = DW_bn_relu(hidden_features)
            self.dwconv_2 = DW_bn_relu(hidden_features)
            self.dwconv_3 = DW_bn_relu(hidden_features)
        elif self.version == 4:
            self.dwconv_1 = DW_bn_relu(hidden_features)
            self.dwconv_2 = DW_bn_relu(hidden_features)
            self.dwconv_3 = DW_bn_relu(hidden_features)
        elif self.version == 5:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()
            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()
            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
        elif self.version == 6:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()
            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()
            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
        elif self.version == 7:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()
            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()
            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
        elif self.version == 8:
            self.dwconv_1 = DW_bn_relu(hidden_features)
            self.dwconv_2 = DW_bn_relu(hidden_features)
            self.dwconv_3 = DW_bn_relu(hidden_features)

        self.drop = nn.Dropout(drop)
        self.shift_size = shift_size
        self.pad = shift_size // 2
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        if self.version == 1:
            x = self.dwconv_1(x, H, W)
            x = self.act_1(x)
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.act_2(x)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.act_3(x)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_4(x, H, W)
            x = self.act_4(x)
        elif self.version == 2:
            x = self.dwconv_1(x, H, W)
            x = self.act_1(x)
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.act_2(x)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.act_3(x)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
        elif self.version == 3:
            x = self.dwconv_1(x, H, W)
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
        elif self.version == 4:
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_1(x, H, W)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
        elif self.version == 5:
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_1(x, H, W)
            x = self.act_1(x)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.act_2(x)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.act_3(x)
        elif self.version == 6:
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_1(x, H, W)
            x = self.act_1(x)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.act_2(x)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.act_3(x)
        elif self.version == 7:
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_1(x, H, W)
            x = self.act_1(x)
            x = self.drop(x)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.act_2(x)
            x = self.drop(x)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.act_3(x)
            x = self.drop(x)
        elif self.version == 8:
            x = self.dwconv_1(x, H, W)
            x = self.drop(x)
            x = self.fc1(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.drop(x)
            x = self.fc2(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.drop(x)
            x = self.fc3(x.reshape(B * N, C))
            x = x.reshape(B, N, C).contiguous()
        return x
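
# Illustrative sketch (not part of the original file): the kan module above
# expects token input of shape (B, N, C) with N == H * W, plus the spatial size
# for its depthwise-conv branches. The sizes used here are example values.
def _demo_kan_tokens():
    B, C, H, W = 2, 64, 4, 4
    block = kan(in_features=C, hidden_features=C, version=4)
    tokens = torch.randn(B, H * W, C)
    out = block(tokens, H, W)
    assert out.shape == (B, H * W, C)  # token mixing is shape-preserving
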
class Ukan_v3(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks1 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            # ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.middleblocks2 = nn.ModuleList([
            # ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        # Copied KAN hyperparameters; the kan module below uses its own internal defaults.
        grid_size = 5
        spline_order = 3
        scale_noise = 0.1
        scale_base = 1.0
        scale_spline = 1.0
        base_activation = torch.nn.SiLU
        grid_eps = 0.02
        grid_range = [-1, 1]
        kan_c = 512  # NOTE: must match the bottleneck channel count (now_ch)
        self.kan1 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4)
        self.kan2 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4)

        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle: two token-based KAN blocks instead of the attention ResBlocks
        # for layer in self.middleblocks1:
        #     h = layer(h, temb)
        B, C, H, W = h.shape
        # transform (B, C, H, W) into (B, H*W, C) tokens
        h = h.reshape(B, C, H * W).permute(0, 2, 1)
        h = self.kan1(h, H, W)
        h = self.kan2(h, H, W)
        h = h.permute(0, 2, 1).reshape(B, C, H, W)
        # for layer in self.middleblocks2:
        #     h = layer(h, temb)

        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


class Ukan_v2(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout, version=4):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks1 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            # ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.middleblocks2 = nn.ModuleList([
            # ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        # Copied KAN hyperparameters; the kan module below uses its own internal defaults.
        grid_size = 5
        spline_order = 3
        scale_noise = 0.1
        scale_base = 1.0
        scale_spline = 1.0
        base_activation = torch.nn.SiLU
        grid_eps = 0.02
        grid_range = [-1, 1]
        kan_c = 512  # NOTE: must match the bottleneck channel count (now_ch)
        self.kan = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=version)

        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle: a single token-based KAN block instead of the attention ResBlocks
        # for layer in self.middleblocks1:
        #     h = layer(h, temb)
        B, C, H, W = h.shape
        # transform (B, C, H, W) into (B, H*W, C) tokens
        h = h.reshape(B, C, H * W).permute(0, 2, 1)
        h = self.kan(h, H, W)
        h = h.permute(0, 2, 1).reshape(B, C, H, W)
        # for layer in self.middleblocks2:
        #     h = layer(h, temb)

        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h
class UNet_MLP(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks1 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            # ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.middleblocks2 = nn.ModuleList([
            # ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        kan_c = 512  # NOTE: must match the bottleneck channel count (now_ch)
        self.fc1 = nn.Linear(kan_c, kan_c * 2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(kan_c * 2, kan_c)

        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle: a plain two-layer MLP bottleneck instead of the attention ResBlocks
        # for layer in self.middleblocks1:
        #     h = layer(h, temb)
        B, C, H, W = h.shape
        # transform (B, C, H, W) into (B*H*W, C)
        h = h.permute(0, 2, 3, 1).reshape(B * H * W, C)
        h = self.fc2(self.act(self.fc1(h)))
        # transform (B*H*W, C) back into (B, C, H, W)
        h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)
        # for layer in self.middleblocks2:
        #     h = layer(h, temb)

        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, (batch_size, ))
    y = model(x, t)
    print(y.shape)
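
    # Illustrative sketch (not part of the original file): a smoke test for the
    # KAN-bottleneck variant. ch_mult=[1, 2, 2, 4] is an assumed setting chosen
    # so that the bottleneck channel count (128 * 4 = 512) matches the
    # hard-coded kan_c = 512 inside Ukan_v2.
    kan_model = Ukan_v2(
        T=1000, ch=128, ch_mult=[1, 2, 2, 4], attn=[1],
        num_res_blocks=2, dropout=0.1, version=4)
    y_kan = kan_model(x, t)
    print(y_kan.shape)  # expected: torch.Size([8, 3, 32, 32])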