Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
fe6e4f3d
"vscode:/vscode.git/clone" did not exist on "15f5632365a98fd43ea42e4948a995aa399e99b5"
Commit
fe6e4f3d
authored
Aug 21, 2023
by
tpys
Browse files
fuxi dropout
parent
e56b2a2e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
1 deletion
+5
-1
inference_fuxi.py
inference_fuxi.py
+5
-1
No files found.
inference_fuxi.py
View file @
fe6e4f3d
...
...
@@ -13,6 +13,7 @@ ort.set_default_logger_severity(3)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
"FuXi onnx model dir"
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
True
,
help
=
"The input data file, store in netcdf format"
)
parser
.
add_argument
(
'--drop_prob'
,
type
=
float
,
help
=
"dropout prob"
,
default
=
0
)
parser
.
add_argument
(
'--input_type'
,
type
=
str
,
help
=
"The input type"
,
default
=
"ERA5"
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
'--num_steps'
,
type
=
int
,
nargs
=
"+"
,
default
=
[
20
])
...
...
@@ -57,6 +58,7 @@ def load_model(model_name):
def
run_inference
(
model_dir
,
data
,
num_steps
,
save_dir
=
""
):
total_step
=
sum
(
num_steps
)
init_time
=
pd
.
to_datetime
(
data
.
time
.
values
[
-
1
])
tembs
=
time_encoding
(
init_time
,
total_step
)
...
...
@@ -68,6 +70,8 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
assert
data
.
lat
.
values
[
-
1
]
==
-
90
input
=
data
.
values
[
None
]
prob
=
np
.
array
([
args
.
drop_prob
],
dtype
=
np
.
float32
)
print
(
f
'input:
{
input
.
shape
}
,
{
input
.
min
():.
2
f
}
~
{
input
.
max
():.
2
f
}
'
)
print
(
f
'tembs:
{
tembs
.
shape
}
,
{
tembs
.
mean
():.
4
f
}
'
)
...
...
@@ -88,7 +92,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
for
_
in
range
(
0
,
num_step
):
temb
=
tembs
[
step
]
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
,
'prob'
:
prob
})
output
=
new_input
[:,
-
1
]
save_like
(
output
,
data
,
step
,
save_dir
,
input_type
=
args
.
input_type
)
print
(
f
'stage:
{
i
}
, step:
{
step
+
1
:
02
d
}
, output:
{
output
.
min
():.
2
f
}
{
output
.
max
():.
2
f
}
'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment