Commit 1c2859f2 authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Improved docstrings in filter basis

parent eeda67aa
...@@ -65,8 +65,10 @@ class FilterBasis(metaclass=abc.ABCMeta): ...@@ -65,8 +65,10 @@ class FilterBasis(metaclass=abc.ABCMeta):
""" """
Initialize the filter basis. Initialize the filter basis.
Parameters: Parameters
kernel_shape: shape of the kernel, can be an integer or tuple of integers -----------
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel, can be an integer or tuple of integers
""" """
self.kernel_shape = kernel_shape self.kernel_shape = kernel_shape
...@@ -76,8 +78,10 @@ class FilterBasis(metaclass=abc.ABCMeta): ...@@ -76,8 +78,10 @@ class FilterBasis(metaclass=abc.ABCMeta):
""" """
Abstract property that should return the size of the kernel. Abstract property that should return the size of the kernel.
Returns: Returns
int: the kernel size -------
kernel_size: int
The size of the kernel
""" """
raise NotImplementedError raise NotImplementedError
...@@ -140,8 +144,10 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -140,8 +144,10 @@ class PiecewiseLinearFilterBasis(FilterBasis):
""" """
Compute the kernel size for piecewise linear basis. Compute the kernel size for piecewise linear basis.
Returns: Returns
int: the kernel size -------
kernel_size: int
The size of the kernel
""" """
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2 return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
...@@ -250,8 +256,10 @@ class MorletFilterBasis(FilterBasis): ...@@ -250,8 +256,10 @@ class MorletFilterBasis(FilterBasis):
""" """
Initialize the Morlet filter basis. Initialize the Morlet filter basis.
Parameters: Parameters
kernel_shape: shape of the kernel, can be an integer or tuple of integers -----------
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel, can be an integer or tuple of integers
""" """
if isinstance(kernel_shape, int): if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape, kernel_shape] kernel_shape = [kernel_shape, kernel_shape]
...@@ -265,8 +273,10 @@ class MorletFilterBasis(FilterBasis): ...@@ -265,8 +273,10 @@ class MorletFilterBasis(FilterBasis):
""" """
Compute the kernel size for Morlet basis. Compute the kernel size for Morlet basis.
Returns: Returns
int: the kernel size -------
kernel_size: int
The size of the kernel
""" """
return self.kernel_shape[0] * self.kernel_shape[1] return self.kernel_shape[0] * self.kernel_shape[1]
...@@ -274,12 +284,17 @@ class MorletFilterBasis(FilterBasis): ...@@ -274,12 +284,17 @@ class MorletFilterBasis(FilterBasis):
""" """
Compute Gaussian window function. Compute Gaussian window function.
Parameters: Parameters
r: radial distance tensor -----------
width: width parameter of the Gaussian r: torch.Tensor
Radial distance tensor
width: float
Width parameter of the Gaussian
Returns: Returns
torch.Tensor: Gaussian window values -------
out: torch.Tensor
Gaussian window values
""" """
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2)) return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
...@@ -287,12 +302,17 @@ class MorletFilterBasis(FilterBasis): ...@@ -287,12 +302,17 @@ class MorletFilterBasis(FilterBasis):
""" """
Compute Hann window function. Compute Hann window function.
Parameters: Parameters
r: radial distance tensor -----------
width: width parameter of the Hann window r: torch.Tensor
Radial distance tensor
width: float
Width parameter of the Hann window
Returns: Returns
torch.Tensor: Hann window values -------
out: torch.Tensor
Hann window values
""" """
return torch.cos(0.5 * torch.pi * r / width) ** 2 return torch.cos(0.5 * torch.pi * r / width) ** 2
...@@ -338,8 +358,10 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -338,8 +358,10 @@ class ZernikeFilterBasis(FilterBasis):
""" """
Initialize the Zernike filter basis. Initialize the Zernike filter basis.
Parameters: Parameters
kernel_shape: shape of the kernel, can be an integer or tuple of integers -----------
kernel_shape: Union[int, Tuple[int]]
Shape of the kernel, can be an integer or tuple of integers
""" """
if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list): if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
kernel_shape = kernel_shape[0] kernel_shape = kernel_shape[0]
...@@ -353,8 +375,10 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -353,8 +375,10 @@ class ZernikeFilterBasis(FilterBasis):
""" """
Compute the kernel size for Zernike basis. Compute the kernel size for Zernike basis.
Returns: Returns
int: the kernel size -------
kernel_size: int
The size of the kernel
""" """
return (self.kernel_shape * (self.kernel_shape + 1)) // 2 return (self.kernel_shape * (self.kernel_shape + 1)) // 2
...@@ -362,13 +386,19 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -362,13 +386,19 @@ class ZernikeFilterBasis(FilterBasis):
""" """
Compute radial Zernike polynomials. Compute radial Zernike polynomials.
Parameters: Parameters
r: radial distance tensor -----------
n: principal quantum number r: torch.Tensor
m: azimuthal quantum number Radial distance tensor
n: torch.Tensor
Principal quantum number
m: torch.Tensor
Azimuthal quantum number
Returns: Returns
torch.Tensor: radial Zernike polynomial values -------
out: torch.Tensor
Radial Zernike polynomial values
""" """
out = torch.zeros_like(r) out = torch.zeros_like(r)
bound = (n - m) // 2 + 1 bound = (n - m) // 2 + 1
...@@ -385,14 +415,21 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -385,14 +415,21 @@ class ZernikeFilterBasis(FilterBasis):
""" """
Compute Zernike polynomials. Compute Zernike polynomials.
Parameters: Parameters
r: radial distance tensor -----------
phi: azimuthal angle tensor r: torch.Tensor
n: principal quantum number Radial distance tensor
l: azimuthal quantum number phi: torch.Tensor
Azimuthal angle tensor
n: torch.Tensor
Principal quantum number
l: torch.Tensor
Azimuthal quantum number
Returns: Returns
torch.Tensor: Zernike polynomial values -------
out: torch.Tensor
Zernike polynomial values
""" """
m = 2 * l - n m = 2 * l - n
return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi)) return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment