"vscode:/vscode.git/clone" did not exist on "7bfe923ce8d7b9791b7f392c7f0b6754f203261c"
Unverified commit 5d7e9b06 authored by Boris Bonev, committed by GitHub

AMP hotfix (#47)

* AMP hotfix

* Bumping up version to 0.7.1
parent 1bfda531
@@ -2,6 +2,10 @@
 ## Versioning
+### v0.7.1
+* Hotfix to AMP in SFNO example
 ### v0.7.0
 * CUDA-accelerated DISCO convolutions
...
@@ -209,7 +209,7 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
 ## Remarks on automatic mixed precision (AMP) support
-Note that torch-harmonics uses Fourier transforms from `torch.fft` which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced precision floating point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
+Note that torch-harmonics uses Fourier transforms from `torch.fft` which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced precision floating point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
 ```python
 import torch
@@ -217,7 +217,7 @@ import torch_harmonics as th
 sht = th.RealSHT(512, 1024, grid="equiangular").cuda()
-with torch.cuda.amp.autocast(enabled = True):
+with torch.autocast(device_type="cuda", enabled = True):
     # do some AMP converted math here
     x = some_math(x)
     # convert tensor to float32
@@ -225,7 +225,7 @@ with torch.cuda.amp.autocast(enabled = True):
     # now disable autocast specifically for the transform,
     # making sure that the tensors are not converted
     # back to reduced precision internally
-    with torch.cuda.amp.autocast(enabled = False):
+    with torch.autocast(device_type="cuda", enabled = False):
         xt = sht(x)
     # continue operating on the transformed tensor
...
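For reference, the README snippet above is assembled into a minimal, self-contained sketch below. This is not part of the commit: the matmul stands in for the `some_math` placeholder, a CUDA-capable GPU is assumed, and the tensor shapes are chosen only for illustration.

```python
import torch
import torch_harmonics as th

# spherical harmonic transform on an equiangular 512 x 1024 grid
sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

x = torch.randn(512, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

with torch.autocast(device_type="cuda", enabled=True):
    # stand-in for "some_math": matmuls run in reduced precision under autocast
    x = x @ w
    # convert the tensor back to float32 before the transform
    x = x.to(torch.float32)
    # disable autocast for the transform, making sure the tensor is not
    # converted back to reduced precision internally
    with torch.autocast(device_type="cuda", enabled=False):
        xt = sht(x)

# the spectral coefficients are complex-valued and remain in full precision
print(xt.shape, xt.dtype)
```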
@@ -38,7 +38,6 @@ from functools import partial
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from torch.cuda import amp
 import numpy as np
 import pandas as pd
@@ -268,7 +267,7 @@ def train_model(model,
     for inp, tar in dataloader:
-        with amp.autocast(enabled=enable_amp):
+        with torch.autocast(device_type="cuda", enabled=enable_amp):
             prd = model(inp)
             for _ in range(nfuture):
@@ -418,7 +417,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
     # optimizer:
     optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
-    gscaler = amp.GradScaler(enabled=enable_amp)
+    gscaler = torch.GradScaler("cuda", enabled=enable_amp)
     start_time = time.time()
...
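The hunks above show only the changed lines. As a reading aid, here is a minimal hypothetical sketch of how the updated `torch.autocast` / `torch.GradScaler` pair fits together in one training step; the model, data, and loss are placeholders, not the example's actual code, and a CUDA device plus a PyTorch version that exposes `torch.GradScaler` at the top level (as the diff assumes; `torch.amp.GradScaler` otherwise) are required.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

enable_amp = True

# placeholder model and data, for illustration only
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 1)).cuda()
inp = torch.randn(8, 64, device="cuda")
tar = torch.randn(8, 1, device="cuda")

optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
gscaler = torch.GradScaler("cuda", enabled=enable_amp)

optimizer.zero_grad(set_to_none=True)
# forward pass and loss run under autocast when AMP is enabled
with torch.autocast(device_type="cuda", enabled=enable_amp):
    prd = model(inp)
    loss = F.mse_loss(prd, tar)
# scale the loss for the backward pass, then step and update through the scaler
gscaler.scale(loss).backward()
gscaler.step(optimizer)
gscaler.update()
```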
@@ -29,7 +29,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
-__version__ = "0.7.0"
+__version__ = "0.7.1"
 from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
...
@@ -33,7 +33,6 @@ import torch
 import torch.nn as nn
 import torch.fft
 from torch.utils.checkpoint import checkpoint
-from torch.cuda import amp
 import math
 from torch_harmonics import *
@@ -288,7 +287,7 @@ class SpectralConvS2(nn.Module):
         x = x.float()
         residual = x
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.forward_transform(x)
             if self.scale_residual:
                 residual = self.inverse_transform(x)
@@ -298,7 +297,7 @@ class SpectralConvS2(nn.Module):
         x = self._contract(x, self.weight)
         x = torch.view_as_complex(x)
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.inverse_transform(x)
         if hasattr(self, "bias"):
@@ -387,14 +386,14 @@ class FactorizedSpectralConvS2(nn.Module):
         x = x.float()
         residual = x
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.forward_transform(x)
             if self.scale_residual:
                 residual = self.inverse_transform(x)
         x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)
-        with amp.autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             x = self.inverse_transform(x)
         if hasattr(self, "bias"):
...
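The same fix seen from the module side, reduced to a minimal hypothetical layer (not the actual `SpectralConvS2` implementation): the spectral multiply may run under an enclosing autocast region, while the forward and inverse transforms are explicitly kept in full precision. `torch_harmonics`, its `lmax`/`mmax` attributes on `RealSHT`, and a CUDA device are assumptions of this sketch.

```python
import torch
import torch.nn as nn
import torch_harmonics as th

class SimpleSpectralLayer(nn.Module):
    """Hypothetical minimal layer illustrating the autocast-disabling pattern."""

    def __init__(self, nlat=64, nlon=128):
        super().__init__()
        self.forward_transform = th.RealSHT(nlat, nlon, grid="equiangular")
        self.inverse_transform = th.InverseRealSHT(nlat, nlon, grid="equiangular")
        # complex spectral weights stored as real pairs in the trailing dimension
        self.weight = nn.Parameter(
            0.02 * torch.randn(self.forward_transform.lmax, self.forward_transform.mmax, 2)
        )

    def forward(self, x):
        # make sure the input is float32 before entering the transform
        x = x.float()
        # keep the forward transform out of autocast
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
        # pointwise spectral multiply (complex tensors are not autocast)
        x = x * torch.view_as_complex(self.weight)
        # keep the inverse transform out of autocast as well
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)
        return x

layer = SimpleSpectralLayer().cuda()
with torch.autocast(device_type="cuda", enabled=True):
    y = layer(torch.randn(2, 64, 128, device="cuda"))
```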