Commit e56b2a2e authored by tpys's avatar tpys
Browse files

support different input: ERA5, HRES, GFS

parent ce91d3d0
...@@ -11,6 +11,7 @@ sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp'] ...@@ -11,6 +11,7 @@ sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
degree = 0.25 degree = 0.25
def weighted_rmse(out, tgt): def weighted_rmse(out, tgt):
wlat = np.cos(np.deg2rad(tgt.lat)) wlat = np.cos(np.deg2rad(tgt.lat))
wlat /= wlat.mean() wlat /= wlat.mean()
...@@ -35,20 +36,19 @@ def split_variable(ds, name): ...@@ -35,20 +36,19 @@ def split_variable(ds, name):
def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split=False): def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split=False):
if save_dir: if save_dir:
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
dtime = (step+1) * freq step = (step+1) * freq
init_time = pd.to_datetime(input.time.values[-1]) init_time = pd.to_datetime(input.time.values[-1])
if input_type.upper() == "HRES": if input_type.upper() == "HRES":
dtime = (step+2) * freq step = (step+2) * freq
init_time = pd.to_datetime(input.time.values[0]) init_time = pd.to_datetime(input.time.values[0])
ds = xr.DataArray( ds = xr.DataArray(
output[None, None], output[None],
dims=['member', 'time', 'dtime', 'level', 'lat', 'lon'], dims=['time', 'step', 'level', 'lat', 'lon'],
coords=dict( coords=dict(
member=['FuXi'],
time=[init_time], time=[init_time],
dtime=[dtime], step=[step],
level=input.level, level=input.level,
lat=input.lat.values, lat=input.lat.values,
lon=input.lon.values, lon=input.lon.values,
...@@ -70,8 +70,39 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split ...@@ -70,8 +70,39 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split
new_ds.append(v) new_ds.append(v)
ds = xr.merge(new_ds, compat="no_conflicts") ds = xr.merge(new_ds, compat="no_conflicts")
save_name = os.path.join(save_dir, f'{dtime:03d}.nc') save_name = os.path.join(save_dir, f'{step:03d}.nc')
print(f'Save to {save_name} ...') # print(f'Save to {save_name} ...')
ds.to_netcdf(save_name)
def make_era5_input(init_time, data_dir, save_dir):
ds = []
init_time = pd.to_datetime(init_time)
hist_time = init_time - pd.Timedelta(hours=6)
print(f"init_time: {init_time}")
level = []
for name in pl_names + sfc_names:
data_name = os.path.join(data_dir, name, f'{init_time.year}')
v = xr.open_zarr(data_name)
v = v.sel(time=[hist_time, init_time])
v = v.rename({name: 'data'})
v.attrs = {}
ds.append(v)
if name in pl_names:
level.extend([f'{name.lower()}{l}' for l in levels])
if name in sfc_names:
level.append(name.lower())
ds = xr.concat(ds, 'level')
ds = ds.assign_coords(level=level)
os.makedirs(save_dir, exist_ok=True)
save_name = os.path.join(save_dir, init_time.strftime("input.%Y%m%d.t%H.nc"))
print(f"save to {save_name} ...")
ds = ds.astype(np.float32)
ds.to_netcdf(save_name) ds.to_netcdf(save_name)
...@@ -185,7 +216,8 @@ def make_gfs_input(init_time, data_dir, save_dir): ...@@ -185,7 +216,8 @@ def make_gfs_input(init_time, data_dir, save_dir):
return return
if v.shape[-2:] != (721, 1440): if v.shape[-2:] != (721, 1440):
v = v.interp(lat=lat, lon=lon, kwargs={"fill_value": "extrapolate"}) v = v.interp(lat=lat, lon=lon, kwargs={
"fill_value": "extrapolate"})
if np.isnan(v).sum() > 0: if np.isnan(v).sum() > 0:
print(f"{src_name} has nan value") print(f"{src_name} has nan value")
return return
...@@ -219,7 +251,8 @@ def make_gfs_input(init_time, data_dir, save_dir): ...@@ -219,7 +251,8 @@ def make_gfs_input(init_time, data_dir, save_dir):
input = xr.concat(input, "level") # T input = xr.concat(input, "level") # T
input = input.rename({"latitude": "lat", "longitude": "lon"}) input = input.rename({"latitude": "lat", "longitude": "lon"})
times = [pd.to_datetime(str(t), format='%Y%m%d%H') for t in input.time.values] times = [pd.to_datetime(str(t), format='%Y%m%d%H')
for t in input.time.values]
input = input.assign_coords(level=level) input = input.assign_coords(level=level)
input = input.assign_coords(time=times) input = input.assign_coords(time=times)
...@@ -229,7 +262,6 @@ def make_gfs_input(init_time, data_dir, save_dir): ...@@ -229,7 +262,6 @@ def make_gfs_input(init_time, data_dir, save_dir):
input.to_netcdf(save_name) input.to_netcdf(save_name)
def visualize(save_name, vars=[], titles=[], vmin=None, vmax=None): def visualize(save_name, vars=[], titles=[], vmin=None, vmax=None):
import cartopy.crs as ccrs import cartopy.crs as ccrs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -277,4 +309,14 @@ def test_visualize(): ...@@ -277,4 +309,14 @@ def test_visualize():
visualize('tp.jpg', [tp], ['tp'], vmin=0, vmax=20) visualize('tp.jpg', [tp], ['tp'], vmin=0, vmax=20)
# test_make_input() def test_rmse(output_name, target_name):
output = xr.open_dataarray(output_name)
output = output.isel(time=0).sel(step=120)
target = xr.open_dataarray(target_name)
for level in ["z500", "t850", "t2m", "u10", "v10", "msl", "tp"]:
out = output.sel(level=level)
tgt = target.sel(level=level)
rmse = weighted_rmse(out, tgt).load()
print(f"{level.upper()} 120h rmse: {rmse:.3f}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment