Unverified Commit 5d7e9b06 authored by Boris Bonev, committed by GitHub

AMP hotfix (#47)

* AMP hotfix

* Bumping up version to 0.7.1
parent 1bfda531
@@ -2,6 +2,10 @@
## Versioning

+### v0.7.1
+* Hotfix to AMP in SFNO example

### v0.7.0
* CUDA-accelerated DISCO convolutions
...
@@ -209,7 +209,7 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
## Remarks on automatic mixed precision (AMP) support
-Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
+Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
```python
import torch
@@ -217,7 +217,7 @@ import torch_harmonics as th
sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

-with torch.cuda.amp.autocast(enabled = True):
+with torch.autocast(device_type="cuda", enabled = True):
    # do some AMP converted math here
    x = some_math(x)
    # convert tensor to float32
@@ -225,7 +225,7 @@ with torch.cuda.amp.autocast(enabled = True):
    # now disable autocast specifically for the transform,
    # making sure that the tensors are not converted
    # back to reduced precision internally
-    with torch.cuda.amp.autocast(enabled = False):
+    with torch.autocast(device_type="cuda", enabled = False):
        xt = sht(x)

    # continue operating on the transformed tensor
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -38,7 +38,6 @@ from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
-from torch.cuda import amp
import numpy as np
import pandas as pd
@@ -55,7 +54,7 @@ def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
    if relative:
        loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
    if not squared:
        loss = torch.sqrt(loss)
    loss = loss.mean()
@@ -124,7 +123,7 @@ def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
    h1_loss = torch.sum(h1_norm2, dim=(-1,-2))
    l2_loss = torch.sum(l2_norm2, dim=(-1,-2))

    # strictly speaking this is not exactly h1 loss
    if not squared:
        loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
    else:
@@ -143,7 +142,7 @@ def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
    fluct = solver.integrate_grid((tar - inp)**2, dimensionless=True, polar_opt=polar_opt)
    weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)
    # weight = weight.reshape(*weight.shape, 1, 1)

    loss = weight * solver.integrate_grid((prd - tar)**2, dimensionless=True, polar_opt=polar_opt)
    if relative:
        loss = loss / (weight * solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt))
@@ -194,7 +193,7 @@ def autoregressive_inference(model,
        # classical model
        start_time = time.time()
        for i in range(1, autoreg_steps+1):
            # advance classical model
            uspec = dataset.solver.timestep(uspec, nsteps)
@@ -212,7 +211,7 @@ def autoregressive_inference(model,
        ref = dataset.solver.spec2grid(uspec)
        prd = prd * torch.sqrt(inp_var) + inp_mean
        losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()

    return losses, fno_times, nwp_times
@@ -267,8 +266,8 @@ def train_model(model,
        model.train()

        for inp, tar in dataloader:

-            with amp.autocast(enabled=enable_amp):
+            with torch.autocast(device_type="cuda", enabled=enable_amp):
                prd = model(inp)
                for _ in range(nfuture):
@@ -357,7 +356,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
    dt_solver = 150
    nsteps = dt//dt_solver
    dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
    # There is still an issue with parallel dataloading. Do NOT use it at the moment
    # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
@@ -418,7 +417,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
    # optimizer:
    optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
-    gscaler = amp.GradScaler(enabled=enable_amp)
+    gscaler = torch.GradScaler("cuda", enabled=enable_amp)

    start_time = time.time()
...
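The hunks above migrate the SFNO example's training loop from the older `torch.cuda.amp` entry points to the device-agnostic `torch.autocast` and `torch.GradScaler` APIs. Below is a minimal sketch, not part of this commit, of how those pieces typically fit together in an AMP training step; `train_step`, `model`, `dataloader`, `optimizer` and `loss_fn` are placeholder names rather than objects defined in the example script.

```python
import torch

def train_step(model, dataloader, optimizer, loss_fn, enable_amp=False):
    # GradScaler with the device as first argument, as in the updated script
    gscaler = torch.GradScaler("cuda", enabled=enable_amp)
    for inp, tar in dataloader:
        optimizer.zero_grad(set_to_none=True)
        # run the forward pass and loss in reduced precision when AMP is enabled
        with torch.autocast(device_type="cuda", enabled=enable_amp):
            prd = model(inp)
            loss = loss_fn(prd, tar)
        # scale the loss to avoid gradient underflow in half precision,
        # then unscale before the optimizer update and adjust the scale factor
        gscaler.scale(loss).backward()
        gscaler.step(optimizer)
        gscaler.update()
```

The same principle motivates the spectral-convolution changes further down: the spherical harmonic transforms are kept outside the autocast region so that cuFFT never sees half-precision inputs with non-power-of-two transform sizes.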
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
-__version__ = "0.7.0"
+__version__ = "0.7.1"

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -33,7 +33,6 @@ import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint
-from torch.cuda import amp
import math

from torch_harmonics import *
@@ -53,36 +52,36 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
@@ -128,9 +127,9 @@ class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class MLP(nn.Module):
    def __init__(self,
@@ -171,11 +170,11 @@ class MLP(nn.Module):
            self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
        else:
            self.fwd = nn.Sequential(fc1, act, fc2)

    @torch.jit.ignore
    def checkpoint_forward(self, x):
        return checkpoint(self.fwd, x)

    def forward(self, x):
        if self.checkpointing:
            return self.checkpoint_forward(x)
@@ -221,14 +220,14 @@ class InverseRealFFT2(nn.Module):
    def forward(self, x):
        return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")

class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
    using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
    domain via the RealFFT2 and InverseRealFFT2 wrappers.
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
@@ -277,18 +276,18 @@ class SpectralConvS2(nn.Module):
        # get the right contraction function
        self._contract = _contract

        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        residual = x

-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)
@@ -298,20 +297,20 @@ class SpectralConvS2(nn.Module):
        x = self._contract(x, self.weight)
        x = torch.view_as_complex(x)

-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias

        x = x.type(dtype)

        return x, residual

class FactorizedSpectralConvS2(nn.Module):
    """
    Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
@@ -366,9 +365,9 @@ class FactorizedSpectralConvS2(nn.Module):
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")
        # form weight tensors
        self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
                                           fixed_rank_modes=False, **decomposition_kwargs)

        # initialization of weights
        scale = math.sqrt(gain / in_channels)
        self.weight.normal_(0, scale)
@@ -376,29 +375,29 @@ class FactorizedSpectralConvS2(nn.Module):
        # get the right contraction function
        from .factorizations import get_contract_fun
        self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)

        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        residual = x

-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)

-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias

        x = x.type(dtype)

        return x, residual