Commit 0e8d5efa authored by tpys's avatar tpys
Browse files

add input_type

parent cfe4af4c
......@@ -25,14 +25,13 @@ def split_variable(ds, name):
return v
def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split=False):
def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split=False):
if save_dir:
os.makedirs(save_dir, exist_ok=True)
dtime = (step+1) * freq
if input_type == "hres":
if input_type.upper() == "HRES":
dtime = (step+2) * freq
elif input_type == "gfs":
dtime = (step+1) * freq
init_time = pd.to_datetime(input.time.values[0])
......
......@@ -13,11 +13,14 @@ ort.set_default_logger_severity(3)
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument('--input_type', type=str, help="The input type", default="ERA5")
parser.add_argument('--save_dir', type=str, default="")
parser.add_argument('--num_steps', type=int, nargs="+", default=[20])
args = parser.parse_args()
assert args.input_type.upper() in ["ERA5", "GFS", "HRES"]
def time_encoding(init_time, total_step, freq=6):
init_time = np.array([init_time])
......@@ -87,7 +90,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
temb = tembs[step]
new_input, = session.run(None, {'input': input, 'temb': temb})
output = new_input[:, -1]
save_like(output, data, step, save_dir)
save_like(output, data, step, save_dir, input_type=args.input_type)
print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
input = new_input
step += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment