Commit ce91d3d0 authored by tpys's avatar tpys
Browse files

add make_gfs_input

parent 49e8c4fd
...@@ -4,11 +4,18 @@ import numpy as np ...@@ -4,11 +4,18 @@ import numpy as np
import pandas as pd import pandas as pd
import xarray as xr import xarray as xr
__all__ = ['make_input', "save_like"] __all__ = ["save_like"]
pl_names = ['z', 't', 'u', 'v', 'r'] pl_names = ['z', 't', 'u', 'v', 'r']
sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp'] sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
degree = 0.25
def weighted_rmse(out, tgt):
wlat = np.cos(np.deg2rad(tgt.lat))
wlat /= wlat.mean()
error = ((out - tgt) ** 2 * wlat)
return np.sqrt(error.mean(('lat', 'lon')))
def split_variable(ds, name): def split_variable(ds, name):
...@@ -29,10 +36,10 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split ...@@ -29,10 +36,10 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split
if save_dir: if save_dir:
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
dtime = (step+1) * freq dtime = (step+1) * freq
init_time = pd.to_datetime(input.time.values[-1])
if input_type.upper() == "HRES": if input_type.upper() == "HRES":
dtime = (step+2) * freq dtime = (step+2) * freq
init_time = pd.to_datetime(input.time.values[0]) init_time = pd.to_datetime(input.time.values[0])
ds = xr.DataArray( ds = xr.DataArray(
...@@ -68,9 +75,9 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split ...@@ -68,9 +75,9 @@ def save_like(output, input, step, save_dir="", input_type="hres", freq=6, split
ds.to_netcdf(save_name) ds.to_netcdf(save_name)
def make_input(init_time, data_dir, save_dir, deg=0.25): def make_hres_input(init_time, data_dir, save_dir):
lat = np.linspace(-90, 90, int(180/deg)+1, dtype=np.float32) lat = np.linspace(-90, 90, int(180/degree)+1, dtype=np.float32)
lon = np.arange(0, 360, deg, dtype=np.float32) lon = np.arange(0, 360, degree, dtype=np.float32)
input = [] input = []
level = [] level = []
...@@ -150,6 +157,79 @@ def make_input(init_time, data_dir, save_dir, deg=0.25): ...@@ -150,6 +157,79 @@ def make_input(init_time, data_dir, save_dir, deg=0.25):
input.to_netcdf(save_name) input.to_netcdf(save_name)
def make_gfs_input(init_time, data_dir, save_dir):
pl_names = ['Z', 'T', 'U', 'V', 'R']
sfc_names = ['t2m', 'u10', 'v10', 'msl', 'tp']
lon = np.arange(0, 360, degree, dtype=np.float32)
lat = np.arange(90, -90, -degree, dtype=np.float32)
input = []
level = []
for name in pl_names + sfc_names:
src_name = '{}_{}'.format(name, init_time.strftime("%Y%m%d.nc"))
src_file = os.path.join(data_dir, src_name)
if not os.path.exists(src_file):
print(src_file)
return
try:
v = xr.open_dataset(src_file)[name]
except:
print(f"open {src_file} failed")
return
if np.isnan(v).sum() > 0:
print(f"{src_name} has nan value")
return
if v.shape[-2:] != (721, 1440):
v = v.interp(lat=lat, lon=lon, kwargs={"fill_value": "extrapolate"})
if np.isnan(v).sum() > 0:
print(f"{src_name} has nan value")
return
if name in pl_names:
level.extend([f'{name.lower()}{l}' for l in levels])
if name in sfc_names:
level.append(name.lower())
if name == "Z":
v = v * 9.8
if name == "tp":
v = v.clip(min=0, max=1000)
v = v.squeeze('step').drop('step')
v.attrs = {}
v.name = 'data'
vmin = v.min().values
vmax = v.max().values
if vmax > 1e10:
v = v.where(v < 1e10, 0)
vmax = v.max().values
assert vmax < 1e10
print(f'{src_name}: {v.shape}, {vmin:.2f} ~ {vmax:.2f}')
input.append(v)
input = xr.concat(input, "level") # T
input = input.rename({"latitude": "lat", "longitude": "lon"})
times = [pd.to_datetime(str(t), format='%Y%m%d%H') for t in input.time.values]
input = input.assign_coords(level=level)
input = input.assign_coords(time=times)
# TODO, we only need two time step input with dims: 2 x 70 x 721 x 1440
save_name = os.path.join(save_dir, init_time.strftime("%Y%m%d.nc"))
input = input.astype(np.float32)
input.to_netcdf(save_name)
def visualize(save_name, vars=[], titles=[], vmin=None, vmax=None): def visualize(save_name, vars=[], titles=[], vmin=None, vmax=None):
import cartopy.crs as ccrs import cartopy.crs as ccrs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -188,7 +268,7 @@ def test_make_input(): ...@@ -188,7 +268,7 @@ def test_make_input():
data_dir = "data/HRES" data_dir = "data/HRES"
save_dir = "data/HRES/input" save_dir = "data/HRES/input"
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
make_input(init_time, data_dir, save_dir) make_hres_input(init_time, data_dir, save_dir)
def test_visualize(): def test_visualize():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment