Unverified Commit 5d7e9b06 authored by Boris Bonev, committed by GitHub

AMP hotfix (#47)

* AMP hotfix

* Bumping up version to 0.7.1
parent 1bfda531
......@@ -2,6 +2,10 @@
## Versioning
### v0.7.1
* Hotfix to AMP in SFNO example
### v0.7.0
* CUDA-accelerated DISCO convolutions
......
......@@ -209,7 +209,7 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
## Remarks on automatic mixed precision (AMP) support
Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
```python
import torch
......@@ -217,7 +217,7 @@ import torch_harmonics as th
sht = th.RealSHT(512, 1024, grid="equiangular").cuda()
with torch.cuda.amp.autocast(enabled = True):
with torch.autocast(device_type="cuda", enabled = True):
# do some AMP converted math here
x = some_math(x)
# convert tensor to float32
......@@ -225,7 +225,7 @@ with torch.cuda.amp.autocast(enabled = True):
# now disable autocast specifically for the transform,
# making sure that the tensors are not converted
# back to reduced precision internally
with torch.cuda.amp.autocast(enabled = False):
with torch.autocast(device_type="cuda", enabled = False):
xt = sht(x)
# continue operating on the transformed tensor
......
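A related pattern that can cut down on repetition is to wrap the transform in a small module that disables autocast internally. This is only a sketch, not part of the torch-harmonics API; `SHTWrapper` is a hypothetical helper name:

```python
import torch
import torch.nn as nn
import torch_harmonics as th

class SHTWrapper(nn.Module):
    # Hypothetical helper: runs the spherical harmonic transform in float32,
    # independent of the surrounding autocast region.
    def __init__(self, nlat, nlon, grid="equiangular"):
        super().__init__()
        self.sht = th.RealSHT(nlat, nlon, grid=grid)

    def forward(self, x):
        # disable autocast so cuFFT receives a float32 input of arbitrary size
        with torch.autocast(device_type="cuda", enabled=False):
            return self.sht(x.float())
```

Usage would mirror the snippet above, e.g. `sht = SHTWrapper(512, 1024).cuda()` inside a model whose forward pass otherwise runs under autocast.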
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -38,7 +38,6 @@ from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda import amp
import numpy as np
import pandas as pd
......@@ -55,7 +54,7 @@ def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
if relative:
loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
......@@ -124,7 +123,7 @@ def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
h1_loss = torch.sum(h1_norm2, dim=(-1,-2))
l2_loss = torch.sum(l2_norm2, dim=(-1,-2))
# strictly speaking this is not exactly h1 loss
# strictly speaking this is not exactly h1 loss
if not squared:
loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
else:
......@@ -143,7 +142,7 @@ def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
fluct = solver.integrate_grid((tar - inp)**2, dimensionless=True, polar_opt=polar_opt)
weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)
# weight = weight.reshape(*weight.shape, 1, 1)
loss = weight * solver.integrate_grid((prd - tar)**2, dimensionless=True, polar_opt=polar_opt)
if relative:
loss = loss / (weight * solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt))
......@@ -194,7 +193,7 @@ def autoregressive_inference(model,
# classical model
start_time = time.time()
for i in range(1, autoreg_steps+1):
# advance classical model
uspec = dataset.solver.timestep(uspec, nsteps)
......@@ -212,7 +211,7 @@ def autoregressive_inference(model,
ref = dataset.solver.spec2grid(uspec)
prd = prd * torch.sqrt(inp_var) + inp_mean
losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()
return losses, fno_times, nwp_times
......@@ -267,8 +266,8 @@ def train_model(model,
model.train()
for inp, tar in dataloader:
with amp.autocast(enabled=enable_amp):
with torch.autocast(device_type="cuda", enabled=enable_amp):
prd = model(inp)
for _ in range(nfuture):
......@@ -357,7 +356,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt_solver = 150
nsteps = dt//dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
......@@ -418,7 +417,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
# optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
gscaler = amp.GradScaler(enabled=enable_amp)
gscaler = torch.GradScaler("cuda", enabled=enable_amp)
start_time = time.time()
......
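For reference, the updated AMP training step keeps the usual scale/step/update pattern, only with the new-style `torch.autocast` and `torch.GradScaler` calls shown in the diff above. The sketch below assumes `model`, `dataloader`, `optimizer`, `loss_fn`, and `enable_amp` are defined as in the example script:

```python
import torch

gscaler = torch.GradScaler("cuda", enabled=enable_amp)

for inp, tar in dataloader:
    optimizer.zero_grad()
    # the forward pass runs under autocast and may use reduced precision
    with torch.autocast(device_type="cuda", enabled=enable_amp):
        prd = model(inp)
        loss = loss_fn(prd, tar)
    # scale the loss so half-precision gradients do not underflow,
    # then let the scaler unscale before the optimizer step
    gscaler.scale(loss).backward()
    gscaler.step(optimizer)
    gscaler.update()
```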
......@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = "0.7.0"
__version__ = "0.7.1"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -33,7 +33,6 @@ import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint
from torch.cuda import amp
import math
from torch_harmonics import *
......@@ -53,36 +52,36 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
......@@ -128,9 +127,9 @@ class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
return drop_path(x, self.drop_prob, self.training)
class MLP(nn.Module):
def __init__(self,
......@@ -171,11 +170,11 @@ class MLP(nn.Module):
self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
else:
self.fwd = nn.Sequential(fc1, act, fc2)
@torch.jit.ignore
def checkpoint_forward(self, x):
return checkpoint(self.fwd, x)
def forward(self, x):
if self.checkpointing:
return self.checkpoint_forward(x)
......@@ -221,14 +220,14 @@ class InverseRealFFT2(nn.Module):
def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class SpectralConvS2(nn.Module):
"""
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
"""
def __init__(self,
forward_transform,
inverse_transform,
......@@ -277,18 +276,18 @@ class SpectralConvS2(nn.Module):
# get the right contraction function
self._contract = _contract
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
with amp.autocast(enabled=False):
with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
......@@ -298,20 +297,20 @@ class SpectralConvS2(nn.Module):
x = self._contract(x, self.weight)
x = torch.view_as_complex(x)
with amp.autocast(enabled=False):
with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)
if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
return x, residual
class FactorizedSpectralConvS2(nn.Module):
"""
Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
"""
def __init__(self,
forward_transform,
inverse_transform,
......@@ -366,9 +365,9 @@ class FactorizedSpectralConvS2(nn.Module):
raise NotImplementedError(f"Unknown operator type {self.operator_type}")
# form weight tensors
self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
fixed_rank_modes=False, **decomposition_kwargs)
# initialization of weights
scale = math.sqrt(gain / in_channels)
self.weight.normal_(0, scale)
......@@ -376,29 +375,29 @@ class FactorizedSpectralConvS2(nn.Module):
# get the right contraction function
from .factorizations import get_contract_fun
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
with amp.autocast(enabled=False):
with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)
with amp.autocast(enabled=False):
with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)
if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
return x, residual
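
Both forward passes above use the same bracketing pattern: the harmonic (or Fourier) transforms are executed in `float32` with autocast disabled, while the spectral contraction in between stays under the surrounding autocast policy. A condensed sketch of that pattern, with a hypothetical `contract` callable standing in for the weight contraction:

```python
import torch

def spectral_block(x, forward_transform, inverse_transform, weight, contract):
    dtype = x.dtype
    x = x.float()
    # the transforms rely on cuFFT, so keep them in float32 even under autocast
    with torch.autocast(device_type="cuda", enabled=False):
        x = forward_transform(x)
    # the pointwise contraction in spectral space tolerates reduced precision
    x = contract(x, weight)
    with torch.autocast(device_type="cuda", enabled=False):
        x = inverse_transform(x)
    # restore the dtype expected by the surrounding layers
    return x.type(dtype)
```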