Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
f3e35c25
Commit
f3e35c25
authored
Aug 16, 2023
by
tpys
Browse files
inference with short only
parent
46cb64ba
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
inference_fuxi.py
inference_fuxi.py
+7
-4
No files found.
inference_fuxi.py
View file @
f3e35c25
...
@@ -14,10 +14,11 @@ parser = argparse.ArgumentParser()
...
@@ -14,10 +14,11 @@ parser = argparse.ArgumentParser()
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
"FuXi onnx model dir"
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
"FuXi onnx model dir"
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
True
,
help
=
"The input data file, store in netcdf format"
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
required
=
True
,
help
=
"The input data file, store in netcdf format"
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
'--num_steps'
,
type
=
int
,
nargs
=
"+"
,
default
=
[
20
,
20
,
20
])
parser
.
add_argument
(
'--num_steps'
,
type
=
int
,
nargs
=
"+"
,
default
=
[
20
])
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
init_time
=
np
.
array
([
init_time
])
init_time
=
np
.
array
([
init_time
])
tembs
=
[]
tembs
=
[]
...
@@ -67,9 +68,11 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
...
@@ -67,9 +68,11 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
print
(
f
'input:
{
input
.
shape
}
,
{
input
.
min
():.
2
f
}
~
{
input
.
max
():.
2
f
}
'
)
print
(
f
'input:
{
input
.
shape
}
,
{
input
.
min
():.
2
f
}
~
{
input
.
max
():.
2
f
}
'
)
print
(
f
'tembs:
{
tembs
.
shape
}
,
{
tembs
.
mean
():.
4
f
}
'
)
print
(
f
'tembs:
{
tembs
.
shape
}
,
{
tembs
.
mean
():.
4
f
}
'
)
stages
=
[
'short'
,
'medium'
,
'long'
]
step
=
0
step
=
0
for
i
,
stage
in
enumerate
([
'short'
,
'medium'
,
'long'
]):
for
i
,
num_step
in
enumerate
(
num_steps
):
stage
=
stages
[
i
]
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
model_name
=
os
.
path
.
join
(
model_dir
,
f
"
{
stage
}
.onnx"
)
model_name
=
os
.
path
.
join
(
model_dir
,
f
"
{
stage
}
.onnx"
)
print
(
f
'Load model from
{
model_name
}
...'
)
print
(
f
'Load model from
{
model_name
}
...'
)
...
@@ -80,7 +83,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
...
@@ -80,7 +83,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
print
(
f
'Inference
{
stage
}
...'
)
print
(
f
'Inference
{
stage
}
...'
)
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
for
_
in
range
(
0
,
num_step
s
[
i
]
):
for
_
in
range
(
0
,
num_step
):
temb
=
tembs
[
step
]
temb
=
tembs
[
step
]
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
output
=
new_input
[:,
-
1
]
output
=
new_input
[:,
-
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment