Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
3b536956
Commit
3b536956
authored
Jul 27, 2023
by
tpys
Browse files
inference separately for memory efficiency
parent
eed7f002
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
71 deletions
+50
-71
README.md
README.md
+16
-49
inference_fuxi.py
inference_fuxi.py
+34
-22
No files found.
README.md
View file @
3b536956
...
...
@@ -9,11 +9,6 @@ This is the official repository for the FuXi paper.
by Lei Chen, Xiaohui Zhong, Feng Zhang, Yuan Cheng, Yinghui Xu, Yuan Qi, Hao Li
## What's New
-
Release of the ONNX model and inference code.
-
Addition of new sample data (20210101).
## Installation
...
...
@@ -24,67 +19,39 @@ The downloaded files shall be organized as the following hierarchy:
```
plain
├── root
│ ├── data
│ │ ├── 20180101
│ │ │ ├── input.nc
│ │ │ ├── output.nc
│ │ │
│ │ ├── 20210101
│ │
├── input.nc
│ │
├── output.nc
│ │
├── target.nc
│ │ ├── input.nc
│ │ ├── output.nc
│ │ ├── target.nc
│ │
│ ├── model
│ | ├── buffer.st
│ | ├── fuxi_short.st
│ | ├── fuxi_medium.st
│ | ├── fuxi_long.st
│ | ├── onnx
│ | ├── short
│ | ├── short.onnx
│ | ├── medium
│ | ├── medium.onnx
│ | ├── long
│ | ├── long.onnx
│ ├── fuxi
│ | ├── short
│ | ├── short.onnx
│ | ├── medium
│ | ├── medium.onnx
│ | ├── long
│ | ├── long.onnx
| |
│ ├── fuxi.py
│ ├── fuxi_demo.ipynb
│ ├── inference_fuxi.py
```
1.
Install xarray
.
1.
Install xarray
```
bash
conda
install
-c
conda-forge xarray dask netCDF4 bottleneck
```
conda install -c conda-forge xarray dask netCDF4 bottleneck
```
2.
Install onnxruntime
```
```
bash
pip
install
-r
requirement.txt
```
3.
(Optional) Install PyTorch and CUDA for inference with
`pth`
model
## Demo
### Inference with ONNX (recommended)
```
python
python
inference_fuxi
.
py
--
model
model
/
onnx
--
input
data
/
20210101
/
input
.
nc
```
bash
python inference_fuxi.py
--model
fuxi
--input
data/20210101/input.nc
```
### Inference with PyTorch
The
`fuxi_demo.ipynb`
consists of multiple sections:
1.
Construct the Fuxi model.
2.
Load weights and buffers.
3.
Load the preprocessed input.
4.
Run inference for 15-day forecasting.
5.
Save the results.
6.
Visualization.
## Data preparation
...
...
inference_fuxi.py
View file @
3b536956
...
...
@@ -17,6 +17,7 @@ parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
args
=
parser
.
parse_args
()
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
init_time
=
np
.
array
([
init_time
])
tembs
=
[]
...
...
@@ -32,18 +33,23 @@ def time_encoding(init_time, total_step, freq=6):
return
np
.
stack
(
tembs
)
def load_model(mo):
    """Load the FuXi stage models found under directory *mo*.

    Looks for ``short.onnx``, ``medium.onnx`` and ``long.onnx`` and builds
    one onnxruntime CUDA session per file that exists (a missing stage file
    is silently skipped). Returns the sessions as a list, in stage order.
    """
    sessions = []
    for stage in ("short", "medium", "long"):
        model_name = os.path.join(mo, f"{stage}.onnx")
        if not os.path.exists(model_name):
            # Stage model not downloaded — skip it rather than fail.
            continue
        t0 = time.perf_counter()
        print(f'Load model from {model_name} ...')
        session = ort.InferenceSession(
            model_name,
            providers=['CUDAExecutionProvider'],
        )
        load_time = time.perf_counter() - t0
        print(f'Load model take {load_time:.2f} sec')
        sessions.append(session)
    return sessions
def load_model(model_name):
    """Create an onnxruntime CUDA session for a single ONNX model file.

    Arena-based memory features are switched off so that each stage model
    can be loaded, run and released independently without onnxruntime
    holding on to large memory pools between stages.
    """
    # Configure the behaviour of onnxruntime.
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption.
    options.intra_op_num_threads = 1

    # Ask the CUDA provider to allocate only what is requested instead of
    # growing its arena speculatively.
    cuda_provider_options = {'arena_extend_strategy': 'kSameAsRequested',}

    session = ort.InferenceSession(
        model_name,
        sess_options=options,
        providers=[('CUDAExecutionProvider', cuda_provider_options)],
    )
    return session
def
load_data
(
data_file
):
...
...
@@ -77,7 +83,7 @@ def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
def
run_inference
(
sessions
,
data
,
num_steps
,
save_dir
=
""
):
def
run_inference
(
model_dir
,
data
,
num_steps
,
save_dir
=
""
):
total_step
=
sum
(
num_steps
)
init_time
=
pd
.
to_datetime
(
data
.
time
.
values
[
-
1
])
...
...
@@ -87,31 +93,37 @@ def run_inference(sessions, data, num_steps, save_dir=""):
print
(
f
'input:
{
input
.
shape
}
,
{
input
.
min
():.
2
f
}
~
{
input
.
max
():.
2
f
}
'
)
print
(
f
'tembs:
{
tembs
.
shape
}
,
{
tembs
.
mean
():.
4
f
}
'
)
print
(
'Inference ...'
)
start
=
time
.
perf_counter
()
step
=
0
for
i
,
session
in
enumerate
(
sessions
):
for
i
,
stage
in
enumerate
([
'short'
,
'medium'
,
'long'
]):
start
=
time
.
perf_counter
()
model_name
=
os
.
path
.
join
(
model_dir
,
f
"
{
stage
}
.onnx"
)
print
(
f
'Load model from
{
model_name
}
...'
)
session
=
load_model
(
model_name
)
load_time
=
time
.
perf_counter
()
-
start
print
(
f
'Load model take
{
load_time
:.
2
f
}
sec'
)
print
(
f
'Inference
{
stage
}
...'
)
start
=
time
.
perf_counter
()
for
_
in
range
(
0
,
num_steps
[
i
]):
temb
=
tembs
[
step
]
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
output
=
new_input
[:,
-
1
]
save_like
(
output
,
data
,
step
,
save_dir
)
print
(
f
'stage:
{
i
}
, step:
{
step
+
1
:
02
d
}
, output:
{
output
.
min
():.
2
f
}
{
output
.
max
():.
2
f
}
'
)
input
=
new_input
step
+=
1
run_time
=
time
.
perf_counter
()
-
start
print
(
f
'Inference
{
stage
}
take
{
run_time
:.
2
f
}
'
)
if
step
>
total_step
:
break
run_time
=
time
.
perf_counter
()
-
start
print
(
f
'Inference take
{
run_time
:.
2
f
}
for
{
total_step
}
step'
)
if __name__ == "__main__":
    # Post-commit entry point: run_inference now takes the model DIRECTORY
    # and loads each stage model itself (one at a time, for memory
    # efficiency), so no pre-built session list is created here.
    # NOTE(review): the scraped diff also showed the old call
    # `run_inference(load_model(args.model), ...)`, which is incompatible
    # with the new `run_inference(model_dir, ...)` signature — only the
    # new form is kept.
    data = xr.open_dataarray(args.input)
    run_inference(args.model, data, args.num_steps, args.save_dir)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment