Commit 37384df6 authored by tpys's avatar tpys
Browse files

inference fuxi with onnx model

parent 83ef4f67
import argparse
import os
import time
import numpy as np
import xarray as xr
import pandas as pd
import onnxruntime as ort
ort.set_default_logger_severity(3)
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument('--save_dir', type=str, default="")
parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
args = parser.parse_args()
def time_encoding(init_time, total_step, freq=6):
init_time = np.array([init_time])
tembs = []
for i in range(total_step):
hours = np.array([pd.Timedelta(hours=t*freq) for t in [i-1, i, i+1]])
times = init_time[:, None] + hours[None]
times = [pd.Period(t, 'H') for t in times.reshape(-1)]
times = [(p.day_of_year/366, p.hour/24) for p in times]
temb = np.array(times, dtype=np.float32)
temb = np.concatenate([np.sin(temb), np.cos(temb)], axis=-1)
temb = temb.reshape(1, -1)
tembs.append(temb)
return np.stack(tembs)
def load_model(mo):
sessions = []
for name in ["short", "medium", "long"]:
model_name = os.path.join(mo, f"{name}.onnx")
if os.path.exists(model_name):
start = time.perf_counter()
print(f'Load model from {model_name} ...')
session = ort.InferenceSession(model_name, providers=['CUDAExecutionProvider'])
load_time = time.perf_counter() - start
print(f'Load model take {load_time:.2f} sec')
sessions.append(session)
return sessions
def load_data(data_file):
input = xr.open_dataarray(data_file)
return input
def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
if save_dir:
os.makedirs(save_dir, exist_ok=True)
lead_time = (step+1) * freq
init_time = pd.to_datetime(data.time.values[-1])
lat = np.linspace(-90, 90, int(180/grid)+1, dtype=np.float32)
lon = np.arange(0, 360, grid, dtype=np.float32)
fcst_time = init_time + pd.Timedelta(hours=lead_time)
output = xr.DataArray(
output, # 1 x 70 x 721 x 1440
dims=['time', 'level', 'lat', 'lon'],
coords=dict(
time=[fcst_time],
level=data.level,
lat=lat,
lon=lon,
)
)
save_name = os.path.join(save_dir, f'{lead_time:03d}.nc')
output.to_netcdf(save_name)
def run_inference(sessions, data, num_steps, save_dir=""):
total_step = sum(num_steps)
init_time = pd.to_datetime(data.time.values[-1])
tembs = time_encoding(init_time, total_step)
input = data.values[None]
print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}')
print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')
print('Inference ...')
start = time.perf_counter()
step = 0
for i, session in enumerate(sessions):
for _ in range(0, num_steps[i]):
temb = tembs[step]
new_input, = session.run(None, {'input': input, 'temb': temb})
output = new_input[:, -1]
save_like(output, data, step, save_dir)
print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
input = new_input
step += 1
if step > total_step:
break
run_time = time.perf_counter() - start
print(f'Inference take {run_time:.2f} for {total_step} step')
if __name__ == "__main__":
sessions = load_model(args.model)
data = xr.open_dataarray(args.input)
run_inference(sessions, data, args.num_steps, args.save_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment