Commit 4369d182 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

added choice of filter basis as option

parent 652c4ab2
...@@ -206,6 +206,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta): ...@@ -206,6 +206,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1, groups: Optional[int] = 1,
bias: Optional[bool] = True, bias: Optional[bool] = True,
): ):
...@@ -214,7 +215,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta): ...@@ -214,7 +215,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
self.kernel_shape = kernel_shape self.kernel_shape = kernel_shape
# get the filter basis functions # get the filter basis functions
self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type="piecewise linear") self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)
# groups # groups
self.groups = groups self.groups = groups
...@@ -256,13 +257,14 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -256,13 +257,14 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1, groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular", grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular", grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True, bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None, theta_cutoff: Optional[float] = None,
): ):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias) super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape self.nlat_out, self.nlon_out = out_shape
...@@ -349,13 +351,14 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -349,13 +351,14 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1, groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular", grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular", grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True, bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None, theta_cutoff: Optional[float] = None,
): ):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias) super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape self.nlat_out, self.nlon_out = out_shape
......
...@@ -196,13 +196,14 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -196,13 +196,14 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1, groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular", grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular", grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True, bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None, theta_cutoff: Optional[float] = None,
): ):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias) super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape self.nlat_out, self.nlon_out = out_shape
...@@ -326,13 +327,14 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -326,13 +327,14 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1, groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular", grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular", grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True, bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None, theta_cutoff: Optional[float] = None,
): ):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias) super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)
self.nlat_in, self.nlon_in = in_shape self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape self.nlat_out, self.nlon_out = out_shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment