Commit c56e400c authored by tpys
Browse files

move save_like to data_util

parent 2f128e8c
......@@ -4,35 +4,73 @@ import numpy as np
import pandas as pd
import xarray as xr
__all__ = ['make_input', 'chunk_time']
__all__ = ['make_input', "save_like"]
# Pressure-level variable names. split_variable joins each name with every
# entry of `levels` to build level labels such as 'z50' or 't1000'.
pl_names = ['z', 't', 'u', 'v', 'r']
# Surface variable names. split_variable selects each one by its own name.
sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
# Level values appended to the pressure-level names above.
# NOTE(review): presumably pressure levels in hPa -- units are not stated here; confirm.
levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
# chunk_time: build a {dimension name: size} mapping for `ds` and force the
# 'time' and 'lead_time' entries (when present) to 1.
# NOTE(review): this text is a flattened commit diff. The statements that
# consume `dims` (`ds = ds.chunk(dims)` / `return ds`) show up further down,
# interleaved with another function -- confirm against the real file.
def chunk_time(ds, shape=None):
if shape is None:
# No explicit shape: copy the name -> size pairs straight from ds.dims.
dims = {k:v for k, v in ds.dims.items()}
else:
# Explicit shape: pair dimension names with the supplied sizes by position.
dims = {k:v for k, v in zip(ds.dims, shape)}
for k in ['time', 'lead_time']:
if k in dims:
dims[k] = 1
def split_variable(ds, name):
    """Extract one variable from a stacked array as its own level-indexed array.

    Args:
        ds: Object with a ``level`` coordinate holding variable labels and the
            dimensions ``member``, ``time``, ``dtime``, ``lat`` and ``lon``.
            It must support ``sel``, ``assign_coords``, ``rename`` and
            ``transpose``.
        name: Variable name; must be listed in ``sfc_names`` or ``pl_names``.

    Returns:
        The selected data. Surface variables are returned with a single
        ``level0`` coordinate equal to 0; pressure-level variables keep the
        ``level`` dimension, relabelled with the values in ``levels``.

    Raises:
        ValueError: If ``name`` is in neither ``sfc_names`` nor ``pl_names``.
    """
    if name in sfc_names:
        # A surface variable occupies exactly one slot of the stacked axis.
        v = ds.sel(level=[name])
        v = v.assign_coords(level=[0])
        v = v.rename({"level": "level0"})
        return v.transpose('member', 'level0', 'time', 'dtime', 'lat', 'lon')

    if name in pl_names:
        # Pressure-level slots are labelled '<name><level>', e.g. 'z500'.
        labels = [f'{name}{lev}' for lev in levels]
        v = ds.sel(level=labels)
        v = v.assign_coords(level=levels)
        return v.transpose('member', 'level', 'time', 'dtime', 'lat', 'lon')

    # Previously this fell through to ``return v`` with ``v`` unbound.
    raise ValueError(f"unknown variable name: {name!r}")
# save_like: wrap a model output array in a DataArray with the dimensions
# member/time/dtime/level/lat/lon, split it into one array per variable with
# split_variable, merge those into a single dataset and write it to
# '<dtime:03d>.nc' under save_dir.
# Coordinates for level/lat/lon and the initial time are taken from `input`.
def save_like(output, input, step, save_dir="", freq=6):
if save_dir:
os.makedirs(save_dir, exist_ok=True)
dtime = (step+2) * freq # (step + 2) multiples of freq; also used as the output file stem
# The first entry of input.time is used as the initialization time.
init_time = pd.to_datetime(input.time.values[0])
print(f'init_time: {init_time}, dtime: {dtime}')
data = xr.DataArray(
# Two leading axes are added in front of `output` for 'member' and 'time'.
# NOTE(review): assumes `output` is an array laid out as
# (dtime, level, lat, lon) -- confirm against the caller.
output[None, None],
dims=['member', 'time', 'dtime', 'level', 'lat', 'lon'],
coords=dict(
member=['FuXi'],
time=[init_time],
dtime=[dtime],
level=input.level,
lat=input.lat.values,
lon=input.lon.values,
)
).astype(np.float32)
# Output variable naming: 'tp' -> 'TP06', 'r' -> 'RH', anything else upper-cased.
def rename(name):
if name == "tp":
return "TP06"
elif name == "r":
return "RH"
return name.upper()
# Collect one named array per variable, pressure-level variables first.
ds = []
for k in pl_names + sfc_names:
v = split_variable(data, k)
v.name = rename(k)
# print(f"{k}: {v.shape} {v.values.min()} ~ {v.values.max()}")
ds.append(v)
ds = xr.merge(ds, compat="no_conflicts")
# NOTE(review): the next two lines look like diff residue from the removed
# chunk_time function -- `dims` is never defined in save_like, and a return
# here would skip the save below. Confirm against the real file.
ds = ds.chunk(dims)
return ds
save_name = os.path.join(save_dir, f'{dtime:03d}.nc')
ds.to_netcdf(save_name)
def make_input(init_time, data_dir, save_dir, deg=0.25):
# These are fixed for FuXi
pl_names = ['z', 't', 'u', 'v', 'r']
sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
lat = np.linspace(-90, 90, int(180/deg)+1, dtype=np.float32)
lon = np.arange(0, 360, deg, dtype=np.float32)
valid_time = init_time + pd.Timedelta(hours=6) # utc time
input = []
level = []
for name in pl_names + sfc_names:
src_name = '{}_{}'.format(name, init_time.strftime("%Y%m%d%H.nc"))
src_file = os.path.join(data_dir, src_name)
......@@ -41,7 +79,8 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
return
try:
v = xr.open_dataset(src_file).sel(time=init_time, drop=True).data
v = xr.open_dataset(src_file)
v = v.sel(time=init_time, drop=True).data
except:
print(f"open {src_file} failed")
return
......@@ -59,7 +98,6 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
print(f"{src_name} has nan value")
return
# reverse pressure level
try:
if name in pl_names:
......@@ -88,28 +126,71 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
v.attrs = {}
v = v.rename({'dtime': 'time'})
v = v.squeeze('member').drop('member')
v = v.assign_coords(time=[init_time, valid_time])
input.append(v)
# concat and reshape
input = xr.concat(input, "level")
input = input.transpose("time", "level", "lat", "lon")
valid_time = init_time + pd.Timedelta(hours=6) # utc time
v = v.assign_coords(time=[init_time, valid_time])
# reverse latitude
input = input.reindex(lat=input.lat[::-1])
input = input.assign_coords(level=level)
input.name = 'data'
input = chunk_time(input, input.shape)
# save to nc
save_name = os.path.join(save_dir, valid_time.strftime("%Y%m%d-%H.nc"))
print(input)
save_name = os.path.join(save_dir, init_time.strftime("%Y%m%d-%H.nc"))
input = input.astype(np.float32)
input.to_netcdf(save_name)
def visualize(save_name, vars=None, titles=None, vmin=None, vmax=None):
    """Plot each field on its own PlateCarree map panel and save the figure.

    Args:
        save_name: Output image path handed to ``plt.savefig``.
        vars: Sequence of plottable fields (objects whose ``.plot`` accepts
            ``x='lon'`` and ``y='lat'``); one subplot row is created per
            entry. Defaults to no fields.
        titles: Sequence of panel titles, indexed in step with ``vars``.
        vmin: Lower color limit forwarded to every ``.plot`` call.
        vmax: Upper color limit forwarded to every ``.plot`` call.

    Raises:
        IndexError: If ``titles`` has fewer entries than ``vars``.
    """
    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    # ``None`` defaults avoid sharing one mutable list across calls.
    fields = [] if vars is None else vars
    labels = [] if titles is None else titles

    # squeeze=False always yields a 2-D axes array, so a single panel needs
    # no special case.
    fig, axes = plt.subplots(
        len(fields), 1, figsize=(8, 6), squeeze=False,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )

    def plot(ax, v, title):
        """Draw one field on ``ax`` with labelled gridlines and no colorbar."""
        v.plot(
            ax=ax,
            x='lon',
            y='lat',
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            add_colorbar=False
        )
        # ax.coastlines()
        ax.set_title(title)
        gl = ax.gridlines(draw_labels=True, linewidth=0.5)
        # Keep tick labels on the left and bottom edges only.
        gl.top_labels = False
        gl.right_labels = False

    for i, v in enumerate(fields):
        plot(axes[i, 0], v, labels[i])

    # ``transparent`` is a boolean flag; the original passed the string 'true'.
    plt.savefig(save_name, bbox_inches='tight',
                pad_inches=0.1, transparent=True, dpi=200)
    plt.close()
def test_make_input():
    """Manual smoke run of make_input for the '20230731-12' initialization.

    Reads from data/HRES and writes into data/HRES/input, creating the
    output directory first when it does not exist.
    """
    start = pd.to_datetime("20230731-12")  # must utc
    source_dir, target_dir = "data/HRES", "data/HRES/input"
    os.makedirs(target_dir, exist_ok=True)
    make_input(start, source_dir, target_dir)
def test_visualize():
    """Manual check: plot the 'tp' level of data/HRES/output/072.nc to tp.jpg."""
    forecast = xr.open_dataarray('data/HRES/output/072.nc')
    precipitation = forecast.sel(level='tp')
    visualize('tp.jpg', [precipitation], ['tp'], vmin=0, vmax=20)
# test_make_input()
\ No newline at end of file
......@@ -6,6 +6,8 @@ import xarray as xr
import pandas as pd
import onnxruntime as ort
from data_util import save_like
ort.set_default_logger_severity(3)
parser = argparse.ArgumentParser()
......@@ -50,39 +52,9 @@ def load_model(model_name):
return session
def load_data(data_file):
    """Open ``data_file`` with ``xr.open_dataarray`` and return the result."""
    return xr.open_dataarray(data_file)
def save_like(output, data, step, save_dir="", freq=6):
    """Write one forecast step to NetCDF, reusing the coordinates of ``data``.

    Args:
        output: Array for a single forecast step, laid out as
            (time, level, lat, lon); the original note gives
            1 x 70 x 721 x 1440.
        data: Reference array supplying the ``level``, ``lat`` and ``lon``
            coordinates; its last ``time`` value is the base time.
        step: Zero-based step index; the lead time is ``(step + 1) * freq``.
        save_dir: Target directory, created when non-empty.
        freq: Hours per step.

    The file is named ``<lead_time:03d>.nc`` inside ``save_dir``.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    lead_time = (step + 1) * freq
    base_time = pd.to_datetime(data.time.values[-1])

    coords = dict(
        time=[base_time + pd.Timedelta(hours=lead_time)],
        level=data.level,
        lat=data.lat,
        lon=data.lon,
    )
    forecast = xr.DataArray(
        output,
        dims=['time', 'level', 'lat', 'lon'],
        coords=coords,
    )
    forecast.name = 'data'

    target = os.path.join(save_dir, f'{lead_time:03d}.nc')
    forecast.to_netcdf(target)
def run_inference(model_dir, data, num_steps, save_dir=""):
total_step = sum(num_steps)
init_time = pd.to_datetime(data.time.values[-1])
tembs = time_encoding(init_time, total_step)
print(f'init_time: {init_time.strftime(("%Y%m%d-%H"))}')
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment