Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
FuXi
Commits
3b536956
Commit
3b536956
authored
Jul 27, 2023
by
tpys
Browse files
inference separately for memory efficiency
parent
eed7f002
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
71 deletions
+50
-71
README.md
README.md
+16
-49
inference_fuxi.py
inference_fuxi.py
+34
-22
No files found.
README.md
View file @
3b536956
...
@@ -9,11 +9,6 @@ This is the official repository for the FuXi paper.
...
@@ -9,11 +9,6 @@ This is the official repository for the FuXi paper.
by Lei Chen, Xiaohui Zhong, Feng Zhang, Yuan Cheng, Yinghui Xu, Yuan Qi, Hao Li
by Lei Chen, Xiaohui Zhong, Feng Zhang, Yuan Cheng, Yinghui Xu, Yuan Qi, Hao Li
## What's New
-
Release of the ONNX model and inference code.
-
Addition of new sample data (20210101).
## Installation
## Installation
...
@@ -24,67 +19,39 @@ The downloaded files shall be organized as the following hierarchy:
...
@@ -24,67 +19,39 @@ The downloaded files shall be organized as the following hierarchy:
```
plain
```
plain
├── root
├── root
│ ├── data
│ ├── data
│ │ ├── 20180101
│ │ │ ├── input.nc
│ │ │ ├── output.nc
│ │ │
│ │ ├── 20210101
│ │ ├── 20210101
│ │
├── input.nc
│ │ ├── input.nc
│ │
├── output.nc
│ │ ├── output.nc
│ │
├── target.nc
│ │ ├── target.nc
│ │
│ │
│ ├── model
│ ├── fuxi
│ | ├── buffer.st
│ | ├── short
│ | ├── fuxi_short.st
│ | ├── short.onnx
│ | ├── fuxi_medium.st
│ | ├── medium
│ | ├── fuxi_long.st
│ | ├── medium.onnx
│ | ├── onnx
│ | ├── long
│ | ├── short
│ | ├── long.onnx
│ | ├── short.onnx
│ | ├── medium
│ | ├── medium.onnx
│ | ├── long
│ | ├── long.onnx
| |
| |
│ ├── fuxi.py
│ ├── fuxi_demo.ipynb
│   ├── inference_fuxi.py
│   ├── inference_fuxi.py
```
```
1.
Install xarray
.
1.
Install xarray
```
bash
conda
install
-c
conda-forge xarray dask netCDF4 bottleneck
```
```
conda install -c conda-forge xarray dask netCDF4 bottleneck
```
2.
Install onnxruntime
2.
Install onnxruntime
```
bash
```
pip
install
-r
requirement.txt
pip
install
-r
requirement.txt
```
```
3.
(Optional) Install PyTorch and CUDA for inference with
`pth`
model
## Demo
## Demo
### Inference with ONNX (recommended)
```
bash
python inference_fuxi.py
--model
fuxi
--input
data/20210101/input.nc
```
python
python
inference_fuxi
.
py
--
model
model
/
onnx
--
input
data
/
20210101
/
input
.
nc
```
```
### Inference with PyTorch
The
`fuxi_demo.ipynb`
consists of multiple sections:
1.
Construct the Fuxi model.
2.
Load weights and buffers.
3.
Load the preprocessed input.
4.
Run inference for 15-day forecasting.
5.
Save the results.
6.
Visualization.
## Data preparation
## Data preparation
...
...
inference_fuxi.py
View file @
3b536956
...
@@ -17,6 +17,7 @@ parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
...
@@ -17,6 +17,7 @@ parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
def
time_encoding
(
init_time
,
total_step
,
freq
=
6
):
init_time
=
np
.
array
([
init_time
])
init_time
=
np
.
array
([
init_time
])
tembs
=
[]
tembs
=
[]
...
@@ -32,18 +33,23 @@ def time_encoding(init_time, total_step, freq=6):
...
@@ -32,18 +33,23 @@ def time_encoding(init_time, total_step, freq=6):
return
np
.
stack
(
tembs
)
return
np
.
stack
(
tembs
)
def load_model(model_name):
    """Create an ONNX Runtime inference session for a single model file.

    Args:
        model_name: Path to an ``.onnx`` model file (e.g. ``model/onnx/short.onnx``).

    Returns:
        An ``ort.InferenceSession`` bound to the CUDA execution provider.
    """
    # Set the behavior of onnxruntime: disable the CPU memory arena and
    # memory-pattern/reuse optimizations so memory is released between
    # per-stage sessions (the point of loading each stage separately).
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption
    options.intra_op_num_threads = 1

    # kSameAsRequested keeps the GPU arena from over-allocating beyond
    # what each request actually needs.
    cuda_provider_options = {
        'arena_extend_strategy': 'kSameAsRequested',
    }
    session = ort.InferenceSession(
        model_name,
        sess_options=options,
        providers=[('CUDAExecutionProvider', cuda_provider_options)],
    )
    return session
def
load_data
(
data_file
):
def
load_data
(
data_file
):
...
@@ -77,7 +83,7 @@ def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
...
@@ -77,7 +83,7 @@ def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
def
run_inference
(
sessions
,
data
,
num_steps
,
save_dir
=
""
):
def
run_inference
(
model_dir
,
data
,
num_steps
,
save_dir
=
""
):
total_step
=
sum
(
num_steps
)
total_step
=
sum
(
num_steps
)
init_time
=
pd
.
to_datetime
(
data
.
time
.
values
[
-
1
])
init_time
=
pd
.
to_datetime
(
data
.
time
.
values
[
-
1
])
...
@@ -87,31 +93,37 @@ def run_inference(sessions, data, num_steps, save_dir=""):
...
@@ -87,31 +93,37 @@ def run_inference(sessions, data, num_steps, save_dir=""):
print
(
f
'input:
{
input
.
shape
}
,
{
input
.
min
():.
2
f
}
~
{
input
.
max
():.
2
f
}
'
)
print
(
f
'input:
{
input
.
shape
}
,
{
input
.
min
():.
2
f
}
~
{
input
.
max
():.
2
f
}
'
)
print
(
f
'tembs:
{
tembs
.
shape
}
,
{
tembs
.
mean
():.
4
f
}
'
)
print
(
f
'tembs:
{
tembs
.
shape
}
,
{
tembs
.
mean
():.
4
f
}
'
)
print
(
'Inference ...'
)
start
=
time
.
perf_counter
()
step
=
0
step
=
0
for
i
,
session
in
enumerate
(
sessions
):
for
i
,
stage
in
enumerate
([
'short'
,
'medium'
,
'long'
]):
start
=
time
.
perf_counter
()
model_name
=
os
.
path
.
join
(
model_dir
,
f
"
{
stage
}
.onnx"
)
print
(
f
'Load model from
{
model_name
}
...'
)
session
=
load_model
(
model_name
)
load_time
=
time
.
perf_counter
()
-
start
print
(
f
'Load model take
{
load_time
:.
2
f
}
sec'
)
print
(
f
'Inference
{
stage
}
...'
)
start
=
time
.
perf_counter
()
for
_
in
range
(
0
,
num_steps
[
i
]):
for
_
in
range
(
0
,
num_steps
[
i
]):
temb
=
tembs
[
step
]
temb
=
tembs
[
step
]
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
new_input
,
=
session
.
run
(
None
,
{
'input'
:
input
,
'temb'
:
temb
})
output
=
new_input
[:,
-
1
]
output
=
new_input
[:,
-
1
]
save_like
(
output
,
data
,
step
,
save_dir
)
save_like
(
output
,
data
,
step
,
save_dir
)
print
(
f
'stage:
{
i
}
, step:
{
step
+
1
:
02
d
}
, output:
{
output
.
min
():.
2
f
}
{
output
.
max
():.
2
f
}
'
)
print
(
f
'stage:
{
i
}
, step:
{
step
+
1
:
02
d
}
, output:
{
output
.
min
():.
2
f
}
{
output
.
max
():.
2
f
}
'
)
input
=
new_input
input
=
new_input
step
+=
1
step
+=
1
run_time
=
time
.
perf_counter
()
-
start
print
(
f
'Inference
{
stage
}
take
{
run_time
:.
2
f
}
'
)
if
step
>
total_step
:
if
step
>
total_step
:
break
break
run_time
=
time
.
perf_counter
()
-
start
print
(
f
'Inference take
{
run_time
:.
2
f
}
for
{
total_step
}
step'
)
if __name__ == "__main__":
    # Load the preprocessed initial-condition DataArray, then run the
    # cascaded short/medium/long inference. The model directory (not a
    # pre-built session) is passed so each stage's session can be created
    # and released inside run_inference for memory efficiency.
    data = xr.open_dataarray(args.input)
    run_inference(args.model, data, args.num_steps, args.save_dir)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment