Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
30f7802b
Commit
30f7802b
authored
Jul 21, 2025
by
Thorsten Kurth
Browse files
fixing some more missing device statements
parent
f30ec30a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
5 deletions
+5
-5
torch_harmonics/random_fields.py
torch_harmonics/random_fields.py
+3
-3
torch_harmonics/resample.py
torch_harmonics/resample.py
+2
-2
No files found.
torch_harmonics/random_fields.py
View file @
30f7802b
...
@@ -77,7 +77,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
...
@@ -77,7 +77,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
self
.
isht
=
InverseRealSHT
(
self
.
nlat
,
2
*
self
.
nlat
,
grid
=
grid
,
norm
=
'backward'
).
to
(
dtype
=
dtype
)
self
.
isht
=
InverseRealSHT
(
self
.
nlat
,
2
*
self
.
nlat
,
grid
=
grid
,
norm
=
'backward'
).
to
(
dtype
=
dtype
)
#Square root of the eigenvalues of C.
#Square root of the eigenvalues of C.
sqrt_eig
=
torch
.
tensor
([
j
*
(
j
+
1
)
for
j
in
range
(
self
.
nlat
)]).
view
(
self
.
nlat
,
1
).
repeat
(
1
,
self
.
nlat
+
1
)
sqrt_eig
=
torch
.
as_
tensor
([
j
*
(
j
+
1
)
for
j
in
range
(
self
.
nlat
)]).
view
(
self
.
nlat
,
1
).
repeat
(
1
,
self
.
nlat
+
1
)
sqrt_eig
=
torch
.
tril
(
sigma
*
(((
sqrt_eig
/
radius
**
2
)
+
tau
**
2
)
**
(
-
alpha
/
2.0
)))
sqrt_eig
=
torch
.
tril
(
sigma
*
(((
sqrt_eig
/
radius
**
2
)
+
tau
**
2
)
**
(
-
alpha
/
2.0
)))
sqrt_eig
[
0
,
0
]
=
0.0
sqrt_eig
[
0
,
0
]
=
0.0
sqrt_eig
=
sqrt_eig
.
unsqueeze
(
0
)
sqrt_eig
=
sqrt_eig
.
unsqueeze
(
0
)
...
@@ -85,8 +85,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
...
@@ -85,8 +85,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
#Save mean and var of the standard Gaussian.
#Save mean and var of the standard Gaussian.
#Need these to re-initialize distribution on a new device.
#Need these to re-initialize distribution on a new device.
mean
=
torch
.
tensor
([
0.0
]).
to
(
dtype
=
dtype
)
mean
=
torch
.
as_
tensor
([
0.0
]).
to
(
dtype
=
dtype
)
var
=
torch
.
tensor
([
1.0
]).
to
(
dtype
=
dtype
)
var
=
torch
.
as_
tensor
([
1.0
]).
to
(
dtype
=
dtype
)
self
.
register_buffer
(
'mean'
,
mean
)
self
.
register_buffer
(
'mean'
,
mean
)
self
.
register_buffer
(
'var'
,
var
)
self
.
register_buffer
(
'var'
,
var
)
...
...
torch_harmonics/resample.py
View file @
30f7802b
...
@@ -75,9 +75,9 @@ class ResampleS2(nn.Module):
...
@@ -75,9 +75,9 @@ class ResampleS2(nn.Module):
# we need to expand the solution to the poles before interpolating
# we need to expand the solution to the poles before interpolating
self
.
expand_poles
=
(
self
.
lats_out
>
self
.
lats_in
[
-
1
]).
any
()
or
(
self
.
lats_out
<
self
.
lats_in
[
0
]).
any
()
self
.
expand_poles
=
(
self
.
lats_out
>
self
.
lats_in
[
-
1
]).
any
()
or
(
self
.
lats_out
<
self
.
lats_in
[
0
]).
any
()
if
self
.
expand_poles
:
if
self
.
expand_poles
:
self
.
lats_in
=
torch
.
cat
([
torch
.
tensor
([
0.
],
dtype
=
torch
.
float64
),
self
.
lats_in
=
torch
.
cat
([
torch
.
as_
tensor
([
0.
],
dtype
=
torch
.
float64
),
self
.
lats_in
,
self
.
lats_in
,
torch
.
tensor
([
math
.
pi
],
dtype
=
torch
.
float64
)]).
contiguous
()
torch
.
as_
tensor
([
math
.
pi
],
dtype
=
torch
.
float64
)]).
contiguous
()
# prepare the interpolation by computing indices to the left and right of each output latitude
# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx
=
torch
.
searchsorted
(
self
.
lats_in
,
self
.
lats_out
,
side
=
"right"
)
-
1
lat_idx
=
torch
.
searchsorted
(
self
.
lats_in
,
self
.
lats_out
,
side
=
"right"
)
-
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment