Commit bdd87fae authored by zhangwenbo

Initial commit: FourCastNet source code only
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import logging
import glob
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
import h5py
import math
#import cv2
from utils.img_utils import reshape_fields, reshape_precip
def get_data_loader(params, files_pattern, distributed, train):
    dataset = GetDataset(params, files_pattern, train)
    sampler = DistributedSampler(dataset, shuffle=train) if distributed else None

    dataloader = DataLoader(dataset,
                            batch_size=int(params.batch_size),
                            num_workers=params.num_data_workers,
                            shuffle=False, #(sampler is None),
                            sampler=sampler if train else None,
                            drop_last=True,
                            pin_memory=torch.cuda.is_available())

    if train:
        return dataloader, dataset, sampler
    else:
        return dataloader, dataset
class GetDataset(Dataset):
    def __init__(self, params, location, train):
        self.params = params
        self.location = location
        self.train = train
        self.dt = params.dt
        self.n_history = params.n_history
        self.in_channels = np.array(params.in_channels)
        self.out_channels = np.array(params.out_channels)
        self.n_in_channels = len(self.in_channels)
        self.n_out_channels = len(self.out_channels)
        self.crop_size_x = params.crop_size_x
        self.crop_size_y = params.crop_size_y
        self.roll = params.roll
        self._get_files_stats()
        self.two_step_training = params.two_step_training
        self.orography = params.orography
        self.precip = True if "precip" in params else False
        self.add_noise = params.add_noise if train else False

        if self.precip:
            path = params.precip + '/train' if train else params.precip + '/test'
            self.precip_paths = glob.glob(path + "/*.h5")
            self.precip_paths.sort()

        try:
            self.normalize = params.normalize
        except:
            self.normalize = True  #by default turn on normalization if not specified in config

        if self.orography:
            self.orography_path = params.orography_path
    def _get_files_stats(self):
        self.files_paths = glob.glob(self.location + "/*.h5")
        self.files_paths.sort()
        self.n_years = len(self.files_paths)
        with h5py.File(self.files_paths[0], 'r') as _f:
            logging.info("Getting file stats from {}".format(self.files_paths[0]))
            self.n_samples_per_year = _f['fields'].shape[0]
            #original image shape (before padding)
            self.img_shape_x = _f['fields'].shape[2] - 1  #just get rid of one of the pixels
            self.img_shape_y = _f['fields'].shape[3]

        self.n_samples_total = self.n_years * self.n_samples_per_year
        self.files = [None for _ in range(self.n_years)]
        self.precip_files = [None for _ in range(self.n_years)]
        logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
        logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(self.location, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels))
        logging.info("Delta t: {} hours".format(6*self.dt))
        logging.info("Including {} hours of past history in training at a frequency of {} hours".format(6*self.dt*self.n_history, 6*self.dt))
    def _open_file(self, year_idx):
        _file = h5py.File(self.files_paths[year_idx], 'r')
        self.files[year_idx] = _file['fields']
        if self.orography:
            _orog_file = h5py.File(self.orography_path, 'r')
            self.orography_field = _orog_file['orog']
        if self.precip:
            self.precip_files[year_idx] = h5py.File(self.precip_paths[year_idx], 'r')['tp']

    def __len__(self):
        return self.n_samples_total
    def __getitem__(self, global_idx):
        year_idx = int(global_idx/self.n_samples_per_year)  #which year we are on
        local_idx = int(global_idx%self.n_samples_per_year)  #which sample in that year we are on - determines indices for centering

        y_roll = np.random.randint(0, 1440) if self.train else 0  #roll image in y direction

        #open image file
        if self.files[year_idx] is None:
            self._open_file(year_idx)

        if not self.precip:
            #if we are not at least self.dt*n_history timesteps into the prediction
            if local_idx < self.dt*self.n_history:
                local_idx += self.dt*self.n_history
            #if we are on the last image in a year predict identity, else predict next timestep
            step = 0 if local_idx >= self.n_samples_per_year-self.dt else self.dt
        else:
            inp_local_idx = local_idx
            tar_local_idx = local_idx
            #if we are on the last image in a year predict identity, else predict next timestep
            step = 0 if tar_local_idx >= self.n_samples_per_year-self.dt else self.dt
            # first year has 2 missing samples in precip (they are first two time points)
            if year_idx == 0:
                lim = 1458
                local_idx = local_idx%lim
                inp_local_idx = local_idx + 2
                tar_local_idx = local_idx
                step = 0 if tar_local_idx >= lim-self.dt else self.dt

        #if two_step_training flag is true then ensure that local_idx is not the last or last but one sample in a year
        if self.two_step_training:
            if local_idx >= self.n_samples_per_year - 2*self.dt:
                #set local_idx to last possible sample in a year that allows taking two steps forward
                local_idx = self.n_samples_per_year - 3*self.dt

        if self.train and self.roll:
            y_roll = random.randint(0, self.img_shape_y)
        else:
            y_roll = 0

        if self.orography:
            orog = self.orography_field[0:720]
        else:
            orog = None

        if self.train and (self.crop_size_x or self.crop_size_y):
            rnd_x = random.randint(0, self.img_shape_x-self.crop_size_x)
            rnd_y = random.randint(0, self.img_shape_y-self.crop_size_y)
        else:
            rnd_x = 0
            rnd_y = 0

        if self.precip:
            return reshape_fields(self.files[year_idx][inp_local_idx, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train), \
                   reshape_precip(self.precip_files[year_idx][tar_local_idx+step], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train)
        else:
            if self.two_step_training:
                return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \
                       reshape_fields(self.files[year_idx][local_idx + step:local_idx + step + 2, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog)
            else:
                return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \
                       reshape_fields(self.files[year_idx][local_idx + step, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog)
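# Worked example of the index mapping used in __getitem__ (added for clarity,
# not part of the original code): with a hypothetical n_samples_per_year = 1460
# (365 days at 6-hour steps), global_idx = 2921 maps to year_idx = 2 and
# local_idx = 1. A typical call sketch follows; the paths and the params object
# are assumptions (params must behave like the FourCastNet config dict, e.g.
# support "precip" in params):
#
#   train_loader, train_dataset, sampler = get_data_loader(
#       params, params.train_data_path, distributed=False, train=True)
#   inp, tar = train_dataset[0]
#   # with no crop/grid/orography: inp is (C*(n_history+1), 720, 1440), tar is (C, 720, 1440)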
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import numpy as np
from datetime import datetime
#day_of_year = datetime.now().timetuple().tm_yday # returns 1 for January 1st
#time_tuple = datetime.now().timetuple()
date_strings = ["2016-01-01 00:00:00", "2016-09-13 00:00:00", "2016-09-17 00:00:00", "2016-09-21 00:00:00", "2016-09-25 00:00:00", "2016-09-29 00:00:00", "2016-10-03 00:00:00", "2016-10-07 00:00:00"]
ics = []
for date_ in date_strings:
    date_obj = datetime.strptime(date_, '%Y-%m-%d %H:%M:%S')  #datetime.fromisoformat(date_)
    print(date_obj.timetuple())
    day_of_year = date_obj.timetuple().tm_yday - 1
    hour_of_day = date_obj.timetuple().tm_hour
    hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
    ics.append(int(hours_since_jan_01_epoch/6))
    print(day_of_year, hour_of_day)
    print("hours = ", hours_since_jan_01_epoch)
    print("steps = ", hours_since_jan_01_epoch/6)

print(ics)

ics = []
for date_ in date_strings:
    date_obj = datetime.fromisoformat(date_)  #datetime.strptime(date_, '%Y-%m-%d %H:%M:%S')
    print(date_obj.timetuple())
    day_of_year = date_obj.timetuple().tm_yday - 1
    hour_of_day = date_obj.timetuple().tm_hour
    hours_since_jan_01_epoch = 24*day_of_year + hour_of_day
    ics.append(int(hours_since_jan_01_epoch/6))
    print(day_of_year, hour_of_day)
    print("hours = ", hours_since_jan_01_epoch)
    print("steps = ", hours_since_jan_01_epoch/6)

print(ics)
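# A compact restatement of the loops above (added for illustration, not part of
# the original script): convert a date string to its 6-hourly step index within
# the year. The expected values in the assertion are derived by hand from the
# same arithmetic (2016 is a leap year, all dates fall at 00:00).
def date_to_step(date_str):
    t = datetime.fromisoformat(date_str).timetuple()
    hours_since_jan_01 = 24*(t.tm_yday - 1) + t.tm_hour
    return hours_since_jan_01 // 6

assert [date_to_step(d) for d in date_strings] == [0, 1024, 1040, 1056, 1072, 1088, 1104, 1120]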
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
import logging
import glob
from types import new_class
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
import h5py
import math
import torchvision.transforms.functional as TF
import matplotlib
import matplotlib.pyplot as plt
class PeriodicPad2d(nn.Module):
    """
    Pad the longitude (left-right) dimension circularly
    and the latitude (top-bottom) dimension with zeros.
    """
    def __init__(self, pad_width):
        super(PeriodicPad2d, self).__init__()
        self.pad_width = pad_width

    def forward(self, x):
        # pad left and right circular
        out = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular")
        # pad top and bottom zeros
        out = F.pad(out, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0)
        return out
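# A minimal shape check for PeriodicPad2d (illustrative, not part of the
# original module): with pad_width=1, a (1, 1, 720, 1440) input is padded
# circularly along longitude and with zeros along latitude, giving
# (1, 1, 722, 1442).
#
#   pad = PeriodicPad2d(pad_width=1)
#   out = pad(torch.zeros(1, 1, 720, 1440))
#   assert out.shape == (1, 1, 722, 1442)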
def reshape_fields(img, inp_or_tar, crop_size_x, crop_size_y, rnd_x, rnd_y, params, y_roll, train, normalize=True, orog=None, add_noise=False):
    #Takes in np array of size (n_history+1, c, h, w) and returns torch tensor of size (n_channels*(n_history+1), crop_size_x, crop_size_y)
    if len(np.shape(img)) == 3:
        img = np.expand_dims(img, 0)

    img = img[:, :, 0:720]  #remove last pixel
    n_history = np.shape(img)[0] - 1
    img_shape_x = np.shape(img)[-2]
    img_shape_y = np.shape(img)[-1]
    n_channels = np.shape(img)[1]  #this will either be N_in_channels or N_out_channels
    channels = params.in_channels if inp_or_tar == 'inp' else params.out_channels
    means = np.load(params.global_means_path)[:, channels]
    stds = np.load(params.global_stds_path)[:, channels]

    if crop_size_x == None:
        crop_size_x = img_shape_x
    if crop_size_y == None:
        crop_size_y = img_shape_y

    if normalize:
        if params.normalization == 'minmax':
            raise Exception("minmax not supported. Use zscore")
        elif params.normalization == 'zscore':
            img -= means
            img /= stds

    if params.add_grid:
        if inp_or_tar == 'inp':
            if params.gridtype == 'linear':
                assert params.N_grid_channels == 2, "N_grid_channels must be set to 2 for gridtype linear"
                x = np.meshgrid(np.linspace(-1, 1, img_shape_x))
                y = np.meshgrid(np.linspace(-1, 1, img_shape_y))
                grid_x, grid_y = np.meshgrid(y, x)
                grid = np.stack((grid_x, grid_y), axis=0)
            elif params.gridtype == 'sinusoidal':
                assert params.N_grid_channels == 4, "N_grid_channels must be set to 4 for gridtype sinusoidal"
                x1 = np.meshgrid(np.sin(np.linspace(0, 2*np.pi, img_shape_x)))
                x2 = np.meshgrid(np.cos(np.linspace(0, 2*np.pi, img_shape_x)))
                y1 = np.meshgrid(np.sin(np.linspace(0, 2*np.pi, img_shape_y)))
                y2 = np.meshgrid(np.cos(np.linspace(0, 2*np.pi, img_shape_y)))
                grid_x1, grid_y1 = np.meshgrid(y1, x1)
                grid_x2, grid_y2 = np.meshgrid(y2, x2)
                grid = np.expand_dims(np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0)
            img = np.concatenate((img, grid), axis=1)

    if params.orography and inp_or_tar == 'inp':
        img = np.concatenate((img, np.expand_dims(orog, axis=(0, 1))), axis=1)
        n_channels += 1

    if params.roll:
        img = np.roll(img, y_roll, axis=-1)

    if train and (crop_size_x or crop_size_y):
        img = img[:, :, rnd_x:rnd_x+crop_size_x, rnd_y:rnd_y+crop_size_y]

    if inp_or_tar == 'inp':
        img = np.reshape(img, (n_channels*(n_history+1), crop_size_x, crop_size_y))
    elif inp_or_tar == 'tar':
        if params.two_step_training:
            img = np.reshape(img, (n_channels*2, crop_size_x, crop_size_y))
        else:
            img = np.reshape(img, (n_channels, crop_size_x, crop_size_y))

    if add_noise:
        img = img + np.random.normal(0, scale=params.noise_std, size=img.shape)

    return torch.as_tensor(img)
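# Shape summary for reshape_fields (added for clarity, not part of the original
# code): for inp_or_tar='inp' with no grid channels, orography, or cropping, an
# input array of shape (n_history+1, C, 721, 1440) is truncated to 720 latitude
# rows, z-score normalized by default, and flattened over the time dimension to
# a tensor of shape (C*(n_history+1), 720, 1440).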
def reshape_precip(img, inp_or_tar, crop_size_x, crop_size_y, rnd_x, rnd_y, params, y_roll, train, normalize=True):
    if len(np.shape(img)) == 2:
        img = np.expand_dims(img, 0)

    img = img[:, :720, :]
    img_shape_x = img.shape[-2]
    img_shape_y = img.shape[-1]
    n_channels = 1

    if crop_size_x == None:
        crop_size_x = img_shape_x
    if crop_size_y == None:
        crop_size_y = img_shape_y

    if normalize:
        eps = params.precip_eps
        img = np.log1p(img/eps)

    if params.add_grid:
        if inp_or_tar == 'inp':
            if params.gridtype == 'linear':
                assert params.N_grid_channels == 2, "N_grid_channels must be set to 2 for gridtype linear"
                x = np.meshgrid(np.linspace(-1, 1, img_shape_x))
                y = np.meshgrid(np.linspace(-1, 1, img_shape_y))
                grid_x, grid_y = np.meshgrid(y, x)
                grid = np.stack((grid_x, grid_y), axis=0)
            elif params.gridtype == 'sinusoidal':
                assert params.N_grid_channels == 4, "N_grid_channels must be set to 4 for gridtype sinusoidal"
                x1 = np.meshgrid(np.sin(np.linspace(0, 2*np.pi, img_shape_x)))
                x2 = np.meshgrid(np.cos(np.linspace(0, 2*np.pi, img_shape_x)))
                y1 = np.meshgrid(np.sin(np.linspace(0, 2*np.pi, img_shape_y)))
                y2 = np.meshgrid(np.cos(np.linspace(0, 2*np.pi, img_shape_y)))
                grid_x1, grid_y1 = np.meshgrid(y1, x1)
                grid_x2, grid_y2 = np.meshgrid(y2, x2)
                grid = np.expand_dims(np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0)
            img = np.concatenate((img, grid), axis=1)

    if params.roll:
        img = np.roll(img, y_roll, axis=-1)

    if train and (crop_size_x or crop_size_y):
        img = img[:, rnd_x:rnd_x+crop_size_x, rnd_y:rnd_y+crop_size_y]

    img = np.reshape(img, (n_channels, crop_size_x, crop_size_y))

    return torch.as_tensor(img)
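# Note on the precipitation transform above (added for clarity, not part of the
# original code): normalization maps tp to log1p(tp/eps), which compresses the
# heavy-tailed precipitation distribution; the inverse is tp = eps * expm1(tp_norm).
# The eps value below is a hypothetical small constant for illustration:
#
#   eps = 1e-5
#   tp_norm = np.log1p(0.002/eps)
#   assert np.isclose(eps*np.expm1(tp_norm), 0.002)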
def vis_precip(fields):
    pred, tar = fields
    fig, ax = plt.subplots(1, 2, figsize=(24, 12))
    ax[0].imshow(pred, cmap="coolwarm")
    ax[0].set_title("tp pred")
    ax[1].imshow(tar, cmap="coolwarm")
    ax[1].set_title("tp tar")
    fig.tight_layout()
    return fig
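# Example usage of vis_precip (illustrative, not part of the original module);
# pred and tar are 2D precipitation fields, the output path is hypothetical:
#
#   fig = vis_precip((np.zeros((720, 1440)), np.zeros((720, 1440))))
#   fig.savefig('/tmp/precip_panel.png')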
import os
import logging

_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

def config_logger(log_level=logging.INFO):
    logging.basicConfig(format=_format, level=log_level)

def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):
    log_dir = os.path.dirname(log_filename)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if logger_name is not None:
        log = logging.getLogger(logger_name)
    else:
        log = logging.getLogger()

    fh = logging.FileHandler(log_filename)
    fh.setLevel(log_level)
    fh.setFormatter(logging.Formatter(_format))
    log.addHandler(fh)

def log_versions():
    import torch
    import subprocess

    logging.info('--------------- Versions ---------------')
    logging.info('git branch: ' + str(subprocess.check_output(['git', 'branch']).strip()))
    logging.info('git hash: ' + str(subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()))
    logging.info('Torch: ' + str(torch.__version__))
    logging.info('----------------------------------------')
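# Example usage (illustrative, not part of the original module): configure
# console logging, then mirror log records to a file; the path is hypothetical.
#
#   config_logger()
#   log_to_file(log_filename='/tmp/fourcastnet_logs/train.log')
#   logging.info('logging configured')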