Commit 0e8d5efa authored by tpys's avatar tpys
Browse files

add input_type

parent cfe4af4c
...@@ -16,7 +16,7 @@ def split_variable(ds, name): ...@@ -16,7 +16,7 @@ def split_variable(ds, name):
v = ds.sel(level=[name]) v = ds.sel(level=[name])
v = v.assign_coords(level=[0]) v = v.assign_coords(level=[0])
v = v.rename({"level": "level0"}) v = v.rename({"level": "level0"})
v = v.transpose('member', 'level0', 'time', 'dtime', 'lat', 'lon') v = v.transpose('member', 'level0', 'time', 'dtime', 'lat', 'lon')
elif name in pl_names: elif name in pl_names:
level = [f'{name}{l}' for l in levels] level = [f'{name}{l}' for l in levels]
v = ds.sel(level=level) v = ds.sel(level=level)
...@@ -25,15 +25,14 @@ def split_variable(ds, name): ...@@ -25,15 +25,14 @@ def split_variable(ds, name):
return v return v
def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split=False): def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split=False):
if save_dir: if save_dir:
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
dtime = (step+1) * freq
if input_type == "hres": if input_type.upper() == "HRES":
dtime = (step+2) * freq dtime = (step+2) * freq
elif input_type == "gfs":
dtime = (step+1) * freq
init_time = pd.to_datetime(input.time.values[0]) init_time = pd.to_datetime(input.time.values[0])
ds = xr.DataArray( ds = xr.DataArray(
...@@ -47,8 +46,8 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split ...@@ -47,8 +46,8 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
lat=input.lat.values, lat=input.lat.values,
lon=input.lon.values, lon=input.lon.values,
) )
).astype(np.float32) ).astype(np.float32)
if split: if split:
def rename(name): def rename(name):
if name == "tp": if name == "tp":
...@@ -56,7 +55,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split ...@@ -56,7 +55,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
elif name == "r": elif name == "r":
return "RH" return "RH"
return name.upper() return name.upper()
new_ds = [] new_ds = []
for k in pl_names + sfc_names: for k in pl_names + sfc_names:
v = split_variable(ds, k) v = split_variable(ds, k)
...@@ -66,7 +65,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split ...@@ -66,7 +65,7 @@ def save_like(output, input, step, input_type="hres", save_dir="", freq=6, split
print(f'Save to {save_name} ...') print(f'Save to {save_name} ...')
save_name = os.path.join(save_dir, f'{dtime:03d}.nc') save_name = os.path.join(save_dir, f'{dtime:03d}.nc')
ds.to_netcdf(save_name) ds.to_netcdf(save_name)
def make_input(init_time, data_dir, save_dir, deg=0.25): def make_input(init_time, data_dir, save_dir, deg=0.25):
...@@ -75,7 +74,7 @@ def make_input(init_time, data_dir, save_dir, deg=0.25): ...@@ -75,7 +74,7 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
input = [] input = []
level = [] level = []
for name in pl_names + sfc_names: for name in pl_names + sfc_names:
src_name = '{}_{}'.format(name, init_time.strftime("%Y%m%d%H.nc")) src_name = '{}_{}'.format(name, init_time.strftime("%Y%m%d%H.nc"))
src_file = os.path.join(data_dir, src_name) src_file = os.path.join(data_dir, src_name)
...@@ -198,4 +197,4 @@ def test_visualize(): ...@@ -198,4 +197,4 @@ def test_visualize():
visualize('tp.jpg', [tp], ['tp'], vmin=0, vmax=20) visualize('tp.jpg', [tp], ['tp'], vmin=0, vmax=20)
# test_make_input() # test_make_input()
\ No newline at end of file
...@@ -13,11 +13,14 @@ ort.set_default_logger_severity(3) ...@@ -13,11 +13,14 @@ ort.set_default_logger_severity(3)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir") parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format") parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument('--input_type', type=str, help="The input type", default="ERA5")
parser.add_argument('--save_dir', type=str, default="") parser.add_argument('--save_dir', type=str, default="")
parser.add_argument('--num_steps', type=int, nargs="+", default=[20]) parser.add_argument('--num_steps', type=int, nargs="+", default=[20])
args = parser.parse_args() args = parser.parse_args()
assert args.input_type.upper() in ["ERA5", "GFS", "HRES"]
def time_encoding(init_time, total_step, freq=6): def time_encoding(init_time, total_step, freq=6):
init_time = np.array([init_time]) init_time = np.array([init_time])
...@@ -87,7 +90,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""): ...@@ -87,7 +90,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
temb = tembs[step] temb = tembs[step]
new_input, = session.run(None, {'input': input, 'temb': temb}) new_input, = session.run(None, {'input': input, 'temb': temb})
output = new_input[:, -1] output = new_input[:, -1]
save_like(output, data, step, save_dir) save_like(output, data, step, save_dir, input_type=args.input_type)
print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}') print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
input = new_input input = new_input
step += 1 step += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment