import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union


class PolynomialFunction(nn.Module):
    """Monomial basis: maps ``x`` to ``[x**0, x**1, ..., x**(degree-1)]``.

    The powers are stacked along a new last dimension, so the output has
    shape ``x.shape + (degree,)`` — i.e. exactly ``degree`` features, matching
    the ``num_grids`` contract of the other basis modules below.
    """

    def __init__(self, degree: int = 3):
        super().__init__()
        self.degree = degree

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.stack([x ** i for i in range(self.degree)], dim=-1)


class BSplineFunction(nn.Module):
    """Uniform B-spline basis evaluated with the Cox-de Boor recursion.

    Emits ``num_basis`` features along a new last dimension.

    NOTE(review): the degree-0 intervals are half-open ``[t_i, t_{i+1})``,
    so every basis function evaluates to exactly 0 at ``x == grid_max``.
    Inputs are also not clamped to ``[grid_min, grid_max]`` — out-of-range
    values simply produce all-zero features.
    """

    def __init__(self, grid_min: float = -2., grid_max: float = 2.,
                 degree: int = 3, num_basis: int = 8):
        super().__init__()
        self.degree = degree
        self.num_basis = num_basis
        # Registered as a (non-persistent) buffer so .to(device)/.cuda() moves
        # the knots together with the module.  The original kept a plain
        # tensor attribute, which stayed on the CPU and crashed on GPU input.
        # persistent=False keeps old checkpoints (which lack this key) loadable.
        self.register_buffer(
            "knots",
            torch.linspace(grid_min, grid_max, num_basis + degree + 1),
            persistent=False,
        )

    def basis_function(self, i: int, k: int, t: torch.Tensor) -> torch.Tensor:
        """Cox-de Boor recursion: value of the i-th order-k basis function at t."""
        if k == 0:
            # Indicator of the half-open knot interval [t_i, t_{i+1}).
            return ((self.knots[i] <= t) & (t < self.knots[i + 1])).float()
        left_den = self.knots[i + k] - self.knots[i]
        right_den = self.knots[i + k + 1] - self.knots[i + 1]
        # 0/0 convention of the recursion: a zero-width knot span contributes 0.
        left = ((t - self.knots[i]) * self.basis_function(i, k - 1, t) / left_den
                if left_den != 0 else 0)
        right = ((self.knots[i + k + 1] - t) * self.basis_function(i + 1, k - 1, t) / right_den
                 if right_den != 0 else 0)
        return left + right

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Drop only a *trailing* singleton, so a (B, 1) input becomes (B,).
        # The original used x.squeeze(), which also removed a batch or channel
        # dimension of size 1 and silently changed the output rank.
        if x.dim() > 1 and x.size(-1) == 1:
            x = x.squeeze(-1)
        return torch.stack(
            [self.basis_function(i, self.degree, x) for i in range(self.num_basis)],
            dim=-1,
        )


class ChebyshevFunction(nn.Module):
    """Chebyshev polynomials of the first kind, T_0 .. T_{degree-1}.

    Built with the recurrence T_n(x) = 2*x*T_{n-1}(x) - T_{n-2}(x) and stacked
    along a new last dimension, yielding exactly ``degree`` features.
    """

    def __init__(self, degree: int = 4):
        super().__init__()
        self.degree = degree

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        polynomials = [torch.ones_like(x)]          # T_0
        if self.degree > 1:
            # Only seed T_1 when it is actually requested; the original always
            # seeded [T_0, T_1] and thus returned 2 features even for degree=1.
            polynomials.append(x)                   # T_1
        for _ in range(2, self.degree):
            polynomials.append(2 * x * polynomials[-1] - polynomials[-2])
        return torch.stack(polynomials, dim=-1)


class FourierBasisFunction(nn.Module):
    """Fourier features: ``num_frequencies // 2`` sines followed by as many cosines.

    Frequencies are the integer multiples 1 .. num_frequencies//2 of the base
    frequency ``1/period``.  Output shape is ``x.shape + (num_frequencies,)``.
    """

    def __init__(self, num_frequencies: int = 4, period: float = 1.0):
        super().__init__()
        if num_frequencies % 2 != 0:
            # Plain `assert` is stripped under -O; validate with an exception.
            raise ValueError("num_frequencies must be even")
        self.num_frequencies = num_frequencies
        # Non-trainable parameter so the period follows the module's device/dtype.
        self.period = nn.Parameter(torch.Tensor([period]), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frequencies = torch.arange(
            1, self.num_frequencies // 2 + 1, device=x.device)
        angles = 2 * torch.pi * frequencies * x[..., None] / self.period
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class RadialBasisFunction(nn.Module):
    """Gaussian RBF features centered on a uniform grid over [grid_min, grid_max].

    Output shape is ``x.shape + (num_grids,)``; ``denominator`` is the Gaussian
    width and defaults to the grid spacing.
    """

    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 4,
        denominator: Optional[float] = None,
    ):
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(grid, requires_grad=False)
        # `is None` rather than `or`, so an explicitly passed 0.0 is not
        # silently replaced (it would still divide by zero in forward()).
        if denominator is None:
            denominator = (grid_max - grid_min) / (num_grids - 1)
        self.denominator = denominator

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)


class SplineConv2D(nn.Conv2d):
    """``nn.Conv2d`` whose weights start from a small truncated normal.

    Used for the spline branch, where a small ``init_scale`` keeps the basis
    expansion from dominating at initialization.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]] = 3,
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 init_scale: float = 0.1,
                 padding_mode: str = "zeros",
                 **kw) -> None:
        # Must be set *before* super().__init__, which calls reset_parameters().
        self.init_scale = init_scale
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode, **kw)

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


class FastKANConvLayer(nn.Module):
    """KAN-style convolution layer.

    Each input channel is expanded into ``num_grids`` basis features (RBF,
    Fourier, polynomial, Chebyshev, or B-spline), the expanded tensor is fed
    through a small-init convolution, and (optionally) a plain convolution of
    ``base_activation(x)`` is added as a residual "base" path.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]] = 3,
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 grid_min: float = -2.,
                 grid_max: float = 2.,
                 num_grids: int = 4,
                 use_base_update: bool = True,
                 base_activation=F.silu,
                 spline_weight_init_scale: float = 0.1,
                 padding_mode: str = "zeros",
                 kan_type: str = "BSpline") -> None:
        super().__init__()
        if kan_type == "RBF":
            self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        elif kan_type == "Fourier":
            self.rbf = FourierBasisFunction(num_grids)
        elif kan_type == "Poly":
            self.rbf = PolynomialFunction(num_grids)
        elif kan_type == "Chebyshev":
            self.rbf = ChebyshevFunction(num_grids)
        elif kan_type == "BSpline":
            self.rbf = BSplineFunction(grid_min, grid_max, 3, num_grids)
        else:
            # The original fell through silently and failed later with an
            # opaque AttributeError on self.rbf.
            raise ValueError(f"unknown kan_type: {kan_type!r}")

        # The spline branch sees in_channels * num_grids expanded channels.
        self.spline_conv = SplineConv2D(in_channels * num_grids, out_channels,
                                        kernel_size, stride, padding, dilation,
                                        groups, bias, spline_weight_init_scale,
                                        padding_mode)

        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                       stride, padding, dilation, groups, bias,
                                       padding_mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = x.shape
        # (B, C, H*W) -> basis expansion -> (B, C, H, W, G)
        x_rbf = self.rbf(x.view(batch_size, channels, -1)).view(
            batch_size, channels, height, width, -1)
        # (B, C, H, W, G) -> (B, G*C, H, W): grid-major channel layout, which
        # is fine because spline_conv's weights are learned over that layout.
        x_rbf = x_rbf.permute(0, 4, 1, 2, 3).contiguous().view(
            batch_size, -1, height, width)
        ret = self.spline_conv(x_rbf)
        if self.use_base_update:
            ret = ret + self.base_conv(self.base_activation(x))
        return ret