Commit 3b536956 authored by tpys's avatar tpys
Browse files

inference separately for memory efficiency

parent eed7f002
...@@ -9,11 +9,6 @@ This is the official repository for the FuXi paper. ...@@ -9,11 +9,6 @@ This is the official repository for the FuXi paper.
by Lei Chen, Xiaohui Zhong, Feng Zhang, Yuan Cheng, Yinghui Xu, Yuan Qi, Hao Li by Lei Chen, Xiaohui Zhong, Feng Zhang, Yuan Cheng, Yinghui Xu, Yuan Qi, Hao Li
## What's New
- Release of the ONNX model and inference code.
- Addition of new sample data (20210101).
## Installation ## Installation
...@@ -24,67 +19,39 @@ The downloaded files shall be organized as the following hierarchy: ...@@ -24,67 +19,39 @@ The downloaded files shall be organized as the following hierarchy:
```plain ```plain
├── root ├── root
│ ├── data │ ├── data
│ │ ├── 20180101
│ │ │ ├── input.nc
│ │ │ ├── output.nc
│ │ │
│ │ ├── 20210101 │ │ ├── 20210101
│ │ ├── input.nc │ │ ├── input.nc
│ │ ├── output.nc │ │ ├── output.nc
│ │ ├── target.nc │ │ ├── target.nc
│ │ │ │
│ ├── model │ ├── fuxi
│ | ├── buffer.st │ | ├── short
│ | ├── fuxi_short.st │ | ├── short.onnx
│ | ├── fuxi_medium.st │ | ├── medium
│ | ├── fuxi_long.st │ | ├── medium.onnx
│ | ├── onnx │ | ├── long
│ | ├── short │ | ├── long.onnx
│ | ├── short.onnx
│ | ├── medium
│ | ├── medium.onnx
│ | ├── long
│ | ├── long.onnx
| | | |
│ ├── fuxi.py
│ ├── fuxi_demo.ipynb
│ ├── inference_fuxi.py │ ├── inference_fuxi.py
``` ```
1. Install xarray. 1. Install xarray
```bash
conda install -c conda-forge xarray dask netCDF4 bottleneck
``` ```
conda install -c conda-forge xarray dask netCDF4 bottleneck
```
2. Install onnxruntime 2. Install onnxruntime
```bash
```
pip install -r requirement.txt pip install -r requirement.txt
``` ```
3. (Optional) Install PyTorch and CUDA for inference with `pth` model
## Demo ## Demo
### Inference with ONNX (recommended) ```bash
python inference_fuxi.py --model fuxi --input data/20210101/input.nc
```python
python inference_fuxi.py --model model/onnx --input data/20210101/input.nc
``` ```
### Inference with PyTorch
The `fuxi_demo.ipynb` consists of multiple sections:
1. Construct the Fuxi model.
2. Load weights and buffers.
3. Load the preprocessed input.
4. Run inference for 15-day forecasting.
5. Save the results.
6. Visualization.
## Data preparation ## Data preparation
......
...@@ -17,6 +17,7 @@ parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20]) ...@@ -17,6 +17,7 @@ parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
args = parser.parse_args() args = parser.parse_args()
def time_encoding(init_time, total_step, freq=6): def time_encoding(init_time, total_step, freq=6):
init_time = np.array([init_time]) init_time = np.array([init_time])
tembs = [] tembs = []
...@@ -32,18 +33,23 @@ def time_encoding(init_time, total_step, freq=6): ...@@ -32,18 +33,23 @@ def time_encoding(init_time, total_step, freq=6):
return np.stack(tembs) return np.stack(tembs)
def load_model(mo):
sessions = [] def load_model(model_name):
for name in ["short", "medium", "long"]: # Set the behavior of onnxruntime
model_name = os.path.join(mo, f"{name}.onnx") options = ort.SessionOptions()
if os.path.exists(model_name): options.enable_cpu_mem_arena=False
start = time.perf_counter() options.enable_mem_pattern = False
print(f'Load model from {model_name} ...') options.enable_mem_reuse = False
session = ort.InferenceSession(model_name, providers=['CUDAExecutionProvider']) # Increase the number for faster inference and more memory consumption
load_time = time.perf_counter() - start options.intra_op_num_threads = 1
print(f'Load model take {load_time:.2f} sec') cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
sessions.append(session)
return sessions session = ort.InferenceSession(
model_name,
sess_options=options,
providers=[('CUDAExecutionProvider', cuda_provider_options)]
)
return session
def load_data(data_file): def load_data(data_file):
...@@ -77,7 +83,7 @@ def save_like(output, data, step, save_dir="", freq=6, grid=0.25): ...@@ -77,7 +83,7 @@ def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
def run_inference(sessions, data, num_steps, save_dir=""): def run_inference(model_dir, data, num_steps, save_dir=""):
total_step = sum(num_steps) total_step = sum(num_steps)
init_time = pd.to_datetime(data.time.values[-1]) init_time = pd.to_datetime(data.time.values[-1])
...@@ -87,31 +93,37 @@ def run_inference(sessions, data, num_steps, save_dir=""): ...@@ -87,31 +93,37 @@ def run_inference(sessions, data, num_steps, save_dir=""):
print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}') print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}')
print(f'tembs: {tembs.shape}, {tembs.mean():.4f}') print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')
print('Inference ...')
start = time.perf_counter()
step = 0 step = 0
for i, session in enumerate(sessions): for i, stage in enumerate(['short', 'medium', 'long']):
start = time.perf_counter()
model_name = os.path.join(model_dir, f"{stage}.onnx")
print(f'Load model from {model_name} ...')
session = load_model(model_name)
load_time = time.perf_counter() - start
print(f'Load model take {load_time:.2f} sec')
print(f'Inference {stage} ...')
start = time.perf_counter()
for _ in range(0, num_steps[i]): for _ in range(0, num_steps[i]):
temb = tembs[step] temb = tembs[step]
new_input, = session.run(None, {'input': input, 'temb': temb}) new_input, = session.run(None, {'input': input, 'temb': temb})
output = new_input[:, -1] output = new_input[:, -1]
save_like(output, data, step, save_dir) save_like(output, data, step, save_dir)
print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}') print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
input = new_input input = new_input
step += 1 step += 1
run_time = time.perf_counter() - start
print(f'Inference {stage} take {run_time:.2f}')
if step > total_step: if step > total_step:
break break
run_time = time.perf_counter() - start
print(f'Inference take {run_time:.2f} for {total_step} step')
if __name__ == "__main__": if __name__ == "__main__":
sessions = load_model(args.model)
data = xr.open_dataarray(args.input) data = xr.open_dataarray(args.input)
run_inference(sessions, data, args.num_steps, args.save_dir) run_inference(args.model, data, args.num_steps, args.save_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment