Commit bdd87fae authored by zhangwenbo's avatar zhangwenbo
Browse files

Initial commit: FourCastNet source code only

parents
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import os
import sys
import time
import numpy as np
import argparse
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
from numpy.core.numeric import False_
import h5py
import torch
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
from utils.weighted_acc_rmse import weighted_rmse_torch_channels, weighted_acc_torch_channels, unweighted_acc_torch_channels, weighted_acc_masked_torch_channels
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet
import wandb
import matplotlib.pyplot as plt
import glob
from datetime import datetime
fld = "z500" # diff flds have diff decor times and hence differnt ics
if fld == "z500" or fld == "2m_temperature" or fld == "t850":
DECORRELATION_TIME = 36 # 9 days (36) for z500, 2 (8 steps) days for u10, v10
else:
DECORRELATION_TIME = 8 # 9 days (36) for z500, 2 (8 steps) days for u10, v10
idxes = {"u10":0, "z500":14, "2m_temperature":2, "v10":1, "t850":5}
def gaussian_perturb(x, level=0.01, device=0):
noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
return (x + noise)
def load_model(model, params, checkpoint_file):
model.zero_grad()
checkpoint_fname = checkpoint_file
checkpoint = torch.load(checkpoint_fname)
try:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
if name != 'ged':
new_state_dict[name] = val
model.load_state_dict(new_state_dict)
except:
model.load_state_dict(checkpoint['model_state'])
model.eval()
return model
def downsample(x, scale=0.125):
return torch.nn.functional.interpolate(x, scale_factor=scale, mode='bilinear')
def setup(params):
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#get data loader
valid_data_loader, valid_dataset = get_data_loader(params, params.inf_data_path, dist.is_initialized(), train=False)
img_shape_x = valid_dataset.img_shape_x
img_shape_y = valid_dataset.img_shape_y
params.img_shape_x = img_shape_x
params.img_shape_y = img_shape_y
if params.log_to_screen:
logging.info('Loading trained model checkpoint from {}'.format(params['best_checkpoint_path']))
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
if params["orography"]:
params['N_in_channels'] = n_in_channels + 1
else:
params['N_in_channels'] = n_in_channels
params['N_out_channels'] = n_out_channels
params.means = np.load(params.global_means_path)[0, out_channels] # needed to standardize wind data
params.stds = np.load(params.global_stds_path)[0, out_channels]
# load the model
if params.nettype == 'afno':
model = AFNONet(params).to(device)
else:
raise Exception("not implemented")
checkpoint_file = params['best_checkpoint_path']
model = load_model(model, params, checkpoint_file)
model = model.to(device)
# load the validation data
files_paths = glob.glob(params.inf_data_path + "/*.h5")
files_paths.sort()
# which year
yr = 0
if params.log_to_screen:
logging.info('Loading inference data')
logging.info('Inference data from {}'.format(files_paths[yr]))
valid_data_full = h5py.File(files_paths[yr], 'r')['fields']
return valid_data_full, model
def autoregressive_inference(params, ic, valid_data_full, model):
ic = int(ic)
#initialize global variables
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
exp_dir = params['experiment_dir']
dt = int(params.dt)
prediction_length = int(params.prediction_length/dt)
n_history = params.n_history
img_shape_x = params.img_shape_x
img_shape_y = params.img_shape_y
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
means = params.means
stds = params.stds
#initialize memory for image sequences and RMSE/ACC
valid_loss = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
# compute metrics in a coarse resolution too if params.interp is nonzero
valid_loss_coarse = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_coarse = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_coarse_unweighted = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_unweighted = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
seq_real = torch.zeros((prediction_length, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred = torch.zeros((prediction_length, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
acc_land = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_sea = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
if params.masked_acc:
maskarray = torch.as_tensor(np.load(params.maskpath)[0:720]).to(device, dtype=torch.float)
valid_data = valid_data_full[ic:(ic+prediction_length*dt+n_history*dt):dt, in_channels, 0:720] #extract valid data from first year
# standardize
valid_data = (valid_data - means)/stds
valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
#load time means
if not params.use_daily_climatology:
m = torch.as_tensor((np.load(params.time_means_path)[0][out_channels] - means)/stds)[:, 0:img_shape_x] # climatology
m = torch.unsqueeze(m, 0)
else:
# use daily clim like weyn et al. (different from rasp)
dc_path = params.dc_path
with h5py.File(dc_path, 'r') as f:
dc = f['time_means_daily'][ic:ic+prediction_length*dt:dt] # 1460,21,721,1440
m = torch.as_tensor((dc[:,out_channels,0:img_shape_x,:] - means)/stds)
m = m.to(device, dtype=torch.float)
if params.interp > 0:
m_coarse = downsample(m, scale=params.interp)
std = torch.as_tensor(stds[:,0,0]).to(device, dtype=torch.float)
orography = params.orography
orography_path = params.orography_path
if orography:
orog = torch.as_tensor(np.expand_dims(np.expand_dims(h5py.File(orography_path, 'r')['orog'][0:720], axis = 0), axis = 0)).to(device, dtype = torch.float)
logging.info("orography loaded; shape:{}".format(orog.shape))
#autoregressive inference
if params.log_to_screen:
logging.info('Begin autoregressive inference')
with torch.no_grad():
for i in range(valid_data.shape[0]):
if i==0: #start of sequence
first = valid_data[0:n_history+1]
future = valid_data[n_history+1]
for h in range(n_history+1):
seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels][0:n_out_channels] #extract history from 1st
seq_pred[h] = seq_real[h]
if params.perturb:
first = gaussian_perturb(first, level=params.n_level, device=device) # perturb the ic
if orography:
future_pred = model(torch.cat((first, orog), axis=1))
else:
future_pred = model(first)
else:
if i < prediction_length-1:
future = valid_data[n_history+i+1]
if orography:
future_pred = model(torch.cat((future_pred, orog), axis=1)) #autoregressive step
else:
future_pred = model(future_pred) #autoregressive step
if i < prediction_length-1: #not on the last step
seq_pred[n_history+i+1] = future_pred
seq_real[n_history+i+1] = future
history_stack = seq_pred[i+1:i+2+n_history]
future_pred = history_stack
#Compute metrics
if params.use_daily_climatology:
clim = m[i:i+1]
if params.interp > 0:
clim_coarse = m_coarse[i:i+1]
else:
clim = m
if params.interp > 0:
clim_coarse = m_coarse
pred = torch.unsqueeze(seq_pred[i], 0)
tar = torch.unsqueeze(seq_real[i], 0)
valid_loss[i] = weighted_rmse_torch_channels(pred, tar) * std
acc[i] = weighted_acc_torch_channels(pred-clim, tar-clim)
acc_unweighted[i] = unweighted_acc_torch_channels(pred-clim, tar-clim)
if params.masked_acc:
acc_land[i] = weighted_acc_masked_torch_channels(pred-clim, tar-clim, maskarray)
acc_sea[i] = weighted_acc_masked_torch_channels(pred-clim, tar-clim, 1-maskarray)
if params.interp > 0:
pred = downsample(pred, scale=params.interp)
tar = downsample(tar, scale=params.interp)
valid_loss_coarse[i] = weighted_rmse_torch_channels(pred, tar) * std
acc_coarse[i] = weighted_acc_torch_channels(pred-clim_coarse, tar-clim_coarse)
acc_coarse_unweighted[i] = unweighted_acc_torch_channels(pred-clim_coarse, tar-clim_coarse)
if params.log_to_screen:
idx = idxes[fld]
logging.info('Predicted timestep {} of {}. {} RMS Error: {}, ACC: {}'.format(i, prediction_length, fld, valid_loss[i, idx], acc[i, idx]))
if params.interp > 0:
logging.info('[COARSE] Predicted timestep {} of {}. {} RMS Error: {}, ACC: {}'.format(i, prediction_length, fld, valid_loss_coarse[i, idx],
acc_coarse[i, idx]))
seq_real = seq_real.cpu().numpy()
seq_pred = seq_pred.cpu().numpy()
valid_loss = valid_loss.cpu().numpy()
acc = acc.cpu().numpy()
acc_unweighted = acc_unweighted.cpu().numpy()
acc_coarse = acc_coarse.cpu().numpy()
acc_coarse_unweighted = acc_coarse_unweighted.cpu().numpy()
valid_loss_coarse = valid_loss_coarse.cpu().numpy()
acc_land = acc_land.cpu().numpy()
acc_sea = acc_sea.cpu().numpy()
return (np.expand_dims(seq_real[n_history:], 0), np.expand_dims(seq_pred[n_history:], 0), np.expand_dims(valid_loss,0), np.expand_dims(acc, 0),
np.expand_dims(acc_unweighted, 0), np.expand_dims(valid_loss_coarse, 0), np.expand_dims(acc_coarse, 0),
np.expand_dims(acc_coarse_unweighted, 0),
np.expand_dims(acc_land, 0),
np.expand_dims(acc_sea, 0))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='full_field', type=str)
parser.add_argument("--use_daily_climatology", action='store_true')
parser.add_argument("--vis", action='store_true')
parser.add_argument("--override_dir", default=None, type = str, help = 'Path to store inference outputs; must also set --weights arg')
parser.add_argument("--interp", default=0, type=float)
parser.add_argument("--weights", default=None, type=str, help = 'Path to model weights, for use with override_dir option')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['world_size'] = 1
params['interp'] = args.interp
params['use_daily_climatology'] = args.use_daily_climatology
params['global_batch_size'] = params.batch_size
torch.cuda.set_device(0)
torch.backends.cudnn.benchmark = True
vis = args.vis
# Set up directory
if args.override_dir is not None:
assert args.weights is not None, 'Must set --weights argument if using --override_dir'
expDir = args.override_dir
else:
assert args.weights is None, 'Cannot use --weights argument without also using --override_dir'
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if not os.path.isdir(expDir):
os.makedirs(expDir)
params['experiment_dir'] = os.path.abspath(expDir)
params['best_checkpoint_path'] = args.weights if args.override_dir is not None else os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
params['resuming'] = False
params['local_rank'] = 0
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference_out.log'))
logging_utils.log_versions()
params.log()
n_ics = params['n_initial_conditions']
if fld == "z500" or fld == "t850":
n_samples_per_year = 1336
else:
n_samples_per_year = 1460
if params["ics_type"] == 'default':
num_samples = n_samples_per_year-params.prediction_length
stop = num_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
if vis: # visualization for just the first ic (or any ic)
ics = [0]
n_ics = len(ics)
elif params["ics_type"] == "datetime":
date_strings = params["date_strings"]
ics = []
if params.perturb: #for perturbations use a single date and create n_ics perturbations
n_ics = params["n_perturbations"]
date = date_strings[0]
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
for ii in range(n_ics):
ics.append(int(hours_since_jan_01_epoch/6))
else:
for date in date_strings:
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
ics.append(int(hours_since_jan_01_epoch/6))
n_ics = len(ics)
logging.info("Inference for {} initial conditions".format(n_ics))
try:
autoregressive_inference_filetag = params["inference_file_tag"]
except:
autoregressive_inference_filetag = ""
if params.interp > 0:
autoregressive_inference_filetag = "_coarse"
autoregressive_inference_filetag += "_" + fld + ""
if vis:
autoregressive_inference_filetag += "_vis"
# get data and models
valid_data_full, model = setup(params)
#initialize lists for image sequences and RMSE/ACC
valid_loss = []
valid_loss_coarse = []
acc_unweighted = []
acc = []
acc_coarse = []
acc_coarse_unweighted = []
seq_pred = []
seq_real = []
acc_land = []
acc_sea = []
#run autoregressive inference for multiple initial conditions
for i, ic in enumerate(ics):
logging.info("Initial condition {} of {}".format(i+1, n_ics))
sr, sp, vl, a, au, vc, ac, acu, accland, accsea = autoregressive_inference(params, ic, valid_data_full, model)
if i ==0 or len(valid_loss) == 0:
seq_real = sr
seq_pred = sp
valid_loss = vl
valid_loss_coarse = vc
acc = a
acc_coarse = ac
acc_coarse_unweighted = acu
acc_unweighted = au
acc_land = accland
acc_sea = accsea
else:
# seq_real = np.concatenate((seq_real, sr), 0)
# seq_pred = np.concatenate((seq_pred, sp), 0)
valid_loss = np.concatenate((valid_loss, vl), 0)
valid_loss_coarse = np.concatenate((valid_loss_coarse, vc), 0)
acc = np.concatenate((acc, a), 0)
acc_coarse = np.concatenate((acc_coarse, ac), 0)
acc_coarse_unweighted = np.concatenate((acc_coarse_unweighted, acu), 0)
acc_unweighted = np.concatenate((acc_unweighted, au), 0)
acc_land = np.concatenate((acc_land, accland), 0)
acc_sea = np.concatenate((acc_sea, accsea), 0)
prediction_length = seq_real[0].shape[0]
n_out_channels = seq_real[0].shape[1]
img_shape_x = seq_real[0].shape[2]
img_shape_y = seq_real[0].shape[3]
#save predictions and loss
if params.log_to_screen:
logging.info("Saving files at {}".format(os.path.join(params['experiment_dir'], 'autoregressive_predictions' + autoregressive_inference_filetag + '.h5')))
with h5py.File(os.path.join(params['experiment_dir'], 'autoregressive_predictions'+ autoregressive_inference_filetag +'.h5'), 'a') as f:
if vis:
try:
f.create_dataset("ground_truth", data = seq_real, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
except:
del f["ground_truth"]
f.create_dataset("ground_truth", data = seq_real, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
f["ground_truth"][...] = seq_real
try:
f.create_dataset("predicted", data = seq_pred, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
except:
del f["predicted"]
f.create_dataset("predicted", data = seq_pred, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
f["predicted"][...]= seq_pred
if params.masked_acc:
try:
f.create_dataset("acc_land", data = acc_land)#, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc_land"]
f.create_dataset("acc_land", data = acc_land)#, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc_land"][...] = acc_land
try:
f.create_dataset("acc_sea", data = acc_sea)#, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc_sea"]
f.create_dataset("acc_sea", data = acc_sea)#, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc_sea"][...] = acc_sea
try:
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse"]
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse"][...] = valid_loss
try:
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc"]
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc"][...] = acc
try:
f.create_dataset("rmse_coarse", data = valid_loss_coarse, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse_coarse"]
f.create_dataset("rmse_coarse", data = valid_loss_coarse, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse_coarse"][...] = valid_loss_coarse
try:
f.create_dataset("acc_coarse", data = acc_coarse, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc_coarse"]
f.create_dataset("acc_coarse", data = acc_coarse, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc_coarse"][...] = acc_coarse
try:
f.create_dataset("acc_unweighted", data = acc_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc_unweighted"]
f.create_dataset("acc_unweighted", data = acc_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc_unweighted"][...] = acc_unweighted
try:
f.create_dataset("acc_coarse_unweighted", data = acc_coarse_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc_coarse_unweighted"]
f.create_dataset("acc_coarse_unweighted", data = acc_coarse_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc_coarse_unweighted"][...] = acc_coarse_unweighted
f.close()
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import os
import time
import numpy as np
import argparse
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
from numpy.core.numeric import False_
import h5py
import torch
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
from utils.weighted_acc_rmse import weighted_rmse_torch_channels, weighted_acc_torch_channels
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet
import wandb
import matplotlib.pyplot as plt
import glob
fld = "z500" # diff flds have diff decor times and hence differnt ics
if fld == "z500" or fld == "2m_temperature" or fld == "t850":
DECORRELATION_TIME = 36 # 9 days (36) for z500, 2 (8 steps) days for u10, v10
else:
DECORRELATION_TIME = 8 # 9 days (36) for z500, 2 (8 steps) days for u10, v10
idxes = {"u10":0, "z500":14, "tp":0}
def gaussian_perturb(x, level=0.01, device=0):
noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
return (x + noise)
def load_model(model, params, checkpoint_file): #, local_rank):
model.zero_grad()
checkpoint = torch.load(checkpoint_file)
try:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
if name != 'ged':
new_state_dict[name] = val
model.load_state_dict(new_state_dict)
except:
model.load_state_dict(checkpoint['model_state'])
model.eval()
return model
def downsample(x, scale=0.125):
return torch.nn.functional.interpolate(x, scale_factor=scale, mode='bilinear')
def setup(params):
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#get data loader
valid_data_loader, valid_dataset = get_data_loader(params, params.inf_data_path, dist.is_initialized(), train=False)
img_shape_x = valid_dataset.img_shape_x
img_shape_y = valid_dataset.img_shape_y
params.img_shape_x = img_shape_x
params.img_shape_y = img_shape_y
if params.log_to_screen:
logging.info('Loading trained model checkpoint from {}'.format(params['best_checkpoint_path']))
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
params['N_in_channels'] = n_in_channels
params['N_out_channels'] = n_out_channels
params.means = np.load(params.global_means_path)[0, out_channels] # needed to standardize wind data
params.stds = np.load(params.global_stds_path)[0, out_channels]
# load the model
if params.nettype == 'afno':
model = AFNONet(params).to(device)
checkpoint_file = params['best_checkpoint_path']
model = load_model(model, params, checkpoint_file)
model = model.to(device)
# load the validation data
files_paths = glob.glob(params.inf_data_path + "/*.h5")
files_paths.sort()
if params.log_to_screen:
logging.info('Loading validation data')
logging.info('Validation data from {}'.format(files_paths[0]))
# which year
yr = 0
valid_data_full = h5py.File(files_paths[yr], 'r')['fields']
return valid_data_full, model
def autoregressive_inference(params, ic, valid_data_full, model):
ic = int(ic)
#initialize global variables
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
exp_dir = params['experiment_dir']
dt = int(params.dt)
prediction_length = int(params.prediction_length/dt)
n_history = params.n_history
img_shape_x = params.img_shape_x
img_shape_y = params.img_shape_y
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
means = params.means
stds = params.stds
n_pert = params.n_pert
#initialize memory for image sequences and RMSE/ACC
valid_loss = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
seq_real = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
valid_data = valid_data_full[ic:(ic+prediction_length*dt+n_history*dt):dt, in_channels, 0:720] #extract valid data from first year
# standardize
valid_data = (valid_data - means)/stds
valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
#load time means
m = torch.as_tensor((np.load(params.time_means_path)[0][out_channels] - means)/stds)[:, 0:img_shape_x] # climatology
m = torch.unsqueeze(m, 0)
m = m.to(device)
std = torch.as_tensor(stds[:,0,0]).to(device, dtype=torch.float)
#autoregressive inference
if params.log_to_screen:
logging.info('Begin autoregressive inference')
for pert in range(n_pert):
logging.info('Running ensemble {}/{}'.format(pert+1, n_pert))
with torch.no_grad():
for i in range(valid_data.shape[0]):
if i==0: #start of sequence
first = valid_data[0:n_history+1]
future = valid_data[n_history+1]
for h in range(n_history+1):
seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels][0:n_out_channels] #extract history from 1st
seq_pred[h] = seq_real[h]
first = gaussian_perturb(first, level=params.n_level, device=device) # perturb the ic
future_pred = model(first)
else:
if i < prediction_length-1:
future = valid_data[n_history+i+1]
future_pred = model(future_pred) #autoregressive step
if i < prediction_length-1: #not on the last step
seq_pred[n_history+i+1] += torch.squeeze(future_pred,0) # add up predictions and average later
seq_real[n_history+i+1] += future
#Compute metrics
for i in range(valid_data.shape[0]):
if i>0:
# avg images
seq_pred[i] /= n_pert
seq_real[i] /= n_pert
pred = torch.unsqueeze(seq_pred[i], 0)
tar = torch.unsqueeze(seq_real[i], 0)
valid_loss[i] = weighted_rmse_torch_channels(pred, tar) * std
acc[i] = weighted_acc_torch_channels(pred-m, tar-m)
if params.log_to_screen:
idx = idxes[fld]
logging.info('Predicted timestep {} of {}. {} RMS Error: {}, ACC: {}'.format(i, prediction_length, fld, valid_loss[i, idx], acc[i, idx]))
valid_loss = valid_loss.cpu().numpy()
acc = acc.cpu().numpy()
return np.expand_dims(valid_loss,0), np.expand_dims(acc, 0)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='full_field', type=str)
parser.add_argument("--override_dir", default=None, type = str, help = 'Path to store inference outputs; must also set --weights arg')
parser.add_argument("--n_pert", default=100, type=int)
parser.add_argument("--n_level", default=0.3, type=float)
parser.add_argument("--weights", default=None, type=str, help = 'Path to model weights, for use with override_dir option')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['world_size'] = 1
if 'WORLD_SIZE' in os.environ:
params['world_size'] = int(os.environ['WORLD_SIZE'])
params['n_pert'] = args.n_pert
world_rank = 0
world_size = params.world_size
params['global_batch_size'] = params.batch_size
local_rank = 0
if params['world_size'] > 1:
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend='nccl',
init_method='env://')
args.gpu = local_rank
world_rank = dist.get_rank()
world_size = dist.get_world_size()
params['global_batch_size'] = params.batch_size
#params['batch_size'] = int(params.batch_size//params['world_size'])
torch.cuda.set_device(local_rank)
torch.backends.cudnn.benchmark = True
# Set up directory
if args.override_dir is not None:
assert args.weights is not None, 'Must set --weights argument if using --override_dir'
expDir = args.override_dir
else:
assert args.weights is None, 'Cannot use --weights argument without also using --override_dir'
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if world_rank==0:
if not os.path.isdir(expDir):
os.makedirs(expDir)
params['experiment_dir'] = os.path.abspath(expDir)
params['best_checkpoint_path'] = args.weights if args.override_dir is not None else os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
params['resuming'] = False
params['local_rank'] = local_rank
# this will be the wandb name
params['name'] = args.config + '_' + str(args.run_num)
params['group'] = args.config
if world_rank==0:
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference_out.log'))
logging_utils.log_versions()
params.log()
params['log_to_wandb'] = (world_rank==0) and params['log_to_wandb']
params['log_to_screen'] = (world_rank==0) and params['log_to_screen']
n_ics = params['n_initial_conditions']
if fld == "z500" or fld == "t850":
n_samples_per_year = 1336
else:
n_samples_per_year = 1460
if params["ics_type"] == 'default':
num_samples = n_samples_per_year-params.prediction_length
stop = num_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
n_ics = len(ics)
logging.info("Inference for {} initial conditions".format(n_ics))
elif params["ics_type"] == "datetime":
date_strings = params["date_strings"]
ics = []
for date in date_strings:
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
ics.append(int(hours_since_jan_01_epoch/6))
print(ics)
n_ics = len(ics)
try:
autoregressive_inference_filetag = params["inference_file_tag"]
except:
autoregressive_inference_filetag = ""
params.n_level = args.n_level
autoregressive_inference_filetag += "_" + str(params.n_level) + "_" + str(params.n_pert) + "ens_" + fld
# get data and models
valid_data_full, model = setup(params)
#initialize lists for image sequences and RMSE/ACC
valid_loss = np.zeros
acc = []
# run autoregressive inference for multiple initial conditions
# parallelize over initial conditions
if world_size > 1:
tot_ics = len(ics)
ics_per_proc = n_ics//world_size
ics = ics[ics_per_proc*world_rank:ics_per_proc*(world_rank+1)] if world_rank < world_size - 1 else ics[(world_size - 1)*ics_per_proc:]
n_ics = len(ics)
logging.info('Rank %d running ics %s'%(world_rank, str(ics)))
for i, ic in enumerate(ics):
t0 = time.time()
logging.info("Initial condition {} of {}".format(i+1, n_ics))
vl, a = autoregressive_inference(params, ic, valid_data_full, model)
if i ==0:
valid_loss = vl
acc = a
else:
valid_loss = np.concatenate((valid_loss, vl), 0)
acc = np.concatenate((acc, a), 0)
t1 = time.time() - t0
logging.info("Time for inference for ic {} = {}".format(i, t1))
prediction_length = acc[0].shape[0]
n_out_channels = acc[0].shape[1]
#save predictions and loss
h5name = os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5')
if dist.is_initialized():
if params.log_to_screen:
logging.info("Saving files at {}".format(h5name))
logging.info("array shapes: %s"%str((tot_ics, prediction_length, n_out_channels)))
dist.barrier()
from mpi4py import MPI
with h5py.File(h5name, 'a', driver='mpio', comm=MPI.COMM_WORLD) as f:
if "rmse" in f.keys() or "acc" in f.keys():
del f["acc"]
del f["rmse"]
f.create_dataset("rmse", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
f.create_dataset("acc", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
start = world_rank*ics_per_proc
f["rmse"][start:start+n_ics] = valid_loss
f["acc"][start:start+n_ics] = acc
dist.barrier()
else:
if params.log_to_screen:
logging.info("Saving files at {}".format(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions' + autoregressive_inference_filetag + '.h5')))
with h5py.File(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5'), 'a') as f:
try:
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse"]
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse"][...] = valid_loss
try:
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc"]
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc"][...] = acc
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import os
import time
import numpy as np
import argparse
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
from numpy.core.numeric import False_
import h5py
import torch
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
from utils.weighted_acc_rmse import weighted_rmse_torch_channels, weighted_acc_torch_channels, unlog_tp_torch, top_quantiles_error_torch
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet, PrecipNet
import wandb
import matplotlib.pyplot as plt
import glob
from datetime import datetime
DECORRELATION_TIME = 8 # 2 days for preicp
def gaussian_perturb(x, level=0.01, device=0):
noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
return (x + noise)
def load_model(model, params, checkpoint_file):
model.zero_grad()
checkpoint = torch.load(checkpoint_file)
try:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
if name != 'ged':
new_state_dict[name] = val
model.load_state_dict(new_state_dict)
except:
model.load_state_dict(checkpoint['model_state'])
model.eval()
return model
def setup(params):
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#get data loader
valid_data_loader, valid_dataset = get_data_loader(params, params.inf_data_path, dist.is_initialized(), train=False)
img_shape_x = valid_dataset.img_shape_x
img_shape_y = valid_dataset.img_shape_y
params.img_shape_x = img_shape_x
params.img_shape_y = img_shape_y
if params.log_to_screen:
logging.info('Loading trained model checkpoint from {}'.format(params['best_checkpoint_path']))
in_channels = np.array(params.in_channels)
out_channels = np.array(params.in_channels) # same as in for the wind model
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
params['N_in_channels'] = n_in_channels
params['N_out_channels'] = n_out_channels
params.means = np.load(params.global_means_path)[0, out_channels] # needed to standardize wind data
params.stds = np.load(params.global_stds_path)[0, out_channels]
# load the models
# load wind model
if params.nettype_wind == 'afno':
model_wind = AFNONet(params).to(device)
checkpoint_file = params['model_wind_path']
model_wind = load_model(model_wind, params, checkpoint_file)
model_wind = model_wind.to(device)
out_channels = np.array(params.out_channels) # change out channels for precip model
n_out_channels = len(out_channels)
params['N_out_channels'] = n_out_channels
if params.nettype == 'afno':
model = AFNONet(params).to(device)
model = PrecipNet(params, backbone=model).to(device)
checkpoint_file = params['best_checkpoint_path']
model = load_model(model, params, checkpoint_file)
model = model.to(device)
# load the validation data
files_paths = glob.glob(params.inf_data_path + "/*.h5")
files_paths.sort()
if params.log_to_screen:
logging.info('Loading validation data')
logging.info('Validation data from {}'.format(files_paths[0]))
valid_data_full = h5py.File(files_paths[0], 'r')['fields']
# precip paths
path = params.precip + '/out_of_sample'
precip_paths = glob.glob(path + "/*.h5")
precip_paths.sort()
if params.log_to_screen:
logging.info('Loading validation precip data')
logging.info('Validation data from {}'.format(precip_paths[0]))
valid_data_tp_full = h5py.File(precip_paths[0], 'r')['tp']
return valid_data_full, valid_data_tp_full, model_wind, model
def autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model):
ic = int(ic)
#initialize global variables
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
exp_dir = params['experiment_dir']
dt = int(params.dt)
prediction_length = int(params.prediction_length/dt)
n_history = params.n_history
img_shape_x = params.img_shape_x
img_shape_y = params.img_shape_y
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
means = params.means
stds = params.stds
n_pert = params.n_pert
#initialize memory for image sequences and RMSE/ACC
valid_loss = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_unweighted = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
tqe = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
# wind seqs
seq_real = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
# precip sequences
seq_real_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
valid_data = valid_data_full[ic:(ic+prediction_length*dt+n_history*dt):dt, in_channels, 0:720] #extract valid data from first year
# standardize
valid_data = (valid_data - means)/stds
valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
len_ic = prediction_length*dt
valid_data_tp = valid_data_tp_full[ic:(ic+prediction_length*dt):dt, 0:720].reshape(len_ic,n_out_channels,720,img_shape_y) #extract valid data from first year
# log normalize
eps = params.precip_eps
valid_data_tp = np.log1p(valid_data_tp/eps)
valid_data_tp = torch.as_tensor(valid_data_tp).to(device, dtype=torch.float)
#load time means
m = torch.as_tensor(np.load(params.time_means_path_tp)[0][out_channels])[:, 0:img_shape_x] # climatology
m = torch.unsqueeze(m, 0)
m = m.to(device)
if params.log_to_screen:
logging.info('Begin autoregressive+tp inference')
for pert in range(n_pert):
logging.info('Running ensemble {}/{}'.format(pert+1, n_pert))
with torch.no_grad():
for i in range(valid_data.shape[0]):
if i==0: #start of sequence
first = valid_data[0:n_history+1]
first_tp = valid_data_tp[0:1]
future = valid_data[n_history+1]
future_tp = valid_data_tp[1]
for h in range(n_history+1):
seq_real[h] = first[h*n_in_channels:(h+1)*n_in_channels][0:n_in_channels] #extract history from 1st
seq_pred[h] = seq_real[h]
seq_real_tp[0] = unlog_tp_torch(first_tp)
seq_pred_tp[0] = unlog_tp_torch(first_tp)
first = gaussian_perturb(first, level=params.n_level, device=device) # perturb the ic
future_pred = model_wind(first)
future_pred_tp = model(future_pred)
else:
if i < prediction_length-1:
future = valid_data[n_history+i+1]
future_tp = valid_data_tp[i+1]
future_pred = model_wind(future_pred) #autoregressive step
future_pred_tp = model(future_pred) # tp diagnosis
if i < prediction_length-1: #not on the last step
seq_pred_tp[n_history+i+1] += unlog_tp_torch(torch.squeeze(future_pred_tp,0)) # add up predictions and average later
seq_real_tp[n_history+i+1] += unlog_tp_torch(future_tp)
#Compute metrics
for i in range(valid_data.shape[0]):
if i>0:
# avg images
seq_pred_tp[i] /= n_pert
seq_real_tp[i] /= n_pert
pred = torch.unsqueeze(seq_pred_tp[i], 0)
tar = torch.unsqueeze(seq_real_tp[i], 0)
valid_loss[i] = weighted_rmse_torch_channels(pred, tar)
acc[i] = weighted_acc_torch_channels(pred-m, tar-m)
tqe[i] = top_quantiles_error_torch(pred, tar)
if params.log_to_screen:
logging.info('Timestep {} of {}. TP RMS Error: {}, ACC: {}'.format((i), prediction_length, valid_loss[i,0], acc[i,0]))
seq_real_tp = seq_real_tp.cpu().numpy()
seq_pred_tp = seq_pred_tp.cpu().numpy()
valid_loss = valid_loss.cpu().numpy()
acc = acc.cpu().numpy()
acc_unweighted = acc_unweighted.cpu().numpy()
tqe = tqe.cpu().numpy()
return np.expand_dims(valid_loss, 0), \
np.expand_dims(acc, 0), \
np.expand_dims(tqe, 0)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='full_field', type=str)
parser.add_argument("--n_level", default = 0.3, type = float)
parser.add_argument("--n_pert", default=100, type=int)
parser.add_argument("--override_dir", default=None, type = str, help = 'Path to store inference outputs; must also set --weights arg')
parser.add_argument("--weights", default=None, type=str, help = 'Path to model weights, for use with override_dir option')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['world_size'] = 1
if 'WORLD_SIZE' in os.environ:
params['world_size'] = int(os.environ['WORLD_SIZE'])
world_rank = 0
local_rank = 0
if params['world_size'] > 1:
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend='nccl', init_method='env://')
args.gpu = local_rank
world_rank = dist.get_rank()
world_size = dist.get_world_size()
params['global_batch_size'] = params.batch_size
# params['batch_size'] = int(params.batch_size//params['world_size'])
torch.cuda.set_device(local_rank)
torch.backends.cudnn.benchmark = True
# Set up directory
if args.override_dir is not None:
assert args.weights is not None, 'Must set --weights argument if using --override_dir'
expDir = args.override_dir
else:
assert args.weights is None, 'Cannot use --weights argument without also using --override_dir'
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if world_rank==0:
if not os.path.isdir(expDir):
os.makedirs(expDir)
params['experiment_dir'] = os.path.abspath(expDir)
params['best_checkpoint_path'] = args.weights if args.override_dir is not None else os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
args.resuming = False
params['resuming'] = args.resuming
params['local_rank'] = local_rank
# this will be the wandb name
params['name'] = args.config + '_' + str(args.run_num)
params['group'] = args.config
if world_rank==0:
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference_out.log'))
logging_utils.log_versions()
params.log()
params['log_to_wandb'] = (world_rank==0) and params['log_to_wandb']
params['log_to_screen'] = (world_rank==0) and params['log_to_screen']
n_ics = params['n_initial_conditions']
if params["ics_type"] == 'default':
num_samples = 1460 - params.prediction_length
stop = num_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
n_ics = len(ics)
logging.info("Inference for {} initial conditions".format(n_ics))
elif params["ics_type"] == "datetime":
date_strings = params["date_strings"]
ics = []
for date in date_strings:
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
ics.append(int(hours_since_jan_01_epoch/6))
try:
autoregressive_inference_filetag = params["inference_file_tag"]
except:
autoregressive_inference_filetag = ""
n_level = args.n_level
params.n_level = n_level
params.n_pert = args.n_pert
logging.info("Doing level = {}".format(n_level))
autoregressive_inference_filetag += "_" + str(params.n_level) + "_" + str(params.n_pert) + "ens_tp"
# get data and models
valid_data_full, valid_data_tp_full, model_wind, model = setup(params)
#initialize lists for image sequences and RMSE/ACC
valid_loss = np.zeros
acc = []
tqe = []
# run autoregressive inference for multiple initial conditions
# parallelize over initial conditions
if world_size > 1:
tot_ics = len(ics)
ics_per_proc = n_ics//world_size
ics = ics[ics_per_proc*world_rank:ics_per_proc*(world_rank+1)] if world_rank < world_size - 1 else ics[(world_size - 1)*ics_per_proc:]
n_ics = len(ics)
logging.info('Rank %d running ics %s'%(world_rank, str(ics)))
for i, ic in enumerate(ics):
t1 = time.time()
logging.info("Initial condition {} of {}".format(i+1, n_ics))
vl, a, tq = autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model)
if i == 0:
valid_loss = vl
acc = a
tqe = tq
else:
valid_loss = np.concatenate((valid_loss, vl), 0)
acc = np.concatenate((acc, a), 0)
tqe = np.concatenate((tqe, tq), 0)
t2 = time.time()-t1
logging.info("Time for inference for ic {} = {}".format(i, t2))
prediction_length = acc[0].shape[0]
n_out_channels = acc[0].shape[1]
#save predictions and loss
#save predictions and loss
h5name = os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5')
if dist.is_initialized():
if params.log_to_screen:
logging.info("Saving files at {}".format(h5name))
logging.info("array shapes: %s"%str((tot_ics, prediction_length, n_out_channels)))
dist.barrier()
from mpi4py import MPI
with h5py.File(h5name, 'a', driver='mpio', comm=MPI.COMM_WORLD) as f:
if "rmse" in f.keys() or "acc" in f.keys():
del f["acc"]
del f["rmse"]
f.create_dataset("rmse", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
f.create_dataset("acc", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
start = world_rank*ics_per_proc
f["rmse"][start:start+n_ics] = valid_loss
f["acc"][start:start+n_ics] = acc
dist.barrier()
else:
if params.log_to_screen:
logging.info("Saving files at {}".format(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions' + autoregressive_inference_filetag + '.h5')))
with h5py.File(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5'), 'a') as f:
try:
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse"]
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse"][...] = valid_loss
try:
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc"]
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc"][...] = acc
try:
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["tqe"]
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["tqe"][...] = tqe
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import os
import sys
import time
import numpy as np
import argparse
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
from numpy.core.numeric import False_
import h5py
import torch
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
from utils.weighted_acc_rmse import weighted_rmse_torch_channels, weighted_acc_torch_channels, unlog_tp_torch, top_quantiles_error_torch
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet, PrecipNet
import wandb
import matplotlib.pyplot as plt
import glob
from datetime import datetime
DECORRELATION_TIME = 8 # 2 days for preicp
def gaussian_perturb(x, level=0.01, device=0):
noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
return (x + noise)
def load_model(model, params, checkpoint_file):
model.zero_grad()
checkpoint_fname = checkpoint_file
checkpoint = torch.load(checkpoint_fname)
try:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
if name != 'ged':
new_state_dict[name] = val
model.load_state_dict(new_state_dict)
except:
model.load_state_dict(checkpoint['model_state'])
model.eval()
return model
def downsample(x, scale=0.125):
return torch.nn.functional.interpolate(x, scale_factor=scale, mode='bilinear')
def setup(params):
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#get data loader
valid_data_loader, valid_dataset = get_data_loader(params, params.inf_data_path, dist.is_initialized(), train=False)
img_shape_x = valid_dataset.img_shape_x
img_shape_y = valid_dataset.img_shape_y
params.img_shape_x = img_shape_x
params.img_shape_y = img_shape_y
if params.log_to_screen:
logging.info('Loading trained model checkpoint from {}'.format(params['best_checkpoint_path']))
in_channels = np.array(params.in_channels)
out_channels = np.array(params.in_channels)# for the backbone model, will be reset later
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
if params["orography"]:
params['N_in_channels'] = n_in_channels + 1
else:
params['N_in_channels'] = n_in_channels
params['N_out_channels'] = n_out_channels
params.means = np.load(params.global_means_path)[0, out_channels] # needed to standardize wind data
params.stds = np.load(params.global_stds_path)[0, out_channels]
# load wind model
if params.nettype_wind == 'afno':
model_wind = AFNONet(params).to(device)
if 'model_wind_path' not in params:
raise Exception("no backbone model weights specified")
checkpoint_file = params['model_wind_path']
model_wind = load_model(model_wind, params, checkpoint_file)
model_wind = model_wind.to(device)
# reset channels for precip
params['N_out_channels'] = len(params['out_channels'])
# load the model
if params.nettype == 'afno':
model = AFNONet(params).to(device)
else:
raise Exception("not implemented")
model = PrecipNet(params, backbone=model).to(device)
checkpoint_file = params['best_checkpoint_path']
model = load_model(model, params, checkpoint_file)
model = model.to(device)
# load the validation data
files_paths = glob.glob(params.inf_data_path + "/*.h5")
files_paths.sort()
# which year
yr = 0
if params.log_to_screen:
logging.info('Loading validation data')
logging.info('Validation data from {}'.format(files_paths[yr]))
valid_data_full = h5py.File(files_paths[yr], 'r')['fields']
# precip paths
path = params.precip + '/out_of_sample'
precip_paths = glob.glob(path + "/*.h5")
precip_paths.sort()
if params.log_to_screen:
logging.info('Loading validation precip data')
logging.info('Validation data from {}'.format(precip_paths[0]))
valid_data_tp_full = h5py.File(precip_paths[0], 'r')['tp']
return valid_data_full, valid_data_tp_full, model_wind, model
def autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model):
ic = int(ic)
#initialize global variables
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
exp_dir = params['experiment_dir']
dt = int(params.dt)
prediction_length = int(params.prediction_length/dt)
n_history = params.n_history
img_shape_x = params.img_shape_x
img_shape_y = params.img_shape_y
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
means = params.means
stds = params.stds
#initialize memory for image sequences and RMSE/ACC, tqe for precip
valid_loss = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_unweighted = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
tqe = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
# wind seqs
seq_real = torch.zeros((prediction_length, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred = torch.zeros((prediction_length, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
# precip sequences
seq_real_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
valid_data = valid_data_full[ic:(ic+prediction_length*dt+n_history*dt):dt, in_channels, 0:720] #extract valid data from first year
# standardize
valid_data = (valid_data - means)/stds
valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
len_ic = prediction_length*dt
valid_data_tp = valid_data_tp_full[ic:(ic+prediction_length*dt):dt, 0:720].reshape(len_ic,n_out_channels,720,img_shape_y) #extract valid data from first year
# log normalize
eps = params.precip_eps
valid_data_tp = np.log1p(valid_data_tp/eps)
valid_data_tp = torch.as_tensor(valid_data_tp).to(device, dtype=torch.float)
m = torch.as_tensor(np.load(params.time_means_path_tp)[0][out_channels])[:, 0:img_shape_x] # climatology
m = torch.unsqueeze(m, 0)
m = m.to(device, dtype=torch.float)
std = torch.as_tensor(stds[:,0,0]).to(device, dtype=torch.float)
orography = params.orography
orography_path = params.orography_path
if orography:
orog = torch.as_tensor(np.expand_dims(np.expand_dims(h5py.File(orography_path, 'r')['orog'][0:720], axis = 0), axis = 0)).to(device, dtype = torch.float)
logging.info("orography loaded; shape:{}".format(orog.shape))
#autoregressive inference
if params.log_to_screen:
logging.info('Begin autoregressive inference')
with torch.no_grad():
for i in range(valid_data.shape[0]):
if i==0: #start of sequence
first = valid_data[0:n_history+1]
first_tp = valid_data_tp[0:1]
future = valid_data[n_history+1]
future_tp = valid_data_tp[1]
for h in range(n_history+1):
seq_real[h] = first[h*n_in_channels:(h+1)*n_in_channels][0:n_in_channels] #extract history from 1st
seq_pred[h] = seq_real[h]
seq_real_tp[0] = unlog_tp_torch(first_tp)
seq_pred_tp[0] = unlog_tp_torch(first_tp)
if params.perturb:
first = gaussian_perturb(first, level=params.n_level, device=device) # perturb the ic
if orography:
future_pred = model_wind(torch.cat((first, orog), axis=1))
else:
future_pred = model_wind(first)
future_pred_tp = model(future_pred)
else:
if i < prediction_length-1:
future = valid_data[n_history+i+1]
future_tp = valid_data_tp[i+1]
if orography:
future_pred = model_wind(torch.cat((future_pred, orog), axis=1)) #autoregressive step
else:
future_pred = model_wind(future_pred) #autoregressive step
future_pred_tp = model(future_pred) # tp diagnosis
if i < prediction_length-1: #not on the last step
seq_pred[n_history+i+1] = future_pred
seq_real[n_history+i+1] = future
seq_pred_tp[i+1] = unlog_tp_torch(future_pred_tp) # this predicts 6-12 precip: 0 -> 6 (afno) -> 6-12 precip
seq_real_tp[i+1] = unlog_tp_torch(future_tp) # which is the i+1th validation data
#collect history
history_stack = seq_pred[i+1:i+2+n_history]
# ic for next wind step
future_pred = history_stack
pred = torch.unsqueeze(seq_pred_tp[i], 0)
tar = torch.unsqueeze(seq_real_tp[i], 0)
valid_loss[i] = weighted_rmse_torch_channels(pred, tar)
acc[i] = weighted_acc_torch_channels(pred-m, tar-m)
tqe[i] = top_quantiles_error_torch(pred, tar)
if params.log_to_screen:
logging.info('Timestep {} of {}. TP RMS Error: {}, ACC: {}'.format((i), prediction_length, valid_loss[i,0], acc[i,0]))
seq_real_tp = seq_real_tp.cpu().numpy()
seq_pred_tp = seq_pred_tp.cpu().numpy()
valid_loss = valid_loss.cpu().numpy()
acc = acc.cpu().numpy()
acc_unweighted = acc_unweighted.cpu().numpy()
tqe = tqe.cpu().numpy()
return np.expand_dims(seq_real_tp, 0), np.expand_dims(seq_pred_tp, 0), np.expand_dims(valid_loss, 0), \
np.expand_dims(acc, 0), np.expand_dims(acc_unweighted, 0), np.expand_dims(tqe, 0)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='full_field', type=str)
parser.add_argument("--vis", action='store_true')
parser.add_argument("--override_dir", default=None, type = str, help = 'Path to store inference outputs; must also set --weights arg')
parser.add_argument("--weights", default=None, type=str, help = 'Path to model weights, for use with override_dir option')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['world_size'] = 1
params['global_batch_size'] = params.batch_size
torch.cuda.set_device(0)
torch.backends.cudnn.benchmark = True
vis = args.vis
# Set up directory
if args.override_dir is not None:
assert args.weights is not None, 'Must set --weights argument if using --override_dir'
expDir = args.override_dir
else:
assert args.weights is None, 'Cannot use --weights argument without also using --override_dir'
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if not os.path.isdir(expDir):
os.makedirs(expDir)
params['experiment_dir'] = os.path.abspath(expDir)
params['best_checkpoint_path'] = args.weights if args.override_dir is not None else os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
params['resuming'] = False
params['local_rank'] = 0
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference_out.log'))
logging_utils.log_versions()
params.log()
n_ics = params['n_initial_conditions']
ics = [1066, 1050, 1034]
n_samples_per_year = 1460
if params["ics_type"] == 'default':
num_samples = n_samples_per_year-params.prediction_length
stop = num_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
if vis: # visualization for just the first ic (or any ic)
ics = [0]
n_ics = len(ics)
elif params["ics_type"] == "datetime":
date_strings = params["date_strings"]
ics = []
if params.perturb: #for perturbations use a single date and create n_ics perturbations
n_ics = params["n_perturbations"]
date = date_strings[0]
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
for ii in range(n_ics):
ics.append(int(hours_since_jan_01_epoch/6))
else:
for date in date_strings:
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
ics.append(int(hours_since_jan_01_epoch/6))
n_ics = len(ics)
logging.info("Inference for {} initial conditions".format(n_ics))
try:
autoregressive_inference_filetag = params["inference_file_tag"]
except:
autoregressive_inference_filetag = ""
autoregressive_inference_filetag += "_tp"
# get data and models
valid_data_full, valid_data_tp_full, model_wind, model = setup(params)
#initialize lists for image sequences and RMSE/ACC
valid_loss = np.zeros
acc_unweighted = []
acc = []
tqe = []
seq_pred = []
seq_real = []
#run autoregressive inference for multiple initial conditions
for i, ic in enumerate(ics):
t1 = time.time()
logging.info("Initial condition {} of {}".format(i+1, n_ics))
sr, sp, vl, a, au, tq = autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model)
if i == 0:
seq_real = sr
seq_pred = sp
valid_loss = vl
acc = a
acc_unweighted = au
tqe = tq
else:
# seq_real = np.concatenate((seq_real, sr), 0)
# seq_pred = np.concatenate((seq_pred, sp), 0)
valid_loss = np.concatenate((valid_loss, vl), 0)
acc = np.concatenate((acc, a), 0)
acc_unweighted = np.concatenate((acc_unweighted, au), 0)
tqe = np.concatenate((tqe, tq), 0)
t2 = time.time()-t1
print("time for 1 autoreg inference = ", t2)
prediction_length = seq_real[0].shape[0]
n_out_channels = seq_real[0].shape[1]
img_shape_x = seq_real[0].shape[2]
img_shape_y = seq_real[0].shape[3]
#save predictions and loss
if params.log_to_screen:
logging.info("Saving files at {}".format(os.path.join(params['experiment_dir'], 'autoregressive_predictions' + autoregressive_inference_filetag + '.h5')))
with h5py.File(os.path.join(params['experiment_dir'], 'autoregressive_predictions'+ autoregressive_inference_filetag +'.h5'), 'a') as f:
if vis:
try:
f.create_dataset("ground_truth", data = seq_real, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
except:
del f["ground_truth"]
f.create_dataset("ground_truth", data = seq_real, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
f["ground_truth"][...] = seq_real
try:
f.create_dataset("predicted", data = seq_pred, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
except:
del f["predicted"]
f.create_dataset("predicted", data = seq_pred, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
f["predicted"][...]= seq_pred
try:
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse"]
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse"][...] = valid_loss
try:
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc"]
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc"][...] = acc
try:
f.create_dataset("acc_unweighted", data = acc_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc_unweighted"]
f.create_dataset("acc_unweighted", data = acc_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc_unweighted"][...] = acc_unweighted
try:
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["tqe"]
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["tqe"][...] = tqe
f.close()
#!/bin/bash
#shifter --image=nersc/pytorch:ngc-21.08-v1 --env PYTHONUSERBASE=/pscratch/home/jpathak/perlmutter/ngc-21.08-v1 python \
# train.py --enable_amp --config pretrained_two_step_afno_20ch_bs_64_lr1em4_blk_8_patch_8_cosine_sched --run_num test0
export MASTER_ADDR=$(hostname)
image=nersc/pytorch:ngc-22.02-v0
ngpu=4
config_file=./config/AFNO.yaml
config="afno_backbone"
run_num="check"
cmd="python train.py --enable_amp --yaml_config=$config_file --config=$config --run_num=$run_num"
srun -n $ngpu --cpus-per-task=32 --gpus-per-node $ngpu shifter --image=${image} bash -c "source export_DDP_vars.sh && $cmd"
#reference: https://github.com/NVlabs/AFNO-transformer
import math
from functools import partial
from collections import OrderedDict
from copy import Error, deepcopy
from re import S
from numpy.lib.arraypad import pad
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
#from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
import torch.fft
from torch.nn.modules.container import Sequential
from torch.utils.checkpoint import checkpoint_sequential
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from utils.img_utils import PeriodicPad2d
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class AFNO2D(nn.Module):
def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
super().__init__()
assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"
self.hidden_size = hidden_size
self.sparsity_threshold = sparsity_threshold
self.num_blocks = num_blocks
self.block_size = self.hidden_size // self.num_blocks
self.hard_thresholding_fraction = hard_thresholding_fraction
self.hidden_size_factor = hidden_size_factor
self.scale = 0.02
self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
def forward(self, x):
bias = x
dtype = x.dtype
x = x.float()
B, H, W, C = x.shape
x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)
o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
o2_real = torch.zeros(x.shape, device=x.device)
o2_imag = torch.zeros(x.shape, device=x.device)
total_modes = H // 2 + 1
kept_modes = int(total_modes * self.hard_thresholding_fraction)
o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
self.b1[0]
)
o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
self.b1[1]
)
o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
self.b2[0]
)
o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
self.b2[1]
)
x = torch.stack([o2_real, o2_imag], dim=-1)
x = F.softshrink(x, lambd=self.sparsity_threshold)
x = torch.view_as_complex(x)
x = x.reshape(B, H, W // 2 + 1, C)
x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho")
x = x.type(dtype)
return x + bias
class Block(nn.Module):
def __init__(
self,
dim,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
double_skip=True,
num_blocks=8,
sparsity_threshold=0.01,
hard_thresholding_fraction=1.0
):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
#self.drop_path = nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.double_skip = double_skip
def forward(self, x):
residual = x
x = self.norm1(x)
x = self.filter(x)
if self.double_skip:
x = x + residual
residual = x
x = self.norm2(x)
x = self.mlp(x)
x = self.drop_path(x)
x = x + residual
return x
class PrecipNet(nn.Module):
def __init__(self, params, backbone):
super().__init__()
self.params = params
self.patch_size = (params.patch_size, params.patch_size)
self.in_chans = params.N_in_channels
self.out_chans = params.N_out_channels
self.backbone = backbone
self.ppad = PeriodicPad2d(1)
self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True)
self.act = nn.ReLU()
def forward(self, x):
x = self.backbone(x)
x = self.ppad(x)
x = self.conv(x)
x = self.act(x)
return x
class AFNONet(nn.Module):
def __init__(
self,
params,
img_size=(720, 1440),
patch_size=(16, 16),
in_chans=2,
out_chans=2,
embed_dim=768,
depth=12,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.,
num_blocks=16,
sparsity_threshold=0.01,
hard_thresholding_fraction=1.0,
):
super().__init__()
self.params = params
self.img_size = img_size
self.patch_size = (params.patch_size, params.patch_size)
self.in_chans = params.N_in_channels
self.out_chans = params.N_out_channels
self.num_features = self.embed_dim = embed_dim
self.num_blocks = params.num_blocks
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.h = img_size[0] // self.patch_size[0]
self.w = img_size[1] // self.patch_size[1]
self.blocks = nn.ModuleList([
Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(embed_dim, self.out_chans*self.patch_size[0]*self.patch_size[1], bias=False)
trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
x = x.reshape(B, self.h, self.w, self.embed_dim)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = rearrange(
x,
"b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)",
p1=self.patch_size[0],
p2=self.patch_size[1],
h=self.img_size[0] // self.patch_size[0],
w=self.img_size[1] // self.patch_size[1],
)
return x
class PatchEmbed(nn.Module):
def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
if __name__ == "__main__":
model = AFNONet(img_size=(720, 1440), patch_size=(4,4), in_chans=3, out_chans=10)
sample = torch.randn(1, 3, 720, 1440)
result = model(sample)
print(result.shape)
print(torch.norm(result))
# 1. 重新初始化,排除大文件
cd /workspace/FourCastNet
rm -rf .git
# 2. 创建 .gitignore 排除数据文件
cat > .gitignore << 'EOF'
data.tar.gz
data/
*.tar.gz
*.zip
EOF
# 3. 重新初始化 Git
git init
git add .
git commit -m "Initial commit: FourCastNet code only"
# 4. 强制推送(这次应该成功)
git remote add origin http://developer.sourcefind.cn/codes/bw_bestperf/fourcastnet_train.git
git push -f origin master
rm -rf ./exp
export HIP_MODULE_LOADING=EAGER
export HIP_ALLOC_INITIALIZE=0
export HIP_REUSE_MODULE=1
export HIP_VISIBLE_DEVICES=0
config_file=./config/AFNO-50epoch.yaml
config='afno_backbone'
run_num='check_exp'
export HDF5_USE_FILE_LOCKING=FALSE
python train.py --enable_amp --yaml_config=$config_file --config=$config --run_num=$run_num
#!/bin/bash -l
#SBATCH --time=06:00:00
#SBATCH -C gpu
#SBATCH --account=m4134_g
#SBATCH --nodes=16
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=32
#SBATCH -J afno
#SBATCH --image=nersc/pytorch:ngc-22.02-v0
#SBATCH -o afno_backbone_finetune.out
config_file=./config/AFNO.yaml
config='afno_backbone_finetune'
run_num='0'
export HDF5_USE_FILE_LOCKING=FALSE
export NCCL_NET_GDR_LEVEL=PHB
export MASTER_ADDR=$(hostname)
set -x
srun -u --mpi=pmi2 shifter \
bash -c "
source export_DDP_vars.sh
python train.py --enable_amp --yaml_config=$config_file --config=$config --run_num=$run_num
"
#!/bin/bash
#SBATCH --time=00:15:00
#SBATCH -N 4
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=32
#SBATCH -C gpu
#SBATCH --account=m4134_g
#SBATCH -q regular
#SBATCH --image=nersc/pytorch:ngc-22.02-v0
export HDF5_USE_FILE_LOCKING=FALSE
export MASTER_ADDR=$(hostname)
launch="python inference/inference_ensemble.py --config=afno_backbone_finetune --run_num=0 --n_level=0.3"
#launch="python inference/inference_ensemble_precip.py --config=precip --run_num=1 --n_level=0.1"
srun --mpi=pmi2 -u -l shifter --module gpu --env PYTHONUSERBASE=$HOME/.local/perlmutter/nersc-pytorch-22.02-v0 bash -c "
source export_DDP_vars.sh
$launch
"
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import os
import time
import numpy as np
import argparse
import h5py
import torch
import cProfile
import re
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet, PrecipNet
from utils.img_utils import vis_precip
import wandb
from utils.weighted_acc_rmse import weighted_acc, weighted_rmse, weighted_rmse_torch, unlog_tp_torch
from apex import optimizers
from utils.darcy_loss import LpLoss
import matplotlib.pyplot as plt
from collections import OrderedDict
import pickle
DECORRELATION_TIME = 36 # 9 days
import json
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
class Trainer():
def count_parameters(self):
return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
def __init__(self, params, world_rank):
self.params = params
self.world_rank = world_rank
self.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
if params.log_to_wandb:
wandb.init(config=params, name=params.name, group=params.group, project=params.project, entity=params.entity)
logging.info('rank %d, begin data loader init'%world_rank)
self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(params, params.train_data_path, dist.is_initialized(), train=True)
self.valid_data_loader, self.valid_dataset = get_data_loader(params, params.valid_data_path, dist.is_initialized(), train=False)
self.loss_obj = LpLoss()
logging.info('rank %d, data loader initialized'%world_rank)
params.crop_size_x = self.valid_dataset.crop_size_x
params.crop_size_y = self.valid_dataset.crop_size_y
params.img_shape_x = self.valid_dataset.img_shape_x
params.img_shape_y = self.valid_dataset.img_shape_y
# precip models
self.precip = True if "precip" in params else False
if self.precip:
if 'model_wind_path' not in params:
raise Exception("no backbone model weights specified")
# load a wind model
# the wind model has out channels = in channels
out_channels = np.array(params['in_channels'])
params['N_out_channels'] = len(out_channels)
if params.nettype_wind == 'afno':
self.model_wind = AFNONet(params).to(self.device)
else:
raise Exception("not implemented")
if dist.is_initialized():
self.model_wind = DistributedDataParallel(self.model_wind,
device_ids=[params.local_rank],
output_device=[params.local_rank],find_unused_parameters=True)
self.load_model_wind(params.model_wind_path)
self.switch_off_grad(self.model_wind) # no backprop through the wind model
# reset out_channels for precip models
if self.precip:
params['N_out_channels'] = len(params['out_channels'])
if params.nettype == 'afno':
self.model = AFNONet(params).to(self.device)
else:
raise Exception("not implemented")
# precip model
if self.precip:
self.model = PrecipNet(params, backbone=self.model).to(self.device)
if self.params.enable_nhwc:
# NHWC: Convert model to channels_last memory format
self.model = self.model.to(memory_format=torch.channels_last)
if params.log_to_wandb:
wandb.watch(self.model)
if params.optimizer_type == 'FusedAdam':
self.optimizer = optimizers.FusedAdam(self.model.parameters(), lr = params.lr)
else:
self.optimizer = torch.optim.Adam(self.model.parameters(), lr = params.lr)
if params.enable_amp == True:
self.gscaler = amp.GradScaler()
if dist.is_initialized():
self.model = DistributedDataParallel(self.model,
device_ids=[params.local_rank],
output_device=[params.local_rank],find_unused_parameters=True)
self.iters = 0
self.startEpoch = 0
if params.resuming:
logging.info("Loading checkpoint %s"%params.checkpoint_path)
self.restore_checkpoint(params.checkpoint_path)
if params.two_step_training:
if params.resuming == False and params.pretrained == True:
logging.info("Starting from pretrained one-step afno model at %s"%params.pretrained_ckpt_path)
self.restore_checkpoint(params.pretrained_ckpt_path)
self.iters = 0
self.startEpoch = 0
#logging.info("Pretrained checkpoint was trained for %d epochs"%self.startEpoch)
#logging.info("Adding %d epochs specified in config file for refining pretrained model"%self.params.max_epochs)
#self.params.max_epochs += self.startEpoch
self.epoch = self.startEpoch
if params.scheduler == 'ReduceLROnPlateau':
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.2, patience=5, mode='min')
elif params.scheduler == 'CosineAnnealingLR':
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=params.max_epochs, last_epoch=self.startEpoch-1)
else:
self.scheduler = None
'''if params.log_to_screen:
logging.info(self.model)'''
if params.log_to_screen:
logging.info("Number of trainable model parameters: {}".format(self.count_parameters()))
def switch_off_grad(self, model):
for param in model.parameters():
param.requires_grad = False
def train(self):
if self.params.log_to_screen:
logging.info("Starting Training Loop...")
best_valid_loss = 1.e6
for epoch in range(self.startEpoch, self.params.max_epochs):
if dist.is_initialized():
self.train_sampler.set_epoch(epoch)
# self.valid_sampler.set_epoch(epoch)
start = time.time()
tr_time, data_time, train_logs = self.train_one_epoch()
valid_time, valid_logs = self.validate_one_epoch()
if epoch==self.params.max_epochs-1 and self.params.prediction_type == 'direct':
valid_weighted_rmse = self.validate_final()
if self.params.scheduler == 'ReduceLROnPlateau':
self.scheduler.step(valid_logs['valid_loss'])
elif self.params.scheduler == 'CosineAnnealingLR':
self.scheduler.step()
if self.epoch >= self.params.max_epochs:
logging.info("Terminating training after reaching params.max_epochs while LR scheduler is set to CosineAnnealingLR")
exit()
if self.params.log_to_wandb:
for pg in self.optimizer.param_groups:
lr = pg['lr']
wandb.log({'lr': lr})
if self.world_rank == 0:
if self.params.save_checkpoint:
#checkpoint at the end of every epoch
self.save_checkpoint(self.params.checkpoint_path)
if valid_logs['valid_loss'] <= best_valid_loss:
#logging.info('Val loss improved from {} to {}'.format(best_valid_loss, valid_logs['valid_loss']))
self.save_checkpoint(self.params.best_checkpoint_path)
best_valid_loss = valid_logs['valid_loss']
if self.params.log_to_screen:
logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
#logging.info('train data time={}, train step time={}, valid step time={}'.format(data_time, tr_time, valid_time))
logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['loss'], valid_logs['valid_loss']))
# if epoch==self.params.max_epochs-1 and self.params.prediction_type == 'direct':
# logging.info('Final Valid RMSE: Z500- {}. T850- {}, 2m_T- {}'.format(valid_weighted_rmse[0], valid_weighted_rmse[1], valid_weighted_rmse[2]))
def train_one_epoch(self):
self.epoch += 1
tr_time = 0
data_time = 0
self.model.train()
for i, data in enumerate(self.train_data_loader, 0):
self.iters += 1
# adjust_LR(optimizer, params, iters)
data_start = time.time()
inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data)
if self.params.orography and self.params.two_step_training:
orog = inp[:,-2:-1]
if self.params.enable_nhwc:
inp = inp.to(memory_format=torch.channels_last)
tar = tar.to(memory_format=torch.channels_last)
if 'residual_field' in self.params.target:
tar -= inp[:, 0:tar.size()[1]]
data_time += time.time() - data_start
tr_start = time.time()
self.model.zero_grad()
if self.params.two_step_training:
with amp.autocast(self.params.enable_amp):
gen_step_one = self.model(inp).to(self.device, dtype = torch.float)
loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels])
if self.params.orography:
gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1) ).to(self.device, dtype = torch.float)
else:
gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float)
loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels])
loss = loss_step_one + loss_step_two
else:
with amp.autocast(self.params.enable_amp):
if self.precip: # use a wind model to predict 17(+n) channels at t+dt
with torch.no_grad():
inp = self.model_wind(inp).to(self.device, dtype = torch.float)
gen = self.model(inp.detach()).to(self.device, dtype = torch.float)
else:
gen = self.model(inp).to(self.device, dtype = torch.float)
loss = self.loss_obj(gen, tar)
if self.params.enable_amp:
self.gscaler.scale(loss).backward()
self.gscaler.step(self.optimizer)
else:
loss.backward()
self.optimizer.step()
if self.params.enable_amp:
self.gscaler.update()
tr_time += time.time() - tr_start
try:
logs = {'loss': loss, 'loss_step_one': loss_step_one, 'loss_step_two': loss_step_two}
except:
logs = {'loss': loss}
if dist.is_initialized():
for key in sorted(logs.keys()):
dist.all_reduce(logs[key].detach())
logs[key] = float(logs[key]/dist.get_world_size())
if self.params.log_to_wandb:
wandb.log(logs, step=self.epoch)
return tr_time, data_time, logs
def validate_one_epoch(self):
self.model.eval()
n_valid_batches = 20 #do validation on first 20 images, just for LR scheduler
if self.params.normalization == 'minmax':
raise Exception("minmax normalization not supported")
elif self.params.normalization == 'zscore':
mult = torch.as_tensor(np.load(self.params.global_stds_path)[0, self.params.out_channels, 0, 0]).to(self.device)
valid_buff = torch.zeros((3), dtype=torch.float32, device=self.device)
valid_loss = valid_buff[0].view(-1)
valid_l1 = valid_buff[1].view(-1)
valid_steps = valid_buff[2].view(-1)
valid_weighted_rmse = torch.zeros((self.params.N_out_channels), dtype=torch.float32, device=self.device)
valid_weighted_acc = torch.zeros((self.params.N_out_channels), dtype=torch.float32, device=self.device)
valid_start = time.time()
sample_idx = np.random.randint(len(self.valid_data_loader))
with torch.no_grad():
for i, data in enumerate(self.valid_data_loader, 0):
if (not self.precip) and i>=n_valid_batches:
break
inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data)
if self.params.orography and self.params.two_step_training:
orog = inp[:,-2:-1]
if self.params.two_step_training:
gen_step_one = self.model(inp).to(self.device, dtype = torch.float)
loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels])
if self.params.orography:
gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1) ).to(self.device, dtype = torch.float)
else:
gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float)
loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels])
valid_loss += loss_step_one + loss_step_two
valid_l1 += nn.functional.l1_loss(gen_step_one, tar[:,0:self.params.N_out_channels])
else:
if self.precip:
with torch.no_grad():
inp = self.model_wind(inp).to(self.device, dtype = torch.float)
gen = self.model(inp.detach())
else:
gen = self.model(inp).to(self.device, dtype = torch.float)
valid_loss += self.loss_obj(gen, tar)
valid_l1 += nn.functional.l1_loss(gen, tar)
valid_steps += 1.
# save fields for vis before log norm
if (i == sample_idx) and (self.precip and self.params.log_to_wandb):
fields = [gen[0,0].detach().cpu().numpy(), tar[0,0].detach().cpu().numpy()]
if self.precip:
gen = unlog_tp_torch(gen, self.params.precip_eps)
tar = unlog_tp_torch(tar, self.params.precip_eps)
#direct prediction weighted rmse
if self.params.two_step_training:
if 'residual_field' in self.params.target:
valid_weighted_rmse += weighted_rmse_torch((gen_step_one + inp), (tar[:,0:self.params.N_out_channels] + inp))
else:
valid_weighted_rmse += weighted_rmse_torch(gen_step_one, tar[:,0:self.params.N_out_channels])
else:
if 'residual_field' in self.params.target:
valid_weighted_rmse += weighted_rmse_torch((gen + inp), (tar + inp))
else:
valid_weighted_rmse += weighted_rmse_torch(gen, tar)
if not self.precip:
try:
os.mkdir(params['experiment_dir'] + "/" + str(i))
except:
pass
#save first channel of image
if self.params.two_step_training:
save_image(torch.cat((gen_step_one[0,0], torch.zeros((self.valid_dataset.img_shape_x, 4)).to(self.device, dtype = torch.float), tar[0,0]), axis = 1), params['experiment_dir'] + "/" + str(i) + "/" + str(self.epoch) + ".png")
else:
save_image(torch.cat((gen[0,0], torch.zeros((self.valid_dataset.img_shape_x, 4)).to(self.device, dtype = torch.float), tar[0,0]), axis = 1), params['experiment_dir'] + "/" + str(i) + "/" + str(self.epoch) + ".png")
if dist.is_initialized():
dist.all_reduce(valid_buff)
dist.all_reduce(valid_weighted_rmse)
# divide by number of steps
valid_buff[0:2] = valid_buff[0:2] / valid_buff[2]
valid_weighted_rmse = valid_weighted_rmse / valid_buff[2]
if not self.precip:
valid_weighted_rmse *= mult
# download buffers
valid_buff_cpu = valid_buff.detach().cpu().numpy()
valid_weighted_rmse_cpu = valid_weighted_rmse.detach().cpu().numpy()
valid_time = time.time() - valid_start
valid_weighted_rmse = mult*torch.mean(valid_weighted_rmse, axis = 0)
if self.precip:
logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_tp': valid_weighted_rmse_cpu[0]}
else:
try:
logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_u10': valid_weighted_rmse_cpu[0], 'valid_rmse_v10': valid_weighted_rmse_cpu[1]}
except:
logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_u10': valid_weighted_rmse_cpu[0]}#, 'valid_rmse_v10': valid_weighted_rmse[1]}
if self.params.log_to_wandb:
if self.precip:
fig = vis_precip(fields)
logs['vis'] = wandb.Image(fig)
plt.close(fig)
wandb.log(logs, step=self.epoch)
return valid_time, logs
def validate_final(self):
self.model.eval()
n_valid_batches = int(self.valid_dataset.n_patches_total/self.valid_dataset.n_patches) #validate on whole dataset
valid_weighted_rmse = torch.zeros(n_valid_batches, self.params.N_out_channels)
if self.params.normalization == 'minmax':
raise Exception("minmax normalization not supported")
elif self.params.normalization == 'zscore':
mult = torch.as_tensor(np.load(self.params.global_stds_path)[0, self.params.out_channels, 0, 0]).to(self.device)
with torch.no_grad():
for i, data in enumerate(self.valid_data_loader):
if i>100:
break
inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data)
if self.params.orography and self.params.two_step_training:
orog = inp[:,-2:-1]
if 'residual_field' in self.params.target:
tar -= inp[:, 0:tar.size()[1]]
if self.params.two_step_training:
gen_step_one = self.model(inp).to(self.device, dtype = torch.float)
loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels])
if self.params.orography:
gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1) ).to(self.device, dtype = torch.float)
else:
gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float)
loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels])
valid_loss[i] = loss_step_one + loss_step_two
valid_l1[i] = nn.functional.l1_loss(gen_step_one, tar[:,0:self.params.N_out_channels])
else:
gen = self.model(inp)
valid_loss[i] += self.loss_obj(gen, tar)
valid_l1[i] += nn.functional.l1_loss(gen, tar)
if self.params.two_step_training:
for c in range(self.params.N_out_channels):
if 'residual_field' in self.params.target:
valid_weighted_rmse[i, c] = weighted_rmse_torch((gen_step_one[0,c] + inp[0,c]), (tar[0,c]+inp[0,c]), self.device)
else:
valid_weighted_rmse[i, c] = weighted_rmse_torch(gen_step_one[0,c], tar[0,c], self.device)
else:
for c in range(self.params.N_out_channels):
if 'residual_field' in self.params.target:
valid_weighted_rmse[i, c] = weighted_rmse_torch((gen[0,c] + inp[0,c]), (tar[0,c]+inp[0,c]), self.device)
else:
valid_weighted_rmse[i, c] = weighted_rmse_torch(gen[0,c], tar[0,c], self.device)
#un-normalize
valid_weighted_rmse = mult*torch.mean(valid_weighted_rmse[0:100], axis = 0).to(self.device)
return valid_weighted_rmse
def load_model_wind(self, model_path):
if self.params.log_to_screen:
logging.info('Loading the wind model weights from {}'.format(model_path))
checkpoint = torch.load(model_path, map_location='cuda:{}'.format(self.params.local_rank))
if dist.is_initialized():
self.model_wind.load_state_dict(checkpoint['model_state'])
else:
new_model_state = OrderedDict()
model_key = 'model_state' if 'model_state' in checkpoint else 'state_dict'
for key in checkpoint[model_key].keys():
if 'module.' in key: # model was stored using ddp which prepends module
name = str(key[7:])
new_model_state[name] = checkpoint[model_key][key]
else:
new_model_state[key] = checkpoint[model_key][key]
self.model_wind.load_state_dict(new_model_state)
self.model_wind.eval()
def save_checkpoint(self, checkpoint_path, model=None):
""" We intentionally require a checkpoint_dir to be passed
in order to allow Ray Tune to use this function """
if not model:
model = self.model
torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path)
def restore_checkpoint(self, checkpoint_path):
""" We intentionally require a checkpoint_dir to be passed
in order to allow Ray Tune to use this function """
checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank))
try:
self.model.load_state_dict(checkpoint['model_state'])
except:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
new_state_dict[name] = val
self.model.load_state_dict(new_state_dict)
self.iters = checkpoint['iters']
self.startEpoch = checkpoint['epoch']
if self.params.resuming: #restore checkpoint is used for finetuning as well as resuming. If finetuning (i.e., not resuming), restore checkpoint does not load optimizer state, instead uses config specified lr.
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='default', type=str)
parser.add_argument("--enable_amp", action='store_true')
parser.add_argument("--epsilon_factor", default = 0, type = float)
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['epsilon_factor'] = args.epsilon_factor
params['world_size'] = 1
if 'WORLD_SIZE' in os.environ:
params['world_size'] = int(os.environ['WORLD_SIZE'])
world_rank = 0
local_rank = 0
if params['world_size'] > 1:
dist.init_process_group(backend='nccl',
init_method='env://')
local_rank = int(os.environ["LOCAL_RANK"])
args.gpu = local_rank
world_rank = dist.get_rank()
params['global_batch_size'] = params.batch_size
params['batch_size'] = int(params.batch_size//params['world_size'])
torch.cuda.set_device(local_rank)
torch.backends.cudnn.benchmark = True
# Set up directory
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if world_rank==0:
if not os.path.isdir(expDir):
os.makedirs(expDir)
os.makedirs(os.path.join(expDir, 'training_checkpoints/'))
params['experiment_dir'] = os.path.abspath(expDir)
params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
# Do not comment this line out please:
args.resuming = True if os.path.isfile(params.checkpoint_path) else False
params['resuming'] = args.resuming
params['local_rank'] = local_rank
params['enable_amp'] = args.enable_amp
# this will be the wandb name
# params['name'] = args.config + '_' + str(args.run_num)
# params['group'] = "era5_wind" + args.config
params['name'] = args.config + '_' + str(args.run_num)
params['group'] = "era5_precip" + args.config
params['project'] = "ERA5_precip"
params['entity'] = "flowgan"
if world_rank==0:
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log'))
logging_utils.log_versions()
params.log()
params['log_to_wandb'] = (world_rank==0) and params['log_to_wandb']
params['log_to_screen'] = (world_rank==0) and params['log_to_screen']
params['in_channels'] = np.array(params['in_channels'])
params['out_channels'] = np.array(params['out_channels'])
if params.orography:
params['N_in_channels'] = len(params['in_channels']) +1
else:
params['N_in_channels'] = len(params['in_channels'])
params['N_out_channels'] = len(params['out_channels'])
if world_rank == 0:
hparams = ruamelDict()
yaml = YAML()
for key, value in params.params.items():
hparams[str(key)] = str(value)
with open(os.path.join(expDir, 'hyperparams.yaml'), 'w') as hpfile:
yaml.dump(hparams, hpfile )
trainer = Trainer(params, world_rank)
trainer.train()
logging.info('DONE ---- rank %d'%world_rank)
from ruamel.yaml import YAML
import logging
class YParams():
""" Yaml file parser """
def __init__(self, yaml_filename, config_name, print_params=False):
self._yaml_filename = yaml_filename
self._config_name = config_name
self.params = {}
if print_params:
print("------------------ Configuration ------------------")
with open(yaml_filename) as _file:
for key, val in YAML().load(_file)[config_name].items():
if print_params: print(key, val)
if val =='None': val = None
self.params[key] = val
self.__setattr__(key, val)
if print_params:
print("---------------------------------------------------")
def __getitem__(self, key):
return self.params[key]
def __setitem__(self, key, val):
self.params[key] = val
self.__setattr__(key, val)
def __contains__(self, key):
return (key in self.params)
def update_params(self, config):
for key, val in config.items():
self.params[key] = val
self.__setattr__(key, val)
def log(self):
logging.info("------------------ Configuration ------------------")
logging.info("Configuration file: "+str(self._yaml_filename))
logging.info("Configuration name: "+str(self._config_name))
for key, val in self.params.items():
logging.info(str(key) + ' ' + str(val))
logging.info("---------------------------------------------------")
#MIT License
#
#Copyright (c) 2020 Zongyi Li
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn
#################################################
#
# Utilities
#
#################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# reading data
class MatReader(object):
def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
super(MatReader, self).__init__()
self.to_torch = to_torch
self.to_cuda = to_cuda
self.to_float = to_float
self.file_path = file_path
self.data = None
self.old_mat = None
self._load_file()
def _load_file(self):
try:
self.data = scipy.io.loadmat(self.file_path)
self.old_mat = True
except:
self.data = h5py.File(self.file_path)
self.old_mat = False
def load_file(self, file_path):
self.file_path = file_path
self._load_file()
def read_field(self, field):
x = self.data[field]
if not self.old_mat:
x = x[()]
x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))
if self.to_float:
x = x.astype(np.float32)
if self.to_torch:
x = torch.from_numpy(x)
if self.to_cuda:
x = x.cuda()
return x
def set_cuda(self, to_cuda):
self.to_cuda = to_cuda
def set_torch(self, to_torch):
self.to_torch = to_torch
def set_float(self, to_float):
self.to_float = to_float
# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
def __init__(self, x, eps=0.00001):
super(UnitGaussianNormalizer, self).__init__()
# x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
self.mean = torch.mean(x, 0)
self.std = torch.std(x, 0)
self.eps = eps
def encode(self, x):
x = (x - self.mean) / (self.std + self.eps)
return x.float()
def decode(self, x, sample_idx=None):
if sample_idx is None:
std = self.std + self.eps # n
mean = self.mean
else:
if len(self.mean.shape) == len(sample_idx[0].shape):
std = self.std[sample_idx] + self.eps # batch*n
mean = self.mean[sample_idx]
if len(self.mean.shape) > len(sample_idx[0].shape):
std = self.std[:,sample_idx]+ self.eps # T*batch*n
mean = self.mean[:,sample_idx]
# x is in shape of batch*n or T*batch*n
x = (x * std) + mean
return x.float()
def cuda(self):
self.mean = self.mean.cuda()
self.std = self.std.cuda()
def cpu(self):
self.mean = self.mean.cpu()
self.std = self.std.cpu()
# normalization, Gaussian
class GaussianNormalizer(object):
def __init__(self, x, eps=0.00001):
super(GaussianNormalizer, self).__init__()
self.mean = torch.mean(x)
self.std = torch.std(x)
self.eps = eps
def encode(self, x):
x = (x - self.mean) / (self.std + self.eps)
return x
def decode(self, x, sample_idx=None):
x = (x * (self.std + self.eps)) + self.mean
return x
def cuda(self):
self.mean = self.mean.cuda()
self.std = self.std.cuda()
def cpu(self):
self.mean = self.mean.cpu()
self.std = self.std.cpu()
# normalization, scaling by range
class RangeNormalizer(object):
def __init__(self, x, low=0.0, high=1.0):
super(RangeNormalizer, self).__init__()
mymin = torch.min(x, 0)[0].view(-1)
mymax = torch.max(x, 0)[0].view(-1)
self.a = (high - low)/(mymax - mymin)
self.b = -self.a*mymax + high
def encode(self, x):
s = x.size()
x = x.view(s[0], -1)
x = self.a*x + self.b
x = x.view(s)
return x
def decode(self, x):
s = x.size()
x = x.view(s[0], -1)
x = (x - self.b)/self.a
x = x.view(s)
return x
#loss function with rel/abs Lp loss
class LpLoss(object):
def __init__(self, d=2, p=2, size_average=True, reduction=True):
super(LpLoss, self).__init__()
#Dimension and Lp-norm type are postive
assert d > 0 and p > 0
self.d = d
self.p = p
self.reduction = reduction
self.size_average = size_average
def abs(self, x, y):
num_examples = x.size()[0]
#Assume uniform mesh
h = 1.0 / (x.size()[1] - 1.0)
all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)
if self.reduction:
if self.size_average:
return torch.mean(all_norms)
else:
return torch.sum(all_norms)
return all_norms
def rel(self, x, y):
num_examples = x.size()[0]
diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
if self.reduction:
if self.size_average:
return torch.mean(diff_norms/y_norms)
else:
return torch.sum(diff_norms/y_norms)
return diff_norms/y_norms
def __call__(self, x, y):
return self.rel(x, y)
# Sobolev norm (HS norm)
# where we also compare the numerical derivatives between the output and target
class HsLoss(object):
def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True):
super(HsLoss, self).__init__()
#Dimension and Lp-norm type are postive
assert d > 0 and p > 0
self.d = d
self.p = p
self.k = k
self.balanced = group
self.reduction = reduction
self.size_average = size_average
if a == None:
a = [1,] * k
self.a = a
def rel(self, x, y):
num_examples = x.size()[0]
diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
if self.reduction:
if self.size_average:
return torch.mean(diff_norms/y_norms)
else:
return torch.sum(diff_norms/y_norms)
return diff_norms/y_norms
def __call__(self, x, y, a=None):
nx = x.size()[1]
ny = x.size()[2]
k = self.k
balanced = self.balanced
a = self.a
x = x.view(x.shape[0], nx, ny, -1)
y = y.view(y.shape[0], nx, ny, -1)
k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny)
k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1)
k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device)
k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device)
x = torch.fft.fftn(x, dim=[1, 2])
y = torch.fft.fftn(y, dim=[1, 2])
if balanced==False:
weight = 1
if k >= 1:
weight += a[0]**2 * (k_x**2 + k_y**2)
if k >= 2:
weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
weight = torch.sqrt(weight)
loss = self.rel(x*weight, y*weight)
else:
loss = self.rel(x, y)
if k >= 1:
weight = a[0] * torch.sqrt(k_x**2 + k_y**2)
loss += self.rel(x*weight, y*weight)
if k >= 2:
weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
loss += self.rel(x*weight, y*weight)
loss = loss / (k+1)
return loss
# A simple feedforward neural network
class DenseNet(torch.nn.Module):
def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False):
super(DenseNet, self).__init__()
self.n_layers = len(layers) - 1
assert self.n_layers >= 1
self.layers = nn.ModuleList()
for j in range(self.n_layers):
self.layers.append(nn.Linear(layers[j], layers[j+1]))
if j != self.n_layers - 1:
if normalize:
self.layers.append(nn.BatchNorm1d(layers[j+1]))
self.layers.append(nonlinearity())
if out_nonlinearity is not None:
self.layers.append(out_nonlinearity())
def forward(self, x):
for _, l in enumerate(self.layers):
x = l(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment