# inference_fuxi.py
import argparse
import os
import time 
import numpy as np
import xarray as xr
import pandas as pd
import onnxruntime as ort

# Keep ONNX Runtime quiet: severity 3 logs errors only.
ort.set_default_logger_severity(3)


# Command-line interface. Note: arguments are parsed at import time and the
# resulting `args` is consumed by the __main__ guard at the bottom of the file.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument('--save_dir', type=str, default="")
parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
args = parser.parse_args()


def time_encoding(init_time, total_step, freq=6):
    """Build sinusoidal time embeddings for each forecast step.

    For step ``i`` the encoding covers the three frames at hours
    ``(i-1)*freq``, ``i*freq`` and ``(i+1)*freq`` relative to *init_time*.
    Each frame contributes a normalized (day_of_year/366, hour/24) pair,
    expanded with sin and cos.

    Returns a float32 array of shape ``(total_step, 1, 12)``.
    """
    base = np.array([init_time])
    embeddings = []
    for step in range(total_step):
        offsets = np.array(
            [pd.Timedelta(hours=h * freq) for h in (step - 1, step, step + 1)]
        )
        stamps = (base[:, None] + offsets[None]).reshape(-1)
        periods = [pd.Period(s, 'H') for s in stamps]
        feats = np.array(
            [(p.day_of_year / 366, p.hour / 24) for p in periods],
            dtype=np.float32,
        )
        # (3, 2) -> (3, 4) sin/cos expansion -> flattened (1, 12) per step.
        enc = np.concatenate([np.sin(feats), np.cos(feats)], axis=-1)
        embeddings.append(enc.reshape(1, -1))
    return np.stack(embeddings)


def load_model(mo):
    """Load the available FuXi ONNX stages from directory *mo*.

    Probes for ``short.onnx``, ``medium.onnx`` and ``long.onnx`` (in that
    order) and builds a CUDA inference session for each file that exists;
    missing stages are skipped silently.

    Returns the list of InferenceSession objects (possibly empty).
    """
    sessions = []
    for stage in ("short", "medium", "long"):
        path = os.path.join(mo, f"{stage}.onnx")
        if not os.path.exists(path):
            continue
        t0 = time.perf_counter()
        print(f'Load model from {path} ...')
        session = ort.InferenceSession(path,  providers=['CUDAExecutionProvider'])
        elapsed = time.perf_counter() - t0
        print(f'Load model take {elapsed:.2f} sec')
        sessions.append(session)
    return sessions


def load_data(data_file):
    """Open *data_file* (netCDF) and return it as an xarray DataArray."""
    # Renamed the local from `input` to avoid shadowing the builtin.
    array = xr.open_dataarray(data_file)
    return array


def save_like(output, data, step, save_dir="", freq=6, grid=0.25):
    """Write one forecast step to ``<save_dir>/<lead_time>.nc``.

    No-op when *save_dir* is empty. The raw ndarray *output* is wrapped in
    a DataArray whose time coordinate is the forecast valid time — the last
    input time plus ``(step+1)*freq`` hours — on a regular lat/lon grid of
    spacing *grid* degrees; the level coordinate is copied from *data*.
    """
    if not save_dir:
        return
    os.makedirs(save_dir, exist_ok=True)

    lead_time = freq * (step + 1)
    init_time = pd.to_datetime(data.time.values[-1])
    fcst_time = init_time + pd.Timedelta(hours=lead_time)

    # Grid axes: latitude spans both poles inclusively, longitude wraps.
    n_lat = int(180 / grid) + 1
    lat = np.linspace(-90, 90, n_lat, dtype=np.float32)
    lon = np.arange(0, 360, grid, dtype=np.float32)

    wrapped = xr.DataArray(
        output,  # 1 x 70 x 721 x 1440
        dims=['time', 'level', 'lat', 'lon'],
        coords=dict(
            time=[fcst_time],
            level=data.level,
            lat=lat,
            lon=lon,
        )
    )
    wrapped.to_netcdf(os.path.join(save_dir, f'{lead_time:03d}.nc'))



def run_inference(sessions, data, num_steps, save_dir=""):
    """Autoregressively roll out the cascaded FuXi stages.

    Each session in *sessions* runs for the matching count in *num_steps*;
    every model output becomes the input of the next step. The newest frame
    of each step is passed to ``save_like`` (a no-op when *save_dir* is
    empty). Progress and min/max diagnostics are printed per step.
    """
    total_step = sum(num_steps)
    init_time = pd.to_datetime(data.time.values[-1])

    tembs = time_encoding(init_time, total_step)
    input = data.values[None]

    print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}')
    print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')

    print('Inference ...')
    start = time.perf_counter()

    step = 0
    for stage, session in enumerate(sessions):
        for _ in range(num_steps[stage]):
            # One autoregressive step: current state + its time embedding in,
            # new state out; the last frame is the forecast for this step.
            new_input, = session.run(None, {'input': input, 'temb': tembs[step]})
            output = new_input[:, -1]

            save_like(output, data, step, save_dir)
            print(f'stage: {stage}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')

            input = new_input
            step += 1

        # Defensive cap; cannot trigger while the inner loops sum to total_step.
        if step > total_step:
            break

    run_time = time.perf_counter() - start
    print(f'Inference take {run_time:.2f} for {total_step} step')

    
if __name__ == "__main__":
    sessions = load_model(args.model)
    data = xr.open_dataarray(args.input)
    run_inference(sessions, data, args.num_steps, args.save_dir)