Commit 3b536956 authored by tpys's avatar tpys
Browse files

inference separately for memory efficiency

parent eed7f002
......@@ -9,11 +9,6 @@ This is the official repository for the FuXi paper.
by Lei Chen, Xiaohui Zhong, Feng Zhang, Yuan Cheng, Yinghui Xu, Yuan Qi, Hao Li
## What's New
- Release of the ONNX model and inference code.
- Addition of new sample data (20210101).
## Installation
......@@ -24,67 +19,39 @@ The downloaded files shall be organized as the following hierarchy:
```plain
├── root
│ ├── data
│ │ ├── 20180101
│ │ │ ├── input.nc
│ │ │ ├── output.nc
│ │ │
│ │ ├── 20210101
│ │ ├── input.nc
│ │ ├── output.nc
│ │ ├── target.nc
│ │ ├── input.nc
│ │ ├── output.nc
│ │ ├── target.nc
│ │
│ ├── model
│ | ├── buffer.st
│ | ├── fuxi_short.st
│ | ├── fuxi_medium.st
│ | ├── fuxi_long.st
│ | ├── onnx
│ | ├── short
│ | ├── short.onnx
│ | ├── medium
│ | ├── medium.onnx
│ | ├── long
│ | ├── long.onnx
│ ├── fuxi
│ | ├── short
│ | ├── short.onnx
│ | ├── medium
│ | ├── medium.onnx
│ | ├── long
│ | ├── long.onnx
| |
│ ├── fuxi.py
│ ├── fuxi_demo.ipynb
│ ├── inference_fuxi.py
```
1. Install xarray.
1. Install xarray
```bash
conda install -c conda-forge xarray dask netCDF4 bottleneck
```
conda install -c conda-forge xarray dask netCDF4 bottleneck
```
2. Install onnxruntime
```
```bash
pip install -r requirement.txt
```
3. (Optional) Install PyTorch and CUDA for inference with `pth` model
## Demo
### Inference with ONNX (recommended)
```python
python inference_fuxi.py --model model/onnx --input data/20210101/input.nc
```bash
python inference_fuxi.py --model fuxi --input data/20210101/input.nc
```
### Inference with PyTorch
The `fuxi_demo.ipynb` consists of multiple sections:
1. Construct the Fuxi model.
2. Load weights and buffers.
3. Load the preprocessed input.
4. Run inference for 15-day forecasting.
5. Save the results.
6. Visualization.
## Data preparation
......
......@@ -17,6 +17,7 @@ parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
args = parser.parse_args()
def time_encoding(init_time, total_step, freq=6):
init_time = np.array([init_time])
tembs = []
......@@ -32,18 +33,23 @@ def time_encoding(init_time, total_step, freq=6):
return np.stack(tembs)
def load_model(mo):
    """Create an ONNX Runtime session for every stage model found in a directory.

    The files ``short.onnx``, ``medium.onnx`` and ``long.onnx`` are looked up,
    in that order, directly inside ``mo``. A file that does not exist is
    skipped without raising, so the result may hold fewer than three sessions.

    Args:
        mo: Directory expected to contain the per-stage ``.onnx`` files.

    Returns:
        A list of ``ort.InferenceSession`` objects in short/medium/long order.
    """
    loaded = []
    for stage in ("short", "medium", "long"):
        model_name = os.path.join(mo, stage + ".onnx")
        if not os.path.exists(model_name):
            # Missing stages are skipped rather than treated as an error.
            continue

        began = time.perf_counter()
        print(f'Load model from {model_name} ...')
        stage_session = ort.InferenceSession(
            model_name,
            providers=['CUDAExecutionProvider'],
        )
        load_time = time.perf_counter() - began
        print(f'Load model take {load_time:.2f} sec')
        loaded.append(stage_session)
    return loaded
def load_model(model_name):
    """Create a single ONNX Runtime session on the CUDA execution provider.

    The session is built with the CPU memory arena, memory-pattern and
    memory-reuse options switched off, one intra-op thread, and the CUDA
    arena extend strategy set to ``kSameAsRequested``.

    Args:
        model_name: Path to the ``.onnx`` file to load.

    Returns:
        The ``ort.InferenceSession`` created for ``model_name``.
    """
    # Configure the behavior of onnxruntime.
    sess_opts = ort.SessionOptions()
    sess_opts.enable_cpu_mem_arena = False
    sess_opts.enable_mem_pattern = False
    sess_opts.enable_mem_reuse = False
    # Raising the thread count gives faster inference at the cost of more
    # memory consumption.
    sess_opts.intra_op_num_threads = 1

    cuda_settings = {'arena_extend_strategy': 'kSameAsRequested'}
    provider_list = [('CUDAExecutionProvider', cuda_settings)]
    return ort.InferenceSession(
        model_name,
        sess_options=sess_opts,
        providers=provider_list,
    )
def load_data(data_file):
......@@ -77,7 +83,7 @@ def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
def run_inference(sessions, data, num_steps, save_dir=""):
def run_inference(model_dir, data, num_steps, save_dir=""):
total_step = sum(num_steps)
init_time = pd.to_datetime(data.time.values[-1])
......@@ -87,31 +93,37 @@ def run_inference(sessions, data, num_steps, save_dir=""):
print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}')
print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')
print('Inference ...')
start = time.perf_counter()
step = 0
for i, session in enumerate(sessions):
for i, stage in enumerate(['short', 'medium', 'long']):
start = time.perf_counter()
model_name = os.path.join(model_dir, f"{stage}.onnx")
print(f'Load model from {model_name} ...')
session = load_model(model_name)
load_time = time.perf_counter() - start
print(f'Load model take {load_time:.2f} sec')
print(f'Inference {stage} ...')
start = time.perf_counter()
for _ in range(0, num_steps[i]):
temb = tembs[step]
new_input, = session.run(None, {'input': input, 'temb': temb})
output = new_input[:, -1]
save_like(output, data, step, save_dir)
print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
input = new_input
step += 1
run_time = time.perf_counter() - start
print(f'Inference {stage} take {run_time:.2f}')
if step > total_step:
break
run_time = time.perf_counter() - start
print(f'Inference take {run_time:.2f} for {total_step} step')
if __name__ == "__main__":
sessions = load_model(args.model)
data = xr.open_dataarray(args.input)
run_inference(sessions, data, args.num_steps, args.save_dir)
run_inference(args.model, data, args.num_steps, args.save_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment