Commit bdd87fae authored by zhangwenbo

Initial commit: FourCastNet source code only
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import os
import time
import numpy as np
import argparse
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
import h5py
import torch
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
from utils.weighted_acc_rmse import weighted_rmse_torch_channels, weighted_acc_torch_channels
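# weighted_* metrics are latitude-weighted: with w(lat_j) = cos(lat_j) / mean_k cos(lat_k),
# RMSE = sqrt(mean_{j,l} w(lat_j) * (pred_{j,l} - tar_{j,l})^2) and ACC is the analogously
# weighted anomaly correlation (the standard definitions; see utils/weighted_acc_rmse.py)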
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet
import wandb
import matplotlib.pyplot as plt
import glob
from datetime import datetime

fld = "z500" # different fields have different decorrelation times and hence different initial conditions
if fld == "z500" or fld == "2m_temperature" or fld == "t850":
    DECORRELATION_TIME = 36 # 36 steps = 9 days for z500, t850, 2m_temperature
else:
    DECORRELATION_TIME = 8 # 8 steps = 2 days for u10, v10
idxes = {"u10":0, "z500":14, "tp":0}
def gaussian_perturb(x, level=0.01, device=0):
noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
return (x + noise)
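# A minimal usage sketch (illustrative values, not from the original pipeline): an
# ensemble is built by re-perturbing the same initial condition,
#   x = torch.zeros(1, 20, 720, 1440)  # hypothetical (batch, channels, lat, lon) state
#   members = [gaussian_perturb(x, level=0.3, device='cpu') for _ in range(4)]
# each member differing by i.i.d. N(0, level^2) noise at every grid point.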
def load_model(model, params, checkpoint_file): #, local_rank):
model.zero_grad()
checkpoint = torch.load(checkpoint_file)
try:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
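# checkpoints are saved from a DistributedDataParallel model, so strip the 'module.' prefix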
name = key[7:]
if name != 'ged':
new_state_dict[name] = val
model.load_state_dict(new_state_dict)
except:
model.load_state_dict(checkpoint['model_state'])
model.eval()
return model
def downsample(x, scale=0.125):
return torch.nn.functional.interpolate(x, scale_factor=scale, mode='bilinear')
def setup(params):
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#get data loader
valid_data_loader, valid_dataset = get_data_loader(params, params.inf_data_path, dist.is_initialized(), train=False)
img_shape_x = valid_dataset.img_shape_x
img_shape_y = valid_dataset.img_shape_y
params.img_shape_x = img_shape_x
params.img_shape_y = img_shape_y
if params.log_to_screen:
logging.info('Loading trained model checkpoint from {}'.format(params['best_checkpoint_path']))
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
params['N_in_channels'] = n_in_channels
params['N_out_channels'] = n_out_channels
params.means = np.load(params.global_means_path)[0, out_channels] # needed to standardize wind data
params.stds = np.load(params.global_stds_path)[0, out_channels]
# load the model
if params.nettype == 'afno':
model = AFNONet(params).to(device)
checkpoint_file = params['best_checkpoint_path']
model = load_model(model, params, checkpoint_file)
model = model.to(device)
# load the validation data
files_paths = glob.glob(params.inf_data_path + "/*.h5")
files_paths.sort()
if params.log_to_screen:
logging.info('Loading validation data')
logging.info('Validation data from {}'.format(files_paths[0]))
# which year
yr = 0
valid_data_full = h5py.File(files_paths[yr], 'r')['fields']
return valid_data_full, model
def autoregressive_inference(params, ic, valid_data_full, model):
ic = int(ic)
#initialize global variables
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
exp_dir = params['experiment_dir']
dt = int(params.dt)
prediction_length = int(params.prediction_length/dt)
n_history = params.n_history
img_shape_x = params.img_shape_x
img_shape_y = params.img_shape_y
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
means = params.means
stds = params.stds
n_pert = params.n_pert
#initialize memory for image sequences and RMSE/ACC
valid_loss = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
seq_real = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
valid_data = valid_data_full[ic:(ic+prediction_length*dt+n_history*dt):dt, in_channels, 0:720] #extract valid data from first year
# standardize
valid_data = (valid_data - means)/stds
valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
#load time means
m = torch.as_tensor((np.load(params.time_means_path)[0][out_channels] - means)/stds)[:, 0:img_shape_x] # climatology
m = torch.unsqueeze(m, 0)
m = m.to(device)
std = torch.as_tensor(stds[:,0,0]).to(device, dtype=torch.float)
#autoregressive inference
if params.log_to_screen:
logging.info('Begin autoregressive inference')
for pert in range(n_pert):
logging.info('Running ensemble {}/{}'.format(pert+1, n_pert))
with torch.no_grad():
for i in range(valid_data.shape[0]):
if i==0: #start of sequence
first = valid_data[0:n_history+1]
future = valid_data[n_history+1]
for h in range(n_history+1):
seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels][0:n_out_channels] #extract history from 1st
seq_pred[h] = seq_real[h]
first = gaussian_perturb(first, level=params.n_level, device=device) # perturb the ic
future_pred = model(first)
else:
if i < prediction_length-1:
future = valid_data[n_history+i+1]
future_pred = model(future_pred) #autoregressive step
if i < prediction_length-1: #not on the last step
seq_pred[n_history+i+1] += torch.squeeze(future_pred,0) # add up predictions and average later
seq_real[n_history+i+1] += future
#Compute metrics
for i in range(valid_data.shape[0]):
if i>0:
# avg images
seq_pred[i] /= n_pert
seq_real[i] /= n_pert
pred = torch.unsqueeze(seq_pred[i], 0)
tar = torch.unsqueeze(seq_real[i], 0)
valid_loss[i] = weighted_rmse_torch_channels(pred, tar) * std
acc[i] = weighted_acc_torch_channels(pred-m, tar-m)
if params.log_to_screen:
idx = idxes[fld]
logging.info('Predicted timestep {} of {}. {} RMS Error: {}, ACC: {}'.format(i, prediction_length, fld, valid_loss[i, idx], acc[i, idx]))
valid_loss = valid_loss.cpu().numpy()
acc = acc.cpu().numpy()
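# both arrays get a leading axis of size 1 so the caller can concatenate results
# across initial conditions into shape (n_ics, prediction_length, n_out_channels)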
return np.expand_dims(valid_loss,0), np.expand_dims(acc, 0)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='full_field', type=str)
parser.add_argument("--override_dir", default=None, type = str, help = 'Path to store inference outputs; must also set --weights arg')
parser.add_argument("--n_pert", default=100, type=int)
parser.add_argument("--n_level", default=0.3, type=float)
parser.add_argument("--weights", default=None, type=str, help = 'Path to model weights, for use with override_dir option')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['world_size'] = 1
if 'WORLD_SIZE' in os.environ:
params['world_size'] = int(os.environ['WORLD_SIZE'])
params['n_pert'] = args.n_pert
world_rank = 0
world_size = params.world_size
params['global_batch_size'] = params.batch_size
local_rank = 0
if params['world_size'] > 1:
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend='nccl',
init_method='env://')
args.gpu = local_rank
world_rank = dist.get_rank()
world_size = dist.get_world_size()
params['global_batch_size'] = params.batch_size
#params['batch_size'] = int(params.batch_size//params['world_size'])
torch.cuda.set_device(local_rank)
torch.backends.cudnn.benchmark = True
# Set up directory
if args.override_dir is not None:
assert args.weights is not None, 'Must set --weights argument if using --override_dir'
expDir = args.override_dir
else:
assert args.weights is None, 'Cannot use --weights argument without also using --override_dir'
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if world_rank==0:
if not os.path.isdir(expDir):
os.makedirs(expDir)
params['experiment_dir'] = os.path.abspath(expDir)
params['best_checkpoint_path'] = args.weights if args.override_dir is not None else os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
params['resuming'] = False
params['local_rank'] = local_rank
# this will be the wandb name
params['name'] = args.config + '_' + str(args.run_num)
params['group'] = args.config
if world_rank==0:
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference_out.log'))
logging_utils.log_versions()
params.log()
params['log_to_wandb'] = (world_rank==0) and params['log_to_wandb']
params['log_to_screen'] = (world_rank==0) and params['log_to_screen']
n_ics = params['n_initial_conditions']
if fld == "z500" or fld == "t850":
n_samples_per_year = 1336
else:
n_samples_per_year = 1460
if params["ics_type"] == 'default':
num_samples = n_samples_per_year-params.prediction_length
stop = num_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
n_ics = len(ics)
logging.info("Inference for {} initial conditions".format(n_ics))
elif params["ics_type"] == "datetime":
date_strings = params["date_strings"]
ics = []
for date in date_strings:
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
ics.append(int(hours_since_jan_01_epoch/6))
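# Worked example (assuming 6-hourly samples starting Jan 1 00:00 of the target year):
#   '2018-01-02 06:00:00' -> day_of_year = 1, hour_of_day = 6
#   hours_since_jan_01_epoch = 24*1 + 6 = 30 -> ic = 30 // 6 = 5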
print(ics)
n_ics = len(ics)
try:
autoregressive_inference_filetag = params["inference_file_tag"]
except:
autoregressive_inference_filetag = ""
params.n_level = args.n_level
autoregressive_inference_filetag += "_" + str(params.n_level) + "_" + str(params.n_pert) + "ens_" + fld
# get data and models
valid_data_full, model = setup(params)
#initialize lists for image sequences and RMSE/ACC
valid_loss = []
acc = []
# run autoregressive inference for multiple initial conditions
# parallelize over initial conditions
if world_size > 1:
tot_ics = len(ics)
ics_per_proc = n_ics//world_size
ics = ics[ics_per_proc*world_rank:ics_per_proc*(world_rank+1)] if world_rank < world_size - 1 else ics[(world_size - 1)*ics_per_proc:]
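# Example split (hypothetical numbers): 10 ics on 4 ranks gives ics_per_proc = 2;
# ranks 0-2 take slices [0:2], [2:4], [4:6] and the last rank takes the remainder [6:10].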
n_ics = len(ics)
logging.info('Rank %d running ics %s'%(world_rank, str(ics)))
for i, ic in enumerate(ics):
t0 = time.time()
logging.info("Initial condition {} of {}".format(i+1, n_ics))
vl, a = autoregressive_inference(params, ic, valid_data_full, model)
if i ==0:
valid_loss = vl
acc = a
else:
valid_loss = np.concatenate((valid_loss, vl), 0)
acc = np.concatenate((acc, a), 0)
t1 = time.time() - t0
logging.info("Time for inference for ic {} = {}".format(i, t1))
prediction_length = acc[0].shape[0]
n_out_channels = acc[0].shape[1]
#save predictions and loss
h5name = os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5')
if dist.is_initialized():
if params.log_to_screen:
logging.info("Saving files at {}".format(h5name))
logging.info("array shapes: %s"%str((tot_ics, prediction_length, n_out_channels)))
dist.barrier()
from mpi4py import MPI
with h5py.File(h5name, 'a', driver='mpio', comm=MPI.COMM_WORLD) as f:
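# collective parallel write: driver='mpio' requires h5py built against parallel HDF5;
# all ranks open the file together and each writes only its [start:start+n_ics] slice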
if "rmse" in f.keys() or "acc" in f.keys():
del f["acc"]
del f["rmse"]
f.create_dataset("rmse", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
f.create_dataset("acc", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
start = world_rank*ics_per_proc
f["rmse"][start:start+n_ics] = valid_loss
f["acc"][start:start+n_ics] = acc
dist.barrier()
else:
if params.log_to_screen:
logging.info("Saving files at {}".format(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions' + autoregressive_inference_filetag + '.h5')))
with h5py.File(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5'), 'a') as f:
try:
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse"]
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse"][...] = valid_loss
try:
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc"]
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc"][...] = acc
import os
import time
import numpy as np
import argparse
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
import h5py
import torch
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
from utils.weighted_acc_rmse import weighted_rmse_torch_channels, weighted_acc_torch_channels, unlog_tp_torch, top_quantiles_error_torch
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet, PrecipNet
import wandb
import matplotlib.pyplot as plt
import glob
from datetime import datetime
DECORRELATION_TIME = 8 # 8 steps = 2 days for precip
def gaussian_perturb(x, level=0.01, device=0):
noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
return (x + noise)
def load_model(model, params, checkpoint_file):
model.zero_grad()
checkpoint = torch.load(checkpoint_file)
try:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
if name != 'ged':
new_state_dict[name] = val
model.load_state_dict(new_state_dict)
except:
model.load_state_dict(checkpoint['model_state'])
model.eval()
return model
def setup(params):
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#get data loader
valid_data_loader, valid_dataset = get_data_loader(params, params.inf_data_path, dist.is_initialized(), train=False)
img_shape_x = valid_dataset.img_shape_x
img_shape_y = valid_dataset.img_shape_y
params.img_shape_x = img_shape_x
params.img_shape_y = img_shape_y
if params.log_to_screen:
logging.info('Loading trained model checkpoint from {}'.format(params['best_checkpoint_path']))
in_channels = np.array(params.in_channels)
out_channels = np.array(params.in_channels) # same as in for the wind model
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
params['N_in_channels'] = n_in_channels
params['N_out_channels'] = n_out_channels
params.means = np.load(params.global_means_path)[0, out_channels] # needed to standardize wind data
params.stds = np.load(params.global_stds_path)[0, out_channels]
# load the models
# load wind model
if params.nettype_wind == 'afno':
model_wind = AFNONet(params).to(device)
checkpoint_file = params['model_wind_path']
model_wind = load_model(model_wind, params, checkpoint_file)
model_wind = model_wind.to(device)
out_channels = np.array(params.out_channels) # change out channels for precip model
n_out_channels = len(out_channels)
params['N_out_channels'] = n_out_channels
if params.nettype == 'afno':
model = AFNONet(params).to(device)
model = PrecipNet(params, backbone=model).to(device)
checkpoint_file = params['best_checkpoint_path']
model = load_model(model, params, checkpoint_file)
model = model.to(device)
# load the validation data
files_paths = glob.glob(params.inf_data_path + "/*.h5")
files_paths.sort()
if params.log_to_screen:
logging.info('Loading validation data')
logging.info('Validation data from {}'.format(files_paths[0]))
valid_data_full = h5py.File(files_paths[0], 'r')['fields']
# precip paths
path = params.precip + '/out_of_sample'
precip_paths = glob.glob(path + "/*.h5")
precip_paths.sort()
if params.log_to_screen:
logging.info('Loading validation precip data')
logging.info('Validation data from {}'.format(precip_paths[0]))
valid_data_tp_full = h5py.File(precip_paths[0], 'r')['tp']
return valid_data_full, valid_data_tp_full, model_wind, model
def autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model):
ic = int(ic)
#initialize global variables
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
exp_dir = params['experiment_dir']
dt = int(params.dt)
prediction_length = int(params.prediction_length/dt)
n_history = params.n_history
img_shape_x = params.img_shape_x
img_shape_y = params.img_shape_y
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
means = params.means
stds = params.stds
n_pert = params.n_pert
#initialize memory for image sequences and RMSE/ACC
valid_loss = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_unweighted = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
tqe = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
# wind seqs
seq_real = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred = torch.zeros((prediction_length+n_history, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
# precip sequences
seq_real_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
valid_data = valid_data_full[ic:(ic+prediction_length*dt+n_history*dt):dt, in_channels, 0:720] #extract valid data from first year
# standardize
valid_data = (valid_data - means)/stds
valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
len_ic = prediction_length*dt
valid_data_tp = valid_data_tp_full[ic:(ic+prediction_length*dt):dt, 0:720].reshape(len_ic,n_out_channels,720,img_shape_y) #extract valid data from first year
# log normalize
eps = params.precip_eps
valid_data_tp = np.log1p(valid_data_tp/eps)
valid_data_tp = torch.as_tensor(valid_data_tp).to(device, dtype=torch.float)
#load time means
m = torch.as_tensor(np.load(params.time_means_path_tp)[0][out_channels])[:, 0:img_shape_x] # climatology
m = torch.unsqueeze(m, 0)
m = m.to(device)
if params.log_to_screen:
logging.info('Begin autoregressive+tp inference')
for pert in range(n_pert):
logging.info('Running ensemble {}/{}'.format(pert+1, n_pert))
with torch.no_grad():
for i in range(valid_data.shape[0]):
if i==0: #start of sequence
first = valid_data[0:n_history+1]
first_tp = valid_data_tp[0:1]
future = valid_data[n_history+1]
future_tp = valid_data_tp[1]
for h in range(n_history+1):
seq_real[h] = first[h*n_in_channels:(h+1)*n_in_channels][0:n_in_channels] #extract history from 1st
seq_pred[h] = seq_real[h]
seq_real_tp[0] = unlog_tp_torch(first_tp)
seq_pred_tp[0] = unlog_tp_torch(first_tp)
first = gaussian_perturb(first, level=params.n_level, device=device) # perturb the ic
future_pred = model_wind(first)
future_pred_tp = model(future_pred)
else:
if i < prediction_length-1:
future = valid_data[n_history+i+1]
future_tp = valid_data_tp[i+1]
future_pred = model_wind(future_pred) #autoregressive step
future_pred_tp = model(future_pred) # tp diagnosis
if i < prediction_length-1: #not on the last step
seq_pred_tp[n_history+i+1] += unlog_tp_torch(torch.squeeze(future_pred_tp,0)) # add up predictions and average later
seq_real_tp[n_history+i+1] += unlog_tp_torch(future_tp)
#Compute metrics
for i in range(valid_data.shape[0]):
if i>0:
# avg images
seq_pred_tp[i] /= n_pert
seq_real_tp[i] /= n_pert
pred = torch.unsqueeze(seq_pred_tp[i], 0)
tar = torch.unsqueeze(seq_real_tp[i], 0)
valid_loss[i] = weighted_rmse_torch_channels(pred, tar)
acc[i] = weighted_acc_torch_channels(pred-m, tar-m)
tqe[i] = top_quantiles_error_torch(pred, tar)
if params.log_to_screen:
logging.info('Timestep {} of {}. TP RMS Error: {}, ACC: {}'.format((i), prediction_length, valid_loss[i,0], acc[i,0]))
seq_real_tp = seq_real_tp.cpu().numpy()
seq_pred_tp = seq_pred_tp.cpu().numpy()
valid_loss = valid_loss.cpu().numpy()
acc = acc.cpu().numpy()
acc_unweighted = acc_unweighted.cpu().numpy()
tqe = tqe.cpu().numpy()
return np.expand_dims(valid_loss, 0), \
np.expand_dims(acc, 0), \
np.expand_dims(tqe, 0)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='full_field', type=str)
parser.add_argument("--n_level", default = 0.3, type = float)
parser.add_argument("--n_pert", default=100, type=int)
parser.add_argument("--override_dir", default=None, type = str, help = 'Path to store inference outputs; must also set --weights arg')
parser.add_argument("--weights", default=None, type=str, help = 'Path to model weights, for use with override_dir option')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['world_size'] = 1
if 'WORLD_SIZE' in os.environ:
params['world_size'] = int(os.environ['WORLD_SIZE'])
world_rank = 0
local_rank = 0
if params['world_size'] > 1:
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend='nccl', init_method='env://')
args.gpu = local_rank
world_rank = dist.get_rank()
world_size = dist.get_world_size()
params['global_batch_size'] = params.batch_size
# params['batch_size'] = int(params.batch_size//params['world_size'])
torch.cuda.set_device(local_rank)
torch.backends.cudnn.benchmark = True
# Set up directory
if args.override_dir is not None:
assert args.weights is not None, 'Must set --weights argument if using --override_dir'
expDir = args.override_dir
else:
assert args.weights is None, 'Cannot use --weights argument without also using --override_dir'
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if world_rank==0:
if not os.path.isdir(expDir):
os.makedirs(expDir)
params['experiment_dir'] = os.path.abspath(expDir)
params['best_checkpoint_path'] = args.weights if args.override_dir is not None else os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
args.resuming = False
params['resuming'] = args.resuming
params['local_rank'] = local_rank
# this will be the wandb name
params['name'] = args.config + '_' + str(args.run_num)
params['group'] = args.config
if world_rank==0:
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference_out.log'))
logging_utils.log_versions()
params.log()
params['log_to_wandb'] = (world_rank==0) and params['log_to_wandb']
params['log_to_screen'] = (world_rank==0) and params['log_to_screen']
n_ics = params['n_initial_conditions']
if params["ics_type"] == 'default':
num_samples = 1460 - params.prediction_length
stop = num_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
n_ics = len(ics)
logging.info("Inference for {} initial conditions".format(n_ics))
elif params["ics_type"] == "datetime":
date_strings = params["date_strings"]
ics = []
for date in date_strings:
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
ics.append(int(hours_since_jan_01_epoch/6))
try:
autoregressive_inference_filetag = params["inference_file_tag"]
except:
autoregressive_inference_filetag = ""
n_level = args.n_level
params.n_level = n_level
params.n_pert = args.n_pert
logging.info("Doing level = {}".format(n_level))
autoregressive_inference_filetag += "_" + str(params.n_level) + "_" + str(params.n_pert) + "ens_tp"
# get data and models
valid_data_full, valid_data_tp_full, model_wind, model = setup(params)
#initialize lists for image sequences and RMSE/ACC
valid_loss = []
acc = []
tqe = []
# run autoregressive inference for multiple initial conditions
# parallelize over initial conditions
if world_size > 1:
tot_ics = len(ics)
ics_per_proc = n_ics//world_size
ics = ics[ics_per_proc*world_rank:ics_per_proc*(world_rank+1)] if world_rank < world_size - 1 else ics[(world_size - 1)*ics_per_proc:]
n_ics = len(ics)
logging.info('Rank %d running ics %s'%(world_rank, str(ics)))
for i, ic in enumerate(ics):
t1 = time.time()
logging.info("Initial condition {} of {}".format(i+1, n_ics))
vl, a, tq = autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model)
if i == 0:
valid_loss = vl
acc = a
tqe = tq
else:
valid_loss = np.concatenate((valid_loss, vl), 0)
acc = np.concatenate((acc, a), 0)
tqe = np.concatenate((tqe, tq), 0)
t2 = time.time()-t1
logging.info("Time for inference for ic {} = {}".format(i, t2))
prediction_length = acc[0].shape[0]
n_out_channels = acc[0].shape[1]
#save predictions and loss
h5name = os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5')
if dist.is_initialized():
if params.log_to_screen:
logging.info("Saving files at {}".format(h5name))
logging.info("array shapes: %s"%str((tot_ics, prediction_length, n_out_channels)))
dist.barrier()
from mpi4py import MPI
with h5py.File(h5name, 'a', driver='mpio', comm=MPI.COMM_WORLD) as f:
if "rmse" in f.keys() or "acc" in f.keys():
del f["acc"]
del f["rmse"]
f.create_dataset("rmse", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
f.create_dataset("acc", shape = (tot_ics, prediction_length, n_out_channels), dtype =np.float32)
start = world_rank*ics_per_proc
f["rmse"][start:start+n_ics] = valid_loss
f["acc"][start:start+n_ics] = acc
dist.barrier()
else:
if params.log_to_screen:
logging.info("Saving files at {}".format(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions' + autoregressive_inference_filetag + '.h5')))
with h5py.File(os.path.join(params['experiment_dir'], 'ens_autoregressive_predictions'+ autoregressive_inference_filetag +'.h5'), 'a') as f:
try:
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse"]
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse"][...] = valid_loss
try:
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc"]
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc"][...] = acc
try:
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["tqe"]
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["tqe"][...] = tqe
import os
import sys
import time
import numpy as np
import argparse
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
import h5py
import torch
import torchvision
from torchvision.utils import save_image
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
from utils.weighted_acc_rmse import weighted_rmse_torch_channels, weighted_acc_torch_channels, unlog_tp_torch, top_quantiles_error_torch
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
from networks.afnonet import AFNONet, PrecipNet
import wandb
import matplotlib.pyplot as plt
import glob
from datetime import datetime
DECORRELATION_TIME = 8 # 8 steps = 2 days for precip
def gaussian_perturb(x, level=0.01, device=0):
noise = level * torch.randn(x.shape).to(device, dtype=torch.float)
return (x + noise)
def load_model(model, params, checkpoint_file):
model.zero_grad()
checkpoint_fname = checkpoint_file
checkpoint = torch.load(checkpoint_fname)
try:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
if name != 'ged':
new_state_dict[name] = val
model.load_state_dict(new_state_dict)
except:
model.load_state_dict(checkpoint['model_state'])
model.eval()
return model
def downsample(x, scale=0.125):
return torch.nn.functional.interpolate(x, scale_factor=scale, mode='bilinear')
def setup(params):
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#get data loader
valid_data_loader, valid_dataset = get_data_loader(params, params.inf_data_path, dist.is_initialized(), train=False)
img_shape_x = valid_dataset.img_shape_x
img_shape_y = valid_dataset.img_shape_y
params.img_shape_x = img_shape_x
params.img_shape_y = img_shape_y
if params.log_to_screen:
logging.info('Loading trained model checkpoint from {}'.format(params['best_checkpoint_path']))
in_channels = np.array(params.in_channels)
out_channels = np.array(params.in_channels)# for the backbone model, will be reset later
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
if params["orography"]:
params['N_in_channels'] = n_in_channels + 1
else:
params['N_in_channels'] = n_in_channels
params['N_out_channels'] = n_out_channels
params.means = np.load(params.global_means_path)[0, out_channels] # needed to standardize wind data
params.stds = np.load(params.global_stds_path)[0, out_channels]
# load wind model
if params.nettype_wind == 'afno':
model_wind = AFNONet(params).to(device)
if 'model_wind_path' not in params:
raise Exception("no backbone model weights specified")
checkpoint_file = params['model_wind_path']
model_wind = load_model(model_wind, params, checkpoint_file)
model_wind = model_wind.to(device)
# reset channels for precip
params['N_out_channels'] = len(params['out_channels'])
# load the model
if params.nettype == 'afno':
model = AFNONet(params).to(device)
else:
raise Exception("not implemented")
model = PrecipNet(params, backbone=model).to(device)
checkpoint_file = params['best_checkpoint_path']
model = load_model(model, params, checkpoint_file)
model = model.to(device)
# load the validation data
files_paths = glob.glob(params.inf_data_path + "/*.h5")
files_paths.sort()
# which year
yr = 0
if params.log_to_screen:
logging.info('Loading validation data')
logging.info('Validation data from {}'.format(files_paths[yr]))
valid_data_full = h5py.File(files_paths[yr], 'r')['fields']
# precip paths
path = params.precip + '/out_of_sample'
precip_paths = glob.glob(path + "/*.h5")
precip_paths.sort()
if params.log_to_screen:
logging.info('Loading validation precip data')
logging.info('Validation data from {}'.format(precip_paths[0]))
valid_data_tp_full = h5py.File(precip_paths[0], 'r')['tp']
return valid_data_full, valid_data_tp_full, model_wind, model
def autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model):
ic = int(ic)
#initialize global variables
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
exp_dir = params['experiment_dir']
dt = int(params.dt)
prediction_length = int(params.prediction_length/dt)
n_history = params.n_history
img_shape_x = params.img_shape_x
img_shape_y = params.img_shape_y
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
n_in_channels = len(in_channels)
n_out_channels = len(out_channels)
means = params.means
stds = params.stds
#initialize memory for image sequences and RMSE/ACC, tqe for precip
valid_loss = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
acc_unweighted = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
tqe = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
# wind seqs
seq_real = torch.zeros((prediction_length, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred = torch.zeros((prediction_length, n_in_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
# precip sequences
seq_real_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
seq_pred_tp = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
valid_data = valid_data_full[ic:(ic+prediction_length*dt+n_history*dt):dt, in_channels, 0:720] #extract valid data from first year
# standardize
valid_data = (valid_data - means)/stds
valid_data = torch.as_tensor(valid_data).to(device, dtype=torch.float)
len_ic = prediction_length*dt
valid_data_tp = valid_data_tp_full[ic:(ic+prediction_length*dt):dt, 0:720].reshape(len_ic,n_out_channels,720,img_shape_y) #extract valid data from first year
# log normalize
eps = params.precip_eps
valid_data_tp = np.log1p(valid_data_tp/eps)
valid_data_tp = torch.as_tensor(valid_data_tp).to(device, dtype=torch.float)
m = torch.as_tensor(np.load(params.time_means_path_tp)[0][out_channels])[:, 0:img_shape_x] # climatology
m = torch.unsqueeze(m, 0)
m = m.to(device, dtype=torch.float)
std = torch.as_tensor(stds[:,0,0]).to(device, dtype=torch.float)
orography = params.orography
orography_path = params.orography_path
if orography:
orog = torch.as_tensor(np.expand_dims(np.expand_dims(h5py.File(orography_path, 'r')['orog'][0:720], axis = 0), axis = 0)).to(device, dtype = torch.float)
logging.info("orography loaded; shape:{}".format(orog.shape))
#autoregressive inference
if params.log_to_screen:
logging.info('Begin autoregressive inference')
with torch.no_grad():
for i in range(valid_data.shape[0]):
if i==0: #start of sequence
first = valid_data[0:n_history+1]
first_tp = valid_data_tp[0:1]
future = valid_data[n_history+1]
future_tp = valid_data_tp[1]
for h in range(n_history+1):
seq_real[h] = first[h*n_in_channels:(h+1)*n_in_channels][0:n_in_channels] #extract history from 1st
seq_pred[h] = seq_real[h]
seq_real_tp[0] = unlog_tp_torch(first_tp)
seq_pred_tp[0] = unlog_tp_torch(first_tp)
if params.perturb:
first = gaussian_perturb(first, level=params.n_level, device=device) # perturb the ic
if orography:
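# orog has shape (1, 1, 720, 1440), so concatenating along the channel axis
# assumes a single-member input batch (i.e. n_history = 0)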
future_pred = model_wind(torch.cat((first, orog), axis=1))
else:
future_pred = model_wind(first)
future_pred_tp = model(future_pred)
else:
if i < prediction_length-1:
future = valid_data[n_history+i+1]
future_tp = valid_data_tp[i+1]
if orography:
future_pred = model_wind(torch.cat((future_pred, orog), axis=1)) #autoregressive step
else:
future_pred = model_wind(future_pred) #autoregressive step
future_pred_tp = model(future_pred) # tp diagnosis
if i < prediction_length-1: #not on the last step
seq_pred[n_history+i+1] = future_pred
seq_real[n_history+i+1] = future
seq_pred_tp[i+1] = unlog_tp_torch(future_pred_tp) # precip diagnosed from the predicted wind: 0h state -> 6h wind (AFNO) -> 6-12h precip
seq_real_tp[i+1] = unlog_tp_torch(future_tp) # which is the i+1th validation data
#collect history
history_stack = seq_pred[i+1:i+2+n_history]
# ic for next wind step
future_pred = history_stack
pred = torch.unsqueeze(seq_pred_tp[i], 0)
tar = torch.unsqueeze(seq_real_tp[i], 0)
valid_loss[i] = weighted_rmse_torch_channels(pred, tar)
acc[i] = weighted_acc_torch_channels(pred-m, tar-m)
tqe[i] = top_quantiles_error_torch(pred, tar)
if params.log_to_screen:
logging.info('Timestep {} of {}. TP RMS Error: {}, ACC: {}'.format((i), prediction_length, valid_loss[i,0], acc[i,0]))
seq_real_tp = seq_real_tp.cpu().numpy()
seq_pred_tp = seq_pred_tp.cpu().numpy()
valid_loss = valid_loss.cpu().numpy()
acc = acc.cpu().numpy()
acc_unweighted = acc_unweighted.cpu().numpy()
tqe = tqe.cpu().numpy()
return np.expand_dims(seq_real_tp, 0), np.expand_dims(seq_pred_tp, 0), np.expand_dims(valid_loss, 0), \
np.expand_dims(acc, 0), np.expand_dims(acc_unweighted, 0), np.expand_dims(tqe, 0)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/AFNO.yaml', type=str)
parser.add_argument("--config", default='full_field', type=str)
parser.add_argument("--vis", action='store_true')
parser.add_argument("--override_dir", default=None, type = str, help = 'Path to store inference outputs; must also set --weights arg')
parser.add_argument("--weights", default=None, type=str, help = 'Path to model weights, for use with override_dir option')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['world_size'] = 1
params['global_batch_size'] = params.batch_size
torch.cuda.set_device(0)
torch.backends.cudnn.benchmark = True
vis = args.vis
# Set up directory
if args.override_dir is not None:
assert args.weights is not None, 'Must set --weights argument if using --override_dir'
expDir = args.override_dir
else:
assert args.weights is None, 'Cannot use --weights argument without also using --override_dir'
expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
if not os.path.isdir(expDir):
os.makedirs(expDir)
params['experiment_dir'] = os.path.abspath(expDir)
params['best_checkpoint_path'] = args.weights if args.override_dir is not None else os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
params['resuming'] = False
params['local_rank'] = 0
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference_out.log'))
logging_utils.log_versions()
params.log()
n_ics = params['n_initial_conditions']
ics = [1066, 1050, 1034]
n_samples_per_year = 1460
if params["ics_type"] == 'default':
num_samples = n_samples_per_year-params.prediction_length
stop = num_samples
ics = np.arange(0, stop, DECORRELATION_TIME)
if vis: # visualization for just the first ic (or any ic)
ics = [0]
n_ics = len(ics)
elif params["ics_type"] == "datetime":
date_strings = params["date_strings"]
ics = []
if params.perturb: #for perturbations use a single date and create n_ics perturbations
n_ics = params["n_perturbations"]
date = date_strings[0]
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
for ii in range(n_ics):
ics.append(int(hours_since_jan_01_epoch/6))
else:
for date in date_strings:
date_obj = datetime.strptime(date,'%Y-%m-%d %H:%M:%S')
day_of_year = date_obj.timetuple().tm_yday - 1
hour_of_day = date_obj.timetuple().tm_hour
hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
ics.append(int(hours_since_jan_01_epoch/6))
n_ics = len(ics)
logging.info("Inference for {} initial conditions".format(n_ics))
try:
autoregressive_inference_filetag = params["inference_file_tag"]
except:
autoregressive_inference_filetag = ""
autoregressive_inference_filetag += "_tp"
# get data and models
valid_data_full, valid_data_tp_full, model_wind, model = setup(params)
#initialize lists for image sequences and RMSE/ACC
valid_loss = []
acc_unweighted = []
acc = []
tqe = []
seq_pred = []
seq_real = []
#run autoregressive inference for multiple initial conditions
for i, ic in enumerate(ics):
t1 = time.time()
logging.info("Initial condition {} of {}".format(i+1, n_ics))
sr, sp, vl, a, au, tq = autoregressive_inference(params, ic, valid_data_full, valid_data_tp_full, model_wind, model)
if i == 0:
seq_real = sr
seq_pred = sp
valid_loss = vl
acc = a
acc_unweighted = au
tqe = tq
else:
# seq_real = np.concatenate((seq_real, sr), 0)
# seq_pred = np.concatenate((seq_pred, sp), 0)
valid_loss = np.concatenate((valid_loss, vl), 0)
acc = np.concatenate((acc, a), 0)
acc_unweighted = np.concatenate((acc_unweighted, au), 0)
tqe = np.concatenate((tqe, tq), 0)
t2 = time.time()-t1
print("time for 1 autoreg inference = ", t2)
prediction_length = seq_real[0].shape[0]
n_out_channels = seq_real[0].shape[1]
img_shape_x = seq_real[0].shape[2]
img_shape_y = seq_real[0].shape[3]
#save predictions and loss
if params.log_to_screen:
logging.info("Saving files at {}".format(os.path.join(params['experiment_dir'], 'autoregressive_predictions' + autoregressive_inference_filetag + '.h5')))
with h5py.File(os.path.join(params['experiment_dir'], 'autoregressive_predictions'+ autoregressive_inference_filetag +'.h5'), 'a') as f:
if vis:
try:
f.create_dataset("ground_truth", data = seq_real, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
except:
del f["ground_truth"]
f.create_dataset("ground_truth", data = seq_real, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
f["ground_truth"][...] = seq_real
try:
f.create_dataset("predicted", data = seq_pred, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
except:
del f["predicted"]
f.create_dataset("predicted", data = seq_pred, shape = (n_ics, prediction_length, n_out_channels, img_shape_x, img_shape_y), dtype = np.float32)
f["predicted"][...]= seq_pred
try:
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["rmse"]
f.create_dataset("rmse", data = valid_loss, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["rmse"][...] = valid_loss
try:
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc"]
f.create_dataset("acc", data = acc, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc"][...] = acc
try:
f.create_dataset("acc_unweighted", data = acc_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["acc_unweighted"]
f.create_dataset("acc_unweighted", data = acc_unweighted, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["acc_unweighted"][...] = acc_unweighted
try:
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
except:
del f["tqe"]
f.create_dataset("tqe", data = tqe, shape = (n_ics, prediction_length, n_out_channels), dtype =np.float32)
f["tqe"][...] = tqe
#!/bin/bash
#shifter --image=nersc/pytorch:ngc-21.08-v1 --env PYTHONUSERBASE=/pscratch/home/jpathak/perlmutter/ngc-21.08-v1 python \
# train.py --enable_amp --config pretrained_two_step_afno_20ch_bs_64_lr1em4_blk_8_patch_8_cosine_sched --run_num test0
export MASTER_ADDR=$(hostname)
image=nersc/pytorch:ngc-22.02-v0
ngpu=4
config_file=./config/AFNO.yaml
config="afno_backbone"
run_num="check"
cmd="python train.py --enable_amp --yaml_config=$config_file --config=$config --run_num=$run_num"
srun -n $ngpu --cpus-per-task=32 --gpus-per-node $ngpu shifter --image=${image} bash -c "source export_DDP_vars.sh && $cmd"
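# For reference, export_DDP_vars.sh is expected to map SLURM variables onto the env
# vars torch.distributed reads with init_method='env://'. A plausible sketch
# (assumption -- check the repo's actual script):
#   export RANK=$SLURM_PROCID
#   export LOCAL_RANK=$SLURM_LOCALID
#   export WORLD_SIZE=$SLURM_NTASKS
#   export MASTER_PORT=29500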
#reference: https://github.com/NVlabs/AFNO-transformer
import math
from functools import partial
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
#from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
import torch.fft
from torch.nn.modules.container import Sequential
from torch.utils.checkpoint import checkpoint_sequential
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from utils.img_utils import PeriodicPad2d
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class AFNO2D(nn.Module):
def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
super().__init__()
assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"
self.hidden_size = hidden_size
self.sparsity_threshold = sparsity_threshold
self.num_blocks = num_blocks
self.block_size = self.hidden_size // self.num_blocks
self.hard_thresholding_fraction = hard_thresholding_fraction
self.hidden_size_factor = hidden_size_factor
self.scale = 0.02
self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
def forward(self, x):
bias = x
dtype = x.dtype
x = x.float()
B, H, W, C = x.shape
x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)
o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
o2_real = torch.zeros(x.shape, device=x.device)
o2_imag = torch.zeros(x.shape, device=x.device)
total_modes = H // 2 + 1
kept_modes = int(total_modes * self.hard_thresholding_fraction)
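# the einsum pairs below implement block-diagonal complex matrix multiplies via
# (a + ib)(c + id) = (ac - bd) + i(ad + bc), with w*[0]/w*[1] holding the real and
# imaginary parts; only the kept_modes lowest frequencies are transformed, the rest stay zero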
o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
self.b1[0]
)
o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
self.b1[1]
)
o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
self.b2[0]
)
o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
self.b2[1]
)
x = torch.stack([o2_real, o2_imag], dim=-1)
x = F.softshrink(x, lambd=self.sparsity_threshold)
x = torch.view_as_complex(x)
x = x.reshape(B, H, W // 2 + 1, C)
x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho")
x = x.type(dtype)
return x + bias
class Block(nn.Module):
def __init__(
self,
dim,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
double_skip=True,
num_blocks=8,
sparsity_threshold=0.01,
hard_thresholding_fraction=1.0
):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
#self.drop_path = nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.double_skip = double_skip
def forward(self, x):
residual = x
x = self.norm1(x)
x = self.filter(x)
if self.double_skip:
x = x + residual
residual = x
x = self.norm2(x)
x = self.mlp(x)
x = self.drop_path(x)
x = x + residual
return x
class PrecipNet(nn.Module):
def __init__(self, params, backbone):
super().__init__()
self.params = params
self.patch_size = (params.patch_size, params.patch_size)
self.in_chans = params.N_in_channels
self.out_chans = params.N_out_channels
self.backbone = backbone
self.ppad = PeriodicPad2d(1)
self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True)
self.act = nn.ReLU()
def forward(self, x):
x = self.backbone(x)
x = self.ppad(x)
x = self.conv(x)
x = self.act(x)
return x
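# shape note: PeriodicPad2d(1) grows H and W by 2 and the 3x3 conv with padding=0
# shrinks them back, so the precip head preserves the backbone's spatial resolution;
# the final ReLU enforces non-negative precipitation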
class AFNONet(nn.Module):
def __init__(
self,
params,
img_size=(720, 1440),
patch_size=(16, 16),
in_chans=2,
out_chans=2,
embed_dim=768,
depth=12,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.,
num_blocks=16,
sparsity_threshold=0.01,
hard_thresholding_fraction=1.0,
):
super().__init__()
self.params = params
self.img_size = img_size
self.patch_size = (params.patch_size, params.patch_size)
self.in_chans = params.N_in_channels
self.out_chans = params.N_out_channels
self.num_features = self.embed_dim = embed_dim
self.num_blocks = params.num_blocks
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.h = img_size[0] // self.patch_size[0]
self.w = img_size[1] // self.patch_size[1]
self.blocks = nn.ModuleList([
Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(embed_dim, self.out_chans*self.patch_size[0]*self.patch_size[1], bias=False)
trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
x = x.reshape(B, self.h, self.w, self.embed_dim)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = rearrange(
x,
"b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)",
p1=self.patch_size[0],
p2=self.patch_size[1],
h=self.img_size[0] // self.patch_size[0],
w=self.img_size[1] // self.patch_size[1],
)
return x
class PatchEmbed(nn.Module):
def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
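def _demo_patch_embed():
    # Added illustration (hedged sketch, not part of the original file):
    # images are cut into non-overlapping patches and linearly projected,
    # mapping (B, C, H, W) -> (B, num_patches, embed_dim).
    pe = PatchEmbed(img_size=(64, 64), patch_size=(16, 16), in_chans=3, embed_dim=96)
    return pe(torch.randn(2, 3, 64, 64)).shape   # torch.Size([2, 16, 96])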
if __name__ == "__main__":
    # AFNONet.__init__ reads patch_size, N_in_channels, N_out_channels and
    # num_blocks from `params` (overriding the keyword arguments), so a
    # minimal namespace is needed for this smoke test to run.
    from types import SimpleNamespace
    params = SimpleNamespace(patch_size=4, N_in_channels=3, N_out_channels=10, num_blocks=16)
    model = AFNONet(params, img_size=(720, 1440), patch_size=(4, 4), in_chans=3, out_chans=10)
    sample = torch.randn(1, 3, 720, 1440)
    result = model(sample)
    print(result.shape)
    print(torch.norm(result))
# 1. Re-initialize the repository, excluding large files
cd /workspace/FourCastNet
rm -rf .git
# 2. Create a .gitignore to exclude data files
cat > .gitignore << 'EOF'
data.tar.gz
data/
*.tar.gz
*.zip
EOF
# 3. Re-initialize Git
git init
git add .
git commit -m "Initial commit: FourCastNet code only"
# 4. Force push (this time it should succeed)
git remote add origin http://developer.sourcefind.cn/codes/bw_bestperf/fourcastnet_train.git
git push -f origin master
rm -rf ./exp
export HIP_MODULE_LOADING=EAGER
export HIP_ALLOC_INITIALIZE=0
export HIP_REUSE_MODULE=1
export HIP_VISIBLE_DEVICES=0
config_file=./config/AFNO-50epoch.yaml
config='afno_backbone'
run_num='check_exp'
export HDF5_USE_FILE_LOCKING=FALSE
python train.py --enable_amp --yaml_config=$config_file --config=$config --run_num=$run_num
#!/bin/bash -l
#SBATCH --time=06:00:00
#SBATCH -C gpu
#SBATCH --account=m4134_g
#SBATCH --nodes=16
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=32
#SBATCH -J afno
#SBATCH --image=nersc/pytorch:ngc-22.02-v0
#SBATCH -o afno_backbone_finetune.out
config_file=./config/AFNO.yaml
config='afno_backbone_finetune'
run_num='0'
export HDF5_USE_FILE_LOCKING=FALSE
export NCCL_NET_GDR_LEVEL=PHB
export MASTER_ADDR=$(hostname)
set -x
srun -u --mpi=pmi2 shifter \
bash -c "
source export_DDP_vars.sh
python train.py --enable_amp --yaml_config=$config_file --config=$config --run_num=$run_num
"
#!/bin/bash
#SBATCH --time=00:15:00
#SBATCH -N 4
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=32
#SBATCH -C gpu
#SBATCH --account=m4134_g
#SBATCH -q regular
#SBATCH --image=nersc/pytorch:ngc-22.02-v0
export HDF5_USE_FILE_LOCKING=FALSE
export MASTER_ADDR=$(hostname)
launch="python inference/inference_ensemble.py --config=afno_backbone_finetune --run_num=0 --n_level=0.3"
#launch="python inference/inference_ensemble_precip.py --config=precip --run_num=1 --n_level=0.1"
srun --mpi=pmi2 -u -l shifter --module gpu --env PYTHONUSERBASE=$HOME/.local/perlmutter/nersc-pytorch-22.02-v0 bash -c "
source export_DDP_vars.sh
$launch
"
from ruamel.yaml import YAML
import logging
class YParams():
""" Yaml file parser """
def __init__(self, yaml_filename, config_name, print_params=False):
self._yaml_filename = yaml_filename
self._config_name = config_name
self.params = {}
if print_params:
print("------------------ Configuration ------------------")
with open(yaml_filename) as _file:
for key, val in YAML().load(_file)[config_name].items():
if print_params: print(key, val)
                if val == 'None': val = None
self.params[key] = val
self.__setattr__(key, val)
if print_params:
print("---------------------------------------------------")
def __getitem__(self, key):
return self.params[key]
def __setitem__(self, key, val):
self.params[key] = val
self.__setattr__(key, val)
def __contains__(self, key):
return (key in self.params)
def update_params(self, config):
for key, val in config.items():
self.params[key] = val
self.__setattr__(key, val)
def log(self):
logging.info("------------------ Configuration ------------------")
logging.info("Configuration file: "+str(self._yaml_filename))
logging.info("Configuration name: "+str(self._config_name))
for key, val in self.params.items():
logging.info(str(key) + ' ' + str(val))
logging.info("---------------------------------------------------")
#MIT License
#
#Copyright (c) 2020 Zongyi Li
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn
#################################################
#
# Utilities
#
#################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# reading data
class MatReader(object):
def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
super(MatReader, self).__init__()
self.to_torch = to_torch
self.to_cuda = to_cuda
self.to_float = to_float
self.file_path = file_path
self.data = None
self.old_mat = None
self._load_file()
def _load_file(self):
try:
self.data = scipy.io.loadmat(self.file_path)
self.old_mat = True
        except Exception:
            # scipy cannot read MATLAB v7.3 (HDF5-based) files; fall back to h5py
            self.data = h5py.File(self.file_path, 'r')
            self.old_mat = False
def load_file(self, file_path):
self.file_path = file_path
self._load_file()
def read_field(self, field):
x = self.data[field]
if not self.old_mat:
x = x[()]
x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))
if self.to_float:
x = x.astype(np.float32)
if self.to_torch:
x = torch.from_numpy(x)
if self.to_cuda:
x = x.cuda()
return x
def set_cuda(self, to_cuda):
self.to_cuda = to_cuda
def set_torch(self, to_torch):
self.to_torch = to_torch
def set_float(self, to_float):
self.to_float = to_float
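def _demo_mat_reader(path='/tmp/_matreader_demo.mat'):
    # Added illustration (hedged sketch, not part of the original file):
    # write a small v5 .mat file with scipy and read a field back as a
    # torch tensor; the path is an assumption for the demo.
    scipy.io.savemat(path, {'u': np.arange(6.0).reshape(2, 3)})
    reader = MatReader(path)
    return reader.read_field('u')   # FloatTensor of shape (2, 3)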
# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
def __init__(self, x, eps=0.00001):
super(UnitGaussianNormalizer, self).__init__()
        # x can have shape ntrain*n, ntrain*T*n, or ntrain*n*T
self.mean = torch.mean(x, 0)
self.std = torch.std(x, 0)
self.eps = eps
def encode(self, x):
x = (x - self.mean) / (self.std + self.eps)
return x.float()
def decode(self, x, sample_idx=None):
if sample_idx is None:
std = self.std + self.eps # n
mean = self.mean
else:
if len(self.mean.shape) == len(sample_idx[0].shape):
std = self.std[sample_idx] + self.eps # batch*n
mean = self.mean[sample_idx]
if len(self.mean.shape) > len(sample_idx[0].shape):
std = self.std[:,sample_idx]+ self.eps # T*batch*n
mean = self.mean[:,sample_idx]
# x is in shape of batch*n or T*batch*n
x = (x * std) + mean
return x.float()
def cuda(self):
self.mean = self.mean.cuda()
self.std = self.std.cuda()
def cpu(self):
self.mean = self.mean.cpu()
self.std = self.std.cpu()
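def _demo_unit_gaussian_normalizer():
    # Added illustration (hedged sketch, not part of the original file):
    # encode followed by decode is the identity up to floating-point error,
    # since decode multiplies by the same std + eps that encode divided by.
    x = torch.randn(100, 32, 32)
    norm = UnitGaussianNormalizer(x)
    x_rt = norm.decode(norm.encode(x))
    return torch.allclose(x, x_rt, atol=1e-4)   # expected: True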
# normalization, Gaussian
class GaussianNormalizer(object):
def __init__(self, x, eps=0.00001):
super(GaussianNormalizer, self).__init__()
self.mean = torch.mean(x)
self.std = torch.std(x)
self.eps = eps
def encode(self, x):
x = (x - self.mean) / (self.std + self.eps)
return x
def decode(self, x, sample_idx=None):
x = (x * (self.std + self.eps)) + self.mean
return x
def cuda(self):
self.mean = self.mean.cuda()
self.std = self.std.cuda()
def cpu(self):
self.mean = self.mean.cpu()
self.std = self.std.cpu()
# normalization, scaling by range
class RangeNormalizer(object):
def __init__(self, x, low=0.0, high=1.0):
super(RangeNormalizer, self).__init__()
mymin = torch.min(x, 0)[0].view(-1)
mymax = torch.max(x, 0)[0].view(-1)
self.a = (high - low)/(mymax - mymin)
self.b = -self.a*mymax + high
def encode(self, x):
s = x.size()
x = x.view(s[0], -1)
x = self.a*x + self.b
x = x.view(s)
return x
def decode(self, x):
s = x.size()
x = x.view(s[0], -1)
x = (x - self.b)/self.a
x = x.view(s)
return x
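def _demo_range_normalizer():
    # Added illustration (hedged sketch, not part of the original file):
    # a per-feature affine map onto [low, high], here the default [0, 1].
    x = torch.rand(10, 4) * 5.0 + 2.0
    rn = RangeNormalizer(x)
    y = rn.encode(x)
    return float(y.min()), float(y.max())   # approximately (0.0, 1.0)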
#loss function with rel/abs Lp loss
class LpLoss(object):
def __init__(self, d=2, p=2, size_average=True, reduction=True):
super(LpLoss, self).__init__()
        #Dimension and Lp-norm type are positive
assert d > 0 and p > 0
self.d = d
self.p = p
self.reduction = reduction
self.size_average = size_average
def abs(self, x, y):
num_examples = x.size()[0]
#Assume uniform mesh
h = 1.0 / (x.size()[1] - 1.0)
all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)
if self.reduction:
if self.size_average:
return torch.mean(all_norms)
else:
return torch.sum(all_norms)
return all_norms
def rel(self, x, y):
num_examples = x.size()[0]
diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
if self.reduction:
if self.size_average:
return torch.mean(diff_norms/y_norms)
else:
return torch.sum(diff_norms/y_norms)
return diff_norms/y_norms
def __call__(self, x, y):
return self.rel(x, y)
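def _demo_lp_loss():
    # Added illustration (hedged sketch, not part of the original file):
    # by default __call__ computes the batch-mean relative L2 error
    # ||x - y||_2 / ||y||_2.
    loss_fn = LpLoss(d=2, p=2)
    x, y = torch.randn(8, 64, 64), torch.randn(8, 64, 64)
    return loss_fn(x, y)   # scalar tensor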
# Sobolev norm (HS norm)
# where we also compare the numerical derivatives between the output and target
class HsLoss(object):
def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True):
super(HsLoss, self).__init__()
        #Dimension and Lp-norm type are positive
assert d > 0 and p > 0
self.d = d
self.p = p
self.k = k
self.balanced = group
self.reduction = reduction
self.size_average = size_average
        if a is None:
a = [1,] * k
self.a = a
def rel(self, x, y):
num_examples = x.size()[0]
diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
if self.reduction:
if self.size_average:
return torch.mean(diff_norms/y_norms)
else:
return torch.sum(diff_norms/y_norms)
return diff_norms/y_norms
def __call__(self, x, y, a=None):
nx = x.size()[1]
ny = x.size()[2]
k = self.k
balanced = self.balanced
a = self.a
x = x.view(x.shape[0], nx, ny, -1)
y = y.view(y.shape[0], nx, ny, -1)
k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny)
k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1)
k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device)
k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device)
x = torch.fft.fftn(x, dim=[1, 2])
y = torch.fft.fftn(y, dim=[1, 2])
        if not balanced:
weight = 1
if k >= 1:
weight += a[0]**2 * (k_x**2 + k_y**2)
if k >= 2:
weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
weight = torch.sqrt(weight)
loss = self.rel(x*weight, y*weight)
else:
loss = self.rel(x, y)
if k >= 1:
weight = a[0] * torch.sqrt(k_x**2 + k_y**2)
loss += self.rel(x*weight, y*weight)
if k >= 2:
weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
loss += self.rel(x*weight, y*weight)
loss = loss / (k+1)
return loss
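def _demo_hs_loss():
    # Added illustration (hedged sketch, not part of the original file):
    # a relative error in Fourier space with an H^1 weight
    # sqrt(1 + a^2 (k_x^2 + k_y^2)) for the default k=1.
    hs = HsLoss(d=2, p=2, k=1)
    u, v = torch.randn(4, 32, 32, 1), torch.randn(4, 32, 32, 1)
    return hs(u, v)   # scalar tensor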
# A simple feedforward neural network
class DenseNet(torch.nn.Module):
def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False):
super(DenseNet, self).__init__()
self.n_layers = len(layers) - 1
assert self.n_layers >= 1
self.layers = nn.ModuleList()
for j in range(self.n_layers):
self.layers.append(nn.Linear(layers[j], layers[j+1]))
if j != self.n_layers - 1:
if normalize:
self.layers.append(nn.BatchNorm1d(layers[j+1]))
self.layers.append(nonlinearity())
if out_nonlinearity is not None:
self.layers.append(out_nonlinearity())
def forward(self, x):
for _, l in enumerate(self.layers):
x = l(x)
return x
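def _demo_dense_net():
    # Added illustration (hedged sketch, not part of the original file):
    # a 16 -> 64 -> 64 -> 1 MLP with ReLU between hidden layers and a
    # linear output layer.
    net = DenseNet([16, 64, 64, 1], nn.ReLU)
    return net(torch.randn(8, 16)).shape   # torch.Size([8, 1])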