Commit fe6e4f3d authored by tpys's avatar tpys
Browse files

fuxi dropout

parent e56b2a2e
...@@ -13,6 +13,7 @@ ort.set_default_logger_severity(3) ...@@ -13,6 +13,7 @@ ort.set_default_logger_severity(3)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir") parser.add_argument('--model', type=str, required=True, help="FuXi onnx model dir")
parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format") parser.add_argument('--input', type=str, required=True, help="The input data file, store in netcdf format")
parser.add_argument('--drop_prob', type=float, help="dropout prob", default=0)
parser.add_argument('--input_type', type=str, help="The input type", default="ERA5") parser.add_argument('--input_type', type=str, help="The input type", default="ERA5")
parser.add_argument('--save_dir', type=str, default="") parser.add_argument('--save_dir', type=str, default="")
parser.add_argument('--num_steps', type=int, nargs="+", default=[20]) parser.add_argument('--num_steps', type=int, nargs="+", default=[20])
...@@ -57,6 +58,7 @@ def load_model(model_name): ...@@ -57,6 +58,7 @@ def load_model(model_name):
def run_inference(model_dir, data, num_steps, save_dir=""): def run_inference(model_dir, data, num_steps, save_dir=""):
total_step = sum(num_steps) total_step = sum(num_steps)
init_time = pd.to_datetime(data.time.values[-1]) init_time = pd.to_datetime(data.time.values[-1])
tembs = time_encoding(init_time, total_step) tembs = time_encoding(init_time, total_step)
...@@ -68,6 +70,8 @@ def run_inference(model_dir, data, num_steps, save_dir=""): ...@@ -68,6 +70,8 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
assert data.lat.values[-1] == -90 assert data.lat.values[-1] == -90
input = data.values[None] input = data.values[None]
prob = np.array([args.drop_prob], dtype=np.float32)
print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}') print(f'input: {input.shape}, {input.min():.2f} ~ {input.max():.2f}')
print(f'tembs: {tembs.shape}, {tembs.mean():.4f}') print(f'tembs: {tembs.shape}, {tembs.mean():.4f}')
...@@ -88,7 +92,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""): ...@@ -88,7 +92,7 @@ def run_inference(model_dir, data, num_steps, save_dir=""):
for _ in range(0, num_step): for _ in range(0, num_step):
temb = tembs[step] temb = tembs[step]
new_input, = session.run(None, {'input': input, 'temb': temb}) new_input, = session.run(None, {'input': input, 'temb': temb, 'prob': prob})
output = new_input[:, -1] output = new_input[:, -1]
save_like(output, data, step, save_dir, input_type=args.input_type) save_like(output, data, step, save_dir, input_type=args.input_type)
print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}') print(f'stage: {i}, step: {step+1:02d}, output: {output.min():.2f} {output.max():.2f}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment