Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
c56e400c
Commit
c56e400c
authored
Aug 01, 2023
by
tpys
Browse files
move save_like to data_util
parent
2f128e8c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
198 additions
and
30 deletions
+198
-30
data_util.py
data_util.py
+196
-0
inference_fuxi.py
inference_fuxi.py
+2
-30
No files found.
make_input
.py
→
data_util
.py
View file @
c56e400c
...
...
@@ -4,35 +4,73 @@ import numpy as np
import
pandas
as
pd
import
xarray
as
xr
# Public names of this module. (An earlier assignment that listed
# 'chunk_time' was overwritten immediately and has been dropped; the
# effective value is unchanged.)
__all__ = ["make_input", "save_like"]

# Short names of the pressure-level variables.
pl_names = ["z", "t", "u", "v", "r"]

# Short names of the surface variables.
sfc_names = ["t2m", "u10", "v10", "msl", "tp"]

# Level values combined with each pressure-level variable name.
levels = [
    50, 100, 150, 200, 250, 300, 400,
    500, 600, 700, 850, 925, 1000,
]
def chunk_time(ds, shape=None):
    """Re-chunk ``ds`` so that the time-like dimensions have chunk size 1.

    Args:
        ds: Object exposing ``dims`` and ``chunk`` (an xarray Dataset or
            DataArray in this module's usage).
        shape: Optional sequence of sizes, one per entry of ``ds.dims`` and
            in the same order. When omitted, ``ds.dims`` must be a mapping
            of dimension name to size.

    Returns:
        The result of ``ds.chunk(...)`` where every dimension keeps its full
        size as chunk size, except ``'time'`` and ``'lead_time'`` which are
        chunked one step at a time.
    """
    if shape is None:
        dims = dict(ds.dims.items())
    else:
        dims = dict(zip(ds.dims, shape))

    for k in ('time', 'lead_time'):
        if k in dims:
            dims[k] = 1

    # The chunk spec was previously built and then discarded, so callers
    # doing ``x = chunk_time(x, ...)`` received None.
    return ds.chunk(dims)
def split_variable(ds, name):
    """Extract one variable from an array whose ``level`` axis stacks channels.

    Args:
        ds: Array indexed along ``level`` by channel labels: the bare variable
            name for surface variables, ``f"{name}{level}"`` for pressure-level
            variables. It must also carry the dimensions ``member``, ``time``,
            ``dtime``, ``lat`` and ``lon`` used in the transpose below.
        name: A variable name from ``sfc_names`` or ``pl_names``.

    Returns:
        For a surface variable: the single selected channel with its ``level``
        coordinate set to ``[0]`` and the dimension renamed to ``level0``.
        For a pressure-level variable: the channels for every entry of
        ``levels``, with ``level`` relabelled to the numeric level values.
        Both are transposed to (member, level/level0, time, dtime, lat, lon).

    Raises:
        ValueError: If ``name`` is in neither ``sfc_names`` nor ``pl_names``
            (previously this fell through to an unbound local variable).
    """
    if name in sfc_names:
        v = ds.sel(level=[name])
        v = v.assign_coords(level=[0])
        # A separate dimension name keeps the surface fields from being
        # aligned with the pressure levels.
        v = v.rename({"level": "level0"})
        return v.transpose('member', 'level0', 'time', 'dtime', 'lat', 'lon')

    if name in pl_names:
        channel_labels = [f'{name}{lev}' for lev in levels]
        v = ds.sel(level=channel_labels)
        v = v.assign_coords(level=levels)
        return v.transpose('member', 'level', 'time', 'dtime', 'lat', 'lon')

    raise ValueError(
        f"unknown variable name {name!r}; expected one of "
        f"{pl_names + sfc_names}"
    )
def save_like(output, input, step, save_dir="", freq=6):
    """Write one forecast step to ``<save_dir>/<dtime:03d>.nc``.

    The raw ``output`` array is wrapped in a labelled DataArray using the
    coordinates of ``input``, split into one named variable per entry of
    ``pl_names + sfc_names`` and merged into a single dataset before saving.

    Args:
        output: Array for a single step. Two leading singleton axes are
            prepended so it lines up with
            (member, time, dtime, level, lat, lon).
            NOTE(review): presumably shaped 1 x level x lat x lon, confirm.
        input: The array the forecast was started from; supplies the ``time``,
            ``level``, ``lat`` and ``lon`` coordinates. (The name shadows the
            builtin but is part of the public signature.)
        step: Zero-based index of the forecast step.
        save_dir: Output directory, created if non-empty. An empty string
            writes into the current working directory.
        freq: Hours per forecast step.

    Returns:
        None. The dataset is written to disk as a side effect.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Offset in hours relative to the FIRST time stamp of ``input``.
    # NOTE(review): the "+ 2" presumably accounts for ``input`` holding two
    # time steps, confirm against the caller.
    dtime = (step + 2) * freq
    init_time = pd.to_datetime(input.time.values[0])
    print(f'init_time: {init_time}, dtime: {dtime}')

    data = xr.DataArray(
        output[None, None],
        dims=['member', 'time', 'dtime', 'level', 'lat', 'lon'],
        coords=dict(
            member=['FuXi'],
            time=[init_time],
            dtime=[dtime],
            level=input.level,
            lat=input.lat.values,
            lon=input.lon.values,
        ),
    ).astype(np.float32)

    def rename(name):
        """Map an internal short name to the variable name used on disk."""
        if name == "tp":
            return "TP06"
        if name == "r":
            return "RH"
        return name.upper()

    variables = []
    for k in pl_names + sfc_names:
        v = split_variable(data, k)
        v.name = rename(k)
        variables.append(v)

    ds = xr.merge(variables, compat="no_conflicts")

    # A stray ``ds = ds.chunk(dims)`` / ``return ds`` used to sit here; it
    # referenced an undefined ``dims`` and returned before the file was saved.
    save_name = os.path.join(save_dir, f'{dtime:03d}.nc')
    ds.to_netcdf(save_name)
def
make_input
(
init_time
,
data_dir
,
save_dir
,
deg
=
0.25
):
# These are fixed for FuXi
pl_names
=
[
'z'
,
't'
,
'u'
,
'v'
,
'r'
]
sfc_names
=
[
't2m'
,
'u10'
,
'v10'
,
'msl'
,
'tp'
]
levels
=
[
50
,
100
,
150
,
200
,
250
,
300
,
400
,
500
,
600
,
700
,
850
,
925
,
1000
]
lat
=
np
.
linspace
(
-
90
,
90
,
int
(
180
/
deg
)
+
1
,
dtype
=
np
.
float32
)
lon
=
np
.
arange
(
0
,
360
,
deg
,
dtype
=
np
.
float32
)
valid_time
=
init_time
+
pd
.
Timedelta
(
hours
=
6
)
# utc time
input
=
[]
level
=
[]
for
name
in
pl_names
+
sfc_names
:
src_name
=
'{}_{}'
.
format
(
name
,
init_time
.
strftime
(
"%Y%m%d%H.nc"
))
src_file
=
os
.
path
.
join
(
data_dir
,
src_name
)
...
...
@@ -41,7 +79,8 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
return
try
:
v
=
xr
.
open_dataset
(
src_file
).
sel
(
time
=
init_time
,
drop
=
True
).
data
v
=
xr
.
open_dataset
(
src_file
)
v
=
v
.
sel
(
time
=
init_time
,
drop
=
True
).
data
except
:
print
(
f
"open
{
src_file
}
failed"
)
return
...
...
@@ -59,7 +98,6 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
print
(
f
"
{
src_name
}
has nan value"
)
return
# reverse pressure level
try
:
if
name
in
pl_names
:
...
...
@@ -88,28 +126,71 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
v
.
attrs
=
{}
v
=
v
.
rename
({
'dtime'
:
'time'
})
v
=
v
.
squeeze
(
'member'
).
drop
(
'member'
)
v
=
v
.
assign_coords
(
time
=
[
init_time
,
valid_time
])
input
.
append
(
v
)
# concat and reshape
input
=
xr
.
concat
(
input
,
"level"
)
input
=
input
.
transpose
(
"time"
,
"level"
,
"lat"
,
"lon"
)
valid_time
=
init_time
+
pd
.
Timedelta
(
hours
=
6
)
# utc time
v
=
v
.
assign_coords
(
time
=
[
init_time
,
valid_time
])
# reverse latitude
input
=
input
.
reindex
(
lat
=
input
.
lat
[::
-
1
])
input
=
input
.
assign_coords
(
level
=
level
)
input
.
name
=
'data'
input
=
chunk_time
(
input
,
input
.
shape
)
# save to nc
save_name
=
os
.
path
.
join
(
save_dir
,
valid_time
.
strftime
(
"%Y%m%d-%H.nc"
))
print
(
input
)
save_name
=
os
.
path
.
join
(
save_dir
,
init_time
.
strftime
(
"%Y%m%d-%H.nc"
))
input
=
input
.
astype
(
np
.
float32
)
input
.
to_netcdf
(
save_name
)
def visualize(save_name, vars=None, titles=None, vmin=None, vmax=None):
    """Plot each field on its own map panel and save the figure.

    Args:
        save_name: Path of the image file to write.
        vars: Sequence of plottable fields; each must provide a
            ``.plot(ax=..., x='lon', y='lat', ...)`` method (xarray
            DataArrays in this module's usage). One panel row per field.
        titles: Sequence of panel titles, at least as long as ``vars``.
        vmin: Lower colour limit shared by all panels, or None.
        vmax: Upper colour limit shared by all panels, or None.

    Raises:
        ValueError: If ``vars`` is empty or there are fewer titles than
            fields.
    """
    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    # None defaults replace the former shared mutable list defaults.
    fields = vars if vars is not None else []
    labels = titles if titles is not None else []
    if len(fields) == 0:
        raise ValueError("visualize() needs at least one field in `vars`")
    if len(labels) < len(fields):
        raise ValueError(
            f"got {len(fields)} fields but only {len(labels)} titles"
        )

    # squeeze=False always yields a 2-D axes array, so the single-panel case
    # needs no special handling.
    fig, axes = plt.subplots(
        len(fields),
        1,
        figsize=(8, 6),
        subplot_kw={"projection": ccrs.PlateCarree()},
        squeeze=False,
    )

    for i, field in enumerate(fields):
        panel = axes[i, 0]
        field.plot(
            ax=panel,
            x='lon',
            y='lat',
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            add_colorbar=False,
        )
        # Coastlines are intentionally not drawn (the call was disabled).
        panel.set_title(labels[i])
        gl = panel.gridlines(draw_labels=True, linewidth=0.5)
        gl.top_labels = False
        gl.right_labels = False

    # ``transparent`` is a boolean flag; the string 'true' only worked
    # because any non-empty string is truthy.
    fig.savefig(
        save_name,
        bbox_inches='tight',
        pad_inches=0.1,
        transparent=True,
        dpi=200,
    )
    plt.close(fig)
def test_make_input():
    """Manual smoke run of ``make_input`` for the fixed 2023-07-31 12h start.

    Reads from ``data/HRES`` and writes into ``data/HRES/input``, creating
    the output directory first if it does not exist.
    """
    source_dir = "data/HRES"
    target_dir = "data/HRES/input"
    # The original comment notes this time stamp must be UTC.
    start = pd.to_datetime("20230731-12")

    os.makedirs(target_dir, exist_ok=True)
    make_input(start, source_dir, target_dir)
def test_visualize():
    """Manual smoke run: plot the 'tp' channel of the 72 h output file.

    Opens ``data/HRES/output/072.nc``, selects ``level='tp'`` and renders it
    to ``tp.jpg`` with the colour range fixed to 0..20.
    """
    forecast = xr.open_dataarray('data/HRES/output/072.nc')
    precipitation = forecast.sel(level='tp')
    visualize(
        'tp.jpg',
        [precipitation],
        ['tp'],
        vmin=0,
        vmax=20,
    )
# test_make_input()
\ No newline at end of file
inference_fuxi.py
View file @
c56e400c
...
...
@@ -6,6 +6,8 @@ import xarray as xr
import
pandas
as
pd
import
onnxruntime
as
ort
from
data_util
import
save_like
ort
.
set_default_logger_severity
(
3
)
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -50,39 +52,9 @@ def load_model(model_name):
return
session
def load_data(data_file):
    """Open ``data_file`` with ``xr.open_dataarray`` and return the result."""
    return xr.open_dataarray(data_file)
def save_like(output, data, step, save_dir="", freq=6):
    """Save one forecast step as ``<save_dir>/<lead_time:03d>.nc``.

    Args:
        output: Raw array for this step, laid out as
            (time, level, lat, lon); the original comment gives
            1 x 70 x 721 x 1440.
        data: Input array that supplies the ``level``, ``lat`` and ``lon``
            coordinates; its last time stamp is taken as the initial time.
        step: Zero-based forecast step index.
        save_dir: Output directory, created when non-empty.
        freq: Hours per forecast step.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    lead_time = freq * (step + 1)
    start = pd.to_datetime(data.time.values[-1])
    valid = start + pd.Timedelta(hours=lead_time)

    coords = {
        "time": [valid],
        "level": data.level,
        "lat": data.lat,
        "lon": data.lon,
    }
    labelled = xr.DataArray(
        output,
        dims=['time', 'level', 'lat', 'lon'],
        coords=coords,
    )
    labelled.name = 'data'

    target = os.path.join(save_dir, f'{lead_time:03d}.nc')
    labelled.to_netcdf(target)
def
run_inference
(
model_dir
,
data
,
num_steps
,
save_dir
=
""
):
total_step
=
sum
(
num_steps
)
init_time
=
pd
.
to_datetime
(
data
.
time
.
values
[
-
1
])
tembs
=
time_encoding
(
init_time
,
total_step
)
print
(
f
'init_time:
{
init_time
.
strftime
((
"%Y%m%d-%H"
))
}
'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment