Commit f3e35c25 authored by tpys's avatar tpys
Browse files

inference with short only

parent 46cb64ba
......@@ -14,10 +14,11 @@ parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument('--save_dir', type=str, default="")
parser.add_argument('--num_steps', type=int, nargs="+", default=[20, 20, 20])
parser.add_argument('--num_steps', type=int, nargs="+", default=[20])
args = parser.parse_args()
def time_encoding(init_time, total_step, freq=6):
init_time = np.array([init_time])
tembs = []
......@@ -67,9 +68,11 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}')
print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')
stages = ['short', 'medium', 'long']
step = 0
for i, stage in enumerate(['short', 'medium', 'long']):
for i, num_step in enumerate(num_steps):
stage = stages[i]
start = time.perf_counter()
model_name = os.path.join(model_dir, f"{stage}.onnx")
print(f'Load model from {model_name} ...')
......@@ -80,7 +83,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
print(f'Inference {stage} ...')
start = time.perf_counter()
for _ in range(0, num_steps[i]):
for _ in range(0, num_step):
temb = tembs[step]
new_input, = session.run(None, {'input': input, 'temb': temb})
output = new_input[:, -1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment