Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
1df8a46a
Commit
1df8a46a
authored
Sep 13, 2023
by
tpys
Browse files
clean and rename
parent
b5342442
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
239 additions
and
12 deletions
+239
-12
README.md
README.md
+12
-6
fuxi.py
fuxi.py
+2
-6
make_era5_input.py
make_era5_input.py
+0
-0
make_gfs_input.py
make_gfs_input.py
+0
-0
make_hres_input.py
make_hres_input.py
+98
-0
util.py
util.py
+127
-0
No files found.
README.md
View file @
1df8a46a
...
...
@@ -21,10 +21,12 @@ The downloaded files shall be organized as the following hierarchy:
│ ├── data
│ │ ├── 20210101
│ │ ├── input.nc
│ │ ├── output.nc
│ │ ├── target.nc
│ │
│ ├── fuxi
│ │ ├── output
│ │ ├── 006.nc
│ │ ├── 012.nc
│ │ ├── ...
│ │ ├── 360.nc
│ ├── model
│ | ├── short
│ | ├── short.onnx
│ | ├── medium
...
...
@@ -32,7 +34,11 @@ The downloaded files shall be organized as the following hierarchy:
│ | ├── long
│ | ├── long.onnx
| |
│ ├── infernece_fuxi.py
│ ├── fuxi.py
│ ├── util.py
│ ├── make_era5_input.py
│ ├── make_hres_input.py
│ ├── make_gfs_input.py
```
...
...
@@ -51,7 +57,7 @@ pip install -r requirement.txt
## Demo
```
bash
python
inference_
fuxi.py
--model
model_dir
--input
input_file
--num_steps
20
--input_type
GFS
python fuxi.py
--model
model_dir
--input
input_file
--num_steps
20
20 20
```
...
...
inference_
fuxi.py
→
fuxi.py
View file @
1df8a46a
...
...
@@ -6,22 +6,18 @@ import xarray as xr
import
pandas
as
pd
import
onnxruntime
as
ort
from
data_
util
import
save_like
from
util
import
save_like
ort
.
set_default_logger_severity
(
3
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
"FuXi onnx model dir"
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
True
,
help
=
"The input data file, store in netcdf format"
)
parser
.
add_argument
(
'--input_type'
,
type
=
str
,
help
=
"The input type"
,
default
=
"ERA5"
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
'--num_steps'
,
type
=
int
,
nargs
=
"+"
,
default
=
[
20
])
args
=
parser
.
parse_args
()
assert
args
.
input_type
.
upper
()
in
[
"ERA5"
,
"GFS"
,
"HRES"
]
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
init_time
=
np
.
array
([
init_time
])
tembs
=
[]
...
...
@@ -91,7 +87,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
temb
=
tembs
[
step
]
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
output
=
new_input
[:,
-
1
]
save_like
(
output
,
data
,
step
,
save_dir
,
input_type
=
args
.
input_type
)
save_like
(
output
,
data
,
step
,
save_dir
)
print
(
f
'stage:
{
i
}
, step:
{
step
+
1
:
02
d
}
, output:
{
output
.
min
():.
2
f
}
{
output
.
max
():.
2
f
}
'
)
input
=
new_input
step
+=
1
...
...
make_era5.py
→
make_era5
_input
.py
View file @
1df8a46a
File moved
make_gfs.py
→
make_gfs
_input
.py
View file @
1df8a46a
File moved
make_hres_input.py
0 → 100644
View file @
1df8a46a
import
os
import
numpy
as
np
import
pandas
as
pd
import
xarray
as
xr
def make_hres_input(init_time, data_dir, save_dir, degree=0.25):
    """Build a two-step FuXi input file from per-variable HRES NetCDF files.

    For each pressure-level and surface variable the file
    ``<name>_<YYYYmmddHH>.nc`` is read from ``data_dir``, checked for NaN,
    interpolated onto a regular ``degree`` lat/lon grid and stacked along a
    ``level`` axis. The result is written to ``<save_dir>/<YYYYmmdd-HH>.nc``.

    Args:
        init_time: Initialisation time; must provide ``strftime`` and support
            adding a ``pd.Timedelta`` (e.g. a ``pd.Timestamp``).
        data_dir: Directory holding the per-variable source files.
        save_dir: Directory the output file is written to.
        degree: Spacing in degrees of the target lat/lon grid.

    Returns:
        None. The function returns early, writing nothing, when a source
        file is missing or cannot be opened, contains NaN before or after
        interpolation, or lacks one of the required pressure levels.
    """
    lat = np.linspace(-90, 90, int(180 / degree) + 1, dtype=np.float32)
    lon = np.arange(0, 360, degree, dtype=np.float32)
    pl_names = ['z', 't', 'u', 'v', 'r']
    sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
    levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    fields = []
    level_names = []

    for name in pl_names + sfc_names:
        src_name = '{}_{}'.format(name, init_time.strftime("%Y%m%d%H.nc"))
        src_file = os.path.join(data_dir, src_name)

        if not os.path.exists(src_file):
            return

        try:
            v = xr.open_dataset(src_file)
            v = v.sel(time=init_time, drop=True).data
        except Exception:
            # Best effort: any read/selection failure aborts the whole run.
            print(f"open {src_file} failed")
            return

        # is there nan in raw data ?
        if np.isnan(v).sum() > 0:
            print(f"{src_name} has nan value")
            return

        # interpolate to the target grid
        v = v.interp(lat=lat, lon=lon, kwargs={"fill_value": "extrapolate"})

        # make sure no nan was introduced by the interpolation
        if np.isnan(v).sum() > 0:
            print(f"{src_name} has nan value")
            return

        # put the pressure levels into the order given by `levels`
        try:
            if name in pl_names:
                v = xr.concat([v.sel(level=lev) for lev in levels], 'level')
                level_names.extend([f'{name}{lev}' for lev in levels])
        except Exception:
            print("missing pressure level")
            return

        if name in sfc_names:
            level_names.append(name)

        # temperature in kelvin
        # NOTE(review): presumably the source 't' is in Celsius - confirm.
        if name == "t":
            v = v + 273.15

        # FuXi takes two steps as input; for "tp" an all-zero field with
        # dtime=0 is prepended to the clipped data.
        if name == "tp":
            v = v.clip(min=0, max=1000)
            zero = v * 0
            zero = zero.assign_coords(dtime=[0])
            v = xr.concat([zero, v], "dtime")

        print(f'{src_name}: {v.min().values:.2f} ~ {v.max().values:.2f}')

        v.attrs = {}
        v = v.rename({'dtime': 'time'})
        v = v.squeeze('member').drop('member')
        fields.append(v)

    # concat and reshape
    stacked = xr.concat(fields, "level")
    stacked = stacked.transpose("time", "level", "lat", "lon")

    # Label the two steps with absolute (UTC) times. These coordinates used
    # to be assigned to the leftover loop variable `v` and were therefore
    # lost; they belong on the stacked array that is saved.
    valid_time = init_time + pd.Timedelta(hours=6)
    stacked = stacked.assign_coords(time=[init_time, valid_time])

    # reverse latitude
    stacked = stacked.reindex(lat=stacked.lat[::-1])
    stacked = stacked.assign_coords(level=level_names)
    stacked.name = 'data'

    # save to nc
    print(stacked)
    save_name = os.path.join(save_dir, init_time.strftime("%Y%m%d-%H.nc"))
    stacked = stacked.astype(np.float32)
    stacked.to_netcdf(save_name)
def test_make_input():
    """Run ``make_hres_input`` once for the 2023-07-31 12:00 cycle.

    Reads from ``data/HRES`` and writes into ``data/HRES/input``, creating
    the output directory if it does not exist yet.
    """
    # The time string must be given in UTC.
    init_time = pd.to_datetime("20230731-12")
    data_dir, save_dir = "data/HRES", "data/HRES/input"
    os.makedirs(save_dir, exist_ok=True)
    make_hres_input(init_time, data_dir, save_dir)
data_
util.py
→
util.py
View file @
1df8a46a
...
...
@@ -9,7 +9,6 @@ __all__ = ["save_like"]
pl_names
=
[
'z'
,
't'
,
'u'
,
'v'
,
'r'
]
sfc_names
=
[
't2m'
,
'u10'
,
'v10'
,
'msl'
,
'tp'
]
levels
=
[
50
,
100
,
150
,
200
,
250
,
300
,
400
,
500
,
600
,
700
,
850
,
925
,
1000
]
degree
=
0.25
def
weighted_rmse
(
out
,
tgt
):
...
...
@@ -33,16 +32,12 @@ def split_variable(ds, name):
return
v
def
save_like
(
output
,
input
,
step
,
save_dir
=
""
,
input_type
=
"hres"
,
freq
=
6
,
split
=
False
):
def
save_like
(
output
,
input
,
step
,
save_dir
=
""
,
freq
=
6
,
split
=
False
):
if
save_dir
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
step
=
(
step
+
1
)
*
freq
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
-
1
])
if
input_type
.
upper
()
==
"HRES"
:
step
=
(
step
+
2
)
*
freq
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
0
])
ds
=
xr
.
DataArray
(
output
[
None
],
dims
=
[
'time'
,
'step'
,
'level'
,
'lat'
,
'lon'
],
...
...
@@ -75,252 +70,6 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split
ds
.
to_netcdf
(
save_name
)
def make_era5_input(init_time, data_dir, save_dir):
    """Assemble a two-step ERA5 input file from per-variable zarr stores.

    For every name in the module-level ``pl_names`` and ``sfc_names`` the
    store ``<data_dir>/<name>/<year>`` is opened, the two times
    ``init_time - 6h`` and ``init_time`` are selected, and the results are
    concatenated along ``level``. The output is written to
    ``<save_dir>/input.<YYYYmmdd>.t<HH>.nc`` as float32.

    Args:
        init_time: Initialisation time, anything ``pd.to_datetime`` accepts.
        data_dir: Root directory of the per-variable zarr stores.
        save_dir: Output directory; created if missing.
    """
    init_time = pd.to_datetime(init_time)
    hist_time = init_time - pd.Timedelta(hours=6)
    print(f"init_time: {init_time}")

    pieces = []
    channel_names = []
    for name in pl_names + sfc_names:
        # NOTE(review): both times are read from the store of init_time.year;
        # for an init_time in the first 6 hours of a year, hist_time lies in
        # the previous year - confirm the stores cover that case.
        store = os.path.join(data_dir, name, f'{init_time.year}')
        var = xr.open_zarr(store)
        var = var.sel(time=[hist_time, init_time]).rename({name: 'data'})
        var.attrs = {}
        pieces.append(var)

        short = name.lower()
        if name in pl_names:
            channel_names += [f'{short}{l}' for l in levels]
        if name in sfc_names:
            channel_names.append(short)

    merged = xr.concat(pieces, 'level').assign_coords(level=channel_names)

    os.makedirs(save_dir, exist_ok=True)
    save_name = os.path.join(save_dir, init_time.strftime("input.%Y%m%d.t%H.nc"))
    print(f"save to {save_name} ...")
    merged.astype(np.float32).to_netcdf(save_name)
def make_era5(init_time, data_dir):
    """Load one ERA5 time step from raw NetCDF files as a stacked array.

    Reads ``P<YYYYmmddHH>.nc`` (pressure levels), ``S<YYYYmmddHH>.nc``
    (surface) and ``R<YYYYmmdd>.nc`` (precipitation) from ``data_dir`` and
    concatenates all variables along ``level``.

    Args:
        init_time: Time to process, anything ``pd.to_datetime`` accepts.
        data_dir: Directory containing the three source files.

    Returns:
        A float32 array with dims renamed to ``lat``/``lon`` and the
        ``level`` coordinate set to upper-case channel names such as
        ``Z50`` ... ``R1000``, ``T2M`` ... ``TP``.

    Raises:
        ValueError: If any variable contains NaN.
    """
    import os
    import numpy as np
    import pandas as pd
    import xarray as xr

    pl_names = ['z', 't', 'u', 'v', 'r']
    sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
    levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    init_time = pd.to_datetime(init_time)
    print(f"process {init_time} ...")

    pl = xr.open_dataset(os.path.join(data_dir, init_time.strftime('P%Y%m%d%H.nc')))
    sfc = xr.open_dataset(os.path.join(data_dir, init_time.strftime('S%Y%m%d%H.nc')))

    # Precipitation: fill NaN with 0, sum over a rolling window of 6 time
    # steps, scale by 1000, keep every 6th time step and clip to [0, 1000].
    # NOTE(review): presumably hourly metres -> 6-hourly millimetres; confirm.
    tp = xr.open_dataarray(os.path.join(data_dir, init_time.strftime('R%Y%m%d.nc')))
    tp = tp.fillna(0).rolling(time=6).sum() * 1000
    tp = tp.sel(time=tp.time[::6]).clip(min=0, max=1000)
    sfc['tp'] = tp

    channel = [f'{n.upper()}{l}' for n in pl_names for l in levels]
    channel.extend(n.upper() for n in sfc_names)

    stacked = []
    for name in pl_names + sfc_names:
        if name in pl_names:
            v = pl[name]
        else:
            # Surface fields have no level axis; add one of length 1 at
            # position 1 so they can be concatenated with the others.
            single = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
            v = sfc[name].expand_dims({'level': single}, axis=1)

        if np.isnan(v).sum() > 0:
            print(f"{name} has nan value")
            raise ValueError

        v.name = "data"
        v.attrs = {}
        print(f"{name}: {v.shape}, {v.min().values} ~ {v.max().values}")
        stacked.append(v)

    ds = xr.concat(stacked, 'level').assign_coords(level=channel)
    ds = ds.rename({'longitude': 'lon', 'latitude': 'lat'})
    return ds.astype(np.float32)
# ds12 = make_era5('20230725-12', 'ERA520230725')
# ds18 = make_era5('20230725-18', 'ERA520230725')
# ds = xr.concat([ds12, ds18], 'time')
# ds.to_netcdf('new_input.nc')
def make_hres_input(init_time, data_dir, save_dir):
    """Build a two-step FuXi input file from per-variable HRES NetCDF files.

    Uses the module-level ``pl_names``, ``sfc_names``, ``levels`` and
    ``degree``. For each variable the file ``<name>_<YYYYmmddHH>.nc`` is read
    from ``data_dir``, checked for NaN, interpolated onto a regular
    ``degree`` lat/lon grid and stacked along a ``level`` axis. The result is
    written to ``<save_dir>/<YYYYmmdd-HH>.nc``.

    Args:
        init_time: Initialisation time; must provide ``strftime`` and support
            adding a ``pd.Timedelta`` (e.g. a ``pd.Timestamp``).
        data_dir: Directory holding the per-variable source files.
        save_dir: Directory the output file is written to.

    Returns:
        None. The function returns early, writing nothing, when a source
        file is missing or cannot be opened, contains NaN before or after
        interpolation, or lacks one of the required pressure levels.
    """
    lat = np.linspace(-90, 90, int(180 / degree) + 1, dtype=np.float32)
    lon = np.arange(0, 360, degree, dtype=np.float32)

    fields = []
    level_names = []

    for name in pl_names + sfc_names:
        src_name = '{}_{}'.format(name, init_time.strftime("%Y%m%d%H.nc"))
        src_file = os.path.join(data_dir, src_name)

        if not os.path.exists(src_file):
            return

        try:
            v = xr.open_dataset(src_file)
            v = v.sel(time=init_time, drop=True).data
        except Exception:
            # Best effort: any read/selection failure aborts the whole run.
            print(f"open {src_file} failed")
            return

        # is there nan in raw data ?
        if np.isnan(v).sum() > 0:
            print(f"{src_name} has nan value")
            return

        # interpolate to the target grid
        v = v.interp(lat=lat, lon=lon, kwargs={"fill_value": "extrapolate"})

        # make sure no nan was introduced by the interpolation
        if np.isnan(v).sum() > 0:
            print(f"{src_name} has nan value")
            return

        # put the pressure levels into the order given by `levels`
        try:
            if name in pl_names:
                v = xr.concat([v.sel(level=lev) for lev in levels], 'level')
                level_names.extend([f'{name}{lev}' for lev in levels])
        except Exception:
            print("missing pressure level")
            return

        if name in sfc_names:
            level_names.append(name)

        # temperature in kelvin
        # NOTE(review): presumably the source 't' is in Celsius - confirm.
        if name == "t":
            v = v + 273.15

        # FuXi takes two steps as input; for "tp" an all-zero field with
        # dtime=0 is prepended to the clipped data.
        if name == "tp":
            v = v.clip(min=0, max=1000)
            zero = v * 0
            zero = zero.assign_coords(dtime=[0])
            v = xr.concat([zero, v], "dtime")

        print(f'{src_name}: {v.min().values:.2f} ~ {v.max().values:.2f}')

        v.attrs = {}
        v = v.rename({'dtime': 'time'})
        v = v.squeeze('member').drop('member')
        fields.append(v)

    # concat and reshape
    stacked = xr.concat(fields, "level")
    stacked = stacked.transpose("time", "level", "lat", "lon")

    # Label the two steps with absolute (UTC) times. These coordinates used
    # to be assigned to the leftover loop variable `v` and were therefore
    # lost; they belong on the stacked array that is saved.
    valid_time = init_time + pd.Timedelta(hours=6)
    stacked = stacked.assign_coords(time=[init_time, valid_time])

    # reverse latitude
    stacked = stacked.reindex(lat=stacked.lat[::-1])
    stacked = stacked.assign_coords(level=level_names)
    stacked.name = 'data'

    # save to nc
    print(stacked)
    save_name = os.path.join(save_dir, init_time.strftime("%Y%m%d-%H.nc"))
    stacked = stacked.astype(np.float32)
    stacked.to_netcdf(save_name)
def make_gfs_input(init_time, data_dir, save_dir):
    """Build a FuXi input file from per-variable GFS NetCDF files.

    For each variable the file ``<name>_<YYYYmmdd>.nc`` is read from
    ``data_dir``, checked for NaN, interpolated to the 721 x 1440 grid when
    its last two dimensions differ from that shape, and stacked along
    ``level``. The result is written to ``<save_dir>/<YYYYmmdd>.nc`` as
    float32. Relies on the module-level ``levels`` and ``degree``.

    Args:
        init_time: Initialisation time; only ``strftime`` is used.
        data_dir: Directory holding the per-variable source files.
        save_dir: Directory the output file is written to.

    Returns:
        None. The function returns early, writing nothing, when a source
        file is missing (its path is printed), cannot be opened, or contains
        NaN before or after interpolation.

    Raises:
        AssertionError: If a variable still has a maximum >= 1e10 after
            values >= 1e10 were replaced by 0.
    """
    pl_names = ['Z', 'T', 'U', 'V', 'R']
    sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
    lon = np.arange(0, 360, degree, dtype=np.float32)
    # Inclusive of -90 so the grid has int(180 / degree) + 1 rows (721 for
    # 0.25 deg), matching the shape check below. The former
    # np.arange(90, -90, -degree) stopped one row short of the south pole.
    lat = np.linspace(90, -90, int(180 / degree) + 1, dtype=np.float32)

    fields = []
    level_names = []

    for name in pl_names + sfc_names:
        src_name = '{}_{}'.format(name, init_time.strftime("%Y%m%d.nc"))
        src_file = os.path.join(data_dir, src_name)

        if not os.path.exists(src_file):
            print(src_file)
            return

        try:
            v = xr.open_dataset(src_file)[name]
        except Exception:
            # Best effort: any read failure aborts the whole run.
            print(f"open {src_file} failed")
            return

        if np.isnan(v).sum() > 0:
            print(f"{src_name} has nan value")
            return

        # NOTE(review): interp addresses the axes as lat/lon, while the
        # rename from latitude/longitude only happens after the concat
        # below - confirm the coordinate names of the source files.
        if v.shape[-2:] != (721, 1440):
            v = v.interp(lat=lat, lon=lon, kwargs={"fill_value": "extrapolate"})

        if np.isnan(v).sum() > 0:
            print(f"{src_name} has nan value")
            return

        if name in pl_names:
            level_names.extend([f'{name.lower()}{l}' for l in levels])
        if name in sfc_names:
            level_names.append(name.lower())

        # NOTE(review): presumably converts geopotential height to
        # geopotential - confirm the unit of the source 'Z'.
        if name == "Z":
            v = v * 9.8

        if name == "tp":
            v = v.clip(min=0, max=1000)
            # NOTE(review): the flattened source does not show whether this
            # squeeze belongs to the "tp" branch or applies to every
            # variable - confirm against the original file.
            v = v.squeeze('step').drop('step')

        v.attrs = {}
        v.name = 'data'

        vmin = v.min().values
        vmax = v.max().values
        # Values >= 1e10 are replaced by 0 before the range is re-checked.
        if vmax > 1e10:
            v = v.where(v < 1e10, 0)
            vmax = v.max().values
        assert vmax < 1e10

        print(f'{src_name}: {v.shape}, {vmin:.2f} ~ {vmax:.2f}')
        fields.append(v)

    stacked = xr.concat(fields, "level")
    stacked = stacked.rename({"latitude": "lat", "longitude": "lon"})

    # The stored time values are parsed as YYYYmmddHH into timestamps.
    times = [pd.to_datetime(str(t), format='%Y%m%d%H') for t in stacked.time.values]
    stacked = stacked.assign_coords(level=level_names)
    stacked = stacked.assign_coords(time=times)

    # TODO, we only need two time step input with dims: 2 x 70 x 721 x 1440
    save_name = os.path.join(save_dir, init_time.strftime("%Y%m%d.nc"))
    stacked = stacked.astype(np.float32)
    stacked.to_netcdf(save_name)
def
visualize
(
save_name
,
vars
=
[],
titles
=
[],
vmin
=
None
,
vmax
=
None
):
import
cartopy.crs
as
ccrs
import
matplotlib.pyplot
as
plt
...
...
@@ -354,14 +103,6 @@ def visualize(save_name, vars=[], titles=[], vmin=None, vmax=None):
plt
.
close
()
def test_make_input():
    """Run ``make_hres_input`` once for the 2023-07-31 12:00 cycle.

    Reads from ``data/HRES`` and writes into ``data/HRES/input``, creating
    the output directory if it does not exist yet.
    """
    # The time string must be given in UTC.
    init_time = pd.to_datetime("20230731-12")
    data_dir, save_dir = "data/HRES", "data/HRES/input"
    os.makedirs(save_dir, exist_ok=True)
    make_hres_input(init_time, data_dir, save_dir)
def
test_visualize
(
step
,
data_dir
):
src_name
=
os
.
path
.
join
(
data_dir
,
f
"
{
step
:
03
d
}
.nc"
)
ds
=
xr
.
open_dataarray
(
src_name
).
isel
(
time
=
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment