# inference_fuxi.py: command-line script that runs FuXi ONNX models on a netCDF input file.
import argparse
import os
import time 
import numpy as np
import xarray as xr
import pandas as pd
import onnxruntime as ort

# Project-local import.
from data_util import save_like

# Runtime configuration and command-line interface.
# Raise onnxruntime's default logger severity threshold to level 3.
ort.set_default_logger_severity(3)

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument('--drop_prob', type=float, help="dropout prob", default=0)
parser.add_argument('--input_type', type=str, help="The input type", default="ERA5")
parser.add_argument('--save_dir', type=str, default="")
# One integer per stage; run_inference maps them, in order, onto the
# 'short', 'medium' and 'long' models.
parser.add_argument('--num_steps', type=int, nargs="+", default=[20])
args = parser.parse_args()

# Validate explicitly instead of with `assert`: assertions are stripped under
# `python -O`, and parser.error() prints the usage text with a clear message.
# The comparison is case-insensitive; args.input_type itself is left untouched.
if args.input_type.upper() not in ("ERA5", "GFS", "HRES"):
    parser.error(
        f"--input_type must be one of ERA5, GFS, HRES (case-insensitive), "
        f"got {args.input_type!r}"
    )

tpys's avatar
tpys committed
25

tpys's avatar
tpys committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def time_encoding(init_time, total_step, freq=6):
    """Build the sin/cos time embedding for every forecast step.

    For step ``i`` three timestamps are considered: ``init_time`` shifted by
    ``(i - 1) * freq``, ``i * freq`` and ``(i + 1) * freq`` hours. Each
    timestamp is reduced to the pair ``(day_of_year / 366, hour / 24)``; the
    sine and cosine of those pairs are concatenated per timestamp as
    ``[sin(doy), sin(hour), cos(doy), cos(hour)]`` and the three rows are
    flattened into a single row of 12 values.

    Args:
        init_time: Initial time; anything accepted by ``pd.Timestamp``.
        total_step: Number of steps to encode.
        freq: Number of hours between consecutive steps.

    Returns:
        A ``float32`` array of shape ``(total_step, 1, 12)``. It is empty
        (shape ``(0, 1, 12)``) when ``total_step`` is 0.
    """
    init_time = pd.Timestamp(init_time)
    offsets = (-1, 0, 1)
    # Four features (sin/cos of two fractions) per neighbouring timestamp.
    tembs = np.empty((total_step, 1, 4 * len(offsets)), dtype=np.float32)
    for i in range(total_step):
        stamps = [init_time + pd.Timedelta(hours=(i + k) * freq) for k in offsets]
        # Read the calendar fields straight from the Timestamp; this yields
        # the same day-of-year and hour as an hourly Period without relying
        # on the deprecated 'H' frequency alias.
        fractions = np.array(
            [(t.dayofyear / 366, t.hour / 24) for t in stamps],
            dtype=np.float32,
        )
        encoded = np.concatenate([np.sin(fractions), np.cos(fractions)], axis=-1)
        tembs[i, 0] = encoded.reshape(-1)
    return tembs


# ONNX Runtime session setup.

def load_model(model_name):
    """Create an onnxruntime inference session for the model at ``model_name``.

    The session is built with the CPU memory arena, memory-pattern and
    memory-reuse options switched off, a single intra-op thread, and
    ``CUDAExecutionProvider`` as the only requested execution provider.

    Args:
        model_name: Path of the ONNX model file to load.

    Returns:
        The constructed ``ort.InferenceSession``.
    """
    # Set the behavior of onnxruntime
    sess_opts = ort.SessionOptions()
    sess_opts.enable_cpu_mem_arena = False
    sess_opts.enable_mem_pattern = False
    sess_opts.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption
    sess_opts.intra_op_num_threads = 1

    cuda_opts = {'arena_extend_strategy': 'kSameAsRequested'}
    providers = [('CUDAExecutionProvider', cuda_opts)]
    return ort.InferenceSession(
        model_name,
        sess_options=sess_opts,
        providers=providers,
    )


# Multi-stage forecast loop.
def run_inference(model_dir, data, num_steps, save_dir=""):
    """Run the staged forecast and save every step's output.

    Stage ``k`` loads ``<model_dir>/<stage>.onnx`` (stages are, in order,
    ``short``, ``medium``, ``long``) and runs it ``num_steps[k]`` times. Each
    run feeds the previous run's full output back in as the next input, and
    the slice ``[:, -1]`` of that output is passed to ``save_like``.

    Besides its parameters, the function reads the module-level ``args`` for
    ``args.drop_prob`` (the model's ``prob`` input) and ``args.input_type``
    (forwarded to ``save_like``).

    Args:
        model_dir: Directory containing the per-stage ``.onnx`` files.
        data: Input array exposing ``time`` and ``lat`` coordinates and
            ``.values``; ``lat`` must run from 90 to -90.
        num_steps: Number of steps per stage; at most three entries.
        save_dir: Forwarded unchanged to ``save_like``.

    Raises:
        ValueError: If ``num_steps`` has more entries than there are stages,
            or if the latitude axis does not start at 90 and end at -90.
    """
    stages = ['short', 'medium', 'long']
    # Fail before any model is loaded rather than with an IndexError after
    # all three available stages have already been run.
    if len(num_steps) > len(stages):
        raise ValueError(
            f'num_steps has {len(num_steps)} entries but only '
            f'{len(stages)} stages are available: {stages}'
        )

    total_step = sum(num_steps)
    # The last entry of the time coordinate is used as the initial time.
    init_time = pd.to_datetime(data.time.values[-1])
    tembs = time_encoding(init_time, total_step)

    print(f'init_time: {init_time.strftime("%Y%m%d-%H")}')
    print(f'latitude: {data.lat.values[0]} ~ {data.lat.values[-1]}')

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if data.lat.values[0] != 90 or data.lat.values[-1] != -90:
        raise ValueError(
            'latitude must run from 90 to -90, got '
            f'{data.lat.values[0]} ~ {data.lat.values[-1]}'
        )

    # Prepend a length-1 leading axis (presumably the batch axis - confirm
    # against the model's expected input layout).
    model_input = data.values[None]
    prob = np.array([args.drop_prob], dtype=np.float32)

    print(f'input: {model_input.shape}, {model_input.min():.2f} ~ {model_input.max():.2f}')
    print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')

    step = 0
    for i, (stage, num_step) in enumerate(zip(stages, num_steps)):
        start = time.perf_counter()
        model_name = os.path.join(model_dir, f"{stage}.onnx")
        print(f'Load model from {model_name} ...')
        session = load_model(model_name)
        load_time = time.perf_counter() - start
        print(f'Load model take {load_time:.2f} sec')

        print(f'Inference {stage} ...')
        start = time.perf_counter()

        for _ in range(num_step):
            temb = tembs[step]
            new_input, = session.run(
                None, {'input': model_input, 'temb': temb, 'prob': prob}
            )
            # Last index along axis 1 is what gets saved (presumably the
            # newest predicted frame - confirm against the model layout).
            output = new_input[:, -1]
            save_like(output, data, step, save_dir, input_type=args.input_type)
            print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
            # Feed the whole model output back in as the next input.
            model_input = new_input
            step += 1

        run_time = time.perf_counter() - start
        print(f'Inference {stage} take {run_time:.2f}')

        # Drop the reference so this stage's session can be released before
        # the next stage's model is loaded.
        del session


    
if __name__ == "__main__":
    # Open the input as a context manager so the underlying file is closed
    # once inference has finished, or if it fails.
    with xr.open_dataarray(args.input) as data:
        run_inference(args.model, data, args.num_steps, args.save_dir)