Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
e7ceb9c8
Commit
e7ceb9c8
authored
Jun 07, 2023
by
Boris Bonev
Browse files
Adding SFNO examples
parent
24490256
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2533 additions
and
0 deletions
+2533
-0
examples/sfno_examples/models/activations.py
examples/sfno_examples/models/activations.py
+96
-0
examples/sfno_examples/models/contractions.py
examples/sfno_examples/models/contractions.py
+141
-0
examples/sfno_examples/models/factorizations.py
examples/sfno_examples/models/factorizations.py
+199
-0
examples/sfno_examples/models/layers.py
examples/sfno_examples/models/layers.py
+552
-0
examples/sfno_examples/models/sfno.py
examples/sfno_examples/models/sfno.py
+517
-0
examples/sfno_examples/train_sfno.ipynb
examples/sfno_examples/train_sfno.ipynb
+488
-0
examples/sfno_examples/train_sfno.py
examples/sfno_examples/train_sfno.py
+419
-0
examples/sfno_examples/utils/pde_dataset.py
examples/sfno_examples/utils/pde_dataset.py
+121
-0
No files found.
examples/sfno_examples/models/activations.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
torch
import
torch.nn
as
nn
# complex activation functions
class ComplexCardioid(nn.Module):
    """
    Complex Cardioid activation function.

    Scales the input by 0.5 * (1 + cos(angle(z))): a phase-sensitive
    attenuation that acts as a complex-plane generalization of ReLU.
    """
    def __init__(self):
        super(ComplexCardioid, self).__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # gain depends only on the phase of z, not its magnitude
        gain = 0.5 * (1. + torch.cos(z.angle()))
        return gain * z
class ComplexReLU(nn.Module):
    """
    Complex-valued variants of the ReLU activation function.

    Supported modes:
      - "cartesian": leaky ReLU applied independently to real and imaginary parts
      - "modulus":   soft thresholding of the modulus with a learnable bias;
                     the phase of z is preserved
      - "cardioid":  phase-dependent attenuation 0.5 * (1 + cos(angle(z))) * z
      - "real":      leaky ReLU applied to the real part only
    """
    def __init__(self, negative_slope=0., mode="real", bias_shape=None, scale=1.):
        super(ComplexReLU, self).__init__()

        # store parameters
        self.mode = mode
        if self.mode in ["modulus", "halfplane"]:
            # learnable threshold, broadcast against the modulus
            if bias_shape is not None:
                self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32))
            else:
                self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32))
        else:
            self.bias = 0

        self.negative_slope = negative_slope
        self.act = nn.LeakyReLU(negative_slope=negative_slope)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        if self.mode == "cartesian":
            zr = torch.view_as_real(z)
            za = self.act(zr)
            out = torch.view_as_complex(za)
        elif self.mode == "modulus":
            zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag))
            # fix: guard the denominator — dividing by zabs produced NaN at
            # z == 0 whenever bias > 0 made the condition true; at z == 0 the
            # numerator is 0 anyway, so substituting 1 yields the correct 0.
            denom = torch.where(zabs == 0., torch.ones_like(zabs), zabs)
            out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / denom, 0.0 * z)
        elif self.mode == "cardioid":
            out = 0.5 * (1. + torch.cos(z.angle())) * z
        # elif self.mode == "halfplane":
        #     # bias is an angle parameter in this case
        #     modified_angle = torch.angle(z) - self.bias
        #     condition = torch.logical_and( (0. <= modified_angle), (modified_angle < torch.pi/2.) )
        #     out = torch.where(condition, z, self.negative_slope * z)
        elif self.mode == "real":
            zr = torch.view_as_real(z)
            outr = zr.clone()
            outr[..., 0] = self.act(zr[..., 0])
            out = torch.view_as_complex(outr)
        else:
            raise NotImplementedError

        return out
\ No newline at end of file
examples/sfno_examples/models/contractions.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
torch
"""
Contains complex contractions wrapped into jit for harmonic layers
"""
@torch.jit.script
def compl_contract2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 2d contraction on real-layout tensors.

    The trailing dimension of size 2 holds (real, imag). Contracts over the
    channel and y dimensions, equivalent to einsum("bixy,kix->bkx") on the
    corresponding complex tensors.
    """
    prod = torch.einsum("bixys,kixr->srbkx", a, b)
    # complex product: (xr + i*xi)(wr + i*wi) -> real = xr*wr - xi*wi, imag = xi*wr + xr*wi
    re = prod[0, 0, ...] - prod[1, 1, ...]
    im = prod[1, 0, ...] + prod[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_contract2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same contraction as compl_contract2d_fwd, but via native complex dtypes.

    Inputs and output are real tensors whose last dimension (size 2) holds
    the (real, imag) components.
    """
    ac = torch.view_as_complex(a)
    bc = torch.view_as_complex(b)
    return torch.view_as_real(torch.einsum("bixy,kix->bkx", ac, bc))
@torch.jit.script
def compl_contract_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 1d contraction on real-layout tensors.

    Equivalent to einsum("bin,kin->bkn") on the corresponding complex tensors;
    the trailing dimension of size 2 holds (real, imag).
    """
    prod = torch.einsum("bins,kinr->srbkn", a, b)
    # assemble real and imaginary parts of the complex product
    re = prod[0, 0, ...] - prod[1, 1, ...]
    im = prod[1, 0, ...] + prod[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_contract_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same contraction as compl_contract_fwd, but via native complex dtypes."""
    ac = torch.view_as_complex(a)
    bc = torch.view_as_complex(b)
    return torch.view_as_real(torch.einsum("bin,kin->bkn", ac, bc))
# Helper routines for spherical MLPs
@torch.jit.script
def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing: einsum("bix,io->box") in real layout.

    a: (batch, in_chans, x, 2), b: (in_chans, out_chans, 2); the trailing
    dimension holds (real, imag).
    """
    prod = torch.einsum("bixs,ior->srbox", a, b)
    re = prod[0, 0, ...] - prod[1, 1, ...]
    im = prod[1, 0, ...] + prod[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_mul1d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing via native complex dtypes; real (re, im) layout in/out."""
    ac = torch.view_as_complex(a)
    bc = torch.view_as_complex(b)
    mixed = torch.einsum("bix,io->box", ac, bc)
    return torch.view_as_real(mixed)
@torch.jit.script
def compl_muladd1d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused multiply-add: complex 1d channel mixing of a with b, plus bias c."""
    return compl_mul1d_fwd(a, b) + c
@torch.jit.script
def compl_muladd1d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused multiply-add in native complex arithmetic: (a * b) + c, real layout in/out."""
    product = torch.view_as_complex(compl_mul1d_fwd_c(a, b))
    bias = torch.view_as_complex(c)
    return torch.view_as_real(product + bias)
# Helper routines for FFT MLPs
@torch.jit.script
def compl_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing over a 2d grid: einsum("bixy,io->boxy") in real layout."""
    prod = torch.einsum("bixys,ior->srboxy", a, b)
    re = prod[0, 0, ...] - prod[1, 1, ...]
    im = prod[1, 0, ...] + prod[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_mul2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 2d channel mixing via native complex dtypes; real (re, im) layout in/out."""
    ac = torch.view_as_complex(a)
    bc = torch.view_as_complex(b)
    mixed = torch.einsum("bixy,io->boxy", ac, bc)
    return torch.view_as_real(mixed)
@torch.jit.script
def compl_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused multiply-add: complex 2d channel mixing of a with b, plus bias c."""
    return compl_mul2d_fwd(a, b) + c
@torch.jit.script
def compl_muladd2d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused multiply-add in native complex arithmetic: (a * b) + c, real layout in/out."""
    product = torch.view_as_complex(compl_mul2d_fwd_c(a, b))
    bias = torch.view_as_complex(c)
    return torch.view_as_real(product + bias)
@torch.jit.script
def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Real-valued channel mixing over a 2d grid: einsum("bixy,io->boxy")."""
    mixed = torch.einsum("bixy,io->boxy", a, b)
    return mixed
@torch.jit.script
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Real-valued fused multiply-add: channel mixing of a with b, plus bias c.

    Fix: this helper previously delegated to the *complex* kernel
    compl_mul2d_fwd_c, which expects a trailing (real, imag) dimension of
    size 2 and computes a complex product — wrong for the real-valued
    tensors this function is meant for. Use the real contraction, matching
    real_mul2d_fwd.
    """
    return torch.einsum("bixy,io->boxy", a, b) + c
# for all the experimental layers
# @torch.jit.script
# def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
# resc = torch.einsum("bixy,xio->boxy", ac, bc)
# res = torch.view_as_real(resc)
# return res
# @torch.jit.script
# def compl_exp_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b))
# cc = torch.view_as_complex(c)
# return torch.view_as_real(tmpcc + cc)
examples/sfno_examples/models/factorizations.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
torch
import
tensorly
as
tl
tl
.
set_backend
(
'pytorch'
)
from
tltorch.factorized_tensors.core
import
FactorizedTensor
einsum_symbols
=
'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
    """Contract input x with a dense weight tensor in Fourier space.

    Builds an einsum equation from the operator type and contracts the input
    with the weight. x is laid out as (batch-size, in_channels, x, y, ...).
    Fix: corrected the misspelled error message ("Unkonw" -> "Unknown").
    """
    order = tl.ndim(x)
    # batch-size, in_channels, x, y...
    x_syms = list(einsum_symbols[:order])

    # in_channels, out_channels, x, y...
    weight_syms = list(x_syms[1:])  # no batch-size

    # batch-size, out_channels, x, y...
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        # extra output-mode symbol on the last (lon) dimension
        weight_syms.insert(-1, einsum_symbols[order + 1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'vector':
        weight_syms.pop()
    else:
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = ''.join(x_syms) + ',' + ''.join(weight_syms) + '->' + ''.join(out_syms)

    # factorized weights expose to_tensor() to reconstruct the dense tensor
    if not torch.is_tensor(weight):
        weight = weight.to_tensor()

    return tl.einsum(eq, x, weight)
def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
    """Contract input x directly with a CP-factorized weight (no reconstruction).

    The CP weight is a rank-weighted sum of outer products of per-mode factors;
    the contraction happens factor-by-factor inside a single einsum.
    Fix: corrected the misspelled error message ("Unkonw" -> "Unknown").
    """
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    rank_sym = einsum_symbols[order]
    out_sym = einsum_symbols[order + 1]
    out_syms = list(x_syms)

    if separable:
        factor_syms = [einsum_symbols[1] + rank_sym]  # in only
    else:
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1] + rank_sym, out_sym + rank_sym]  # in, out
    factor_syms += [xs + rank_sym for xs in x_syms[2:]]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        out_syms[-1] = einsum_symbols[order + 2]
        factor_syms += [out_syms[-1] + rank_sym]
    elif operator_type == 'vector':
        factor_syms.pop()
    else:
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + rank_sym + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)
def _contract_tucker(x, tucker_weight, separable=False, operator_type='diagonal'):
    """Contract input x directly with a Tucker-factorized weight (no reconstruction).

    Contracts the Tucker core and its per-mode factors with the input in a
    single einsum. Only the 'diagonal' operator type is supported.
    Fix: corrected the misspelled error message ("Unkonw" -> "Unknown").
    """
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    out_sym = einsum_symbols[order]
    out_syms = list(x_syms)
    if separable:
        core_syms = einsum_symbols[order + 1:2 * order]
        # factor_syms = [einsum_symbols[1]+core_syms[0]] #in only
        factor_syms = [xs + rs for (xs, rs) in zip(x_syms[1:], core_syms)]  # x, y, ...
    else:
        core_syms = einsum_symbols[order + 1:2 * order + 1]
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1] + core_syms[0], out_sym + core_syms[1]]  # out, in
        factor_syms += [xs + rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker")
    else:
        # NOTE: 'vector' also lands here, unlike the dense/CP/TT variants
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + core_syms + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)
def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
    """Contract input x directly with a tensor-train-factorized weight.

    Each TT factor carries (left-rank, mode, right-rank) indices; the chain of
    factors is contracted with the input in a single einsum.
    Fix: corrected the misspelled error message ("Unkonw" -> "Unknown").
    """
    order = tl.ndim(x)

    x_syms = list(einsum_symbols[:order])
    weight_syms = list(x_syms[1:])  # no batch-size

    if not separable:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    else:
        out_syms = list(x_syms)

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order + 1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'vector':
        weight_syms.pop()
    else:
        raise ValueError(f"Unknown operator type {operator_type}")

    # one (left-rank, mode, right-rank) symbol triple per TT factor
    rank_syms = list(einsum_symbols[order + 2:])
    tt_syms = []
    for i, s in enumerate(weight_syms):
        tt_syms.append([rank_syms[i], s, rank_syms[i + 1]])
    eq = ''.join(x_syms) + ',' + ','.join(''.join(f) for f in tt_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, *tt_weight.factors)
def get_contract_fun(weight, implementation='reconstructed', separable=False):
    """Generic ND implementation of Fourier Spectral Conv contraction.

    Parameters
    ----------
    weight : tensorly-torch's FactorizedTensor (or a plain tensor)
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or contract directly the factors of the factorized weight with the input (factorized)

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space
    """
    if implementation == 'reconstructed':
        return _contract_dense
    if implementation == 'factorized':
        # a plain tensor is contracted densely regardless of factorization
        if torch.is_tensor(weight):
            return _contract_dense
        if isinstance(weight, FactorizedTensor):
            # dispatch on the factorization name
            handlers = {
                'complexdense': _contract_dense,
                'complextucker': _contract_tucker,
                'complextt': _contract_tt,
                'complexcp': _contract_cp,
            }
            try:
                return handlers[weight.name.lower()]
            except KeyError:
                raise ValueError(f'Got unexpected factorized weight type {weight.name}')
        raise ValueError(f'Got unexpected weight type of class {weight.__class__.__name__}')
    raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"')
examples/sfno_examples/models/layers.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from
functools
import
partial
from
collections
import
OrderedDict
from
copy
import
Error
,
deepcopy
from
re
import
S
from
numpy.lib.arraypad
import
pad
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.fft
from
torch.nn.modules.container
import
Sequential
from
torch.utils.checkpoint
import
checkpoint
,
checkpoint_sequential
from
torch.cuda
import
amp
from
typing
import
Optional
import
math
from
torch_harmonics
import
*
from
models.contractions
import
*
from
models.activations
import
*
from
models.factorizations
import
get_contract_fun
# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
# tl.set_backend('pytorch')
# use_opt_einsum('optimal')
from
tltorch.factorized_tensors.core
import
FactorizedTensor
def
_no_grad_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def
norm_cdf
(
x
):
# Computes standard normal cumulative distribution function
return
(
1.
+
math
.
erf
(
x
/
math
.
sqrt
(
2.
)))
/
2.
if
(
mean
<
a
-
2
*
std
)
or
(
mean
>
b
+
2
*
std
):
warnings
.
warn
(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect."
,
stacklevel
=
2
)
with
torch
.
no_grad
():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l
=
norm_cdf
((
a
-
mean
)
/
std
)
u
=
norm_cdf
((
b
-
mean
)
/
std
)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor
.
uniform_
(
2
*
l
-
1
,
2
*
u
-
1
)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor
.
erfinv_
()
# Transform to proper mean, std
tensor
.
mul_
(
std
*
math
.
sqrt
(
2.
))
tensor
.
add_
(
mean
)
# Clamp to ensure it's in the proper range
tensor
.
clamp_
(
min
=
a
,
max
=
b
)
return
tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # thin public wrapper around the gradient-free implementation
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1. - drop_prob
    # one Bernoulli draw per sample; broadcast across all non-batch dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize to {0, 1}
    # rescale surviving samples so the expected value is unchanged
    return x.div(keep_prob) * mask
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # probability of zeroing out an entire sample
        self.drop_prob = drop_prob

    def forward(self, x):
        # delegate to the functional form, passing the module's training flag
        return drop_path(x, self.drop_prob, self.training)
class MLP(nn.Module):
    """Pointwise two-layer MLP implemented with 1x1 convolutions.

    Optionally wraps the forward pass in activation checkpointing to trade
    compute for memory.
    """
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 output_bias=True,
                 drop_rate=0.,
                 checkpointing=False):
        super(MLP, self).__init__()
        self.checkpointing = checkpointing
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        modules = [nn.Conv2d(in_features, hidden_features, 1, bias=True), act_layer()]
        if drop_rate > 0.:
            # the same dropout module is applied after the activation and after fc2
            dropout = nn.Dropout(drop_rate)
            modules += [dropout, nn.Conv2d(hidden_features, out_features, 1, bias=output_bias), dropout]
        else:
            modules.append(nn.Conv2d(hidden_features, out_features, 1, bias=output_bias))
        self.fwd = nn.Sequential(*modules)

    @torch.jit.ignore
    def checkpoint_forward(self, x):
        # recompute activations in backward to save memory
        return checkpoint(self.fwd, x)

    def forward(self, x):
        if self.checkpointing:
            return self.checkpoint_forward(x)
        else:
            return self.fwd(x)
class RealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT.

    Applies a 2d real FFT and truncates the result to the retained modes:
    the first ceil(lmax/2) and last floor(lmax/2) rows along the lat axis,
    and the first mmax columns along the lon axis.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(RealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
        # low positive frequencies from the top, negative frequencies from the bottom
        n_pos = math.ceil(self.lmax / 2)
        n_neg = math.floor(self.lmax / 2)
        top = y[..., :n_pos, :self.mmax]
        bottom = y[..., -n_neg:, :self.mmax]
        return torch.cat((top, bottom), dim=-2)
class InverseRealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT.

    Inverse 2d real FFT back onto the full (nlat, nlon) spatial grid;
    missing high-frequency modes are implicitly zero-padded via `s`.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(InverseRealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        out = torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
        return out
class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
    using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
    domain via the RealFFT2 and InverseRealFFT2 wrappers.

    forward() returns (output, residual); the residual is the (possibly
    re-gridded) input, for use by the caller's skip connection.
    Fix: corrected the malformed error message (typo "Unkonw" and a stray
    literal "f" that leaked into the formatted string).
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 in_channels,
                 out_channels,
                 scale='auto',
                 operator_type='diagonal',
                 rank=0.2,
                 factorization=None,
                 separable=False,
                 implementation='factorized',
                 decomposition_kwargs=dict(),
                 bias=False):
        super(SpectralConvS2, self).__init__()

        if scale == 'auto':
            scale = (1 / (in_channels * out_channels))

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        # input and output grids differ -> the skip connection must be re-gridded
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
            or (self.forward_transform.nlon != self.inverse_transform.nlon)

        # Make sure we are using a Complex Factorized Tensor
        if factorization is None:
            factorization = 'Dense'  # No factorization
        if not factorization.lower().startswith('complex'):
            factorization = f'Complex{factorization}'

        # remember factorization details
        self.operator_type = operator_type
        self.rank = rank
        self.factorization = factorization
        self.separable = separable

        assert self.inverse_transform.lmax == self.modes_lat
        assert self.inverse_transform.mmax == self.modes_lon

        weight_shape = [in_channels]

        if not self.separable:
            weight_shape += [out_channels]

        if self.operator_type == 'diagonal':
            weight_shape += [self.modes_lat, self.modes_lon]
        elif self.operator_type == 'block-diagonal':
            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
        elif self.operator_type == 'vector':
            weight_shape += [self.modes_lat]
        else:
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")

        # form weight tensors
        self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
                                           fixed_rank_modes=False, **decomposition_kwargs)
        # initialization of weights
        self.weight.normal_(0, scale)

        # pick the contraction matching the factorization/implementation
        self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)

        if bias:
            self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))

    def forward(self, x):
        dtype = x.dtype
        # transforms are run in fp32 even under autocast
        x = x.float()
        residual = x

        B, C, H, W = x.shape

        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)

        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, 'bias'):
            x = x + self.bias

        x = x.type(dtype)

        return x, residual
class SpectralAttention2d(nn.Module):
    """
    geometrical Spectral Attention layer.

    Transforms the input into the spectral domain, applies a small complex-valued
    MLP (`forward_mlp`) to the coefficients, and transforms back. forward()
    returns (output, residual) where the residual is the (possibly re-gridded)
    input, for use by the caller's skip connection.
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 sparsity_threshold=0.0,
                 hidden_size_factor=2,
                 use_complex_kernels=False,
                 complex_activation='real',
                 bias=False,
                 spectral_layers=1,
                 drop_rate=0.):
        super(SpectralAttention2d, self).__init__()

        self.embed_dim = embed_dim
        # NOTE(review): sparsity_threshold is stored but never used in this class — confirm intent
        self.sparsity_threshold = sparsity_threshold
        self.hidden_size = int(hidden_size_factor * self.embed_dim)
        self.scale = 1 / embed_dim**2
        # choose between the native-complex kernels (_c) and the real-layout ones
        self.mul_add_handle = compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd
        self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd
        self.spectral_layers = spectral_layers

        self.modes_lat = forward_transform.lmax
        self.modes_lon = forward_transform.mmax

        # only storing the forward handle to be able to call it
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        # input and output grids differ -> the skip connection must be re-gridded
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
            or (self.forward_transform.nlon != self.inverse_transform.nlon)

        assert inverse_transform.lmax == self.modes_lat
        assert inverse_transform.mmax == self.modes_lon

        # weights
        # trailing dimension of 2 holds the (real, imag) components
        # NOTE(review): these entries are plain tensors, not nn.Parameter —
        # some PyTorch versions reject/convert them in ParameterList; confirm
        # against the targeted torch version
        w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)]
        for l in range(1, self.spectral_layers):
            w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2))
        self.w = nn.ParameterList(w)

        if bias:
            # per-layer complex bias, broadcast over the mode dimension
            self.b = nn.ParameterList([self.scale * torch.randn(self.hidden_size, 1, 2) for _ in range(self.spectral_layers)])

        # output projection back to embed_dim
        self.wout = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.embed_dim, 2))

        self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()

        # one complex activation per spectral layer
        self.activations = nn.ModuleList([])
        for l in range(0, self.spectral_layers):
            self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(self.hidden_size, 1, 1), scale=self.scale))

    def forward_mlp(self, x):
        # complex-valued MLP on the spectral coefficients; the kernels operate
        # on real layout, so we round-trip through view_as_real/view_as_complex
        x = torch.view_as_real(x)
        xr = x
        for l in range(self.spectral_layers):
            if hasattr(self, 'b'):
                xr = self.mul_add_handle(xr, self.w[l], self.b[l])
            else:
                xr = self.mul_handle(xr, self.w[l])
            xr = torch.view_as_complex(xr)
            xr = self.activations[l](xr)
            xr = self.drop(xr)
            xr = torch.view_as_real(xr)

        # final projection back to the embedding dimension
        x = self.mul_handle(xr, self.wout)
        x = torch.view_as_complex(x)
        return x

    def forward(self, x):
        dtype = x.dtype
        # transforms are run in fp32 even under autocast
        x = x.float()
        residual = x

        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = self.forward_mlp(x)

        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        x = x.type(dtype)

        return x, residual
class SpectralAttentionS2(nn.Module):
    """
    Spherical non-linear FNO layer ("spectral attention").

    Applies a small MLP with complex weights directly in the spherical-harmonic
    domain. With operator_type='diagonal' the weights are shared across modes;
    with operator_type='vector' they additionally depend on the degree l.

    Parameters
    ----------
    forward_transform / inverse_transform : SHT modules exposing
        lmax, mmax, nlat, nlon; lmax/mmax of both must agree.
    embed_dim : int
        Channel dimension of the input.
    operator_type : str
        'diagonal' or 'vector'.
    sparsity_threshold : float
        Stored but not used by this layer.
    hidden_size_factor : int
        Hidden width = hidden_size_factor * embed_dim.
    complex_activation : str
        Mode passed to ComplexReLU.
    scale : 'auto' or float
        Initialization scale for the random weights.
    bias : bool
        Whether to learn complex biases per layer.
    spectral_layers : int
        Number of hidden spectral layers.
    drop_rate : float
        Dropout rate applied between spectral layers.
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 operator_type='diagonal',
                 sparsity_threshold=0.0,
                 hidden_size_factor=2,
                 complex_activation='real',
                 scale='auto',
                 bias=False,
                 spectral_layers=1,
                 drop_rate=0.):
        super(SpectralAttentionS2, self).__init__()

        self.embed_dim = embed_dim
        self.sparsity_threshold = sparsity_threshold
        self.operator_type = operator_type
        self.spectral_layers = spectral_layers

        if scale == 'auto':
            self.scale = (1 / (embed_dim * embed_dim))
        else:
            # BUGFIX: the original only set self.scale for scale == 'auto',
            # so any numeric scale crashed later with AttributeError.
            self.scale = scale

        self.modes_lat = forward_transform.lmax
        self.modes_lon = forward_transform.mmax

        # only storing the forward handle to be able to call it
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        # residual needs resampling when input/output grids differ
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
            or (self.forward_transform.nlon != self.inverse_transform.nlon)

        assert inverse_transform.lmax == self.modes_lat
        assert inverse_transform.mmax == self.modes_lon

        hidden_size = int(hidden_size_factor * self.embed_dim)

        if operator_type == 'diagonal':
            self.mul_add_handle = compl_muladd2d_fwd
            self.mul_handle = compl_mul2d_fwd

            # weights (real view, trailing dim of 2 = re/im)
            # NOTE(review): these are plain tensors inside ParameterList, not
            # nn.Parameter — confirm whether they are meant to be trainable.
            w = [self.scale * torch.randn(self.embed_dim, hidden_size, 2)]
            for l in range(1, self.spectral_layers):
                w.append(self.scale * torch.randn(hidden_size, hidden_size, 2))
            self.w = nn.ParameterList(w)

            self.wout = nn.Parameter(self.scale * torch.randn(hidden_size, self.embed_dim, 2))

            if bias:
                self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))

        elif operator_type == 'vector':
            self.mul_add_handle = compl_exp_muladd2d_fwd
            self.mul_handle = compl_exp_mul2d_fwd

            # weights with an extra leading modes_lat dimension (per-degree)
            w = [self.scale * torch.randn(self.modes_lat, self.embed_dim, hidden_size, 2)]
            for l in range(1, self.spectral_layers):
                w.append(self.scale * torch.randn(self.modes_lat, hidden_size, hidden_size, 2))
            self.w = nn.ParameterList(w)

            if bias:
                self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])

            self.wout = nn.Parameter(self.scale * torch.randn(self.modes_lat, hidden_size, self.embed_dim, 2))

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))

        else:
            raise ValueError('Unknown operator type')

        self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward_mlp(self, x):
        """Run the spectral MLP stack on the complex coefficient tensor x."""
        # (original unpacked x.shape into unused B, C, H, W — removed)
        xr = torch.view_as_real(x)
        for l in range(self.spectral_layers):
            if hasattr(self, 'b'):
                xr = self.mul_add_handle(xr, self.w[l], self.b[l])
            else:
                xr = self.mul_handle(xr, self.w[l])
            xr = torch.view_as_complex(xr)
            xr = self.activations[l](xr)
            xr = self.drop(xr)
            xr = torch.view_as_real(xr)

        # final MLP projection back to embed_dim
        x = self.mul_handle(xr, self.wout)
        x = torch.view_as_complex(x)
        return x

    def forward(self, x):
        """SHT -> spectral MLP -> inverse SHT; returns (output, residual)."""
        dtype = x.dtype
        x = x.to(torch.float32)
        residual = x

        # FWD transform (full precision, autocast disabled)
        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                # resample residual onto the output grid
                residual = self.inverse_transform(x)

        # MLP
        x = self.forward_mlp(x)

        # BWD transform (full precision, autocast disabled)
        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        # cast back to initial precision
        x = x.to(dtype)
        return x, residual
\ No newline at end of file
examples/sfno_examples/models/sfno.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from functools import partial

import torch
import torch.nn as nn
from apex.normalization import FusedLayerNorm

from torch_harmonics import *

from models.layers import *
class SpectralFilterLayer(nn.Module):
    """
    Fourier layer. Contains the convolution part of the FNO/SFNO.

    Dispatches to the appropriate spectral filter implementation based on
    filter_type and the kind of forward transform (SHT vs FFT):
    - 'non-linear' + RealSHT  -> SpectralAttentionS2
    - 'non-linear' + RealFFT2 -> SpectralAttention2d
    - 'linear'                -> SpectralConvS2
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 filter_type='non-linear',
                 operator_type='diagonal',
                 sparsity_threshold=0.0,
                 use_complex_kernels=True,
                 hidden_size_factor=2,
                 factorization=None,
                 separable=False,
                 rank=1e-2,
                 complex_activation='real',
                 spectral_layers=1,
                 drop_rate=0):
        super(SpectralFilterLayer, self).__init__()

        if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
            self.filter = SpectralAttentionS2(forward_transform,
                                              inverse_transform,
                                              embed_dim,
                                              operator_type=operator_type,
                                              sparsity_threshold=sparsity_threshold,
                                              hidden_size_factor=hidden_size_factor,
                                              complex_activation=complex_activation,
                                              spectral_layers=spectral_layers,
                                              drop_rate=drop_rate,
                                              bias=False)

        elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
            self.filter = SpectralAttention2d(forward_transform,
                                              inverse_transform,
                                              embed_dim,
                                              sparsity_threshold=sparsity_threshold,
                                              use_complex_kernels=use_complex_kernels,
                                              hidden_size_factor=hidden_size_factor,
                                              complex_activation=complex_activation,
                                              spectral_layers=spectral_layers,
                                              drop_rate=drop_rate,
                                              bias=False)

        elif filter_type == 'linear':
            self.filter = SpectralConvS2(forward_transform,
                                         inverse_transform,
                                         embed_dim,
                                         embed_dim,
                                         operator_type=operator_type,
                                         rank=rank,
                                         factorization=factorization,
                                         separable=separable,
                                         bias=True)

        else:
            # BUGFIX: original raised the bare NotImplementedError class with
            # no message; keep the exception type for callers but say why.
            raise NotImplementedError(
                f"Unknown filter_type '{filter_type}' for transform {type(forward_transform).__name__}")

    def forward(self, x):
        """Delegate to the selected spectral filter."""
        return self.filter(x)
class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to
    represent either FNO or SFNO blocks.

    Structure: norm0 -> spectral filter (returns output + residual) ->
    optional inner skip (+ activation for linear filters) -> norm1 ->
    optional MLP -> drop_path -> optional outer skip.

    inner_skip / outer_skip accept 'linear' (1x1 conv), 'identity', or None;
    concat_skip switches additive skips to concatenation followed by a 1x1 conv.
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 filter_type='non-linear',
                 operator_type='diagonal',
                 mlp_ratio=2.,
                 drop_rate=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=(nn.LayerNorm, nn.LayerNorm),
                 sparsity_threshold=0.0,
                 use_complex_kernels=True,
                 factorization=None,
                 separable=False,
                 rank=128,
                 inner_skip='linear',
                 outer_skip=None,  # None, 'linear' or 'identity'
                 concat_skip=False,
                 use_mlp=True,
                 complex_activation='real',
                 spectral_layers=3):
        super(SphericalFourierNeuralOperatorBlock, self).__init__()

        # first norm layer
        self.norm0 = norm_layer[0]()

        # convolution layer
        self.filter = SpectralFilterLayer(forward_transform,
                                          inverse_transform,
                                          embed_dim,
                                          filter_type,
                                          operator_type=operator_type,
                                          sparsity_threshold=sparsity_threshold,
                                          use_complex_kernels=use_complex_kernels,
                                          hidden_size_factor=mlp_ratio,
                                          factorization=factorization,
                                          separable=separable,
                                          rank=rank,
                                          complex_activation=complex_activation,
                                          spectral_layers=spectral_layers,
                                          drop_rate=drop_rate)

        if inner_skip == 'linear':
            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif inner_skip == 'identity':
            self.inner_skip = nn.Identity()
        # any other value: no inner skip attribute is created

        self.concat_skip = concat_skip
        if concat_skip and inner_skip is not None:
            self.inner_skip_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1, bias=False)

        # linear filters need an explicit activation after the skip
        if filter_type in ('linear', 'local'):
            self.act_layer = act_layer()

        # stochastic depth
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # second norm layer
        self.norm1 = norm_layer[1]()

        if use_mlp:  # idiom fix: was `use_mlp == True`
            mlp_hidden_dim = int(embed_dim * mlp_ratio)
            self.mlp = MLP(in_features=embed_dim,
                           hidden_features=mlp_hidden_dim,
                           act_layer=act_layer,
                           drop_rate=drop_rate,
                           checkpointing=False)

        if outer_skip == 'linear':
            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif outer_skip == 'identity':
            self.outer_skip = nn.Identity()

        if concat_skip and outer_skip is not None:
            self.outer_skip_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1, bias=False)

    def forward(self, x):
        """Run one SFNO/FNO block; x is (batch, embed_dim, H, W)."""
        x = self.norm0(x)

        # the filter returns both its output and a (possibly resampled) residual
        x, residual = self.filter(x)

        if hasattr(self, 'inner_skip'):
            if self.concat_skip:
                x = torch.cat((x, self.inner_skip(residual)), dim=1)
                x = self.inner_skip_conv(x)
            else:
                x = x + self.inner_skip(residual)

        if hasattr(self, 'act_layer'):
            x = self.act_layer(x)

        x = self.norm1(x)

        if hasattr(self, 'mlp'):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, 'outer_skip'):
            if self.concat_skip:
                x = torch.cat((x, self.outer_skip(residual)), dim=1)
                x = self.outer_skip_conv(x)
            else:
                x = x + self.outer_skip(residual)

        return x
class SphericalFourierNeuralOperatorNet(nn.Module):
    """
    SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
    both linear and non-linear variants.

    Parameters
    ----------
    filter_type : str, optional
        Type of filter to use ('linear', 'non-linear'), by default "linear"
    spectral_transform : str, optional
        Type of spectral transformation to use ('sht', 'fft'), by default "sht"
    operator_type : str, optional
        Type of operator to use ('vector', 'diagonal'), by default "vector"
    img_size : tuple, optional
        Shape of the input channels, by default (128, 256)
    scale_factor : int, optional
        Scale factor to use, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 4
    activation_function : str, optional
        Activation function to use ('relu', 'gelu'), by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    use_mlp : bool, optional
        Whether to use MLP, by default True
    mlp_ratio : float, optional
        Ratio of MLP to use, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    sparsity_threshold : float, optional
        Threshold for sparsity, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    use_complex_kernels : bool, optional
        Whether to use complex kernels, by default True
    big_skip : bool, optional
        Whether to add a single large skip connection, by default True
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    rank : (int, Tuple[int]), optional
        If a factorization is used, which rank to use. Argument is passed to tensorly
    complex_activation : str, optional
        Type of complex activation function to use, by default "real"
    spectral_layers : int, optional
        Number of spectral layers, by default 2
    pos_embed : bool, optional
        Whether to use positional embedding, by default True

    Example:
    --------
    >>> model = SphericalFourierNeuralOperatorNet(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=2,
    ...         encoder_layers=1,
    ...         spectral_layers=2,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
    """

    def __init__(self,
                 filter_type='linear',
                 spectral_transform='sht',
                 operator_type='vector',
                 img_size=(128, 256),
                 scale_factor=3,
                 in_chans=3,
                 out_chans=3,
                 embed_dim=256,
                 num_layers=4,
                 activation_function='gelu',
                 encoder_layers=1,
                 use_mlp=True,
                 mlp_ratio=2.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 sparsity_threshold=0.0,
                 normalization_layer='instance_norm',
                 hard_thresholding_fraction=1.0,
                 use_complex_kernels=True,
                 big_skip=True,
                 factorization=None,
                 separable=False,
                 rank=128,
                 complex_activation='real',
                 spectral_layers=2,
                 pos_embed=True):
        super(SphericalFourierNeuralOperatorNet, self).__init__()

        self.filter_type = filter_type
        self.spectral_transform = spectral_transform
        self.operator_type = operator_type
        self.img_size = img_size
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = self.num_features = embed_dim
        self.pos_embed_dim = self.embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.big_skip = big_skip
        self.factorization = factorization
        # BUGFIX: original read `self.separable = separable,` — the trailing
        # comma stored a 1-tuple, which is truthy even when separable=False
        # and was forwarded as-is to every block.
        self.separable = separable
        self.rank = rank
        self.complex_activation = complex_activation
        self.spectral_layers = spectral_layers

        # activation function
        if activation_function == 'relu':
            self.activation_function = nn.ReLU
        elif activation_function == 'gelu':
            self.activation_function = nn.GELU
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size
        self.h = self.img_size[0] // scale_factor
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
        # per-block stochastic depth rates, linearly increasing
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # pick norm layer (norm_layer0 for the full-resolution grid,
        # norm_layer1 for the downsampled grid)
        if self.normalization_layer == "layer_norm":
            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
        elif self.normalization_layer == "instance_norm":
            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
            norm_layer1 = norm_layer0
        elif self.normalization_layer == "none":
            norm_layer0 = nn.Identity
            norm_layer1 = norm_layer0
        else:
            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

        if pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
        else:
            self.pos_embed = None

        # encoder: 1x1 convolutions lifting in_chans -> embed_dim
        encoder_hidden_dim = self.embed_dim
        current_dim = self.in_chans
        encoder_modules = []
        for i in range(self.encoder_layers):
            encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
            encoder_modules.append(self.activation_function())
            current_dim = encoder_hidden_dim
        encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
        self.encoder = nn.Sequential(*encoder_modules)

        # prepare the spectral transform:
        # trans_down/itrans_up operate on the full-resolution grid,
        # trans/itrans on the downsampled internal grid
        if self.spectral_transform == 'sht':
            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()

        elif self.spectral_transform == 'fft':
            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

            self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()

        else:
            raise ValueError('Unknown spectral transform')

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):
            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            # first block downsamples, last block upsamples back to img_size
            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans

            inner_skip = 'linear'
            outer_skip = 'identity'

            if first_layer:
                norm_layer = (norm_layer0, norm_layer1)
            elif last_layer:
                norm_layer = (norm_layer1, norm_layer0)
            else:
                norm_layer = (norm_layer1, norm_layer1)

            block = SphericalFourierNeuralOperatorBlock(forward_transform,
                                                        inverse_transform,
                                                        self.embed_dim,
                                                        filter_type=filter_type,
                                                        operator_type=self.operator_type,
                                                        mlp_ratio=mlp_ratio,
                                                        drop_rate=drop_rate,
                                                        drop_path=dpr[i],
                                                        act_layer=self.activation_function,
                                                        norm_layer=norm_layer,
                                                        sparsity_threshold=sparsity_threshold,
                                                        use_complex_kernels=use_complex_kernels,
                                                        inner_skip=inner_skip,
                                                        outer_skip=outer_skip,
                                                        use_mlp=use_mlp,
                                                        factorization=self.factorization,
                                                        separable=self.separable,
                                                        rank=self.rank,
                                                        complex_activation=self.complex_activation,
                                                        spectral_layers=self.spectral_layers)

            self.blocks.append(block)

        # decoder: 1x1 convolutions mapping embed_dim (+ big-skip channels) -> out_chans
        decoder_hidden_dim = self.embed_dim
        current_dim = self.embed_dim + self.big_skip * self.in_chans
        decoder_modules = []
        for i in range(self.encoder_layers):
            decoder_modules.append(nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True))
            decoder_modules.append(self.activation_function())
            current_dim = decoder_hidden_dim
        decoder_modules.append(nn.Conv2d(current_dim, self.out_chans, 1, bias=False))
        self.decoder = nn.Sequential(*decoder_modules)

        # trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize conv/linear weights with truncated normal, norms with 1/0."""
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            # nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Apply positional dropout and the stack of SFNO/FNO blocks."""
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        return x

    def forward(self, x):
        if self.big_skip:
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        x = self.forward_features(x)

        if self.big_skip:
            x = torch.cat((x, residual), dim=1)

        x = self.decoder(x)
        return x
examples/sfno_examples/train_sfno.ipynb
0 → 100644
View file @
e7ceb9c8
This source diff could not be displayed because it is too large. You can
view the blob
instead.
examples/sfno_examples/train_sfno.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import time
from tqdm import tqdm
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda import amp

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# wandb logging
import wandb
# NOTE: side effect at import time — prompts for / validates W&B credentials
wandb.login()

# make the parent examples directory importable (pde_sphere lives there)
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from pde_sphere import SphereSolver
def l2loss_sphere(solver, prd, tar, relative=False, squared=False):
    """Grid-integrated L2 loss on the sphere.

    Integrates the squared error over the sphere via solver.integrate_grid,
    sums over channels, optionally normalizes by the target norm (relative),
    optionally skips the square root (squared), and averages over the batch.
    """
    err2 = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
    if relative:
        ref2 = solver.integrate_grid(tar ** 2, dimensionless=True).sum(dim=-1)
        err2 = err2 / ref2
    result = err2 if squared else torch.sqrt(err2)
    return result.mean()
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=False):
    """Spectral-domain L2 loss from SHT coefficients.

    Sums squared coefficient magnitudes; the m=0 column is counted once and
    m>0 columns twice (real-signal symmetry). Optionally relative to the
    target's spectral energy and/or left squared; mean over the batch.
    """
    def energy(field):
        # squared magnitudes |c_lm|^2 from the real view of the coefficients
        c = torch.view_as_real(solver.sht(field))
        mag2 = c[..., 0] ** 2 + c[..., 1] ** 2
        # m = 0 once, m > 0 twice
        per_l = mag2[..., :, 0] + 2 * torch.sum(mag2[..., :, 1:], dim=-1)
        return torch.sum(per_l, dim=(-1, -2))

    loss = energy(prd - tar)
    if relative:
        loss = loss / energy(tar)
    if not squared:
        loss = torch.sqrt(loss)
    return loss.mean()
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=False):
    """Gradient-weighted spectral loss.

    Same structure as spectral_l2loss_sphere, but each degree l is weighted
    by l(l+1) (the spherical Laplacian eigenvalue), emphasizing gradients.
    """
    # gradient weighting factors, shaped (1, 1, lmax, 1) to broadcast over modes
    lmax = solver.sht.lmax
    degrees = torch.arange(lmax).float()
    weights = (degrees * (degrees + 1)).reshape(1, 1, -1, 1).to(prd.device)

    def weighted_energy(field):
        c = torch.view_as_real(solver.sht(field))
        mag2 = weights * (c[..., 0] ** 2 + c[..., 1] ** 2)
        # m = 0 once, m > 0 twice (real-signal symmetry)
        per_l = mag2[..., :, 0] + 2 * torch.sum(mag2[..., :, 1:], dim=-1)
        return torch.sum(per_l, dim=(-1, -2))

    loss = weighted_energy(prd - tar)
    if relative:
        loss = loss / weighted_energy(tar)
    if not squared:
        loss = torch.sqrt(loss)
    return loss.mean()
def h1loss_sphere(solver, prd, tar, relative=False, squared=False):
    """H1-type loss on the sphere: L2 term plus an l(l+1)-weighted gradient term.

    Parameters mirror l2loss_sphere; relative=True is not supported.

    Raises
    ------
    NotImplementedError
        If relative=True.
    """
    # BUGFIX(perf/fail-fast): the original validated `relative` only AFTER
    # computing the SHT and both norms; reject unsupported input up front.
    if relative:
        raise NotImplementedError("Relative H1 loss not implemented")

    # gradient weighting factors (spherical Laplacian eigenvalues)
    lmax = solver.sht.lmax
    ls = torch.arange(lmax).float()
    spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)

    # compute coefficients and squared magnitudes
    coeffs = torch.view_as_real(solver.sht(prd - tar))
    coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
    h1_coeffs = spectral_weights * coeffs

    # m = 0 counted once, m > 0 twice (real-signal symmetry)
    h1_norm2 = h1_coeffs[..., :, 0] + 2 * torch.sum(h1_coeffs[..., :, 1:], dim=-1)
    l2_norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)

    h1_loss = torch.sum(h1_norm2, dim=(-1, -2))
    l2_loss = torch.sum(l2_norm2, dim=(-1, -2))

    # strictly speaking this is not exactly the H1 norm
    if not squared:
        loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
    else:
        loss = h1_loss + l2_loss

    return loss.mean()
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
    """L2 loss weighted per channel by the target's fluctuation from the input.

    Channels that changed more between inp and tar receive proportionally
    larger weight; optionally normalized by the (weighted) target norm.
    """
    # per-channel weighting factor from the input->target fluctuation
    fluct = solver.integrate_grid((tar - inp) ** 2, dimensionless=True, polar_opt=polar_opt)
    weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)
    # weight = weight.reshape(*weight.shape, 1, 1)

    loss = weight * solver.integrate_grid((prd - tar) ** 2, dimensionless=True, polar_opt=polar_opt)
    if relative:
        denom = weight * solver.integrate_grid(tar ** 2, dimensionless=True, polar_opt=polar_opt)
        loss = loss / denom
    return torch.mean(loss)
def
main
(
train
=
True
,
load_checkpoint
=
False
,
enable_amp
=
False
):
# set seed
torch
.
manual_seed
(
333
)
torch
.
cuda
.
manual_seed
(
333
)
# set device
device
=
torch
.
device
(
'cuda:1'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
set_device
(
device
.
index
)
# dataset
from
utils.pde_dataset
import
PdeDataset
# 1 hour prediction steps
dt
=
1
*
3600
dt_solver
=
150
nsteps
=
dt
//
dt_solver
dataset
=
PdeDataset
(
dt
=
dt
,
nsteps
=
nsteps
,
dims
=
(
256
,
512
),
device
=
device
,
normalize
=
True
)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
,
persistent_workers
=
False
)
solver
=
dataset
.
solver
.
to
(
device
)
nlat
=
dataset
.
nlat
nlon
=
dataset
.
nlon
# training function
def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn='l2'):
    """Train `model` for `nepochs` epochs and return the final validation loss.

    Closure note: reads `enable_amp` and `solver` from the enclosing scope
    (defined in main). `nfuture` > 0 unrolls additional autoregressive steps
    through the model before computing the loss. `loss_fn` selects among
    'l2', 'h1', 'spectral' and 'fluct'. Logs to wandb when a run is active.
    """
    train_start = time.time()

    for epoch in range(nepochs):

        # time each epoch
        epoch_start = time.time()

        # resample the training set for this epoch
        dataloader.dataset.set_initial_condition('random')
        dataloader.dataset.set_num_examples(num_examples)

        # do the training
        acc_loss = 0
        model.train()

        for inp, tar in dataloader:
            with amp.autocast(enabled=enable_amp):
                prd = model(inp)
                # optional autoregressive unrolling
                for _ in range(nfuture):
                    prd = model(prd)
                if loss_fn == 'l2':
                    loss = l2loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == 'h1':
                    loss = h1loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == 'spectral':
                    loss = spectral_loss_sphere(solver, prd, tar, relative=False)
                elif loss_fn == 'fluct':
                    loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
                else:
                    raise NotImplementedError(f'Unknown loss function {loss_fn}')

            # accumulate example-weighted loss for the epoch summary
            acc_loss += loss.item() * inp.size(0)

            optimizer.zero_grad(set_to_none=True)
            # gscaler.scale(loss).backward()
            gscaler.scale(loss).backward()
            gscaler.step(optimizer)
            gscaler.update()

        acc_loss = acc_loss / len(dataloader.dataset)

        # switch the same dataset/dataloader to a small validation split
        dataloader.dataset.set_initial_condition('random')
        dataloader.dataset.set_num_examples(num_valid)

        # perform validation
        valid_loss = 0
        model.eval()
        with torch.no_grad():
            for inp, tar in dataloader:
                prd = model(inp)
                for _ in range(nfuture):
                    prd = model(prd)
                # validation always uses the relative L2 metric
                loss = l2loss_sphere(solver, prd, tar, relative=True)

                valid_loss += loss.item() * inp.size(0)

        valid_loss = valid_loss / len(dataloader.dataset)

        if scheduler is not None:
            # assumes a metric-driven scheduler (e.g. ReduceLROnPlateau)
            scheduler.step(valid_loss)

        epoch_time = time.time() - epoch_start

        print(f'--------------------------------------------------------------------------------')
        print(f'Epoch {epoch} summary:')
        print(f'time taken: {epoch_time}')
        print(f'accumulated training loss: {acc_loss}')
        print(f'relative validation loss: {valid_loss}')

        if wandb.run is not None:
            current_lr = optimizer.param_groups[0]['lr']
            wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})

    train_time = time.time() - train_start

    print(f'--------------------------------------------------------------------------------')
    print(f'done. Training took {train_time}.')
    return valid_loss
# rolls out the FNO and compares to the classical solver
def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=20):
    """Roll out `model` autoregressively and compare against the classical solver.

    For `nics` random initial conditions, advances both the ML model and the
    spectral solver `autoreg_steps` steps, timing each, and records the final
    relative L2 loss. For the last initial condition, saves snapshot plots
    every `nskip` steps to path_root + '_pred_*.png' / '_truth_*.png'.

    Closure note: reads `solver` from the enclosing scope (defined in main).
    Returns (losses, fno_times, nwp_times) as numpy arrays of length nics.
    """
    model.eval()

    losses = np.zeros(nics)
    fno_times = np.zeros(nics)
    nwp_times = np.zeros(nics)

    for iic in range(nics):
        ic = dataset.solver.random_initial_condition(mach=0.2)
        inp_mean = dataset.inp_mean
        inp_var = dataset.inp_var

        # normalized grid-space input for the ML model
        prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
        prd = prd.unsqueeze(0)
        # keep a spectral copy for the classical reference run
        uspec = ic.clone()

        # ML model
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):
            # evaluate the ML model
            prd = model(prd)

            if iic == nics - 1 and nskip > 0 and i % nskip == 0:
                # do plotting
                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
                plt.savefig(path_root + '_pred_' + str(i // nskip) + '.png')
                plt.clf()

        fno_times[iic] = time.time() - start_time

        # classical model
        start_time = time.time()
        for i in range(1, autoreg_steps + 1):
            # advance classical model
            uspec = dataset.solver.timestep(uspec, nsteps)

            if iic == nics - 1 and i % nskip == 0 and nskip > 0:
                ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
                fig = plt.figure(figsize=(7.5, 6))
                dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
                plt.savefig(path_root + '_truth_' + str(i // nskip) + '.png')
                plt.clf()

        nwp_times[iic] = time.time() - start_time

        # compare in unnormalized physical units
        # ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
        ref = dataset.solver.spec2grid(uspec)
        prd = prd * torch.sqrt(inp_var) + inp_mean
        losses[iic] = l2loss_sphere(solver, prd, ref, relative=True).item()

    return losses, fno_times, nwp_times
def count_parameters(model):
    """Return the total number of trainable (requires_grad) parameters of *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# prepare dicts containing models and corresponding metrics
# NOTE(review): this span appears to be the body of an enclosing training
# routine — names such as nlat, nlon, dt, dt_solver, nsteps, device, dataset,
# dataloader, train, load_checkpoint, enable_amp, wandb, amp, pd and
# train_model are defined outside this excerpt; confirm against the full file.
models = {}
metrics = {}

# # U-Net if installed
# from models.unet import UNet
# models['unet_baseline'] = partial(UNet)

# SFNO and FNO models
from models.sfno import SphericalFourierNeuralOperatorNet as SFNO

# SFNO models (spherical harmonic transform backbone)
models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon), num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector')
models['sfno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='sht', filter_type='non-linear', img_size=(nlat, nlon), num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real', operator_type='diagonal')

# FNO models (plain FFT backbone, same architecture otherwise)
models['fno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='fft', filter_type='linear', img_size=(nlat, nlon), num_layers=4, scale_factor=3, embed_dim=256, operator_type='diagonal')
models['fno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='fft', filter_type='non-linear', img_size=(nlat, nlon), num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real')

# iterate over models and train each model
root_path = os.path.dirname(__file__)
for model_name, model_handle in models.items():

    model = model_handle().to(device)

    metrics[model_name] = {}

    num_params = count_parameters(model)
    print(f'number of trainable params: {num_params}')
    metrics[model_name]['num_params'] = num_params

    if load_checkpoint:
        model.load_state_dict(torch.load(os.path.join(root_path, 'checkpoints/'+model_name)))

    # run the training
    if train:
        # one wandb run per model; partial() keywords double as the run config
        run = wandb.init(project="sfno spherical swe", group=model_name, name=model_name + '_' + str(time.time()), config=model_handle.keywords)

        # optimizer:
        optimizer = torch.optim.Adam(model.parameters(), lr=1E-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
        gscaler = amp.GradScaler(enabled=enable_amp)

        start_time = time.time()

        print(f'Training {model_name}, single step')
        train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn='l2')

        # multistep training: fine-tune at a lower learning rate with the
        # dataset target advanced two dt intervals and one extra model step
        print(f'Training {model_name}, two step')
        optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
        gscaler = amp.GradScaler(enabled=enable_amp)
        dataloader.dataset.nsteps = 2 * dt // dt_solver
        train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, nfuture=1)
        # restore the single-step target for inference
        dataloader.dataset.nsteps = 1 * dt // dt_solver

        training_time = time.time() - start_time

        run.finish()

        torch.save(model.state_dict(), os.path.join(root_path, 'checkpoints/'+model_name))

    # set seed so every model is evaluated on the same initial conditions
    torch.manual_seed(333)
    torch.cuda.manual_seed(333)

    with torch.inference_mode():
        losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path, 'paper_figures/'+model_name), nsteps=nsteps, autoreg_steps=10)
        metrics[model_name]['loss_mean'] = np.mean(losses)
        metrics[model_name]['loss_std'] = np.std(losses)
        metrics[model_name]['fno_time_mean'] = np.mean(fno_times)
        metrics[model_name]['fno_time_std'] = np.std(fno_times)
        metrics[model_name]['nwp_time_mean'] = np.mean(nwp_times)
        metrics[model_name]['nwp_time_std'] = np.std(nwp_times)
    if train:
        metrics[model_name]['training_time'] = training_time

# persist all collected metrics for the paper tables/plots
df = pd.DataFrame(metrics)
df.to_pickle(os.path.join(root_path, 'output_data/metrics.pkl'))
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # NOTE(review): 'forkserver' presumably chosen so dataloader worker
    # processes do not inherit CUDA state via fork — confirm.
    mp.set_start_method('forkserver', force=True)

    main(train=True, load_checkpoint=False, enable_amp=False)
examples/sfno_examples/utils/pde_dataset.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import os
from math import ceil

import sys
# Make the repository root's "torch_harmonics" package and the "examples"
# directory (which provides the reference shallow-water solver) importable
# without installing the package. This file lives four levels below the
# repository root (examples/sfno_examples/utils/pde_dataset.py).
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "torch_harmonics"))
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "examples"))

from shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
    """Custom Dataset class for PDE training data.

    Samples are generated on the fly: an initial condition is drawn (random
    or the deterministic 'galewsky' state), advanced ``nsteps`` sub-steps by
    a classical spectral solver to produce the target, and both fields are
    returned on the grid, optionally standardized channel-wise with
    statistics estimated from one reference sample.
    """

    def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random',
                 num_examples=32, device=torch.device('cpu'), normalize=True, stream=None):
        """
        Parameters
        ----------
        dt : float
            time interval between input and target
        nsteps : int
            number of solver sub-steps used to advance by dt
        dims : tuple of int
            (nlat, nlon) spatial resolution
        pde : str
            which PDE to solve; only 'shallow water equations' is implemented
        initial_condition : str
            'random' or 'galewsky'
        num_examples : int
            samples per epoch when using random initial conditions
        device : torch.device
            device on which the solver runs
        normalize : bool
            standardize inputs/targets channel-wise if True
        stream : optional
            stored but unused (kept for interface compatibility)
        """
        self.num_examples = num_examples
        self.device = device
        self.stream = stream

        self.nlat = dims[0]
        self.nlon = dims[1]

        # number of solver steps used to compute the target
        self.nsteps = nsteps
        self.normalize = normalize

        if pde == 'shallow water equations':
            # triangular truncation at roughly nlat/3
            lmax = ceil(self.nlat / 3)
            mmax = lmax
            dt_solver = dt / float(self.nsteps)
            self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float()
        else:
            # FIX: attach a message so the failure is self-explanatory
            raise NotImplementedError(f"unknown pde {pde!r}; only 'shallow water equations' is supported")

        self.set_initial_condition(ictype=initial_condition)

        if self.normalize:
            # estimate channel-wise mean/variance from a single reference sample
            inp0, _ = self._get_sample()
            self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
            self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)

    def __len__(self):
        # a deterministic (galewsky) initial condition yields exactly one sample
        length = self.num_examples if self.ictype == 'random' else 1
        return length

    def set_initial_condition(self, ictype='random'):
        """Select the initial-condition type ('random' or 'galewsky')."""
        self.ictype = ictype

    def set_num_examples(self, num_examples=32):
        """Set the number of samples per epoch (relevant for random ICs)."""
        self.num_examples = num_examples

    def _get_sample(self):
        """Draw an initial condition and return the (input, target) grid pair."""
        if self.ictype == 'random':
            inp = self.solver.random_initial_condition(mach=0.2)
        elif self.ictype == 'galewsky':
            inp = self.solver.galewsky_initial_condition()
        else:
            # FIX: previously an unknown ictype fell through and surfaced
            # later as a confusing NameError on `inp`
            raise ValueError(f"unknown initial condition type {self.ictype!r}")

        # solve pde for n steps to return the target
        tar = self.solver.timestep(inp, self.nsteps)
        inp = self.solver.spec2grid(inp)
        tar = self.solver.spec2grid(tar)

        return inp, tar

    def __getitem__(self, index):
        # samples are generated on the fly; `index` is ignored for random ICs
        with torch.inference_mode():
            with torch.no_grad():
                inp, tar = self._get_sample()

                if self.normalize:
                    inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var)
                    tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var)

        # clone to detach the tensors from inference mode before returning
        return inp.clone(), tar.clone()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment