inference_fuxi.py 4.28 KB
Newer Older
tpys's avatar
tpys committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
import os
import time 
import numpy as np
import xarray as xr
import pandas as pd
import onnxruntime as ort

# Only surface onnxruntime log messages at severity 3 and above
# (presumably ERROR/FATAL on ORT's 0-4 scale -- confirm against ORT docs).
ort.set_default_logger_severity(3)

# Command-line interface. Parsed at import time because the ``__main__``
# block at the bottom of the file reads the module-level ``args``.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument(
    '--save_dir',
    type=str,
    default="",
    help="Directory for the per-lead-time NetCDF outputs; nothing is saved when empty",
)
parser.add_argument(
    '--num_steps',
    type=int,
    nargs="+",
    default=[20, 20, 20],
    help="Number of 6-hour steps to run with the short, medium and long models, in that order",
)
args = parser.parse_args()


def time_encoding(init_time, total_step, freq=6):
    """Build the per-step time embeddings fed to the model as ``temb``.

    For step ``i`` the embedding covers the three times
    ``init_time + (i-1, i, i+1) * freq`` hours. Each time contributes
    ``(day_of_year / 366, hour / 24)``. The sin and cos of those pairs are
    concatenated, giving 3 x 4 = 12 values per step.

    Args:
        init_time: Initialisation time; anything ``pd.Timestamp`` accepts
            (Timestamp, ``np.datetime64``, ISO string).
        total_step: Number of forecast steps to encode.
        freq: Hours between consecutive steps.

    Returns:
        float32 array of shape ``(total_step, 1, 12)``. It is empty
        (``(0, 1, 12)``) when ``total_step`` is not positive.
    """
    init_time = pd.Timestamp(init_time)

    # Encode every offset -1 .. total_step (in units of ``freq`` hours) once.
    # Step i uses offsets i-1, i, i+1, i.e. rows i .. i+2 of ``enc``.
    # Timestamp.dayofyear/.hour give the same values as the hour-resolution
    # Period used previously, without the deprecated 'H' frequency alias.
    feats = []
    for k in range(-1, total_step + 1):
        t = init_time + pd.Timedelta(hours=k * freq)
        feats.append((t.dayofyear / 366, t.hour / 24))
    feats = np.array(feats, dtype=np.float32).reshape(-1, 2)
    enc = np.concatenate([np.sin(feats), np.cos(feats)], axis=-1)  # (rows, 4)

    tembs = [enc[i:i + 3].reshape(1, -1) for i in range(total_step)]
    if not tembs:
        # np.stack([]) raises; return a correctly-shaped empty array instead.
        return np.empty((0, 1, 12), dtype=np.float32)
    return np.stack(tembs)


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

def load_model(model_name, intra_op_num_threads=1):
    """Create an onnxruntime ``InferenceSession`` for the ONNX file ``model_name``.

    The CPU memory arena, memory-pattern and memory-reuse optimisations are
    all disabled. CUDA is the preferred execution provider. CPU is listed
    explicitly as a fallback so a session can still be created where the
    CUDA provider is unusable.

    Args:
        model_name: Path to the ``.onnx`` file.
        intra_op_num_threads: Threads onnxruntime may use inside a single
            operator. Increase for faster inference at the cost of more
            memory. Defaults to 1, as before.

    Returns:
        The created ``ort.InferenceSession``.
    """
    # Set the behaviour of onnxruntime.
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    options.intra_op_num_threads = intra_op_num_threads

    # Grow the CUDA memory arena only by the amount each request needs.
    cuda_provider_options = {'arena_extend_strategy': 'kSameAsRequested'}

    session = ort.InferenceSession(
        model_name,
        sess_options=options,
        providers=[
            ('CUDAExecutionProvider', cuda_provider_options),
            'CPUExecutionProvider',
        ],
    )
    return session
tpys's avatar
tpys committed
51
52
53
54
55
56
57


def load_data(data_file):
    """Open ``data_file`` with ``xr.open_dataarray`` and hand back the DataArray."""
    return xr.open_dataarray(data_file)


tpys's avatar
tpys committed
58
def save_like(output, data, step, save_dir="", freq=6):
    """Write one forecast step to ``<save_dir>/<lead_time>.nc`` on ``data``'s grid.

    Does nothing when ``save_dir`` is empty (or otherwise falsy).

    Args:
        output: Array for a single forecast step with dims
            ``(time, level, lat, lon)``. The caller passes ``new_input[:, -1]``,
            presumably 1 x 70 x 721 x 1440 -- TODO confirm against the model.
        data: Input DataArray. Its last ``time`` value is taken as the
            initialisation time, and its ``level``/``lat``/``lon``
            coordinates are copied onto the output.
        step: Zero-based step index; the lead time is ``(step + 1) * freq`` hours.
        save_dir: Output directory, created if missing.
        freq: Hours per model step.
    """
    if not save_dir:
        return

    os.makedirs(save_dir, exist_ok=True)

    lead_time = (step + 1) * freq
    init_time = pd.to_datetime(data.time.values[-1])
    fcst_time = init_time + pd.Timedelta(hours=lead_time)

    output = xr.DataArray(
        output,
        dims=['time', 'level', 'lat', 'lon'],
        coords=dict(
            time=[fcst_time],
            level=data.level,
            lat=data.lat,
            lon=data.lon,
        ),
    )
    output.name = 'data'

    # Zero-padded lead time in hours, e.g. "006.nc".
    save_name = os.path.join(save_dir, f'{lead_time:03d}.nc')
    output.to_netcdf(save_name)



82
def run_inference(model_dir, data, num_steps, save_dir=""):
    """Roll the short/medium/long FuXi models forward from ``data``.

    The models ``<model_dir>/short.onnx``, ``medium.onnx`` and ``long.onnx``
    run in that order, and ``num_steps[i]`` gives how many autoregressive
    steps the i-th of them takes. Every step feeds the model's full output
    back in as the next input. The last frame (``new_input[:, -1]``) is
    handed to ``save_like``.

    Args:
        model_dir: Directory containing the three ``<stage>.onnx`` files.
        data: Input DataArray with ``time`` and ``lat`` coordinates. The
            latitude must run from 90 (first) to -90 (last).
        num_steps: Per-stage step counts, at most three entries. Stages with
            no entry, or a count <= 0, are skipped without loading their model.
        save_dir: Passed to ``save_like``; an empty string disables saving.

    Raises:
        ValueError: if ``num_steps`` has more than three entries, or the
            latitude axis does not run from 90 to -90.
    """
    stages = ['short', 'medium', 'long']
    if len(num_steps) > len(stages):
        raise ValueError(
            f"num_steps has {len(num_steps)} entries, but there are only "
            f"{len(stages)} stages: {stages}"
        )

    total_step = sum(num_steps)
    init_time = pd.to_datetime(data.time.values[-1])

    tembs = time_encoding(init_time, total_step)

    print(f'init_time: {init_time.strftime(("%Y%m%d-%H"))}')
    print(f'latitude: {data.lat.values[0]} ~ {data.lat.values[-1]}')

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    lat = data.lat.values
    if lat[0] != 90 or lat[-1] != -90:
        raise ValueError(
            f"expected latitude to run from 90 to -90, got {lat[0]} ~ {lat[-1]}"
        )

    model_input = data.values[None]  # prepend a leading batch axis
    print(f'input: {model_input.shape}, {model_input.min():.2f} ~ {model_input.max():.2f}')
    print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')

    step = 0
    # zip() stops at the shorter sequence, so fewer than three counts no
    # longer raises IndexError.
    for i, (stage, stage_steps) in enumerate(zip(stages, num_steps)):
        if stage_steps <= 0:
            continue  # nothing to run; don't load a model that would go unused

        start = time.perf_counter()
        model_name = os.path.join(model_dir, f"{stage}.onnx")
        print(f'Load model from {model_name} ...')
        session = load_model(model_name)
        load_time = time.perf_counter() - start
        print(f'Load model take {load_time:.2f} sec')

        print(f'Inference {stage} ...')
        start = time.perf_counter()

        for _ in range(stage_steps):
            temb = tembs[step]
            new_input, = session.run(None, {'input': model_input, 'temb': temb})
            output = new_input[:, -1]
            save_like(output, data, step, save_dir)
            print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
            model_input = new_input
            step += 1

        run_time = time.perf_counter() - start
        print(f'Inference {stage} take {run_time:.2f}')

        # Drop this stage's session before the next stage loads its model,
        # so its memory can be released first.
        del session


    
if __name__ == "__main__":
    # Open the input NetCDF as a DataArray, then run the full model cascade.
    data = load_data(args.input)
    run_inference(args.model, data, args.num_steps, args.save_dir)
tpys's avatar
tpys committed
131