Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
e56b2a2e
Commit
e56b2a2e
authored
Aug 17, 2023
by
tpys
Browse files
support different input: ERA5, HRES, GFS
parent
ce91d3d0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
62 additions
and
20 deletions
+62
-20
data_util.py
data_util.py
+62
-20
No files found.
data_util.py
View file @
e56b2a2e
...
@@ -11,6 +11,7 @@ sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
...
@@ -11,6 +11,7 @@ sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
levels
=
[
50
,
100
,
150
,
200
,
250
,
300
,
400
,
500
,
600
,
700
,
850
,
925
,
1000
]
levels
=
[
50
,
100
,
150
,
200
,
250
,
300
,
400
,
500
,
600
,
700
,
850
,
925
,
1000
]
degree
=
0.25
degree
=
0.25
def
weighted_rmse
(
out
,
tgt
):
def
weighted_rmse
(
out
,
tgt
):
wlat
=
np
.
cos
(
np
.
deg2rad
(
tgt
.
lat
))
wlat
=
np
.
cos
(
np
.
deg2rad
(
tgt
.
lat
))
wlat
/=
wlat
.
mean
()
wlat
/=
wlat
.
mean
()
...
@@ -35,20 +36,19 @@ def split_variable(ds, name):
...
@@ -35,20 +36,19 @@ def split_variable(ds, name):
def
save_like
(
output
,
input
,
step
,
save_dir
=
""
,
input_type
=
"hres"
,
freq
=
6
,
split
=
False
):
def
save_like
(
output
,
input
,
step
,
save_dir
=
""
,
input_type
=
"hres"
,
freq
=
6
,
split
=
False
):
if
save_dir
:
if
save_dir
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
dtime
=
(
step
+
1
)
*
freq
step
=
(
step
+
1
)
*
freq
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
-
1
])
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
-
1
])
if
input_type
.
upper
()
==
"HRES"
:
if
input_type
.
upper
()
==
"HRES"
:
dtime
=
(
step
+
2
)
*
freq
step
=
(
step
+
2
)
*
freq
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
0
])
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
0
])
ds
=
xr
.
DataArray
(
ds
=
xr
.
DataArray
(
output
[
None
,
None
],
output
[
None
],
dims
=
[
'member'
,
'time'
,
'
dtime
'
,
'level'
,
'lat'
,
'lon'
],
dims
=
[
'time'
,
'
step
'
,
'level'
,
'lat'
,
'lon'
],
coords
=
dict
(
coords
=
dict
(
member
=
[
'FuXi'
],
time
=
[
init_time
],
time
=
[
init_time
],
dtime
=
[
dtime
],
step
=
[
step
],
level
=
input
.
level
,
level
=
input
.
level
,
lat
=
input
.
lat
.
values
,
lat
=
input
.
lat
.
values
,
lon
=
input
.
lon
.
values
,
lon
=
input
.
lon
.
values
,
...
@@ -70,8 +70,39 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split
...
@@ -70,8 +70,39 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split
new_ds
.
append
(
v
)
new_ds
.
append
(
v
)
ds
=
xr
.
merge
(
new_ds
,
compat
=
"no_conflicts"
)
ds
=
xr
.
merge
(
new_ds
,
compat
=
"no_conflicts"
)
save_name
=
os
.
path
.
join
(
save_dir
,
f
'
{
dtime
:
03
d
}
.nc'
)
save_name
=
os
.
path
.
join
(
save_dir
,
f
'
{
step
:
03
d
}
.nc'
)
print
(
f
'Save to
{
save_name
}
...'
)
# print(f'Save to {save_name} ...')
ds
.
to_netcdf
(
save_name
)
def make_era5_input(init_time, data_dir, save_dir):
    """Assemble the two-frame ERA5 model input and write it to NetCDF.

    For every variable in ``pl_names + sfc_names`` the zarr store at
    ``<data_dir>/<name>/<year>`` is opened, the frames at ``init_time - 6h``
    and ``init_time`` are selected, the variable is renamed to ``data`` and
    its attributes are cleared.  All variables are then concatenated along
    ``level``, relabelled, cast to float32 and saved as
    ``<save_dir>/input.YYYYmmdd.tHH.nc``.

    Args:
        init_time: Forecast initialisation time; anything accepted by
            ``pd.to_datetime``.
        data_dir: Root directory holding one sub-directory per variable.
        save_dir: Output directory; created if it does not exist.

    Raises:
        KeyError: If a requested frame is missing from the store(s).
    """
    init_time = pd.to_datetime(init_time)
    hist_time = init_time - pd.Timedelta(hours=6)
    print(f"init_time: {init_time}")

    fields = []
    level = []
    for name in pl_names + sfc_names:
        data_name = os.path.join(data_dir, name, f'{init_time.year}')
        v = xr.open_zarr(data_name)
        try:
            v = v.sel(time=[hist_time, init_time])
        except KeyError:
            # An init time in the first hours of 1 January puts the history
            # frame in the previous year.  Only then look in the store for
            # ``hist_time.year``; any other miss is a genuine data gap.
            # NOTE(review): assumes stores are split per calendar year, as
            # the directory layout suggests -- confirm against the archive.
            if hist_time.year == init_time.year:
                raise
            prev_name = os.path.join(data_dir, name, f'{hist_time.year}')
            prev = xr.open_zarr(prev_name)
            v = xr.concat(
                [prev.sel(time=[hist_time]), v.sel(time=[init_time])],
                'time',
            )
        v = v.rename({name: 'data'})
        v.attrs = {}
        fields.append(v)

        # Build the channel labels in the same order the data is stacked:
        # "<var><pressure>" for pressure-level variables, "<var>" otherwise.
        # Presumably each pressure-level store carries one entry per value
        # in ``levels`` -- TODO confirm.
        if name in pl_names:
            level.extend(f'{name.lower()}{lev}' for lev in levels)
        if name in sfc_names:
            level.append(name.lower())

    ds = xr.concat(fields, 'level')
    ds = ds.assign_coords(level=level)

    os.makedirs(save_dir, exist_ok=True)
    save_name = os.path.join(save_dir, init_time.strftime("input.%Y%m%d.t%H.nc"))
    print(f"save to {save_name} ...")
    ds = ds.astype(np.float32)
    ds.to_netcdf(save_name)
...
@@ -185,7 +216,8 @@ def make_gfs_input(init_time, data_dir, save_dir):
...
@@ -185,7 +216,8 @@ def make_gfs_input(init_time, data_dir, save_dir):
return
return
if
v
.
shape
[
-
2
:]
!=
(
721
,
1440
):
if
v
.
shape
[
-
2
:]
!=
(
721
,
1440
):
v
=
v
.
interp
(
lat
=
lat
,
lon
=
lon
,
kwargs
=
{
"fill_value"
:
"extrapolate"
})
v
=
v
.
interp
(
lat
=
lat
,
lon
=
lon
,
kwargs
=
{
"fill_value"
:
"extrapolate"
})
if
np
.
isnan
(
v
).
sum
()
>
0
:
if
np
.
isnan
(
v
).
sum
()
>
0
:
print
(
f
"
{
src_name
}
has nan value"
)
print
(
f
"
{
src_name
}
has nan value"
)
return
return
...
@@ -219,7 +251,8 @@ def make_gfs_input(init_time, data_dir, save_dir):
...
@@ -219,7 +251,8 @@ def make_gfs_input(init_time, data_dir, save_dir):
input
=
xr
.
concat
(
input
,
"level"
)
# T
input
=
xr
.
concat
(
input
,
"level"
)
# T
input
=
input
.
rename
({
"latitude"
:
"lat"
,
"longitude"
:
"lon"
})
input
=
input
.
rename
({
"latitude"
:
"lat"
,
"longitude"
:
"lon"
})
times
=
[
pd
.
to_datetime
(
str
(
t
),
format
=
'%Y%m%d%H'
)
for
t
in
input
.
time
.
values
]
times
=
[
pd
.
to_datetime
(
str
(
t
),
format
=
'%Y%m%d%H'
)
for
t
in
input
.
time
.
values
]
input
=
input
.
assign_coords
(
level
=
level
)
input
=
input
.
assign_coords
(
level
=
level
)
input
=
input
.
assign_coords
(
time
=
times
)
input
=
input
.
assign_coords
(
time
=
times
)
...
@@ -229,7 +262,6 @@ def make_gfs_input(init_time, data_dir, save_dir):
...
@@ -229,7 +262,6 @@ def make_gfs_input(init_time, data_dir, save_dir):
input
.
to_netcdf
(
save_name
)
input
.
to_netcdf
(
save_name
)
def
visualize
(
save_name
,
vars
=
[],
titles
=
[],
vmin
=
None
,
vmax
=
None
):
def
visualize
(
save_name
,
vars
=
[],
titles
=
[],
vmin
=
None
,
vmax
=
None
):
import
cartopy.crs
as
ccrs
import
cartopy.crs
as
ccrs
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
...
@@ -277,4 +309,14 @@ def test_visualize():
...
@@ -277,4 +309,14 @@ def test_visualize():
visualize
(
'tp.jpg'
,
[
tp
],
[
'tp'
],
vmin
=
0
,
vmax
=
20
)
visualize
(
'tp.jpg'
,
[
tp
],
[
'tp'
],
vmin
=
0
,
vmax
=
20
)
# test_make_input()
def test_rmse(output_name, target_name):
    """Print the 120 h ``weighted_rmse`` score for a fixed set of channels.

    Opens the forecast and target arrays, takes the first ``time`` entry of
    the forecast at ``step=120``, and for each channel prints the score
    returned by ``weighted_rmse`` with three decimals.

    Args:
        output_name: Path of the forecast file readable by
            ``xr.open_dataarray``; must have ``time``, ``step`` and
            ``level`` coordinates.
        target_name: Path of the reference file readable by
            ``xr.open_dataarray``; must have a ``level`` coordinate.
    """
    # Context managers make sure both file handles are released even when a
    # selection or the metric raises.
    with xr.open_dataarray(output_name) as forecast, \
            xr.open_dataarray(target_name) as target:
        output = forecast.isel(time=0).sel(step=120)

        for level in ["z500", "t850", "t2m", "u10", "v10", "msl", "tp"]:
            out = output.sel(level=level)
            tgt = target.sel(level=level)
            rmse = weighted_rmse(out, tgt).load()
            # Convert to a plain float so the ".3f" spec does not depend on
            # DataArray.__format__ support in the installed xarray.
            print(f"{level.upper()} 120h rmse: {float(rmse):.3f}")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment