Unverified commit 5d7e9b06 authored by Boris Bonev, committed by GitHub

AMP hotfix (#47)

* AMP hotfix

* Bumping up version to 0.7.1
parent 1bfda531
@@ -2,6 +2,10 @@
 ## Versioning
+### v0.7.1
+* Hotfix to AMP in SFNO example
 ### v0.7.0
 * CUDA-accelerated DISCO convolutions
@@ -209,7 +209,7 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
 ## Remarks on automatic mixed precision (AMP) support
-Note that torch-harmonics uses Fourier transforms from `torch.fft` which in turn uses kernels from the optimized `cuFFT` library. This library supports fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced precision floating point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
+Note that torch-harmonics uses Fourier transforms from `torch.fft` which in turn uses kernels from the optimized `cuFFT` library. This library supports fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced precision floating point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
 ```python
 import torch
@@ -217,7 +217,7 @@ import torch_harmonics as th
 sht = th.RealSHT(512, 1024, grid="equiangular").cuda()
-with torch.cuda.amp.autocast(enabled = True):
+with torch.autocast(device_type="cuda", enabled = True):
     # do some AMP converted math here
     x = some_math(x)
     # convert tensor to float32
@@ -225,7 +225,7 @@ with torch.cuda.amp.autocast(enabled = True):
     # now disable autocast specifically for the transform,
     # making sure that the tensors are not converted
     # back to reduced precision internally
-    with torch.cuda.amp.autocast(enabled = False):
+    with torch.autocast(device_type="cuda", enabled = False):
         xt = sht(x)
     # continue operating on the transformed tensor
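Put together, the updated README recommendation looks roughly like the following. This is a minimal sketch, assuming a CUDA device is available; the random input tensor and the `x * 2.0` line are hypothetical stand-ins for the `some_math` placeholder used in the README snippet.

```python
import torch
import torch_harmonics as th

# spherical harmonic transform on a 512 x 1024 equiangular grid
sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

# random test field of shape (batch, channels, nlat, nlon)
x = torch.randn(1, 1, 512, 1024, device="cuda")

with torch.autocast(device_type="cuda", enabled=True):
    # reduced-precision math is fine here
    x = x * 2.0
    # convert back to float32 before the transform
    x = x.float()
    # disable autocast locally so the SHT input is not cast back
    # to half/bfloat16 internally
    with torch.autocast(device_type="cuda", enabled=False):
        xt = sht(x)
    # continue operating on the transformed (complex-valued) tensor
```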
@@ -38,7 +38,6 @@ from functools import partial
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from torch.cuda import amp
 import numpy as np
 import pandas as pd
@@ -268,7 +267,7 @@ def train_model(model,
     for inp, tar in dataloader:
-        with amp.autocast(enabled=enable_amp):
+        with torch.autocast(device_type="cuda", enabled=enable_amp):
             prd = model(inp)
             for _ in range(nfuture):
@@ -418,7 +417,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
     # optimizer:
     optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
-    gscaler = amp.GradScaler(enabled=enable_amp)
+    gscaler = torch.GradScaler("cuda", enabled=enable_amp)
     start_time = time.time()
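For context, the non-deprecated AMP training-step pattern that the example script now follows looks roughly like the sketch below. It is not the example itself: the tiny `nn.Conv2d` model, the random `inp`/`tar` tensors, and the MSE loss are placeholders, and `torch.amp.GradScaler` is used here as the device-generic scaler spelling.

```python
import torch
import torch.nn as nn

# placeholder model and data standing in for the SFNO example
model = nn.Conv2d(3, 3, kernel_size=1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
gscaler = torch.amp.GradScaler("cuda", enabled=True)

inp = torch.randn(4, 3, 32, 32, device="cuda")
tar = torch.randn(4, 3, 32, 32, device="cuda")

# one AMP training step: forward pass under autocast,
# scaled backward pass, gradients unscaled before the optimizer step
optimizer.zero_grad()
with torch.autocast(device_type="cuda", enabled=True):
    prd = model(inp)
    loss = nn.functional.mse_loss(prd, tar)

gscaler.scale(loss).backward()
gscaler.step(optimizer)
gscaler.update()
```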
@@ -29,7 +29,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
-__version__ = "0.7.0"
+__version__ = "0.7.1"
 from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
@@ -33,7 +33,6 @@ import torch
 import torch.nn as nn
 import torch.fft
 from torch.utils.checkpoint import checkpoint
-from torch.cuda import amp
 import math
 from torch_harmonics import *
@@ -288,7 +287,7 @@ class SpectralConvS2(nn.Module):
         x = x.float()
         residual = x
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.forward_transform(x)
             if self.scale_residual:
                 residual = self.inverse_transform(x)
@@ -298,7 +297,7 @@ class SpectralConvS2(nn.Module):
         x = self._contract(x, self.weight)
         x = torch.view_as_complex(x)
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.inverse_transform(x)
         if hasattr(self, "bias"):
@@ -387,14 +386,14 @@ class FactorizedSpectralConvS2(nn.Module):
         x = x.float()
         residual = x
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.forward_transform(x)
             if self.scale_residual:
                 residual = self.inverse_transform(x)
         x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.inverse_transform(x)
         if hasattr(self, "bias"):
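The pattern used in these spectral convolution layers, casting to `float32` and disabling autocast around the forward and inverse spherical harmonic transforms, can be illustrated with a stripped-down module. This is a minimal sketch, not the repository's `SpectralConvS2`: the grid sizes, the single scalar `weight`, and the class name `SimpleSpectralLayer` are illustrative choices.

```python
import torch
import torch.nn as nn
import torch_harmonics as th


class SimpleSpectralLayer(nn.Module):
    """Toy spectral layer: SHT -> pointwise scaling -> inverse SHT,
    with both transforms kept out of autocast."""

    def __init__(self, nlat=64, nlon=128):
        super().__init__()
        self.forward_transform = th.RealSHT(nlat, nlon, grid="equiangular")
        self.inverse_transform = th.InverseRealSHT(nlat, nlon, grid="equiangular")
        # a single real-valued scale applied to all spectral coefficients
        self.weight = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # make sure the transforms see float32 inputs even under AMP
        x = x.float()
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
        # learned pointwise scaling in spectral space (complex coefficients)
        x = x * self.weight
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)
        return x
```

With this structure, surrounding layers can run under `torch.autocast` in reduced precision while the transforms themselves always operate on `float32` inputs, avoiding the cuFFT power-of-two restriction described in the README.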