Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
e7ceb9c8
Commit
e7ceb9c8
authored
Jun 07, 2023
by
Boris Bonev
Browse files
Adding SFNO examples
parent
24490256
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2533 additions
and
0 deletions
+2533
-0
examples/sfno_examples/models/activations.py
examples/sfno_examples/models/activations.py
+96
-0
examples/sfno_examples/models/contractions.py
examples/sfno_examples/models/contractions.py
+141
-0
examples/sfno_examples/models/factorizations.py
examples/sfno_examples/models/factorizations.py
+199
-0
examples/sfno_examples/models/layers.py
examples/sfno_examples/models/layers.py
+552
-0
examples/sfno_examples/models/sfno.py
examples/sfno_examples/models/sfno.py
+517
-0
examples/sfno_examples/train_sfno.ipynb
examples/sfno_examples/train_sfno.ipynb
+488
-0
examples/sfno_examples/train_sfno.py
examples/sfno_examples/train_sfno.py
+419
-0
examples/sfno_examples/utils/pde_dataset.py
examples/sfno_examples/utils/pde_dataset.py
+121
-0
No files found.
examples/sfno_examples/models/activations.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
torch
import
torch.nn
as
nn
# complex activation functions
class
ComplexCardioid
(
nn
.
Module
):
"""
Complex Cardioid activation function
"""
def
__init__
(
self
):
super
(
ComplexCardioid
,
self
).
__init__
()
def
forward
(
self
,
z
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out
=
0.5
*
(
1.
+
torch
.
cos
(
z
.
angle
()))
*
z
return
out
class
ComplexReLU
(
nn
.
Module
):
"""
Complex-valued variants of the ReLU activation function
"""
def
__init__
(
self
,
negative_slope
=
0.
,
mode
=
"real"
,
bias_shape
=
None
,
scale
=
1.
):
super
(
ComplexReLU
,
self
).
__init__
()
# store parameters
self
.
mode
=
mode
if
self
.
mode
in
[
"modulus"
,
"halfplane"
]:
if
bias_shape
is
not
None
:
self
.
bias
=
nn
.
Parameter
(
scale
*
torch
.
ones
(
bias_shape
,
dtype
=
torch
.
float32
))
else
:
self
.
bias
=
nn
.
Parameter
(
scale
*
torch
.
ones
((
1
),
dtype
=
torch
.
float32
))
else
:
self
.
bias
=
0
self
.
negative_slope
=
negative_slope
self
.
act
=
nn
.
LeakyReLU
(
negative_slope
=
negative_slope
)
def
forward
(
self
,
z
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
mode
==
"cartesian"
:
zr
=
torch
.
view_as_real
(
z
)
za
=
self
.
act
(
zr
)
out
=
torch
.
view_as_complex
(
za
)
elif
self
.
mode
==
"modulus"
:
zabs
=
torch
.
sqrt
(
torch
.
square
(
z
.
real
)
+
torch
.
square
(
z
.
imag
))
out
=
torch
.
where
(
zabs
+
self
.
bias
>
0
,
(
zabs
+
self
.
bias
)
*
z
/
zabs
,
0.0
)
elif
self
.
mode
==
"cardioid"
:
out
=
0.5
*
(
1.
+
torch
.
cos
(
z
.
angle
()))
*
z
# elif self.mode == "halfplane":
# # bias is an angle parameter in this case
# modified_angle = torch.angle(z) - self.bias
# condition = torch.logical_and( (0. <= modified_angle), (modified_angle < torch.pi/2.) )
# out = torch.where(condition, z, self.negative_slope * z)
elif
self
.
mode
==
"real"
:
zr
=
torch
.
view_as_real
(
z
)
outr
=
zr
.
clone
()
outr
[...,
0
]
=
self
.
act
(
zr
[...,
0
])
out
=
torch
.
view_as_complex
(
outr
)
else
:
raise
NotImplementedError
return
out
\ No newline at end of file
examples/sfno_examples/models/contractions.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
torch
"""
Contains complex contractions wrapped into jit for harmonic layers
"""
@
torch
.
jit
.
script
def
compl_contract2d_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tmp
=
torch
.
einsum
(
"bixys,kixr->srbkx"
,
a
,
b
)
res
=
torch
.
stack
([
tmp
[
0
,
0
,...]
-
tmp
[
1
,
1
,...],
tmp
[
1
,
0
,...]
+
tmp
[
0
,
1
,...]],
dim
=-
1
)
return
res
@
torch
.
jit
.
script
def
compl_contract2d_fwd_c
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
ac
=
torch
.
view_as_complex
(
a
)
bc
=
torch
.
view_as_complex
(
b
)
res
=
torch
.
einsum
(
"bixy,kix->bkx"
,
ac
,
bc
)
return
torch
.
view_as_real
(
res
)
@
torch
.
jit
.
script
def
compl_contract_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tmp
=
torch
.
einsum
(
"bins,kinr->srbkn"
,
a
,
b
)
res
=
torch
.
stack
([
tmp
[
0
,
0
,...]
-
tmp
[
1
,
1
,...],
tmp
[
1
,
0
,...]
+
tmp
[
0
,
1
,...]],
dim
=-
1
)
return
res
@
torch
.
jit
.
script
def
compl_contract_fwd_c
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
ac
=
torch
.
view_as_complex
(
a
)
bc
=
torch
.
view_as_complex
(
b
)
res
=
torch
.
einsum
(
"bin,kin->bkn"
,
ac
,
bc
)
return
torch
.
view_as_real
(
res
)
# Helper routines for spherical MLPs
@
torch
.
jit
.
script
def
compl_mul1d_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tmp
=
torch
.
einsum
(
"bixs,ior->srbox"
,
a
,
b
)
res
=
torch
.
stack
([
tmp
[
0
,
0
,...]
-
tmp
[
1
,
1
,...],
tmp
[
1
,
0
,...]
+
tmp
[
0
,
1
,...]],
dim
=-
1
)
return
res
@
torch
.
jit
.
script
def
compl_mul1d_fwd_c
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
ac
=
torch
.
view_as_complex
(
a
)
bc
=
torch
.
view_as_complex
(
b
)
resc
=
torch
.
einsum
(
"bix,io->box"
,
ac
,
bc
)
res
=
torch
.
view_as_real
(
resc
)
return
res
@
torch
.
jit
.
script
def
compl_muladd1d_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
c
:
torch
.
Tensor
)
->
torch
.
Tensor
:
res
=
compl_mul1d_fwd
(
a
,
b
)
+
c
return
res
@
torch
.
jit
.
script
def
compl_muladd1d_fwd_c
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
c
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tmpcc
=
torch
.
view_as_complex
(
compl_mul1d_fwd_c
(
a
,
b
))
cc
=
torch
.
view_as_complex
(
c
)
return
torch
.
view_as_real
(
tmpcc
+
cc
)
# Helper routines for FFT MLPs
@
torch
.
jit
.
script
def
compl_mul2d_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tmp
=
torch
.
einsum
(
"bixys,ior->srboxy"
,
a
,
b
)
res
=
torch
.
stack
([
tmp
[
0
,
0
,...]
-
tmp
[
1
,
1
,...],
tmp
[
1
,
0
,...]
+
tmp
[
0
,
1
,...]],
dim
=-
1
)
return
res
@
torch
.
jit
.
script
def
compl_mul2d_fwd_c
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
ac
=
torch
.
view_as_complex
(
a
)
bc
=
torch
.
view_as_complex
(
b
)
resc
=
torch
.
einsum
(
"bixy,io->boxy"
,
ac
,
bc
)
res
=
torch
.
view_as_real
(
resc
)
return
res
@
torch
.
jit
.
script
def
compl_muladd2d_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
c
:
torch
.
Tensor
)
->
torch
.
Tensor
:
res
=
compl_mul2d_fwd
(
a
,
b
)
+
c
return
res
@
torch
.
jit
.
script
def
compl_muladd2d_fwd_c
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
c
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tmpcc
=
torch
.
view_as_complex
(
compl_mul2d_fwd_c
(
a
,
b
))
cc
=
torch
.
view_as_complex
(
c
)
return
torch
.
view_as_real
(
tmpcc
+
cc
)
@
torch
.
jit
.
script
def
real_mul2d_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out
=
torch
.
einsum
(
"bixy,io->boxy"
,
a
,
b
)
return
out
@
torch
.
jit
.
script
def
real_muladd2d_fwd
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
c
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
compl_mul2d_fwd_c
(
a
,
b
)
+
c
# for all the experimental layers
# @torch.jit.script
# def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
# resc = torch.einsum("bixy,xio->boxy", ac, bc)
# res = torch.view_as_real(resc)
# return res
# @torch.jit.script
# def compl_exp_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b))
# cc = torch.view_as_complex(c)
# return torch.view_as_real(tmpcc + cc)
examples/sfno_examples/models/factorizations.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
torch
import
tensorly
as
tl
tl
.
set_backend
(
'pytorch'
)
from
tltorch.factorized_tensors.core
import
FactorizedTensor
einsum_symbols
=
'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def
_contract_dense
(
x
,
weight
,
separable
=
False
,
operator_type
=
'diagonal'
):
order
=
tl
.
ndim
(
x
)
# batch-size, in_channels, x, y...
x_syms
=
list
(
einsum_symbols
[:
order
])
# in_channels, out_channels, x, y...
weight_syms
=
list
(
x_syms
[
1
:])
# no batch-size
# batch-size, out_channels, x, y...
if
separable
:
out_syms
=
[
x_syms
[
0
]]
+
list
(
weight_syms
)
else
:
weight_syms
.
insert
(
1
,
einsum_symbols
[
order
])
# outputs
out_syms
=
list
(
weight_syms
)
out_syms
[
0
]
=
x_syms
[
0
]
if
operator_type
==
'diagonal'
:
pass
elif
operator_type
==
'block-diagonal'
:
weight_syms
.
insert
(
-
1
,
einsum_symbols
[
order
+
1
])
out_syms
[
-
1
]
=
weight_syms
[
-
2
]
elif
operator_type
==
'vector'
:
weight_syms
.
pop
()
else
:
raise
ValueError
(
f
"Unkonw operator type
{
operator_type
}
"
)
eq
=
''
.
join
(
x_syms
)
+
','
+
''
.
join
(
weight_syms
)
+
'->'
+
''
.
join
(
out_syms
)
if
not
torch
.
is_tensor
(
weight
):
weight
=
weight
.
to_tensor
()
return
tl
.
einsum
(
eq
,
x
,
weight
)
def
_contract_cp
(
x
,
cp_weight
,
separable
=
False
,
operator_type
=
'diagonal'
):
order
=
tl
.
ndim
(
x
)
x_syms
=
str
(
einsum_symbols
[:
order
])
rank_sym
=
einsum_symbols
[
order
]
out_sym
=
einsum_symbols
[
order
+
1
]
out_syms
=
list
(
x_syms
)
if
separable
:
factor_syms
=
[
einsum_symbols
[
1
]
+
rank_sym
]
#in only
else
:
out_syms
[
1
]
=
out_sym
factor_syms
=
[
einsum_symbols
[
1
]
+
rank_sym
,
out_sym
+
rank_sym
]
#in, out
factor_syms
+=
[
xs
+
rank_sym
for
xs
in
x_syms
[
2
:]]
#x, y, ...
if
operator_type
==
'diagonal'
:
pass
elif
operator_type
==
'block-diagonal'
:
out_syms
[
-
1
]
=
einsum_symbols
[
order
+
2
]
factor_syms
+=
[
out_syms
[
-
1
]
+
rank_sym
]
elif
operator_type
==
'vector'
:
factor_syms
.
pop
()
else
:
raise
ValueError
(
f
"Unkonw operator type
{
operator_type
}
"
)
eq
=
x_syms
+
','
+
rank_sym
+
','
+
','
.
join
(
factor_syms
)
+
'->'
+
''
.
join
(
out_syms
)
return
tl
.
einsum
(
eq
,
x
,
cp_weight
.
weights
,
*
cp_weight
.
factors
)
def
_contract_tucker
(
x
,
tucker_weight
,
separable
=
False
,
operator_type
=
'diagonal'
):
order
=
tl
.
ndim
(
x
)
x_syms
=
str
(
einsum_symbols
[:
order
])
out_sym
=
einsum_symbols
[
order
]
out_syms
=
list
(
x_syms
)
if
separable
:
core_syms
=
einsum_symbols
[
order
+
1
:
2
*
order
]
# factor_syms = [einsum_symbols[1]+core_syms[0]] #in only
factor_syms
=
[
xs
+
rs
for
(
xs
,
rs
)
in
zip
(
x_syms
[
1
:],
core_syms
)]
#x, y, ...
else
:
core_syms
=
einsum_symbols
[
order
+
1
:
2
*
order
+
1
]
out_syms
[
1
]
=
out_sym
factor_syms
=
[
einsum_symbols
[
1
]
+
core_syms
[
0
],
out_sym
+
core_syms
[
1
]]
#out, in
factor_syms
+=
[
xs
+
rs
for
(
xs
,
rs
)
in
zip
(
x_syms
[
2
:],
core_syms
[
2
:])]
#x, y, ...
if
operator_type
==
'diagonal'
:
pass
elif
operator_type
==
'block-diagonal'
:
raise
NotImplementedError
(
f
"Operator type
{
operator_type
}
not implemented for Tucker"
)
else
:
raise
ValueError
(
f
"Unkonw operator type
{
operator_type
}
"
)
eq
=
x_syms
+
','
+
core_syms
+
','
+
','
.
join
(
factor_syms
)
+
'->'
+
''
.
join
(
out_syms
)
return
tl
.
einsum
(
eq
,
x
,
tucker_weight
.
core
,
*
tucker_weight
.
factors
)
def
_contract_tt
(
x
,
tt_weight
,
separable
=
False
,
operator_type
=
'diagonal'
):
order
=
tl
.
ndim
(
x
)
x_syms
=
list
(
einsum_symbols
[:
order
])
weight_syms
=
list
(
x_syms
[
1
:])
# no batch-size
if
not
separable
:
weight_syms
.
insert
(
1
,
einsum_symbols
[
order
])
# outputs
out_syms
=
list
(
weight_syms
)
out_syms
[
0
]
=
x_syms
[
0
]
else
:
out_syms
=
list
(
x_syms
)
if
operator_type
==
'diagonal'
:
pass
elif
operator_type
==
'block-diagonal'
:
weight_syms
.
insert
(
-
1
,
einsum_symbols
[
order
+
1
])
out_syms
[
-
1
]
=
weight_syms
[
-
2
]
elif
operator_type
==
'vector'
:
weight_syms
.
pop
()
else
:
raise
ValueError
(
f
"Unkonw operator type
{
operator_type
}
"
)
rank_syms
=
list
(
einsum_symbols
[
order
+
2
:])
tt_syms
=
[]
for
i
,
s
in
enumerate
(
weight_syms
):
tt_syms
.
append
([
rank_syms
[
i
],
s
,
rank_syms
[
i
+
1
]])
eq
=
''
.
join
(
x_syms
)
+
','
+
','
.
join
(
''
.
join
(
f
)
for
f
in
tt_syms
)
+
'->'
+
''
.
join
(
out_syms
)
return
tl
.
einsum
(
eq
,
x
,
*
tt_weight
.
factors
)
def
get_contract_fun
(
weight
,
implementation
=
'reconstructed'
,
separable
=
False
):
"""Generic ND implementation of Fourier Spectral Conv contraction
Parameters
----------
weight : tensorly-torch's FactorizedTensor
implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
whether to reconstruct the weight and do a forward pass (reconstructed)
or contract directly the factors of the factorized weight with the input (factorized)
Returns
-------
function : (x, weight) -> x * weight in Fourier space
"""
if
implementation
==
'reconstructed'
:
return
_contract_dense
elif
implementation
==
'factorized'
:
if
torch
.
is_tensor
(
weight
):
return
_contract_dense
elif
isinstance
(
weight
,
FactorizedTensor
):
if
weight
.
name
.
lower
()
==
'complexdense'
:
return
_contract_dense
elif
weight
.
name
.
lower
()
==
'complextucker'
:
return
_contract_tucker
elif
weight
.
name
.
lower
()
==
'complextt'
:
return
_contract_tt
elif
weight
.
name
.
lower
()
==
'complexcp'
:
return
_contract_cp
else
:
raise
ValueError
(
f
'Got unexpected factorized weight type
{
weight
.
name
}
'
)
else
:
raise
ValueError
(
f
'Got unexpected weight type of class
{
weight
.
__class__
.
__name__
}
'
)
else
:
raise
ValueError
(
f
'Got
{
implementation
=
}
, expected "reconstructed" or "factorized"'
)
examples/sfno_examples/models/layers.py
0 → 100644
View file @
e7ceb9c8
This diff is collapsed.
Click to expand it.
examples/sfno_examples/models/sfno.py
0 → 100644
View file @
e7ceb9c8
This diff is collapsed.
Click to expand it.
examples/sfno_examples/train_sfno.ipynb
0 → 100644
View file @
e7ceb9c8
This source diff could not be displayed because it is too large. You can
view the blob
instead.
examples/sfno_examples/train_sfno.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
os
import
time
from
tqdm
import
tqdm
from
functools
import
partial
import
torch
import
torch.nn
as
nn
from
torch.utils.data
import
DataLoader
from
torch.cuda
import
amp
import
numpy
as
np
import
pandas
as
pd
import
matplotlib.pyplot
as
plt
# wandb logging
import
wandb
wandb
.
login
()
import
sys
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../"
))
from
pde_sphere
import
SphereSolver
def
l2loss_sphere
(
solver
,
prd
,
tar
,
relative
=
False
,
squared
=
False
):
loss
=
solver
.
integrate_grid
((
prd
-
tar
)
**
2
,
dimensionless
=
True
).
sum
(
dim
=-
1
)
if
relative
:
loss
=
loss
/
solver
.
integrate_grid
(
tar
**
2
,
dimensionless
=
True
).
sum
(
dim
=-
1
)
if
not
squared
:
loss
=
torch
.
sqrt
(
loss
)
loss
=
loss
.
mean
()
return
loss
def
spectral_l2loss_sphere
(
solver
,
prd
,
tar
,
relative
=
False
,
squared
=
False
):
# compute coefficients
coeffs
=
torch
.
view_as_real
(
solver
.
sht
(
prd
-
tar
))
coeffs
=
coeffs
[...,
0
]
**
2
+
coeffs
[...,
1
]
**
2
norm2
=
coeffs
[...,
:,
0
]
+
2
*
torch
.
sum
(
coeffs
[...,
:,
1
:],
dim
=-
1
)
loss
=
torch
.
sum
(
norm2
,
dim
=
(
-
1
,
-
2
))
if
relative
:
tar_coeffs
=
torch
.
view_as_real
(
solver
.
sht
(
tar
))
tar_coeffs
=
tar_coeffs
[...,
0
]
**
2
+
tar_coeffs
[...,
1
]
**
2
tar_norm2
=
tar_coeffs
[...,
:,
0
]
+
2
*
torch
.
sum
(
tar_coeffs
[...,
:,
1
:],
dim
=-
1
)
tar_norm2
=
torch
.
sum
(
tar_norm2
,
dim
=
(
-
1
,
-
2
))
loss
=
loss
/
tar_norm2
if
not
squared
:
loss
=
torch
.
sqrt
(
loss
)
loss
=
loss
.
mean
()
return
loss
def
spectral_loss_sphere
(
solver
,
prd
,
tar
,
relative
=
False
,
squared
=
False
):
# gradient weighting factors
lmax
=
solver
.
sht
.
lmax
ls
=
torch
.
arange
(
lmax
).
float
()
spectral_weights
=
(
ls
*
(
ls
+
1
)).
reshape
(
1
,
1
,
-
1
,
1
).
to
(
prd
.
device
)
# compute coefficients
coeffs
=
torch
.
view_as_real
(
solver
.
sht
(
prd
-
tar
))
coeffs
=
coeffs
[...,
0
]
**
2
+
coeffs
[...,
1
]
**
2
coeffs
=
spectral_weights
*
coeffs
norm2
=
coeffs
[...,
:,
0
]
+
2
*
torch
.
sum
(
coeffs
[...,
:,
1
:],
dim
=-
1
)
loss
=
torch
.
sum
(
norm2
,
dim
=
(
-
1
,
-
2
))
if
relative
:
tar_coeffs
=
torch
.
view_as_real
(
solver
.
sht
(
tar
))
tar_coeffs
=
tar_coeffs
[...,
0
]
**
2
+
tar_coeffs
[...,
1
]
**
2
tar_coeffs
=
spectral_weights
*
tar_coeffs
tar_norm2
=
tar_coeffs
[...,
:,
0
]
+
2
*
torch
.
sum
(
tar_coeffs
[...,
:,
1
:],
dim
=-
1
)
tar_norm2
=
torch
.
sum
(
tar_norm2
,
dim
=
(
-
1
,
-
2
))
loss
=
loss
/
tar_norm2
if
not
squared
:
loss
=
torch
.
sqrt
(
loss
)
loss
=
loss
.
mean
()
return
loss
def
h1loss_sphere
(
solver
,
prd
,
tar
,
relative
=
False
,
squared
=
False
):
# gradient weighting factors
lmax
=
solver
.
sht
.
lmax
ls
=
torch
.
arange
(
lmax
).
float
()
spectral_weights
=
(
ls
*
(
ls
+
1
)).
reshape
(
1
,
1
,
-
1
,
1
).
to
(
prd
.
device
)
# compute coefficients
coeffs
=
torch
.
view_as_real
(
solver
.
sht
(
prd
-
tar
))
coeffs
=
coeffs
[...,
0
]
**
2
+
coeffs
[...,
1
]
**
2
h1_coeffs
=
spectral_weights
*
coeffs
h1_norm2
=
h1_coeffs
[...,
:,
0
]
+
2
*
torch
.
sum
(
h1_coeffs
[...,
:,
1
:],
dim
=-
1
)
l2_norm2
=
coeffs
[...,
:,
0
]
+
2
*
torch
.
sum
(
coeffs
[...,
:,
1
:],
dim
=-
1
)
h1_loss
=
torch
.
sum
(
h1_norm2
,
dim
=
(
-
1
,
-
2
))
l2_loss
=
torch
.
sum
(
l2_norm2
,
dim
=
(
-
1
,
-
2
))
# strictly speaking this is not exactly h1 loss
if
not
squared
:
loss
=
torch
.
sqrt
(
h1_loss
)
+
torch
.
sqrt
(
l2_loss
)
else
:
loss
=
h1_loss
+
l2_loss
if
relative
:
raise
NotImplementedError
(
"Relative H1 loss not implemented"
)
loss
=
loss
.
mean
()
return
loss
def
fluct_l2loss_sphere
(
solver
,
prd
,
tar
,
inp
,
relative
=
False
,
polar_opt
=
0
):
# compute the weighting factor first
fluct
=
solver
.
integrate_grid
((
tar
-
inp
)
**
2
,
dimensionless
=
True
,
polar_opt
=
polar_opt
)
weight
=
fluct
/
torch
.
sum
(
fluct
,
dim
=-
1
,
keepdim
=
True
)
# weight = weight.reshape(*weight.shape, 1, 1)
loss
=
weight
*
solver
.
integrate_grid
((
prd
-
tar
)
**
2
,
dimensionless
=
True
,
polar_opt
=
polar_opt
)
if
relative
:
loss
=
loss
/
(
weight
*
solver
.
integrate_grid
(
tar
**
2
,
dimensionless
=
True
,
polar_opt
=
polar_opt
))
loss
=
torch
.
mean
(
loss
)
return
loss
def
main
(
train
=
True
,
load_checkpoint
=
False
,
enable_amp
=
False
):
# set seed
torch
.
manual_seed
(
333
)
torch
.
cuda
.
manual_seed
(
333
)
# set device
device
=
torch
.
device
(
'cuda:1'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
set_device
(
device
.
index
)
# dataset
from
utils.pde_dataset
import
PdeDataset
# 1 hour prediction steps
dt
=
1
*
3600
dt_solver
=
150
nsteps
=
dt
//
dt_solver
dataset
=
PdeDataset
(
dt
=
dt
,
nsteps
=
nsteps
,
dims
=
(
256
,
512
),
device
=
device
,
normalize
=
True
)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
,
persistent_workers
=
False
)
solver
=
dataset
.
solver
.
to
(
device
)
nlat
=
dataset
.
nlat
nlon
=
dataset
.
nlon
# training function
def
train_model
(
model
,
dataloader
,
optimizer
,
gscaler
,
scheduler
=
None
,
nepochs
=
20
,
nfuture
=
0
,
num_examples
=
256
,
num_valid
=
8
,
loss_fn
=
'l2'
):
train_start
=
time
.
time
()
for
epoch
in
range
(
nepochs
):
# time each epoch
epoch_start
=
time
.
time
()
dataloader
.
dataset
.
set_initial_condition
(
'random'
)
dataloader
.
dataset
.
set_num_examples
(
num_examples
)
# do the training
acc_loss
=
0
model
.
train
()
for
inp
,
tar
in
dataloader
:
with
amp
.
autocast
(
enabled
=
enable_amp
):
prd
=
model
(
inp
)
for
_
in
range
(
nfuture
):
prd
=
model
(
prd
)
if
loss_fn
==
'l2'
:
loss
=
l2loss_sphere
(
solver
,
prd
,
tar
,
relative
=
False
)
elif
loss_fn
==
'h1'
:
loss
=
h1loss_sphere
(
solver
,
prd
,
tar
,
relative
=
False
)
elif
loss_fn
==
'spectral'
:
loss
=
spectral_loss_sphere
(
solver
,
prd
,
tar
,
relative
=
False
)
elif
loss_fn
==
'fluct'
:
loss
=
fluct_l2loss_sphere
(
solver
,
prd
,
tar
,
inp
,
relative
=
True
)
else
:
raise
NotImplementedError
(
f
'Unknown loss function
{
loss_fn
}
'
)
acc_loss
+=
loss
.
item
()
*
inp
.
size
(
0
)
optimizer
.
zero_grad
(
set_to_none
=
True
)
# gscaler.scale(loss).backward()
gscaler
.
scale
(
loss
).
backward
()
gscaler
.
step
(
optimizer
)
gscaler
.
update
()
acc_loss
=
acc_loss
/
len
(
dataloader
.
dataset
)
dataloader
.
dataset
.
set_initial_condition
(
'random'
)
dataloader
.
dataset
.
set_num_examples
(
num_valid
)
# perform validation
valid_loss
=
0
model
.
eval
()
with
torch
.
no_grad
():
for
inp
,
tar
in
dataloader
:
prd
=
model
(
inp
)
for
_
in
range
(
nfuture
):
prd
=
model
(
prd
)
loss
=
l2loss_sphere
(
solver
,
prd
,
tar
,
relative
=
True
)
valid_loss
+=
loss
.
item
()
*
inp
.
size
(
0
)
valid_loss
=
valid_loss
/
len
(
dataloader
.
dataset
)
if
scheduler
is
not
None
:
scheduler
.
step
(
valid_loss
)
epoch_time
=
time
.
time
()
-
epoch_start
print
(
f
'--------------------------------------------------------------------------------'
)
print
(
f
'Epoch
{
epoch
}
summary:'
)
print
(
f
'time taken:
{
epoch_time
}
'
)
print
(
f
'accumulated training loss:
{
acc_loss
}
'
)
print
(
f
'relative validation loss:
{
valid_loss
}
'
)
if
wandb
.
run
is
not
None
:
current_lr
=
optimizer
.
param_groups
[
0
][
'lr'
]
wandb
.
log
({
"loss"
:
acc_loss
,
"validation loss"
:
valid_loss
,
"learning rate"
:
current_lr
})
train_time
=
time
.
time
()
-
train_start
print
(
f
'--------------------------------------------------------------------------------'
)
print
(
f
'done. Training took
{
train_time
}
.'
)
return
valid_loss
# rolls out the FNO and compares to the classical solver
def
autoregressive_inference
(
model
,
dataset
,
path_root
,
nsteps
,
autoreg_steps
=
10
,
nskip
=
1
,
plot_channel
=
0
,
nics
=
20
):
model
.
eval
()
losses
=
np
.
zeros
(
nics
)
fno_times
=
np
.
zeros
(
nics
)
nwp_times
=
np
.
zeros
(
nics
)
for
iic
in
range
(
nics
):
ic
=
dataset
.
solver
.
random_initial_condition
(
mach
=
0.2
)
inp_mean
=
dataset
.
inp_mean
inp_var
=
dataset
.
inp_var
prd
=
(
dataset
.
solver
.
spec2grid
(
ic
)
-
inp_mean
)
/
torch
.
sqrt
(
inp_var
)
prd
=
prd
.
unsqueeze
(
0
)
uspec
=
ic
.
clone
()
# ML model
start_time
=
time
.
time
()
for
i
in
range
(
1
,
autoreg_steps
+
1
):
# evaluate the ML model
prd
=
model
(
prd
)
if
iic
==
nics
-
1
and
nskip
>
0
and
i
%
nskip
==
0
:
# do plotting
fig
=
plt
.
figure
(
figsize
=
(
7.5
,
6
))
dataset
.
solver
.
plot_griddata
(
prd
[
0
,
plot_channel
],
fig
,
vmax
=
4
,
vmin
=-
4
)
plt
.
savefig
(
path_root
+
'_pred_'
+
str
(
i
//
nskip
)
+
'.png'
)
plt
.
clf
()
fno_times
[
iic
]
=
time
.
time
()
-
start_time
# classical model
start_time
=
time
.
time
()
for
i
in
range
(
1
,
autoreg_steps
+
1
):
# advance classical model
uspec
=
dataset
.
solver
.
timestep
(
uspec
,
nsteps
)
if
iic
==
nics
-
1
and
i
%
nskip
==
0
and
nskip
>
0
:
ref
=
(
dataset
.
solver
.
spec2grid
(
uspec
)
-
inp_mean
)
/
torch
.
sqrt
(
inp_var
)
fig
=
plt
.
figure
(
figsize
=
(
7.5
,
6
))
dataset
.
solver
.
plot_griddata
(
ref
[
plot_channel
],
fig
,
vmax
=
4
,
vmin
=-
4
)
plt
.
savefig
(
path_root
+
'_truth_'
+
str
(
i
//
nskip
)
+
'.png'
)
plt
.
clf
()
nwp_times
[
iic
]
=
time
.
time
()
-
start_time
# ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref
=
dataset
.
solver
.
spec2grid
(
uspec
)
prd
=
prd
*
torch
.
sqrt
(
inp_var
)
+
inp_mean
losses
[
iic
]
=
l2loss_sphere
(
solver
,
prd
,
ref
,
relative
=
True
).
item
()
return
losses
,
fno_times
,
nwp_times
def
count_parameters
(
model
):
return
sum
(
p
.
numel
()
for
p
in
model
.
parameters
()
if
p
.
requires_grad
)
# prepare dicts containing models and corresponding metrics
models
=
{}
metrics
=
{}
# # U-Net if installed
# from models.unet import UNet
# models['unet_baseline'] = partial(UNet)
# SFNO and FNO models
from
models.sfno
import
SphericalFourierNeuralOperatorNet
as
SFNO
# SFNO models
models
[
'sfno_sc3_layer4_edim256_linear'
]
=
partial
(
SFNO
,
spectral_transform
=
'sht'
,
filter_type
=
'linear'
,
img_size
=
(
nlat
,
nlon
),
num_layers
=
4
,
scale_factor
=
3
,
embed_dim
=
256
,
operator_type
=
'vector'
)
models
[
'sfno_sc3_layer4_edim256_real'
]
=
partial
(
SFNO
,
spectral_transform
=
'sht'
,
filter_type
=
'non-linear'
,
img_size
=
(
nlat
,
nlon
),
num_layers
=
4
,
scale_factor
=
3
,
embed_dim
=
256
,
complex_activation
=
'real'
,
operator_type
=
'diagonal'
)
# FNO models
models
[
'fno_sc3_layer4_edim256_linear'
]
=
partial
(
SFNO
,
spectral_transform
=
'fft'
,
filter_type
=
'linear'
,
img_size
=
(
nlat
,
nlon
),
num_layers
=
4
,
scale_factor
=
3
,
embed_dim
=
256
,
operator_type
=
'diagonal'
)
models
[
'fno_sc3_layer4_edim256_real'
]
=
partial
(
SFNO
,
spectral_transform
=
'fft'
,
filter_type
=
'non-linear'
,
img_size
=
(
nlat
,
nlon
),
num_layers
=
4
,
scale_factor
=
3
,
embed_dim
=
256
,
complex_activation
=
'real'
)
# iterate over models and train each model
root_path
=
os
.
path
.
dirname
(
__file__
)
for
model_name
,
model_handle
in
models
.
items
():
model
=
model_handle
().
to
(
device
)
metrics
[
model_name
]
=
{}
num_params
=
count_parameters
(
model
)
print
(
f
'number of trainable params:
{
num_params
}
'
)
metrics
[
model_name
][
'num_params'
]
=
num_params
if
load_checkpoint
:
model
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
root_path
,
'checkpoints/'
+
model_name
)))
# run the training
if
train
:
run
=
wandb
.
init
(
project
=
"sfno spherical swe"
,
group
=
model_name
,
name
=
model_name
+
'_'
+
str
(
time
.
time
()),
config
=
model_handle
.
keywords
)
# optimizer:
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
1E-3
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
'min'
)
gscaler
=
amp
.
GradScaler
(
enabled
=
enable_amp
)
start_time
=
time
.
time
()
print
(
f
'Training
{
model_name
}
, single step'
)
train_model
(
model
,
dataloader
,
optimizer
,
gscaler
,
scheduler
,
nepochs
=
200
,
loss_fn
=
'l2'
)
# multistep training
print
(
f
'Training
{
model_name
}
, two step'
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5E-5
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
'min'
)
gscaler
=
amp
.
GradScaler
(
enabled
=
enable_amp
)
dataloader
.
dataset
.
nsteps
=
2
*
dt
//
dt_solver
train_model
(
model
,
dataloader
,
optimizer
,
gscaler
,
scheduler
,
nepochs
=
20
,
nfuture
=
1
)
dataloader
.
dataset
.
nsteps
=
1
*
dt
//
dt_solver
training_time
=
time
.
time
()
-
start_time
run
.
finish
()
torch
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
root_path
,
'checkpoints/'
+
model_name
))
# set seed
torch
.
manual_seed
(
333
)
torch
.
cuda
.
manual_seed
(
333
)
with
torch
.
inference_mode
():
losses
,
fno_times
,
nwp_times
=
autoregressive_inference
(
model
,
dataset
,
os
.
path
.
join
(
root_path
,
'paper_figures/'
+
model_name
),
nsteps
=
nsteps
,
autoreg_steps
=
10
)
metrics
[
model_name
][
'loss_mean'
]
=
np
.
mean
(
losses
)
metrics
[
model_name
][
'loss_std'
]
=
np
.
std
(
losses
)
metrics
[
model_name
][
'fno_time_mean'
]
=
np
.
mean
(
fno_times
)
metrics
[
model_name
][
'fno_time_std'
]
=
np
.
std
(
fno_times
)
metrics
[
model_name
][
'nwp_time_mean'
]
=
np
.
mean
(
nwp_times
)
metrics
[
model_name
][
'nwp_time_std'
]
=
np
.
std
(
nwp_times
)
if
train
:
metrics
[
model_name
][
'training_time'
]
=
training_time
df
=
pd
.
DataFrame
(
metrics
)
df
.
to_pickle
(
os
.
path
.
join
(
root_path
,
'output_data/metrics.pkl'
))
if
__name__
==
"__main__"
:
import
torch.multiprocessing
as
mp
mp
.
set_start_method
(
'forkserver'
,
force
=
True
)
main
(
train
=
True
,
load_checkpoint
=
False
,
enable_amp
=
False
)
examples/sfno_examples/utils/pde_dataset.py
0 → 100644
View file @
e7ceb9c8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
torch
import
os
from
math
import
ceil
import
sys
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
)))),
"torch_harmonics"
))
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
)))),
"examples"
))
from
shallow_water_equations
import
ShallowWaterSolver
class
PdeDataset
(
torch
.
utils
.
data
.
Dataset
):
"""Custom Dataset class for PDE training data"""
def
__init__
(
self
,
dt
,
nsteps
,
dims
=
(
384
,
768
),
pde
=
'shallow water equations'
,
initial_condition
=
'random'
,
num_examples
=
32
,
device
=
torch
.
device
(
'cpu'
),
normalize
=
True
,
stream
=
None
):
self
.
num_examples
=
num_examples
self
.
device
=
device
self
.
stream
=
stream
self
.
nlat
=
dims
[
0
]
self
.
nlon
=
dims
[
1
]
# number of solver steps used to compute the target
self
.
nsteps
=
nsteps
self
.
normalize
=
normalize
if
pde
==
'shallow water equations'
:
lmax
=
ceil
(
self
.
nlat
/
3
)
mmax
=
lmax
dt_solver
=
dt
/
float
(
self
.
nsteps
)
self
.
solver
=
ShallowWaterSolver
(
self
.
nlat
,
self
.
nlon
,
dt_solver
,
lmax
=
lmax
,
mmax
=
mmax
,
grid
=
'equiangular'
).
to
(
self
.
device
).
float
()
else
:
raise
NotImplementedError
self
.
set_initial_condition
(
ictype
=
initial_condition
)
if
self
.
normalize
:
inp0
,
_
=
self
.
_get_sample
()
self
.
inp_mean
=
torch
.
mean
(
inp0
,
dim
=
(
-
1
,
-
2
)).
reshape
(
-
1
,
1
,
1
)
self
.
inp_var
=
torch
.
var
(
inp0
,
dim
=
(
-
1
,
-
2
)).
reshape
(
-
1
,
1
,
1
)
def
__len__
(
self
):
length
=
self
.
num_examples
if
self
.
ictype
==
'random'
else
1
return
length
def
set_initial_condition
(
self
,
ictype
=
'random'
):
self
.
ictype
=
ictype
def
set_num_examples
(
self
,
num_examples
=
32
):
self
.
num_examples
=
num_examples
def
_get_sample
(
self
):
if
self
.
ictype
==
'random'
:
inp
=
self
.
solver
.
random_initial_condition
(
mach
=
0.2
)
elif
self
.
ictype
==
'galewsky'
:
inp
=
self
.
solver
.
galewsky_initial_condition
()
# solve pde for n steps to return the target
tar
=
self
.
solver
.
timestep
(
inp
,
self
.
nsteps
)
inp
=
self
.
solver
.
spec2grid
(
inp
)
tar
=
self
.
solver
.
spec2grid
(
tar
)
return
inp
,
tar
def
__getitem__
(
self
,
index
):
# if self.stream is None:
# self.stream = torch.cuda.Stream()
# with torch.cuda.stream(self.stream):
# with torch.inference_mode():
# with torch.no_grad():
# inp, tar = self._get_sample()
# if self.normalize:
# inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var)
# tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var)
# self.stream.synchronize()
with
torch
.
inference_mode
():
with
torch
.
no_grad
():
inp
,
tar
=
self
.
_get_sample
()
if
self
.
normalize
:
inp
=
(
inp
-
self
.
inp_mean
)
/
torch
.
sqrt
(
self
.
inp_var
)
tar
=
(
tar
-
self
.
inp_mean
)
/
torch
.
sqrt
(
self
.
inp_var
)
return
inp
.
clone
(),
tar
.
clone
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment