Commit 4369d182 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

added choice of filter basis as option

parent 652c4ab2
......@@ -206,6 +206,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
in_channels: int,
out_channels: int,
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1,
bias: Optional[bool] = True,
):
......@@ -214,7 +215,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
self.kernel_shape = kernel_shape
# get the filter basis functions
self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type="piecewise linear")
self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)
# groups
self.groups = groups
......@@ -256,13 +257,14 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
......@@ -349,13 +351,14 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
......
......@@ -196,13 +196,14 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
......@@ -326,13 +327,14 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment