Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
0e8d5efa
Commit
0e8d5efa
authored
Aug 16, 2023
by
tpys
Browse files
add input_type
parent
cfe4af4c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
13 deletions
+15
-13
data_util.py
data_util.py
+11
-12
inference_fuxi.py
inference_fuxi.py
+4
-1
No files found.
data_util.py
View file @
0e8d5efa
...
@@ -16,7 +16,7 @@ def split_variable(ds, name):
...
@@ -16,7 +16,7 @@ def split_variable(ds, name):
v
=
ds
.
sel
(
level
=
[
name
])
v
=
ds
.
sel
(
level
=
[
name
])
v
=
v
.
assign_coords
(
level
=
[
0
])
v
=
v
.
assign_coords
(
level
=
[
0
])
v
=
v
.
rename
({
"level"
:
"level0"
})
v
=
v
.
rename
({
"level"
:
"level0"
})
v
=
v
.
transpose
(
'member'
,
'level0'
,
'time'
,
'dtime'
,
'lat'
,
'lon'
)
v
=
v
.
transpose
(
'member'
,
'level0'
,
'time'
,
'dtime'
,
'lat'
,
'lon'
)
elif
name
in
pl_names
:
elif
name
in
pl_names
:
level
=
[
f
'
{
name
}{
l
}
'
for
l
in
levels
]
level
=
[
f
'
{
name
}{
l
}
'
for
l
in
levels
]
v
=
ds
.
sel
(
level
=
level
)
v
=
ds
.
sel
(
level
=
level
)
...
@@ -25,15 +25,14 @@ def split_variable(ds, name):
...
@@ -25,15 +25,14 @@ def split_variable(ds, name):
return
v
return
v
def
save_like
(
output
,
input
,
step
,
input_type
=
"hres"
,
save_dir
=
""
,
freq
=
6
,
split
=
False
):
def
save_like
(
output
,
input
,
step
,
save_dir
=
""
,
input_type
=
"hres"
,
freq
=
6
,
split
=
False
):
if
save_dir
:
if
save_dir
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
dtime
=
(
step
+
1
)
*
freq
if
input_type
==
"
hres
"
:
if
input_type
.
upper
()
==
"
HRES
"
:
dtime
=
(
step
+
2
)
*
freq
dtime
=
(
step
+
2
)
*
freq
elif
input_type
==
"gfs"
:
dtime
=
(
step
+
1
)
*
freq
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
0
])
init_time
=
pd
.
to_datetime
(
input
.
time
.
values
[
0
])
ds
=
xr
.
DataArray
(
ds
=
xr
.
DataArray
(
...
@@ -47,8 +46,8 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
...
@@ -47,8 +46,8 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
lat
=
input
.
lat
.
values
,
lat
=
input
.
lat
.
values
,
lon
=
input
.
lon
.
values
,
lon
=
input
.
lon
.
values
,
)
)
).
astype
(
np
.
float32
)
).
astype
(
np
.
float32
)
if
split
:
if
split
:
def
rename
(
name
):
def
rename
(
name
):
if
name
==
"tp"
:
if
name
==
"tp"
:
...
@@ -56,7 +55,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
...
@@ -56,7 +55,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
elif
name
==
"r"
:
elif
name
==
"r"
:
return
"RH"
return
"RH"
return
name
.
upper
()
return
name
.
upper
()
new_ds
=
[]
new_ds
=
[]
for
k
in
pl_names
+
sfc_names
:
for
k
in
pl_names
+
sfc_names
:
v
=
split_variable
(
ds
,
k
)
v
=
split_variable
(
ds
,
k
)
...
@@ -66,7 +65,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
...
@@ -66,7 +65,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
print
(
f
'Save to
{
save_name
}
...'
)
print
(
f
'Save to
{
save_name
}
...'
)
save_name
=
os
.
path
.
join
(
save_dir
,
f
'
{
dtime
:
03
d
}
.nc'
)
save_name
=
os
.
path
.
join
(
save_dir
,
f
'
{
dtime
:
03
d
}
.nc'
)
ds
.
to_netcdf
(
save_name
)
ds
.
to_netcdf
(
save_name
)
def
make_input
(
init_time
,
data_dir
,
save_dir
,
deg
=
0.25
):
def
make_input
(
init_time
,
data_dir
,
save_dir
,
deg
=
0.25
):
...
@@ -75,7 +74,7 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
...
@@ -75,7 +74,7 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
input
=
[]
input
=
[]
level
=
[]
level
=
[]
for
name
in
pl_names
+
sfc_names
:
for
name
in
pl_names
+
sfc_names
:
src_name
=
'{}_{}'
.
format
(
name
,
init_time
.
strftime
(
"%Y%m%d%H.nc"
))
src_name
=
'{}_{}'
.
format
(
name
,
init_time
.
strftime
(
"%Y%m%d%H.nc"
))
src_file
=
os
.
path
.
join
(
data_dir
,
src_name
)
src_file
=
os
.
path
.
join
(
data_dir
,
src_name
)
...
@@ -198,4 +197,4 @@ def test_visualize():
...
@@ -198,4 +197,4 @@ def test_visualize():
visualize
(
'tp.jpg'
,
[
tp
],
[
'tp'
],
vmin
=
0
,
vmax
=
20
)
visualize
(
'tp.jpg'
,
[
tp
],
[
'tp'
],
vmin
=
0
,
vmax
=
20
)
# test_make_input()
# test_make_input()
\ No newline at end of file
inference_fuxi.py
View file @
0e8d5efa
...
@@ -13,11 +13,14 @@ ort.set_default_logger_severity(3)
...
@@ -13,11 +13,14 @@ ort.set_default_logger_severity(3)
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
"FuXi onnx model dir"
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
"FuXi onnx model dir"
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
True
,
help
=
"The input data file, store in netcdf format"
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
True
,
help
=
"The input data file, store in netcdf format"
)
parser
.
add_argument
(
'--input_type'
,
type
=
str
,
help
=
"The input type"
,
default
=
"ERA5"
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
'--num_steps'
,
type
=
int
,
nargs
=
"+"
,
default
=
[
20
])
parser
.
add_argument
(
'--num_steps'
,
type
=
int
,
nargs
=
"+"
,
default
=
[
20
])
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
assert
args
.
input_type
.
upper
()
in
[
"ERA5"
,
"GFS"
,
"HRES"
]
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
init_time
=
np
.
array
([
init_time
])
init_time
=
np
.
array
([
init_time
])
...
@@ -87,7 +90,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
...
@@ -87,7 +90,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
temb
=
tembs
[
step
]
temb
=
tembs
[
step
]
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
output
=
new_input
[:,
-
1
]
output
=
new_input
[:,
-
1
]
save_like
(
output
,
data
,
step
,
save_dir
)
save_like
(
output
,
data
,
step
,
save_dir
,
input_type
=
args
.
input_type
)
print
(
f
'stage:
{
i
}
, step:
{
step
+
1
:
02
d
}
, output:
{
output
.
min
():.
2
f
}
{
output
.
max
():.
2
f
}
'
)
print
(
f
'stage:
{
i
}
, step:
{
step
+
1
:
02
d
}
, output:
{
output
.
min
():.
2
f
}
{
output
.
max
():.
2
f
}
'
)
input
=
new_input
input
=
new_input
step
+=
1
step
+=
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment