Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
b6b2bce3
Commit
b6b2bce3
authored
Jan 11, 2025
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
implemented Zernike filter basis
parent
7126fb9a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
129 additions
and
30 deletions
+129
-30
Changelog.md
Changelog.md
+1
-0
examples/train_sfno.py
examples/train_sfno.py
+24
-3
torch_harmonics/examples/models/lsno.py
torch_harmonics/examples/models/lsno.py
+21
-24
torch_harmonics/filter_basis.py
torch_harmonics/filter_basis.py
+83
-3
No files found.
Changelog.md
View file @
b6b2bce3
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
*
New filter basis normalization in DISCO convolutions
*
New filter basis normalization in DISCO convolutions
*
Reworked DISCO filter basis datastructure
*
Reworked DISCO filter basis datastructure
*
Support for new filter basis types
*
Support for new filter basis types
*
Adding Zernike polynomial basis on a disk
*
Adding Morlet wavelet basis functions on a spherical disk
*
Adding Morlet wavelet basis functions on a spherical disk
*
Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
*
Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
*
Updated resampling module to extend input signal to the poles if needed
*
Updated resampling module to extend input signal to the poles if needed
...
...
examples/train_sfno.py
View file @
b6b2bce3
...
@@ -430,7 +430,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
...
@@ -430,7 +430,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer
=
"none"
,
normalization_layer
=
"none"
,
)
)
models
[
f
"lsno_sc2_layers4_e32"
]
=
partial
(
models
[
f
"lsno_sc2_layers4_e32
_morlet
"
]
=
partial
(
LSNO
,
LSNO
,
img_size
=
(
nlat
,
nlon
),
img_size
=
(
nlat
,
nlon
),
grid
=
grid
,
grid
=
grid
,
...
@@ -443,6 +443,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
...
@@ -443,6 +443,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
pos_embed
=
False
,
pos_embed
=
False
,
use_mlp
=
True
,
use_mlp
=
True
,
normalization_layer
=
"none"
,
normalization_layer
=
"none"
,
kernel_shape
=
[
4
,
4
],
encoder_kernel_shape
=
[
4
,
4
],
filter_basis_type
=
"morlet"
)
models
[
f
"lsno_sc2_layers4_e32_zernike"
]
=
partial
(
LSNO
,
img_size
=
(
nlat
,
nlon
),
grid
=
grid
,
num_layers
=
4
,
scale_factor
=
2
,
embed_dim
=
32
,
operator_type
=
"driscoll-healy"
,
activation_function
=
"gelu"
,
big_skip
=
False
,
pos_embed
=
False
,
use_mlp
=
True
,
normalization_layer
=
"none"
,
kernel_shape
=
[
4
],
encoder_kernel_shape
=
[
4
],
filter_basis_type
=
"zernike"
)
)
# iterate over models and train each model
# iterate over models and train each model
...
@@ -468,7 +489,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
...
@@ -468,7 +489,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
# run the training
# run the training
if
train
:
if
train
:
run
=
wandb
.
init
(
project
=
"
sfno ablations
spherical swe"
,
group
=
model_name
,
name
=
model_name
+
"_"
+
str
(
time
.
time
()),
config
=
model_handle
.
keywords
)
run
=
wandb
.
init
(
project
=
"
local sno
spherical swe"
,
group
=
model_name
,
name
=
model_name
+
"_"
+
str
(
time
.
time
()),
config
=
model_handle
.
keywords
)
# optimizer:
# optimizer:
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
...
@@ -478,7 +499,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
...
@@ -478,7 +499,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
print
(
f
"Training
{
model_name
}
, single step"
)
print
(
f
"Training
{
model_name
}
, single step"
)
train_model
(
model
,
dataloader
,
optimizer
,
gscaler
,
scheduler
,
nepochs
=
2
0
,
loss_fn
=
"l2"
,
enable_amp
=
enable_amp
,
log_grads
=
log_grads
)
train_model
(
model
,
dataloader
,
optimizer
,
gscaler
,
scheduler
,
nepochs
=
10
0
,
loss_fn
=
"l2"
,
enable_amp
=
enable_amp
,
log_grads
=
log_grads
)
if
nfuture
>
0
:
if
nfuture
>
0
:
print
(
f
'Training
{
model_name
}
,
{
nfuture
}
step'
)
print
(
f
'Training
{
model_name
}
,
{
nfuture
}
step'
)
...
...
torch_harmonics/examples/models/lsno.py
View file @
b6b2bce3
...
@@ -70,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
...
@@ -70,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out
=
grid_out
,
grid_out
=
grid_out
,
groups
=
groups
,
groups
=
groups
,
bias
=
bias
,
bias
=
bias
,
theta_cutoff
=
math
.
sqrt
(
2
.0
)
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
theta_cutoff
=
1
.0
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
)
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -103,7 +103,7 @@ class DiscreteContinuousDecoder(nn.Module):
...
@@ -103,7 +103,7 @@ class DiscreteContinuousDecoder(nn.Module):
# # set up
# # set up
self
.
sht
=
RealSHT
(
*
in_shape
,
grid
=
grid_in
).
float
()
self
.
sht
=
RealSHT
(
*
in_shape
,
grid
=
grid_in
).
float
()
self
.
isht
=
InverseRealSHT
(
*
out_shape
,
lmax
=
self
.
sht
.
lmax
,
mmax
=
self
.
sht
.
mmax
,
grid
=
grid_out
).
float
()
self
.
isht
=
InverseRealSHT
(
*
out_shape
,
lmax
=
self
.
sht
.
lmax
,
mmax
=
self
.
sht
.
mmax
,
grid
=
grid_out
).
float
()
#
self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
self
.
upscale
=
ResampleS2
(
*
in_shape
,
*
out_shape
,
grid_in
=
grid_in
,
grid_out
=
grid_out
)
# set up DISCO convolution
# set up DISCO convolution
self
.
conv
=
DiscreteContinuousConvS2
(
self
.
conv
=
DiscreteContinuousConvS2
(
...
@@ -117,28 +117,25 @@ class DiscreteContinuousDecoder(nn.Module):
...
@@ -117,28 +117,25 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out
=
grid_out
,
grid_out
=
grid_out
,
groups
=
groups
,
groups
=
groups
,
bias
=
False
,
bias
=
False
,
theta_cutoff
=
math
.
sqrt
(
2
.0
)
*
torch
.
pi
/
float
(
in_shape
[
0
]
-
1
),
theta_cutoff
=
1
.0
*
torch
.
pi
/
float
(
in_shape
[
0
]
-
1
),
)
)
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
def
upscale_sht
(
self
,
x
:
torch
.
Tensor
):
def
upscale_sht
(
self
,
x
:
torch
.
Tensor
):
return
self
.
isht
(
self
.
sht
(
x
))
return
self
.
isht
(
self
.
sht
(
x
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
dtype
=
x
.
dtype
dtype
=
x
.
dtype
#
x = self.upscale(x)
x
=
self
.
upscale
(
x
)
with
amp
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
with
amp
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
x
=
x
.
float
()
x
=
x
.
float
()
x
=
self
.
upscale_sht
(
x
)
#
x = self.upscale_sht(x)
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
x
=
x
.
to
(
dtype
=
dtype
)
x
=
x
.
to
(
dtype
=
dtype
)
return
x
return
x
class
SphericalNeuralOperatorBlock
(
nn
.
Module
):
class
SphericalNeuralOperatorBlock
(
nn
.
Module
):
"""
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
...
@@ -160,7 +157,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
...
@@ -160,7 +157,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
inner_skip
=
"None"
,
inner_skip
=
"None"
,
outer_skip
=
"linear"
,
outer_skip
=
"linear"
,
use_mlp
=
True
,
use_mlp
=
True
,
disco_kernel_shape
=
[
2
,
4
],
disco_kernel_shape
=
[
3
,
4
],
disco_basis_type
=
"piecewise linear"
,
disco_basis_type
=
"piecewise linear"
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -185,7 +182,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
...
@@ -185,7 +182,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
grid_in
=
forward_transform
.
grid
,
grid_in
=
forward_transform
.
grid
,
grid_out
=
inverse_transform
.
grid
,
grid_out
=
inverse_transform
.
grid
,
bias
=
False
,
bias
=
False
,
theta_cutoff
=
4
*
math
.
sqrt
(
2.0
)
*
torch
.
pi
/
float
(
inverse_transform
.
nlat
-
1
),
theta_cutoff
=
1.0
*
(
disco_kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
inverse_transform
.
nlat
-
1
),
)
)
elif
conv_type
==
"global"
:
elif
conv_type
==
"global"
:
self
.
global_conv
=
SpectralConvS2
(
forward_transform
,
inverse_transform
,
input_dim
,
output_dim
,
gain
=
gain_factor
,
operator_type
=
operator_type
,
bias
=
False
)
self
.
global_conv
=
SpectralConvS2
(
forward_transform
,
inverse_transform
,
input_dim
,
output_dim
,
gain
=
gain_factor
,
operator_type
=
operator_type
,
bias
=
False
)
...
@@ -294,6 +291,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
...
@@ -294,6 +291,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Activation function to use, by default "gelu"
Activation function to use, by default "gelu"
encoder_kernel_shape : int, optional
encoder_kernel_shape : int, optional
size of the encoder kernel
size of the encoder kernel
filter_basis_type: Optional[str]: str, optional
filter basis type
use_mlp : int, optional
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
mlp_ratio : int, optional
...
@@ -350,7 +349,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
...
@@ -350,7 +349,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
activation_function
=
"relu"
,
activation_function
=
"relu"
,
kernel_shape
=
[
3
,
4
],
kernel_shape
=
[
3
,
4
],
encoder_kernel_shape
=
[
3
,
4
],
encoder_kernel_shape
=
[
3
,
4
],
disco
_basis_type
=
"piecewise linear"
,
filter
_basis_type
=
"piecewise linear"
,
use_mlp
=
True
,
use_mlp
=
True
,
mlp_ratio
=
2.0
,
mlp_ratio
=
2.0
,
drop_rate
=
0.0
,
drop_rate
=
0.0
,
...
@@ -423,18 +422,17 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
...
@@ -423,18 +422,17 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self
.
pos_embed
=
None
self
.
pos_embed
=
None
# encoder
# encoder
self
.
encoder
=
DiscreteContinuousConvS2
(
self
.
encoder
=
DiscreteContinuousEncoder
(
self
.
in_chans
,
in_shape
=
self
.
img_size
,
self
.
embed_dim
,
out_shape
=
(
self
.
h
,
self
.
w
),
self
.
img_size
,
(
self
.
h
,
self
.
w
),
self
.
encoder_kernel_shape
,
basis_type
=
disco_basis_type
,
groups
=
1
,
grid_in
=
grid
,
grid_in
=
grid
,
grid_out
=
grid_internal
,
grid_out
=
grid_internal
,
inp_chans
=
self
.
in_chans
,
out_chans
=
self
.
embed_dim
,
kernel_shape
=
self
.
encoder_kernel_shape
,
basis_type
=
filter_basis_type
,
groups
=
1
,
bias
=
False
,
bias
=
False
,
theta_cutoff
=
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
self
.
h
-
1
),
)
)
# prepare the SHT
# prepare the SHT
...
@@ -476,7 +474,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
...
@@ -476,7 +474,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
outer_skip
=
outer_skip
,
outer_skip
=
outer_skip
,
use_mlp
=
use_mlp
,
use_mlp
=
use_mlp
,
disco_kernel_shape
=
kernel_shape
,
disco_kernel_shape
=
kernel_shape
,
disco_basis_type
=
disco
_basis_type
,
disco_basis_type
=
filter
_basis_type
,
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
...
@@ -490,7 +488,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
...
@@ -490,7 +488,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
inp_chans
=
self
.
embed_dim
,
inp_chans
=
self
.
embed_dim
,
out_chans
=
self
.
out_chans
,
out_chans
=
self
.
out_chans
,
kernel_shape
=
self
.
encoder_kernel_shape
,
kernel_shape
=
self
.
encoder_kernel_shape
,
basis_type
=
disco
_basis_type
,
basis_type
=
filter
_basis_type
,
groups
=
1
,
groups
=
1
,
bias
=
False
,
bias
=
False
,
)
)
...
@@ -503,7 +501,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
...
@@ -503,7 +501,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
# scale = math.sqrt(0.5 / self.in_chans)
# scale = math.sqrt(0.5 / self.in_chans)
# nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale)
# nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale)
@
torch
.
jit
.
ignore
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
def
no_weight_decay
(
self
):
return
{
"pos_embed"
,
"cls_token"
}
return
{
"pos_embed"
,
"cls_token"
}
...
...
torch_harmonics/filter_basis.py
View file @
b6b2bce3
...
@@ -44,7 +44,7 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
...
@@ -44,7 +44,7 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
elif
basis_type
==
"morlet"
:
elif
basis_type
==
"morlet"
:
return
MorletFilterBasis
(
kernel_shape
=
kernel_shape
)
return
MorletFilterBasis
(
kernel_shape
=
kernel_shape
)
elif
basis_type
==
"zernike"
:
elif
basis_type
==
"zernike"
:
r
aise
NotImplementedError
(
)
r
eturn
ZernikeFilterBasis
(
kernel_shape
=
kernel_shape
)
else
:
else
:
raise
ValueError
(
f
"Unknown basis_type
{
basis_type
}
"
)
raise
ValueError
(
f
"Unknown basis_type
{
basis_type
}
"
)
...
@@ -54,6 +54,16 @@ def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
...
@@ -54,6 +54,16 @@ def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
return
torch
.
minimum
(
torch
.
abs
(
x1
-
x2
),
torch
.
abs
(
2
*
math
.
pi
-
torch
.
abs
(
x1
-
x2
)))
return
torch
.
minimum
(
torch
.
abs
(
x1
-
x2
),
torch
.
abs
(
2
*
math
.
pi
-
torch
.
abs
(
x1
-
x2
)))
def
_log_factorial
(
x
:
torch
.
Tensor
):
"""Helper function to compute the log factorial on a torch tensor"""
return
torch
.
lgamma
(
x
+
1
)
def
_factorial
(
x
:
torch
.
Tensor
):
"""Helper function to compute the factorial on a torch tensor"""
return
torch
.
exp
(
_log_factorial
(
x
))
class
FilterBasis
(
metaclass
=
abc
.
ABCMeta
):
class
FilterBasis
(
metaclass
=
abc
.
ABCMeta
):
"""
"""
Abstract base class for a filter basis
Abstract base class for a filter basis
...
@@ -226,7 +236,7 @@ class MorletFilterBasis(FilterBasis):
...
@@ -226,7 +236,7 @@ class MorletFilterBasis(FilterBasis):
def
kernel_size
(
self
):
def
kernel_size
(
self
):
return
self
.
kernel_shape
[
0
]
*
self
.
kernel_shape
[
1
]
return
self
.
kernel_shape
[
0
]
*
self
.
kernel_shape
[
1
]
def
_
gaussian_window
(
self
,
r
:
torch
.
Tensor
,
width
:
float
=
1.0
):
def
gaussian_window
(
self
,
r
:
torch
.
Tensor
,
width
:
float
=
1.0
):
return
1
/
(
2
*
math
.
pi
*
width
**
2
)
*
torch
.
exp
(
-
0.5
*
r
**
2
/
(
width
**
2
))
return
1
/
(
2
*
math
.
pi
*
width
**
2
)
*
torch
.
exp
(
-
0.5
*
r
**
2
/
(
width
**
2
))
def
compute_support_vals
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
,
width
:
float
=
0.25
):
def
compute_support_vals
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
,
width
:
float
=
0.25
):
...
@@ -254,6 +264,76 @@ class MorletFilterBasis(FilterBasis):
...
@@ -254,6 +264,76 @@ class MorletFilterBasis(FilterBasis):
disk_area
=
1.0
disk_area
=
1.0
# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals
=
self
.
_gaussian_window
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
/
r_cutoff
,
width
=
width
)
*
harmonic
[
iidx
[:,
0
],
iidx
[:,
1
],
iidx
[:,
2
]]
/
disk_area
vals
=
self
.
gaussian_window
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
/
r_cutoff
,
width
=
width
)
*
harmonic
[
iidx
[:,
0
],
iidx
[:,
1
],
iidx
[:,
2
]]
/
disk_area
return
iidx
,
vals
class
ZernikeFilterBasis
(
FilterBasis
):
"""
Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials
"""
def
__init__
(
self
,
kernel_shape
:
Union
[
int
,
Tuple
[
int
],
List
[
int
]],
):
if
isinstance
(
kernel_shape
,
tuple
)
or
isinstance
(
kernel_shape
,
list
):
kernel_shape
=
kernel_shape
[
0
]
if
not
isinstance
(
kernel_shape
,
int
):
raise
ValueError
(
f
"expected kernel_shape to be an integer but got
{
kernel_shape
}
instead."
)
super
().
__init__
(
kernel_shape
=
kernel_shape
)
@
property
def
kernel_size
(
self
):
return
(
self
.
kernel_shape
*
(
self
.
kernel_shape
+
1
))
//
2
def
zernikeradial
(
self
,
r
:
torch
.
Tensor
,
n
:
torch
.
Tensor
,
m
:
torch
.
Tensor
):
out
=
torch
.
zeros_like
(
r
)
bound
=
(
n
-
m
)
//
2
+
1
max_bound
=
bound
.
max
().
item
()
for
k
in
range
(
max_bound
):
inc
=
(
-
1
)
**
k
*
_factorial
(
n
-
k
)
*
r
**
(
n
-
2
*
k
)
/
(
math
.
factorial
(
k
)
*
_factorial
((
n
+
m
)
//
2
-
k
)
*
_factorial
((
n
-
m
)
//
2
-
k
))
out
+=
torch
.
where
(
k
<
bound
,
inc
,
0.0
)
return
out
def
zernikepoly
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
n
:
torch
.
Tensor
,
l
:
torch
.
Tensor
):
m
=
2
*
l
-
n
return
torch
.
where
(
m
<
0
,
self
.
zernikeradial
(
r
,
n
,
-
m
)
*
torch
.
sin
(
m
*
phi
),
self
.
zernikeradial
(
r
,
n
,
m
)
*
torch
.
cos
(
m
*
phi
))
def
compute_support_vals
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
,
width
:
float
=
0.25
):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel
=
torch
.
arange
(
self
.
kernel_size
).
reshape
(
-
1
,
1
,
1
)
# get relevant indices
iidx
=
torch
.
argwhere
((
r
<=
r_cutoff
)
&
torch
.
full_like
(
ikernel
,
True
,
dtype
=
torch
.
bool
))
# indexing logic for zernike polynomials
# the total index is given by (n * (n + 2) + l ) // 2 which needs to be reversed
# precompute shifts in the level of the "pyramid"
nshifts
=
torch
.
arange
(
self
.
kernel_shape
)
nshifts
=
(
nshifts
+
1
)
*
nshifts
//
2
# find the level and position within the pyramid
nkernel
=
torch
.
searchsorted
(
nshifts
,
ikernel
,
right
=
True
)
-
1
lkernel
=
ikernel
-
nshifts
[
nkernel
]
# mkernel = 2 * ikernel - nkernel * (nkernel + 2)
# get corresponding coordinates and n and l indices
r
=
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
/
r_cutoff
phi
=
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
n
=
nkernel
[
iidx
[:,
0
],
0
,
0
]
l
=
lkernel
[
iidx
[:,
0
],
0
,
0
]
# computes the Zernike polynomials using helper functions
vals
=
self
.
zernikepoly
(
r
,
phi
,
n
,
l
)
return
iidx
,
vals
return
iidx
,
vals
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment