"vscode:/vscode.git/clone" did not exist on "8cd15cbeb72bd0968fc19c25a740e10006191de3"
Unverified commit 813f6e61 authored by Jinze Xue, committed by GitHub

CUAEV double backward for force training (#571)



* init

* init

* double backward test

* fix doublebackward test

* add another test

* rm gaev

* radial done

* angular init

* angular done

* update

* force training benchmark

* format

* update

* benchmark

* update

* update

* clean redundancy codes

* update

* adapt review request

* update

* update

* update

* update

* update

* update

* fix

* fix

* cuAngularAEVs code deduplicate

* pairwise double backward

* cuRadialAEVs dedup

* pairwiseDistance dedup

* format

* readme build notes

* save

* update

* save

* save

* update

* fix

* save

* add equations on comments
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
parent efae6d9d
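This PR makes the cuaev extension differentiable twice, which is what force training needs: the loss depends on forces, and forces are themselves gradients of the energy, so backpropagating the force loss requires a second backward pass through the AEV computation. A minimal sketch of the training step this enables, assuming a built cuaev extension and the `ANI2x` model API that the benchmark script below uses (the linear readout `nn` is a toy stand-in for a real network):

```python
import torch
import torchani

device = torch.device('cuda')
nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
nnp.aev_computer.use_cuda_extension = True  # route AEVs through the cuaev kernels

# toy energy head on top of the AEV (a stand-in, not part of TorchANI)
nn = torch.nn.Linear(nnp.aev_computer.aev_length, 1, bias=False).to(device)

species = torch.tensor([[8, 1, 1]], device=device)  # O, H, H as atomic numbers
coordinates = torch.tensor([[[0.00, 0.00, 0.00],
                             [0.00, 0.00, 0.96],
                             [0.93, 0.00, -0.24]]], device=device, requires_grad=True)
force_true = torch.randn(1, 3, 3, device=device)  # fake reference forces

species, coordinates = nnp.species_converter((species, coordinates))
_, aev = nnp.aev_computer((species, coordinates))
energy = nn(aev).sum()
# first backward: create_graph=True keeps the force differentiable
force = -torch.autograd.grad(energy, coordinates, create_graph=True)[0]
loss = (force_true - force).abs().sum(dim=(1, 2)).mean()
loss.backward()  # second backward: the double backward implemented by this PR
```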
......@@ -50,6 +50,7 @@ def cuda_extension(build_all=False):
import torch
from torch.utils.cpp_extension import CUDAExtension
SMs = None
print('-' * 75)
if not build_all:
SMs = []
devices = torch.cuda.device_count()
......@@ -81,12 +82,13 @@ def cuda_extension(build_all=False):
if cuda_version >= 11.1:
nvcc_args.append("-gencode=arch=compute_86,code=sm_86")
print("nvcc_args: ", nvcc_args)
print('-' * 75)
return CUDAExtension(
name='torchani.cuaev',
pkg='torchani.cuaev',
sources=glob.glob('torchani/cuaev/*.cu'),
include_dirs=maybe_download_cub(),
extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args})
extra_compile_args={'cxx': ['-std=c++17'], 'nvcc': nvcc_args})
def cuaev_kwargs():
......
......@@ -3,9 +3,9 @@ import torch
import torchani
import unittest
import pickle
import copy
from torchani.testing import TestCase, make_tensor
path = os.path.dirname(os.path.realpath(__file__))
skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(),
......@@ -52,6 +52,64 @@ class TestCUAEV(TestCase):
num_species = 4
self.aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
self.cuaev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species, use_cuda_extension=True)
self.nn = torch.nn.Sequential(torch.nn.Linear(384, 1, False)).to(self.device)
self.radial_length = self.aev_computer.radial_length
def _double_backward_1_test(self, species, coordinates):
def double_backward(aev_computer, species, coordinates):
torch.manual_seed(12345)
self.nn.zero_grad()
_, aev = aev_computer((species, coordinates))
E = self.nn(aev).sum()
force = -torch.autograd.grad(E, coordinates, create_graph=True, retain_graph=True)[0]
force_true = torch.randn_like(force)
loss = torch.abs(force_true - force).sum(dim=(1, 2)).mean()
loss.backward()
param = next(self.nn.parameters())
param_grad = copy.deepcopy(param.grad)
return aev, force, param_grad
aev, force_ref, param_grad_ref = double_backward(self.aev_computer, species, coordinates)
cu_aev, force_cuaev, param_grad = double_backward(self.cuaev_computer, species, coordinates)
self.assertEqual(cu_aev, aev, f'cu_aev: {cu_aev}\n aev: {aev}')
self.assertEqual(force_cuaev, force_ref, f'\nforce_cuaev: {force_cuaev}\n force_ref: {force_ref}')
self.assertEqual(param_grad, param_grad_ref, f'\nparam_grad: {param_grad}\n param_grad_ref: {param_grad_ref}', atol=5e-5, rtol=5e-5)
def _double_backward_2_test(self, species, coordinates):
def double_backward(aev_computer, species, coordinates):
"""
# We want to get the gradient of `grad_aev`, which requires `grad_aev` to be a leaf node
# due to `torch.autograd`'s limitation. So we split the coord->aev->energy graph into two separate
# graphs: coord->aev and aev->energy, so that aev and grad_aev are now leaves.
"""
torch.manual_seed(12345)
# graph1 input -> aev
coordinates = coordinates.clone().detach().requires_grad_()
_, aev = aev_computer((species, coordinates))
# graph2 aev -> E
aev_ = aev.clone().detach().requires_grad_()
E = self.nn(aev_).sum()
# graph2 backward
aev_grad = torch.autograd.grad(E, aev_, create_graph=True, retain_graph=True)[0]
# graph1 backward
aev_grad_ = aev_grad.clone().detach().requires_grad_()
force = torch.autograd.grad(aev, coordinates, aev_grad_, create_graph=True, retain_graph=True)[0]
# force loss backward
force_true = torch.randn_like(force)
loss = torch.abs(force_true - force).sum(dim=(1, 2)).mean()
aev_grad_grad = torch.autograd.grad(loss, aev_grad_, create_graph=True, retain_graph=True)[0]
return aev, force, aev_grad_grad
aev, force_ref, aev_grad_grad = double_backward(self.aev_computer, species, coordinates)
cu_aev, force_cuaev, cuaev_grad_grad = double_backward(self.cuaev_computer, species, coordinates)
self.assertEqual(cu_aev, aev, f'cu_aev: {cu_aev}\n aev: {aev}', atol=5e-5, rtol=5e-5)
self.assertEqual(force_cuaev, force_ref, f'\nforce_cuaev: {force_cuaev}\n force_ref: {force_ref}', atol=5e-5, rtol=5e-5)
self.assertEqual(cuaev_grad_grad, aev_grad_grad, f'\ncuaev_grad_grad: {cuaev_grad_grad}\n aev_grad_grad: {aev_grad_grad}', atol=5e-5, rtol=5e-5)
def testSimple(self):
coordinates = torch.tensor([
......@@ -89,15 +147,58 @@ class TestCUAEV(TestCase):
_, aev = self.aev_computer((species, coordinates))
aev.backward(torch.ones_like(aev))
aev_grad = coordinates.grad
force_ref = coordinates.grad
coordinates = coordinates.clone().detach()
coordinates.requires_grad_()
_, cu_aev = self.cuaev_computer((species, coordinates))
cu_aev.backward(torch.ones_like(cu_aev))
cuaev_grad = coordinates.grad
force_cuaev = coordinates.grad
self.assertEqual(cu_aev, aev, f'cu_aev: {cu_aev}\n aev: {aev}')
self.assertEqual(cuaev_grad, aev_grad, f'\ncuaev_grad: {cuaev_grad}\n aev_grad: {aev_grad}')
self.assertEqual(force_cuaev, force_ref, f'\nforce_cuaev: {force_cuaev}\n force_ref: {force_ref}')
def testSimpleDoubleBackward_1(self):
"""
Test double backward (force training) by comparing the parameters' gradients
"""
coordinates = torch.tensor([
[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]],
[[-4.1862600, 0.0575700, -0.0381200],
[-3.1689400, 0.0523700, 0.0200000],
[-4.4978600, 0.8211300, 0.5604100],
[-4.4978700, -0.8000100, 0.4155600],
[0.00000000, -0.00000000, -0.00000000]]
], requires_grad=True, device=self.device)
species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)
self._double_backward_1_test(species, coordinates)
def testSimpleDoubleBackward_2(self):
"""
Test Double Backward (Force training) directly.
Double backward:
Forward: input is dE/dAEV, output is force
Backward: input is dLoss/dForce, output is dLoss/(dE/dAEV)
"""
coordinates = torch.tensor([
[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]],
[[-4.1862600, 0.0575700, -0.0381200],
[-3.1689400, 0.0523700, 0.0200000],
[-4.4978600, 0.8211300, 0.5604100],
[-4.4978700, -0.8000100, 0.4155600],
[0.00000000, -0.00000000, -0.00000000]]
], requires_grad=True, device=self.device)
species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)
self._double_backward_2_test(species, coordinates)
def testTripeptideMD(self):
for i in range(100):
......@@ -129,6 +230,15 @@ class TestCUAEV(TestCase):
self.assertEqual(cu_aev, aev)
self.assertEqual(cuaev_grad, aev_grad, atol=5e-5, rtol=5e-5)
def testTripeptideMDDoubleBackward_2(self):
for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, *_ = pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device).requires_grad_(True)
species = torch.from_numpy(species).unsqueeze(0).to(self.device)
self._double_backward_2_test(species, coordinates)
def testNIST(self):
datafile = os.path.join(path, 'test_data/NIST/all')
with open(datafile, 'rb') as f:
......@@ -144,7 +254,7 @@ class TestCUAEV(TestCase):
datafile = os.path.join(path, 'test_data/NIST/all')
with open(datafile, 'rb') as f:
data = pickle.load(f)
for coordinates, species, _, _, _, _ in data:
for coordinates, species, _, _, _, _ in data[:10]:
coordinates = torch.from_numpy(coordinates).to(torch.float).to(self.device).requires_grad_(True)
species = torch.from_numpy(species).to(self.device)
_, aev = self.aev_computer((species, coordinates))
......@@ -159,12 +269,21 @@ class TestCUAEV(TestCase):
self.assertEqual(cu_aev, aev)
self.assertEqual(cuaev_grad, aev_grad, atol=5e-5, rtol=5e-5)
def testNISTDoubleBackward_2(self):
datafile = os.path.join(path, 'test_data/NIST/all')
with open(datafile, 'rb') as f:
data = pickle.load(f)
for coordinates, species, _, _, _, _ in data[:3]:
coordinates = torch.from_numpy(coordinates).to(torch.float).to(self.device).requires_grad_(True)
species = torch.from_numpy(species).to(self.device)
self._double_backward_2_test(species, coordinates)
def testVeryDenseMolecule(self):
"""
Test a very dense molecule for AEV correctness, especially the angular kernel when a center atom has more than 32 pairs.
issue: https://github.com/aiqm/torchani/pull/555
"""
for i in range(100):
for i in range(5):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, *_ = pickle.load(f)
......@@ -176,7 +295,7 @@ class TestCUAEV(TestCase):
self.assertEqual(cu_aev, aev, atol=5e-5, rtol=5e-5)
def testVeryDenseMoleculeBackward(self):
for i in range(100):
for i in range(5):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, *_ = pickle.load(f)
......
......@@ -26,7 +26,7 @@ def info(text):
print('\033[32m{}\33[0m'.format(text)) # green
def benchmark(speciesPositions, aev_comp, N, check_gpu_mem):
def benchmark(speciesPositions, aev_comp, N, check_gpu_mem, nn=None, verbose=True):
torch.cuda.empty_cache()
gc.collect()
torch.cuda.synchronize()
......@@ -34,14 +34,25 @@ def benchmark(speciesPositions, aev_comp, N, check_gpu_mem):
aev = None
for i in range(N):
aev = aev_comp(speciesPositions).aevs
species, coordinates = speciesPositions
if nn is not None: # double backward
coordinates = coordinates.requires_grad_()
_, aev = aev_comp((species, coordinates))
E = nn(aev).sum()
force = -torch.autograd.grad(E, coordinates, create_graph=True, retain_graph=True)[0]
force_true = torch.randn_like(force)
loss = torch.abs(force_true - force).sum(dim=(1, 2)).mean()
loss.backward()
else:
_, aev = aev_comp((species, coordinates))
if i == 2 and check_gpu_mem:
checkgpu()
torch.cuda.synchronize()
delta = time.time() - start
print(f' Duration: {delta:.2f} s')
print(f' Speed: {delta/N*1000:.2f} ms/it')
if verbose:
print(f' Duration: {delta:.2f} s')
print(f' Speed: {delta/N*1000:.2f} ms/it')
return aev, delta
......@@ -63,10 +74,14 @@ if __name__ == "__main__":
dest='check_gpu_mem',
action='store_const',
const=1)
parser.add_argument('--nsight',
parser.add_argument('-s', '--nsight',
action='store_true',
help='use nsight profile')
parser.add_argument('-b', '--backward',
action='store_true',
help='benchmark double backward')
parser.set_defaults(check_gpu_mem=0)
parser.set_defaults(backward=0)
parser = parser.parse_args()
path = os.path.dirname(os.path.realpath(__file__))
......@@ -74,7 +89,7 @@ if __name__ == "__main__":
device = torch.device('cuda')
files = ['small.pdb', '1hz5.pdb', '6W8H.pdb']
N = 500
N = 200
if parser.nsight:
N = 3
torch.cuda.profiler.start()
......@@ -89,17 +104,24 @@ if __name__ == "__main__":
nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
speciesPositions = nnp.species_converter((species, positions))
aev_computer = nnp.aev_computer
if parser.backward:
nn = torch.nn.Sequential(torch.nn.Linear(nnp.aev_computer.aev_length, 1, False)).to(device)
else:
nn = None
if parser.nsight:
torch.cuda.nvtx.range_push(file)
print('Original TorchANI:')
aev_ref, delta_ref = benchmark(speciesPositions, aev_computer, N, check_gpu_mem)
aev_ref, delta_ref = benchmark(speciesPositions, aev_computer, N, check_gpu_mem, nn)
print()
print('CUaev:')
nnp.aev_computer.use_cuda_extension = True
cuaev_computer = nnp.aev_computer
aev, delta = benchmark(speciesPositions, cuaev_computer, N, check_gpu_mem)
# warm up
_, _ = benchmark(speciesPositions, cuaev_computer, 1, check_gpu_mem, nn, verbose=False)
# run
aev, delta = benchmark(speciesPositions, cuaev_computer, N, check_gpu_mem, nn)
if parser.nsight:
torch.cuda.nvtx.range_pop()
......
......@@ -10,6 +10,9 @@ import os
import pickle
from torchani.units import hartree2kcalmol
summary = ''
runcounter = 0
def build_network():
H_network = torch.nn.Sequential(
......@@ -51,7 +54,17 @@ def build_network():
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
return [H_network, C_network, N_network, O_network]
nets = [H_network, C_network, N_network, O_network]
for net in nets:
net.apply(init_normal)
return nets
def init_normal(m):
if type(m) == torch.nn.Linear:
torch.nn.init.kaiming_uniform_(m.weight)
def checkgpu(device=None):
......@@ -66,6 +79,7 @@ def checkgpu(device=None):
info = pynvml.nvmlDeviceGetMemoryInfo(h)
name = pynvml.nvmlDeviceGetName(h)
print(' GPU Memory Used (nvidia-smi): {:7.1f}MB / {:.1f}MB ({})'.format(info.used / 1024 / 1024, info.total / 1024 / 1024, name.decode()))
return f'{(info.used / 1024 / 1024):.1f}MB'
def alert(text):
......@@ -85,7 +99,20 @@ def print_timer(label, t):
print(f'{label} - {t}')
def benchmark(parser, dataset, use_cuda_extension, force_inference=False):
def format_time(t):
if t < 1:
t = f'{t * 1000:.1f} ms'
else:
t = f'{t:.3f} sec'
return t
def benchmark(parser, dataset, use_cuda_extension, force_train=False):
global summary
global runcounter
if parser.nsight and runcounter >= 0:
torch.cuda.nvtx.range_push(parser.runname)
synchronize = True
timers = {}
......@@ -145,14 +172,14 @@ def benchmark(parser, dataset, use_cuda_extension, force_inference=False):
for i, properties in enumerate(dataset):
species = properties['species'].to(parser.device)
coordinates = properties['coordinates'].to(parser.device).float().requires_grad_(force_inference)
coordinates = properties['coordinates'].to(parser.device).float().requires_grad_(force_train)
true_energies = properties['energies'].to(parser.device).float()
num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
_, predicted_energies = model((species, coordinates))
# TODO add sync after aev is done
sync_cuda(synchronize)
energy_loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
if force_inference:
if force_train:
sync_cuda(synchronize)
force_coefficient = 0.1
true_forces = properties['forces'].to(parser.device).float()
......@@ -172,21 +199,21 @@ def benchmark(parser, dataset, use_cuda_extension, force_inference=False):
loss = energy_loss
rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
progbar.update(i, values=[("rmse", rmse)])
if not force_inference:
sync_cuda(synchronize)
loss_start = time.time()
loss.backward()
# print('2', coordinates.grad)
sync_cuda(synchronize)
loss_stop = time.time()
loss_time += loss_stop - loss_start
optimizer.step()
sync_cuda(synchronize)
sync_cuda(synchronize)
loss_start = time.time()
loss.backward()
sync_cuda(synchronize)
loss_stop = time.time()
loss_time += loss_stop - loss_start
optimizer.step()
sync_cuda(synchronize)
checkgpu()
gpumem = checkgpu()
sync_cuda(synchronize)
stop = time.time()
if parser.nsight and runcounter >= 0:
torch.cuda.nvtx.range_pop()
print('=> More detail about benchmark PER EPOCH')
total_time = (stop - start) / parser.num_epochs
loss_time = loss_time / parser.num_epochs
......@@ -199,9 +226,18 @@ def benchmark(parser, dataset, use_cuda_extension, force_inference=False):
print_timer(' Backward', loss_time)
print_timer(' Force', force_time)
print_timer(' Optimizer', opti_time)
print_timer(' Others', total_time - loss_time - aev_time - forward_time - opti_time - force_time)
others_time = total_time - loss_time - aev_time - forward_time - opti_time - force_time
print_timer(' Others', others_time)
print_timer(' Epoch time', total_time)
if runcounter == 0:
summary += '\n' + 'RUN'.ljust(27) + 'Total AEV'.ljust(13) + 'Forward'.ljust(13) + 'Backward'.ljust(13) + 'Force'.ljust(13) + \
'Optimizer'.ljust(13) + 'Others'.ljust(13) + 'Epoch time'.ljust(13) + 'GPU'.ljust(13) + '\n'
if runcounter >= 0:
summary += f'{runcounter} {parser.runname}'.ljust(27) + f'{format_time(aev_time)}'.ljust(13) + f'{format_time(forward_time)}'.ljust(13) + f'{format_time(loss_time)}'.ljust(13) + f'{format_time(force_time)}'.ljust(13) + \
f'{format_time(opti_time)}'.ljust(13) + f'{format_time(others_time)}'.ljust(13) + f'{format_time(total_time)}'.ljust(13) + f'{gpumem}'.ljust(13) + '\n'
runcounter += 1
if __name__ == "__main__":
# parse command line arguments
......@@ -249,20 +285,43 @@ if __name__ == "__main__":
print(' {}'.format(torch.cuda.get_device_properties(i)))
checkgpu(i)
print("\n\n=> Test 1: USE cuda extension, Energy training")
# Warm-up run
if len(dataset_shuffled) < 100:
runcounter = -1
parser.runname = 'Warm up'
print(f"\n\n=> Test 0: {parser.runname}")
torch.cuda.empty_cache()
gc.collect()
benchmark(parser, dataset_shuffled, use_cuda_extension=True, force_train=False)
if parser.nsight:
torch.cuda.profiler.start()
parser.runname = 'cu Energy train'
print(f"\n\n=> Test 1: {parser.runname}")
torch.cuda.empty_cache()
gc.collect()
benchmark(parser, dataset_shuffled, use_cuda_extension=True, force_inference=False)
print("\n\n=> Test 2: NO cuda extension, Energy training")
benchmark(parser, dataset_shuffled, use_cuda_extension=True, force_train=False)
parser.runname = 'py Energy train'
print(f"\n\n=> Test 2: {parser.runname}")
torch.cuda.empty_cache()
gc.collect()
benchmark(parser, dataset_shuffled, use_cuda_extension=False, force_inference=False)
benchmark(parser, dataset_shuffled, use_cuda_extension=False, force_train=False)
print("\n\n=> Test 3: USE cuda extension, Force and Energy inference")
parser.runname = 'cu Energy + Force train'
print(f"\n\n=> Test 3: {parser.runname}")
torch.cuda.empty_cache()
gc.collect()
benchmark(parser, dataset_shuffled, use_cuda_extension=True, force_inference=True)
print("\n\n=> Test 4: NO cuda extension, Force and Energy inference")
benchmark(parser, dataset_shuffled, use_cuda_extension=True, force_train=True)
parser.runname = 'py Energy + Force train'
print(f"\n\n=> Test 4: {parser.runname}")
torch.cuda.empty_cache()
gc.collect()
benchmark(parser, dataset_shuffled, use_cuda_extension=False, force_inference=True)
benchmark(parser, dataset_shuffled, use_cuda_extension=False, force_train=True)
print(summary)
if parser.nsight:
torch.cuda.profiler.stop()
# CUAEV
CUDA Extension for AEV calculation.
Performance improvement is expected to be ~3X for AEV computation and ~1.5X for overall training workflow.
Performance improvement is expected to be ~3X for AEV computation, ~1.5X for energy training, and ~2.6X for energy+force training.
## Requirement
CUAEV requires a nightly version of [pytorch](https://pytorch.org/) to work.
If you you use conda, you could install it by
If you use conda, you can install it with
```
conda install pytorch torchvision torchaudio cudatoolkit={YOUR_CUDA_VERSION} -c pytorch-nightly
```
......@@ -18,21 +18,102 @@ cd torchani
# choose one option below
# use --cuaev-all-sms if you are building in a SLURM environment and there are multiple different gpus in a node
# use --cuaev to build only for detected gpus
python setup.py install --cuaev-all-sms # build for all sms
python setup.py install --cuaev # only build for detected gpus
python setup.py install --cuaev-all-sms # build for all gpus
# or for development
# `pip install -e . && ` is only needed for the very first install (because of https://github.com/pypa/pip/issues/1883)
pip install -e . && pip install -v -e . --global-option="--cuaev-all-sms" # build for all sms
pip install -e . && pip install -v -e . --global-option="--cuaev" # only build for detected gpus
pip install -e . && pip install -v -e . --global-option="--cuaev-all-sms" # build for all gpus
```
<del>Notes for install on Hipergator</del> (Currently not working because Pytorch dropped the official build for cuda/10.0)
Notes for building CUAEV on several HPC clusters
<details>
<summary>Bridges2</summary>
```bash
# prepare
srun -p GPU-small --ntasks=1 --cpus-per-task=5 --gpus=1 --time=02:00:00 --mem=20gb --pty -u bash -i
module load cuda/10.2.0
conda create -n cuaev python=3.8
conda activate cuaev
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-nightly
# install torchani
git clone https://github.com/aiqm/torchani.git
cd torchani
pip install -e . && pip install -v -e . --global-option="--cuaev"
```
</details>
<details>
<summary>Hipergator</summary>
```bash
srun -p gpu --ntasks=1 --cpus-per-task=2 --gpus=geforce:1 --time=02:00:00 --mem=10gb --pty -u bash -i
module load cuda/10.0.130 gcc/7.3.0 git
conda remove --name cuaev --all -y && conda create -n cuaev python=3.8 -y
conda activate cuaev
# install a prebuilt torch-cu100 wheel because pytorch dropped the official build for cuda 10.0
. /home/jinzexue/pytorch/loadmodule # note that there is a space after .
. /home/jinzexue/pytorch/install_deps
pip install $(realpath /home/jinzexue/pytorch/dist/torch-nightly-cu100.whl)
# check that pytorch is working; it should print information about the available gpus
python /home/jinzexue/pytorch/testcuda/testcuda.py
# install torchani
git clone https://github.com/aiqm/torchani.git
cd torchani
pip install -e . && pip install -v -e . --global-option="--cuaev"
```
</details>
<details>
<summary>Expanse</summary>
```bash
srun -p gpu-shared --ntasks=1 --account=cwr109 --cpus-per-task=1 --gpus=1 --time=01:00:00 --mem=10gb --pty -u bash -i
# create env if necessary
conda create -n cuaev python=3.8
conda activate cuaev
# modules
module load cuda10.2/toolkit/10.2.89 gcc/7.5.0
# pytorch
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-nightly
# install
git clone https://github.com/aiqm/torchani.git
cd torchani
pip install -e . && pip install -v -e . --global-option="--cuaev"
```
</details>
<details>
<summary>Moria</summary>
```bash
srun -p gpu --gpus=geforce:1 --time=01:00:00 --mem=10gb --pty -u bash -i # compilation may fail when memory is low (less than 5 GB)
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch # make sure it's cudatoolkit=10.0
module load cuda/10.0.130
module load gcc/7.3.0
python setup.py install --cuaev-all-sms
srun --ntasks=1 --cpus-per-task=2 --gpus=1 --time=02:00:00 --mem=10gb --pty -u bash -i
# create env if necessary
conda create -n cuaev python=3.8
conda activate cuaev
# cuda path (could be added to ~/.bashrc)
export PATH=/usr/local/cuda/bin:$PATH # nvcc for cuda 9.2
# pytorch
conda install pytorch torchvision cudatoolkit=9.2 -c pytorch-nightly
# install
git clone https://github.com/aiqm/torchani.git
cd torchani
pip install -e . && pip install -v -e . --global-option="--cuaev"
```
</details>
## Test
```bash
cd torchani
./download.sh
python tests/test_cuaev.py
```
## Usage
......@@ -44,27 +125,33 @@ cuaev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, Sh
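With the extension built, force training runs through the same `AEVComputer` interface. A minimal sketch continuing the usage snippet above (it assumes `species` and `coordinates` tensors on the GPU with `coordinates.requires_grad_()` set, reference forces `force_true`, and some energy head `nn` mapping AEVs to energies):

```python
_, aev = cuaev_computer((species, coordinates))
energy = nn(aev).sum()
# create_graph=True keeps the force differentiable for the force loss
force = -torch.autograd.grad(energy, coordinates, create_graph=True)[0]
loss = (force_true - force).abs().sum(dim=(1, 2)).mean()
loss.backward()  # double backward through the cuaev kernels
```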
## TODOs
- [x] CUAEV Forward
- [x] CUAEV Backward (Force)
- [x] CUAEV Double Backward (force training needs the AEV's double backward w.r.t. grad_aev)
- [ ] PBC
- [ ] Force training (Need cuaev's second derivative)
## Benchmark
Benchmark of [torchani/tools/training-aev-benchmark.py](https://github.com/aiqm/torchani/blob/master/tools/training-aev-benchmark.py) on TITAN V:
Benchmark of [torchani/tools/training-aev-benchmark.py](https://github.com/aiqm/torchani/blob/master/tools/training-aev-benchmark.py):
| ANI-1x dataset (Batchsize 2560) | Energy Training | Energy and Force Inference |
|---------------------------------|-------------------------|-----------------------------------|
| Time per Epoch / Memory | AEV / Total / GPU Mem | AEV / Force / Total / GPU Mem |
| aev cuda extension | 3.90s / 31.5s / 2088 MB | 3.90s / 22.6s / 43.0s / 4234 MB |
| aev python code | 23.7s / 50.2s / 3540 MB | 25.3s / 48.0s / 88.2s / 11316 MB |
Train ANI-1x dataset (Batchsize 2560) on Tesla V100 for 1 epoch:
```
RUN Total AEV Forward Backward Force Optimizer Others Epoch time GPU
0 cu Energy 3.355 sec 4.470 sec 4.685 sec 0.0 ms 3.508 sec 2.223 sec 18.241 sec 2780.8MB
1 py Energy 19.682 sec 4.149 sec 4.663 sec 0.0 ms 3.495 sec 2.220 sec 34.209 sec 4038.8MB
2 cu Energy+Force 3.351 sec 4.200 sec 27.402 sec 16.514 sec 3.467 sec 4.556 sec 59.490 sec 7492.8MB
3 py Energy+Force 19.964 sec 4.176 sec 91.866 sec 36.554 sec 3.473 sec 5.403 sec 161.435 sec 8034.8MB
```
## Test
```bash
cd torchani
./download.sh
python tests/test_cuaev.py
Train ANI-1x dataset (Batchsize 1500) on GTX 1080 for 1 epoch:
```
RUN Total AEV Forward Backward Force Optimizer Others Epoch time GPU
0 cu Energy 14.373 sec 10.870 sec 13.100 sec 0.0 ms 11.043 sec 2.913 sec 52.299 sec 1527.5MB
1 py Energy 51.545 sec 10.228 sec 13.154 sec 0.0 ms 11.384 sec 2.874 sec 89.185 sec 2403.5MB
2 cu Energy+Force 14.275 sec 10.024 sec 85.423 sec 51.380 sec 7.396 sec 5.494 sec 173.992 sec 3577.5MB
3 py Energy+Force 51.305 sec 9.951 sec 271.078 sec 107.252 sec 7.835 sec 4.941 sec 452.362 sec 7307.5MB
```
To run the benchmarks:
```
pip install pynvml pkbar
python tools/training-aev-benchmark.py download/dataset/ani-1x/sample.h5
python tools/aev-benchmark-size.py
```
......@@ -10,8 +10,109 @@
#define PI 3.141592653589793
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;
// [Computation graph for forward, backward, and double backward]
//
// backward
// force = (dE / daev) * (daev / dcoord) = g * (daev / dcoord)
//
// double backward (to do force training, the term needed is)
// dloss / dg = (dloss / dforce) * (dforce / dg) = (dloss / dforce) * (daev / dcoord)
//
//
// [Forward]
// out ^
// | ^
// ... ^
// | ^
// e n e r g y ^
// | \ ^
// aev \ ^
// / | \ ^
// radial angular params ^
// / / | ^
// dist---^ / ^
// \ / ^
// coord ^
//
// Functional relationship:
// coord <-- input
// dist(coord)
// radial(dist)
// angular(dist, coord)
// aev = concatenate(radial, angular)
// energy(aev, params)
// out(energy, ....) <-- output
//
//
// [Backward]
// dout v
// | v
// ... v
// | v
// aev params denergy aev params v
// \ | / \ | / v
// d a e v dparams v
// / \____ v
// dist dradial \ v
// \ / \ v
// ddist dist coord dangular dist coord v
// \ / / \ | / v
// \_/____/ \___|___/ v
// | __________________/ v
// | / v
// dcoord v
// | v
// ... v
// | v
// out2 v
//
// Functional relationship:
// dout <-- input
// denergy(dout)
// dparams(denergy, aev, params) <-- output
// daev(denergy, aev, params)
// dradial = slice(daev)
// dangular = slice(daev)
// ddist = radial_backward(dradial, dist) + angular_backward_dist(dangular, ...)
// = radial_backward(dradial, dist) + 0 (all contributions route to dcoord)
// = radial_backward(dradial, dist)
// dcoord = dist_backward(ddist, coord, dist) + angular_backward_coord(dangular, coord, dist)
// out2(dcoord, ...) <-- output
//
//
// [Double backward w.r.t params (i.e. force training)]
// Note: only a very limited subset of double backward is implemented;
// currently it supports force training only, and there is no Hessian support.
// Terms that are not implemented are marked with $s
// $$$ [dparams] $$$$ ^
// \_ | __/ ^
// [ddaev] ^
// / \_____ ^
// $$$$ [ddradial] \ ^
// \ / \ ^
// [dddist] $$$$ $$$$ [ddangular] $$$$ $$$$ ^
// \ / / \ | / ^
// \_/____/ \_____|___/ ^
// | _____________________/ ^
// | / ^
// [ddcoord] ^
// | ^
// ... ^
// | ^
// [dout2] ^
//
// Functional relationship:
// dout2 <-- input
// ddcoord(dout2, ...)
// dddist = dist_doublebackward(ddcoord, coord, dist)
// ddradial = radial_doublebackward(dddist, dist)
// ddangular = angular_doublebackward(ddcoord, coord, dist)
// ddaev = concatenate(ddradial, ddangular)
// dparams(ddaev, ...) <-- output
template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
DataT Rcr;
......@@ -170,11 +271,13 @@ __global__ void pairwiseDistanceSingleMolecule(
}
// every block computes the gradients of blocksize Rij entries in column-major order, to reduce atomicAdd contention
template <typename DataT, typename IndexT = int>
__global__ void pairwiseDistance_backward(
template <bool is_double_backward, typename DataT, typename IndexT = int>
__global__ void pairwiseDistance_backward_or_doublebackward(
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> grad_radial_dist,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_coord,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits>
grad_dist, // ddist for backward, dddist for double backward
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits>
grad_coord_or_force, // dcoord for backward, dforce(i.e. ddcoord) for double backward
const PairDist<DataT>* d_radialRij,
IndexT nRadialRij) {
int gidx = threadIdx.x * gridDim.x + blockIdx.x;
......@@ -192,17 +295,28 @@ __global__ void pairwiseDistance_backward(
const DataT dely = pos_t[mol_idx][j][1] - pos_t[mol_idx][i][1];
const DataT delz = pos_t[mol_idx][j][2] - pos_t[mol_idx][i][2];
DataT grad_dist_coord_x = delx / Rij;
DataT grad_dist_coord_y = dely / Rij;
DataT grad_dist_coord_z = delz / Rij;
DataT grad_radial_dist_item = grad_radial_dist[gidx];
atomicAdd(&grad_coord[mol_idx][j][0], grad_radial_dist_item * grad_dist_coord_x);
atomicAdd(&grad_coord[mol_idx][j][1], grad_radial_dist_item * grad_dist_coord_y);
atomicAdd(&grad_coord[mol_idx][j][2], grad_radial_dist_item * grad_dist_coord_z);
atomicAdd(&grad_coord[mol_idx][i][0], -grad_radial_dist_item * grad_dist_coord_x);
atomicAdd(&grad_coord[mol_idx][i][1], -grad_radial_dist_item * grad_dist_coord_y);
atomicAdd(&grad_coord[mol_idx][i][2], -grad_radial_dist_item * grad_dist_coord_z);
if constexpr (is_double_backward) {
auto& grad_force = grad_coord_or_force;
DataT grad_force_coord_Rij_item = (grad_force[mol_idx][j][0] - grad_force[mol_idx][i][0]) * delx / Rij +
(grad_force[mol_idx][j][1] - grad_force[mol_idx][i][1]) * dely / Rij +
(grad_force[mol_idx][j][2] - grad_force[mol_idx][i][2]) * delz / Rij;
grad_dist[gidx] = grad_force_coord_Rij_item;
} else {
auto& grad_coord = grad_coord_or_force;
DataT grad_dist_coord_x = delx / Rij;
DataT grad_dist_coord_y = dely / Rij;
DataT grad_dist_coord_z = delz / Rij;
DataT grad_radial_dist_item = grad_dist[gidx];
atomicAdd(&grad_coord[mol_idx][j][0], grad_radial_dist_item * grad_dist_coord_x);
atomicAdd(&grad_coord[mol_idx][j][1], grad_radial_dist_item * grad_dist_coord_y);
atomicAdd(&grad_coord[mol_idx][j][2], grad_radial_dist_item * grad_dist_coord_z);
atomicAdd(&grad_coord[mol_idx][i][0], -grad_radial_dist_item * grad_dist_coord_x);
atomicAdd(&grad_coord[mol_idx][i][1], -grad_radial_dist_item * grad_dist_coord_y);
atomicAdd(&grad_coord[mol_idx][i][2], -grad_radial_dist_item * grad_dist_coord_z);
}
}
template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
......@@ -349,18 +463,24 @@ __global__ void cuAngularAEVs(
}
}
template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void
// __launch_bounds__(32)
cuAngularAEVs_backward(
template <
bool is_double_backward,
typename SpeciesT,
typename DataT,
typename IndexT = int,
int TILEX = 8,
int TILEY = 4>
__global__ void cuAngularAEVs_backward_or_doublebackward(
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfA_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfZ_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaA_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> Zeta_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_output,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_coord,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits>
grad_output, // for backward, this is daev, for double backward, this is dforce (i.e. ddcoord)
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits>
grad_input, // for backward, this is dcoord, for double backward, this is ddaev
const PairDist<DataT>* d_Rij,
const PairDist<DataT>* d_centralAtom,
int* d_nPairsPerCenterAtom,
......@@ -533,53 +653,76 @@ cuAngularAEVs_backward(
DataT factor2 = exp(-EtaA * (Rijk - ShfA) * (Rijk - ShfA));
DataT grad_factor2_dist = -EtaA * (Rijk - ShfA) * factor2;
DataT grad_output_item =
grad_output[mol_idx][i][aev_params.radial_length + subaev_offset + ishfr * nShfZ + itheta];
DataT grad_vij_x = 2 * grad_output_item *
DataT grad_vij_x = 2 *
(grad_factor1_theta * grad_theta_vij_x_ * factor2 * fc_ijk +
factor1 * grad_factor2_dist * sdx[jj] / Rij * fc_ijk +
factor1 * factor2 * fc_ik * grad_fc_ij * sdx[jj] / Rij);
DataT grad_vij_y = 2 * grad_output_item *
DataT grad_vij_y = 2 *
(grad_factor1_theta * grad_theta_vij_y_ * factor2 * fc_ijk +
factor1 * grad_factor2_dist * sdy[jj] / Rij * fc_ijk +
factor1 * factor2 * fc_ik * grad_fc_ij * sdy[jj] / Rij);
DataT grad_vij_z = 2 * grad_output_item *
DataT grad_vij_z = 2 *
(grad_factor1_theta * grad_theta_vij_z_ * factor2 * fc_ijk +
factor1 * grad_factor2_dist * sdz[jj] / Rij * fc_ijk +
factor1 * factor2 * fc_ik * grad_fc_ij * sdz[jj] / Rij);
DataT grad_vik_x = 2 * grad_output_item *
DataT grad_vik_x = 2 *
(grad_factor1_theta * grad_theta_vik_x_ * factor2 * fc_ijk +
factor1 * grad_factor2_dist * sdx[kk] / Rik * fc_ijk +
factor1 * factor2 * fc_ij * grad_fc_ik * sdx[kk] / Rik);
DataT grad_vik_y = 2 * grad_output_item *
DataT grad_vik_y = 2 *
(grad_factor1_theta * grad_theta_vik_y_ * factor2 * fc_ijk +
factor1 * grad_factor2_dist * sdy[kk] / Rik * fc_ijk +
factor1 * factor2 * fc_ij * grad_fc_ik * sdy[kk] / Rik);
DataT grad_vik_z = 2 * grad_output_item *
DataT grad_vik_z = 2 *
(grad_factor1_theta * grad_theta_vik_z_ * factor2 * fc_ijk +
factor1 * grad_factor2_dist * sdz[kk] / Rik * fc_ijk +
factor1 * factor2 * fc_ij * grad_fc_ik * sdz[kk] / Rik);
sdix_grad += (-grad_vij_x - grad_vik_x);
sdiy_grad += (-grad_vij_y - grad_vik_y);
sdiz_grad += (-grad_vij_z - grad_vik_z);
for (int offset = 16; offset > 0; offset /= 2) {
grad_vij_x += __shfl_down_sync(0xFFFFFFFF, grad_vij_x, offset);
grad_vij_y += __shfl_down_sync(0xFFFFFFFF, grad_vij_y, offset);
grad_vij_z += __shfl_down_sync(0xFFFFFFFF, grad_vij_z, offset);
grad_vik_x += __shfl_down_sync(0xFFFFFFFF, grad_vik_x, offset);
grad_vik_y += __shfl_down_sync(0xFFFFFFFF, grad_vik_y, offset);
grad_vik_z += __shfl_down_sync(0xFFFFFFFF, grad_vik_z, offset);
}
if (laneIdx == 0) {
sdjx_grad[jj] += grad_vij_x;
sdjy_grad[jj] += grad_vij_y;
sdjz_grad[jj] += grad_vij_z;
sdjx_grad[kk] += grad_vik_x;
sdjy_grad[kk] += grad_vik_y;
sdjz_grad[kk] += grad_vik_z;
if constexpr (is_double_backward) {
int atomj_idx = d_Rij[start_idx + jj].j;
int atomk_idx = d_Rij[start_idx + kk].j;
auto& grad_force = grad_output;
auto& grad_grad_aev = grad_input;
grad_vij_x *= (grad_force[mol_idx][atomj_idx][0] - grad_force[mol_idx][i][0]);
grad_vij_y *= (grad_force[mol_idx][atomj_idx][1] - grad_force[mol_idx][i][1]);
grad_vij_z *= (grad_force[mol_idx][atomj_idx][2] - grad_force[mol_idx][i][2]);
grad_vik_x *= (grad_force[mol_idx][atomk_idx][0] - grad_force[mol_idx][i][0]);
grad_vik_y *= (grad_force[mol_idx][atomk_idx][1] - grad_force[mol_idx][i][1]);
grad_vik_z *= (grad_force[mol_idx][atomk_idx][2] - grad_force[mol_idx][i][2]);
atomicAdd(
&grad_grad_aev[mol_idx][i][aev_params.radial_length + subaev_offset + ishfr * nShfZ + itheta],
grad_vij_x + grad_vij_y + grad_vij_z + grad_vik_x + grad_vik_y + grad_vik_z);
} else {
DataT grad_output_item =
grad_output[mol_idx][i][aev_params.radial_length + subaev_offset + ishfr * nShfZ + itheta];
grad_vij_x *= grad_output_item;
grad_vij_y *= grad_output_item;
grad_vij_z *= grad_output_item;
grad_vik_x *= grad_output_item;
grad_vik_y *= grad_output_item;
grad_vik_z *= grad_output_item;
sdix_grad += (-grad_vij_x - grad_vik_x);
sdiy_grad += (-grad_vij_y - grad_vik_y);
sdiz_grad += (-grad_vij_z - grad_vik_z);
for (int offset = 16; offset > 0; offset /= 2) {
grad_vij_x += __shfl_down_sync(0xFFFFFFFF, grad_vij_x, offset);
grad_vij_y += __shfl_down_sync(0xFFFFFFFF, grad_vij_y, offset);
grad_vij_z += __shfl_down_sync(0xFFFFFFFF, grad_vij_z, offset);
grad_vik_x += __shfl_down_sync(0xFFFFFFFF, grad_vik_x, offset);
grad_vik_y += __shfl_down_sync(0xFFFFFFFF, grad_vik_y, offset);
grad_vik_z += __shfl_down_sync(0xFFFFFFFF, grad_vik_z, offset);
}
if (laneIdx == 0) {
sdjx_grad[jj] += grad_vij_x;
sdjy_grad[jj] += grad_vij_y;
sdjz_grad[jj] += grad_vij_z;
sdjx_grad[kk] += grad_vik_x;
sdjy_grad[kk] += grad_vik_y;
sdjz_grad[kk] += grad_vik_z;
}
}
}
}
......@@ -587,17 +730,20 @@ cuAngularAEVs_backward(
}
}
int atomi_idx = i;
atomicAdd(&grad_coord[mol_idx][atomi_idx][0], sdix_grad);
atomicAdd(&grad_coord[mol_idx][atomi_idx][1], sdiy_grad);
atomicAdd(&grad_coord[mol_idx][atomi_idx][2], sdiz_grad);
if constexpr (!is_double_backward) {
auto& grad_coord = grad_input;
int atomi_idx = i;
atomicAdd(&grad_coord[mol_idx][atomi_idx][0], sdix_grad);
atomicAdd(&grad_coord[mol_idx][atomi_idx][1], sdiy_grad);
atomicAdd(&grad_coord[mol_idx][atomi_idx][2], sdiz_grad);
for (int jj = laneIdx; jj < jnum; jj += threads_per_catom) {
int atomj_idx = d_Rij[start_idx + jj].j;
for (int jj = laneIdx; jj < jnum; jj += threads_per_catom) {
int atomj_idx = d_Rij[start_idx + jj].j;
atomicAdd(&grad_coord[mol_idx][atomj_idx][0], sdjx_grad[jj]);
atomicAdd(&grad_coord[mol_idx][atomj_idx][1], sdjy_grad[jj]);
atomicAdd(&grad_coord[mol_idx][atomj_idx][2], sdjz_grad[jj]);
atomicAdd(&grad_coord[mol_idx][atomj_idx][0], sdjx_grad[jj]);
atomicAdd(&grad_coord[mol_idx][atomj_idx][1], sdjy_grad[jj]);
atomicAdd(&grad_coord[mol_idx][atomj_idx][2], sdjz_grad[jj]);
}
}
}
......@@ -641,13 +787,15 @@ __global__ void cuRadialAEVs(
}
// every <THREADS_PER_RIJ> threads handle one Rij, and iterate <nShfR / THREADS_PER_RIJ> times
template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs_backward(
template <bool is_double_backward, typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs_backward_or_doublebackward(
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfR_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_output,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> grad_radial_dist,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits>
grad_aev, // daev for backward, ddaev for double backward
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits>
grad_dist, // ddist for backward, dddist for double backward
const PairDist<DataT>* d_Rij,
AEVScalarParams<DataT, int> aev_params,
int nRadialRij) {
......@@ -673,16 +821,24 @@ __global__ void cuRadialAEVs_backward(
DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5;
DataT fc_grad = -0.5 * (PI / aev_params.Rcr) * sin(PI * Rij / aev_params.Rcr);
DataT upstream_grad;
if constexpr (is_double_backward) {
upstream_grad = grad_dist[idx];
}
for (int ishfr = laneIdx; ishfr < nShfR; ishfr += THREADS_PER_RIJ) {
DataT ShfR = ShfR_t[ishfr];
DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR));
DataT GmR_grad = -EtaR * (-2 * ShfR + 2 * Rij) * GmR;
DataT jacobian = GmR_grad * fc + GmR * fc_grad;
DataT grad_output_item = grad_output[mol_idx][i][type_j * aev_params.radial_sublength + ishfr];
DataT grad_radial_dist_item = grad_output_item * (GmR_grad * fc + GmR * fc_grad);
atomicAdd(&grad_radial_dist[idx], grad_radial_dist_item);
if constexpr (is_double_backward) {
atomicAdd(&grad_aev[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], upstream_grad * jacobian);
} else {
upstream_grad = grad_aev[mol_idx][i][type_j * aev_params.radial_sublength + ishfr];
atomicAdd(&grad_dist[idx], upstream_grad * jacobian);
}
}
}
......@@ -898,7 +1054,6 @@ Result cuaev_forward(
const int block_size = 64;
dim3 block(8, 8, 1);
if (n_molecules == 1) {
int tileWidth = 32;
int tilesPerRow = (max_natoms_per_mol + tileWidth - 1) / tileWidth;
......@@ -1063,7 +1218,7 @@ Tensor cuaev_backward(
int block_size = 64;
int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
cuRadialAEVs_backward<int, float, 8><<<nblocks, block_size, 0, stream>>>(
cuRadialAEVs_backward_or_doublebackward<false, int, float, 8><<<nblocks, block_size, 0, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
......@@ -1075,7 +1230,7 @@ Tensor cuaev_backward(
// For best results, block_size should match the average molecule size (no padding) to reduce atomicAdd contention
nblocks = (nRadialRij + block_size - 1) / block_size;
pairwiseDistance_backward<<<nblocks, block_size, 0, stream>>>(
pairwiseDistance_backward_or_doublebackward<false><<<nblocks, block_size, 0, stream>>>(
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_radial_dist.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
......@@ -1099,7 +1254,7 @@ Tensor cuaev_backward(
int smem_size_aligned = smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom);
Tensor grad_angular_coord = torch::zeros({nAngularRij, 3}, coordinates_t.options().requires_grad(false));
cuAngularAEVs_backward<<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
cuAngularAEVs_backward_or_doublebackward<false><<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
......@@ -1120,50 +1275,171 @@ Tensor cuaev_backward(
return grad_coord;
}
#define AEV_INPUT \
const Tensor &coordinates_t, const Tensor &species_t, double Rcr_, double Rca_, const Tensor &EtaR_t, \
const Tensor &ShfR_t, const Tensor &EtaA_t, const Tensor &Zeta_t, const Tensor &ShfA_t, const Tensor &ShfZ_t, \
int64_t num_species_
Tensor cuaev_double_backward(
const Tensor& grad_force,
const Tensor& coordinates_t,
const Tensor& species_t,
const AEVScalarParams<float>& aev_params,
const Tensor& EtaR_t,
const Tensor& ShfR_t,
const Tensor& EtaA_t,
const Tensor& Zeta_t,
const Tensor& ShfA_t,
const Tensor& ShfZ_t,
const Tensor& tensor_Rij,
int total_natom_pairs,
const Tensor& tensor_radialRij,
int nRadialRij,
const Tensor& tensor_angularRij,
int nAngularRij,
const Tensor& tensor_centralAtom,
const Tensor& tensor_numPairsPerCenterAtom,
const Tensor& tensor_centerAtomStartIdx,
int maxnbrs_per_atom_aligned,
int angular_length_aligned,
int ncenter_atoms) {
using namespace torch::indexing;
const int n_molecules = coordinates_t.size(0);
const int max_natoms_per_mol = coordinates_t.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Tensor cuaev_cuda(AEV_INPUT) {
Result res = cuaev_forward<float>(
coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
return res.aev_t;
int aev_length = aev_params.radial_length + aev_params.angular_length;
auto grad_grad_aev = torch::zeros(
{coordinates_t.size(0), coordinates_t.size(1), aev_length},
coordinates_t.options().requires_grad(false)); // [2, 5, 384]
PairDist<float>* d_Rij = (PairDist<float>*)tensor_Rij.data_ptr();
PairDist<float>* d_radialRij = (PairDist<float>*)tensor_radialRij.data_ptr();
PairDist<float>* d_angularRij = (PairDist<float>*)tensor_angularRij.data_ptr();
PairDist<float>* d_centralAtom = (PairDist<float>*)tensor_centralAtom.data_ptr();
int* d_numPairsPerCenterAtom = (int*)tensor_numPairsPerCenterAtom.data_ptr();
int* d_centerAtomStartIdx = (int*)tensor_centerAtomStartIdx.data_ptr();
auto grad_force_coord_Rij = torch::zeros({nRadialRij}, coordinates_t.options().requires_grad(false));
int block_size = 64;
int nblocks = (nRadialRij + block_size - 1) / block_size;
pairwiseDistance_backward_or_doublebackward<true><<<nblocks, block_size, 0, stream>>>(
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_force_coord_Rij.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_radialRij,
nRadialRij);
nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
cuRadialAEVs_backward_or_doublebackward<true, int, float, 8><<<nblocks, block_size, 0, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_grad_aev.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_force_coord_Rij.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
d_radialRij,
aev_params,
nRadialRij);
auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
int sxyz = sizeof(float) * max_nbrs * 3;
int sj_xyz_grad = sizeof(float) * max_nbrs * 3;
int sRij = sizeof(float) * max_nbrs;
int sfc = sizeof(float) * max_nbrs;
int sfc_grad = sizeof(float) * max_nbrs;
int sj = sizeof(int) * max_nbrs;
return (sxyz + sj_xyz_grad + sRij + sfc + sfc_grad + sj) * ncatom_per_tpb;
};
block_size = 32;
const int nthreads_per_catom = 32;
const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
int smem_size_aligned = smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom);
cuAngularAEVs_backward_or_doublebackward<true><<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_grad_aev.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_angularRij,
d_centralAtom,
d_numPairsPerCenterAtom,
d_centerAtomStartIdx,
aev_params,
maxnbrs_per_atom_aligned,
angular_length_aligned,
ncenter_atoms);
return grad_grad_aev;
}
class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
class CuaevDoubleAutograd : public torch::autograd::Function<CuaevDoubleAutograd> {
public:
static Tensor forward(torch::autograd::AutogradContext* ctx, AEV_INPUT) {
at::AutoNonVariableTypeMode g;
Result res = cuaev_forward<float>(
coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
if (coordinates_t.requires_grad()) {
static Tensor forward(AutogradContext* ctx, Tensor grad_e_aev, AutogradContext* prectx) {
auto saved = prectx->get_saved_variables();
auto coordinates_t = saved[0], species_t = saved[1];
auto tensor_Rij = saved[2], tensor_radialRij = saved[3], tensor_angularRij = saved[4];
auto EtaR_t = saved[5], ShfR_t = saved[6], EtaA_t = saved[7], Zeta_t = saved[8], ShfA_t = saved[9],
ShfZ_t = saved[10];
auto tensor_centralAtom = saved[11], tensor_numPairsPerCenterAtom = saved[12],
tensor_centerAtomStartIdx = saved[13];
AEVScalarParams<float> aev_params(prectx->saved_data["aev_params"]);
c10::List<int64_t> int_list = prectx->saved_data["int_list"].toIntList();
int total_natom_pairs = int_list[0], nRadialRij = int_list[1], nAngularRij = int_list[2];
int maxnbrs_per_atom_aligned = int_list[3], angular_length_aligned = int_list[4];
int ncenter_atoms = int_list[5];
if (grad_e_aev.requires_grad()) {
ctx->save_for_backward({coordinates_t,
species_t,
res.tensor_Rij,
res.tensor_radialRij,
res.tensor_angularRij,
tensor_Rij,
tensor_radialRij,
tensor_angularRij,
EtaR_t,
ShfR_t,
EtaA_t,
Zeta_t,
ShfA_t,
ShfZ_t,
res.tensor_centralAtom,
res.tensor_numPairsPerCenterAtom,
res.tensor_centerAtomStartIdx});
ctx->saved_data["aev_params"] = res.aev_params;
ctx->saved_data["int_list"] = c10::List<int64_t>{res.total_natom_pairs,
res.nRadialRij,
res.nAngularRij,
res.maxnbrs_per_atom_aligned,
res.angular_length_aligned,
res.ncenter_atoms};
tensor_centralAtom,
tensor_numPairsPerCenterAtom,
tensor_centerAtomStartIdx});
ctx->saved_data["aev_params"] = aev_params;
ctx->saved_data["int_list"] = c10::List<int64_t>{
total_natom_pairs, nRadialRij, nAngularRij, maxnbrs_per_atom_aligned, angular_length_aligned, ncenter_atoms};
}
return res.aev_t;
Tensor grad_coord = cuaev_backward(
grad_e_aev,
coordinates_t,
species_t,
aev_params,
EtaR_t,
ShfR_t,
EtaA_t,
Zeta_t,
ShfA_t,
ShfZ_t,
tensor_Rij,
total_natom_pairs,
tensor_radialRij,
nRadialRij,
tensor_angularRij,
nAngularRij,
tensor_centralAtom,
tensor_numPairsPerCenterAtom,
tensor_centerAtomStartIdx,
maxnbrs_per_atom_aligned,
angular_length_aligned,
ncenter_atoms);
return grad_coord;
}
static tensor_list backward(torch::autograd::AutogradContext* ctx, tensor_list grad_outputs) {
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
Tensor grad_force = grad_outputs[0];
auto saved = ctx->get_saved_variables();
auto coordinates_t = saved[0], species_t = saved[1];
auto tensor_Rij = saved[2], tensor_radialRij = saved[3], tensor_angularRij = saved[4];
......@@ -1177,8 +1453,8 @@ class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
int maxnbrs_per_atom_aligned = int_list[3], angular_length_aligned = int_list[4];
int ncenter_atoms = int_list[5];
Tensor grad_coord = cuaev_backward(
grad_outputs[0],
Tensor grad_grad_aev = cuaev_double_backward(
grad_force,
coordinates_t,
species_t,
aev_params,
......@@ -1201,6 +1477,56 @@ class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
angular_length_aligned,
ncenter_atoms);
return {grad_grad_aev, torch::Tensor()};
}
};
#define AEV_INPUT \
const Tensor &coordinates_t, const Tensor &species_t, double Rcr_, double Rca_, const Tensor &EtaR_t, \
const Tensor &ShfR_t, const Tensor &EtaA_t, const Tensor &Zeta_t, const Tensor &ShfA_t, const Tensor &ShfZ_t, \
int64_t num_species_
Tensor cuaev_cuda(AEV_INPUT) {
Result res = cuaev_forward<float>(
coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
return res.aev_t;
}
class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
public:
static Tensor forward(AutogradContext* ctx, AEV_INPUT) {
at::AutoNonVariableTypeMode g;
Result res = cuaev_forward<float>(
coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
if (coordinates_t.requires_grad()) {
ctx->save_for_backward({coordinates_t,
species_t,
res.tensor_Rij,
res.tensor_radialRij,
res.tensor_angularRij,
EtaR_t,
ShfR_t,
EtaA_t,
Zeta_t,
ShfA_t,
ShfZ_t,
res.tensor_centralAtom,
res.tensor_numPairsPerCenterAtom,
res.tensor_centerAtomStartIdx});
ctx->saved_data["aev_params"] = res.aev_params;
ctx->saved_data["int_list"] = c10::List<int64_t>{res.total_natom_pairs,
res.nRadialRij,
res.nAngularRij,
res.maxnbrs_per_atom_aligned,
res.angular_length_aligned,
res.ncenter_atoms};
}
return res.aev_t;
}
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
Tensor grad_coord = CuaevDoubleAutograd::apply(grad_outputs[0], ctx);
return {
grad_coord, Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor()};
}
......
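The structural trick in the autograd wiring above is that `CuaevAutograd::backward` does not compute gradients inline; it applies a second autograd function, `CuaevDoubleAutograd`, whose forward turns g = dE/dAEV into dE/dcoord and whose backward maps dLoss/dforce back to dLoss/dg via `cuaev_double_backward`. A Python sketch of the same nesting pattern (the `cuaev_*` calls are placeholders standing in for the CUDA launchers in this file; the wrapper itself is hypothetical):

```python
import torch

class CuaevDoubleSketch(torch.autograd.Function):
    # forward: consumes g = dE/dAEV, produces dE/dcoord (the force path)
    @staticmethod
    def forward(ctx, grad_e_aev, coordinates, species):
        ctx.save_for_backward(coordinates, species)
        return cuaev_backward(grad_e_aev, coordinates, species)  # placeholder launcher

    # backward: consumes dLoss/d(dE/dcoord), produces dLoss/dg
    @staticmethod
    def backward(ctx, grad_force):
        coordinates, species = ctx.saved_tensors
        grad_grad_aev = cuaev_double_backward(grad_force, coordinates, species)  # placeholder
        return grad_grad_aev, None, None

class CuaevSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, coordinates, species):
        ctx.save_for_backward(coordinates, species)
        return cuaev_forward(coordinates, species)  # placeholder launcher

    @staticmethod
    def backward(ctx, grad_aev):
        coordinates, species = ctx.saved_tensors
        # delegating to another Function makes the returned gradient itself
        # differentiable, which is exactly what force training needs
        grad_coord = CuaevDoubleSketch.apply(grad_aev, coordinates, species)
        return grad_coord, None
```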