Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
37384df6
Commit
37384df6
authored
Jul 27, 2023
by
tpys
Browse files
inference fuxi with onnx model
parent
83ef4f67
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
117 additions
and
0 deletions
+117
-0
inference_fuxi.py
inference_fuxi.py
+117
-0
No files found.
inference_fuxi.py
0 → 100644
View file @
37384df6
import
argparse
import
os
import
time
import
numpy
as
np
import
xarray
as
xr
import
pandas
as
pd
import
onnxruntime
as
ort
# Severity 3 = ERROR: suppress onnxruntime info/warning log spam,
# keep only errors and fatals.
ort.set_default_logger_severity(3)

# Command-line interface for the FuXi inference script.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
# Empty save_dir means step outputs are not written to disk.
parser.add_argument('--save_dir', type=str, default="")
# One step count per stage model (short/medium/long), 6-hourly steps.
parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
args = parser.parse_args()
def time_encoding(init_time, total_step, freq=6):
    """Build a sin/cos time embedding for every forecast step.

    For step ``i`` the three timestamps at offsets ``(i-1, i, i+1) * freq``
    hours from *init_time* are reduced to normalised (day-of-year, hour)
    pairs, passed through sin and cos, and flattened.

    Args:
        init_time: forecast initialisation time (anything ``np.array`` can
            broadcast with ``pd.Timedelta``, e.g. a ``pd.Timestamp``).
        total_step: number of autoregressive steps to encode.
        freq: hours between consecutive steps.

    Returns:
        float32 array of shape ``(total_step, 1, 12)``.
    """
    base = np.array([init_time])
    embeddings = []
    for i in range(total_step):
        offsets = np.array([pd.Timedelta(hours=h * freq) for h in (i - 1, i, i + 1)])
        stamps = (base[:, None] + offsets[None]).reshape(-1)
        periods = [pd.Period(t, 'H') for t in stamps]
        # Normalise day-of-year to [0, 1] (366 covers leap years) and
        # hour-of-day to [0, 1).
        feats = np.array(
            [(p.day_of_year / 366, p.hour / 24) for p in periods],
            dtype=np.float32,
        )
        emb = np.concatenate([np.sin(feats), np.cos(feats)], axis=-1)
        embeddings.append(emb.reshape(1, -1))
    return np.stack(embeddings)
def load_model(mo):
    """Create ONNX Runtime CUDA sessions for the FuXi stage models.

    Looks for ``short.onnx``, ``medium.onnx`` and ``long.onnx`` inside the
    directory *mo* and returns an InferenceSession for each file that
    exists, preserving that stage order. Missing files are skipped
    silently, so the result may be shorter than three.
    """
    sessions = []
    for stage in ("short", "medium", "long"):
        path = os.path.join(mo, f"{stage}.onnx")
        if not os.path.exists(path):
            continue
        t0 = time.perf_counter()
        print(f'Load model from {path} ...')
        sess = ort.InferenceSession(path, providers=['CUDAExecutionProvider'])
        print(f'Load model take {time.perf_counter() - t0:.2f} sec')
        sessions.append(sess)
    return sessions
def load_data(data_file):
    """Open *data_file* (a netcdf file) and return it as an xarray DataArray.

    Fix: the original bound the result to a local named ``input``,
    shadowing the builtin; return the value directly instead.
    """
    return xr.open_dataarray(data_file)
def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
    """Wrap one forecast step in a georeferenced DataArray and write netcdf.

    A no-op when *save_dir* is empty. The file is named by forecast lead
    time in hours (e.g. ``006.nc``). Coordinates are rebuilt on a regular
    *grid*-degree lat/lon mesh; levels are copied from *data*.
    """
    if not save_dir:
        return
    os.makedirs(save_dir, exist_ok=True)

    lead_time = (step + 1) * freq
    init_time = pd.to_datetime(data.time.values[-1])
    fcst_time = init_time + pd.Timedelta(hours=lead_time)

    n_lat = int(180 / grid) + 1
    lat = np.linspace(-90, 90, n_lat, dtype=np.float32)
    lon = np.arange(0, 360, grid, dtype=np.float32)

    # output is expected to be 1 x 70 x 721 x 1440 (time, level, lat, lon)
    array = xr.DataArray(
        output,
        dims=['time', 'level', 'lat', 'lon'],
        coords=dict(
            time=[fcst_time],
            level=data.level,
            lat=lat,
            lon=lon,
        ),
    )
    array.to_netcdf(os.path.join(save_dir, f'{lead_time:03d}.nc'))
def run_inference(sessions, data, num_steps, save_dir=""):
    """Autoregressively roll the FuXi stage models forward.

    Args:
        sessions: ordered list of onnxruntime InferenceSession objects
            (short/medium/long); each maps {'input', 'temb'} to the next
            model state.
        data: xarray.DataArray initial condition; its last 'time'
            coordinate is taken as the forecast init time.
        num_steps: number of autoregressive steps per stage.
        save_dir: if non-empty, each step's output is written there as
            netcdf via save_like().

    Fixes vs. original: the running state no longer shadows the ``input``
    builtin; zip() tolerates a sessions/num_steps length mismatch (the
    original indexed num_steps[i] and could raise IndexError); the dead
    ``step > total_step`` guard is replaced by a correct ``>=`` check
    placed BEFORE the tembs[step] access, so it can actually prevent an
    out-of-range index.
    """
    total_step = sum(num_steps)
    init_time = pd.to_datetime(data.time.values[-1])
    tembs = time_encoding(init_time, total_step)

    # Prepend the batch axis the model expects.
    state = data.values[None]
    print(f'input: {state.shape}, {state.min():.2f} ~ {state.max():.2f}')
    print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')

    print('Inference ...')
    start = time.perf_counter()
    step = 0
    for i, (session, stage_steps) in enumerate(zip(sessions, num_steps)):
        for _ in range(stage_steps):
            if step >= total_step:
                # Defensive: never index past the precomputed embeddings.
                break
            temb = tembs[step]
            new_state, = session.run(None, {'input': state, 'temb': temb})
            # Last frame of the model output is the forecast for this step.
            output = new_state[:, -1]
            save_like(output, data, step, save_dir)
            print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
            state = new_state
            step += 1
    run_time = time.perf_counter() - start
    print(f'Inference take {run_time:.2f} for {total_step} step')
if __name__ == "__main__":
    # Entry point: load the stage models and the initial condition, then
    # roll the forecast forward.
    sessions = load_model(args.model)
    # Consistency fix: use the file's own load_data() helper instead of
    # calling xr.open_dataarray directly (same behavior).
    data = load_data(args.input)
    run_inference(sessions, data, args.num_steps, args.save_dir)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment