OpenDAS / torchani, commit d2d63056 (unverified)
Authored Nov 24, 2020 by Richard Xue; committed by GitHub on Nov 24, 2020

CRLF to LF (#553)

* line-limit 120
* CRLF to LF

Parent: a6d819ed
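The change itself is a whitespace normalization: every CRLF line ending in the two files below is rewritten as a bare LF, and the .clang-format hunk additionally raises the column limit to 120. For illustration only, here is a minimal self-contained C++ sketch of such a CRLF-to-LF pass; it is not the tooling used for this commit, and the helper name normalize_line_endings is made up for the example.

// Illustrative only: strip the '\r' of every CRLF pair in the files given on
// the command line, leaving lone '\r' bytes and existing LF endings untouched.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

bool normalize_line_endings(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  if (!in)
    return false;
  std::ostringstream buf;
  buf << in.rdbuf();
  const std::string text = buf.str();
  in.close();

  std::string out;
  out.reserve(text.size());
  for (size_t k = 0; k < text.size(); ++k) {
    if (text[k] == '\r' && k + 1 < text.size() && text[k + 1] == '\n')
      continue; // drop the carriage return; the '\n' that follows is kept
    out.push_back(text[k]);
  }

  std::ofstream out_file(path, std::ios::binary | std::ios::trunc);
  out_file << out;
  return static_cast<bool>(out_file);
}

int main(int argc, char** argv) {
  for (int k = 1; k < argc; ++k)
    std::cout << argv[k] << (normalize_line_endings(argv[k]) ? ": converted\n" : ": failed\n");
  return 0;
}

Run with the files to convert as command-line arguments, it rewrites each one in place.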
Changes: 2 changed files, with 585 additions and 615 deletions (+585 / -615)

.clang-format           +1    -1
torchani/cuaev/aev.cu   +584  -614
.clang-format (view file @ d2d63056)

...
@@ -35,7 +35,7 @@ BreakBeforeTernaryOperators: true
 BreakConstructorInitializersBeforeComma: false
 BreakAfterJavaFieldAnnotations: false
 BreakStringLiterals: false
-ColumnLimit: 100
+ColumnLimit: 120
 CommentPragmas: '^ IWYU pragma:'
 CompactNamespaces: false
 ConstructorInitializerAllOnOneLineOrOnePerLine: true
...
torchani/cuaev/aev.cu (view file @ d2d63056)
#include <thrust/equal.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include <ATen/Context.h>
#include <THC/THC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <THC/THCThrustAllocator.cuh>

#define PI 3.141592653589793

template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
  DataT Rcr;
  DataT Rca;
  IndexT radial_sublength;
  IndexT radial_length;
  IndexT angular_sublength;
  IndexT angular_length;
  IndexT num_species;
};

#define MAX_NSPECIES 10
__constant__ int csubaev_offsets[MAX_NSPECIES * MAX_NSPECIES];

template <typename DataT>
struct PairDist {
  DataT Rij;
  int midx;
  short i;
  short j;
};

// used to group Rijs by atom id
template <typename DataT>
__host__ __device__ bool operator==(const PairDist<DataT>& lhs, const PairDist<DataT>& rhs) {
  return lhs.midx == rhs.midx && lhs.i == rhs.i;
}

/// Alignment of memory. Must be a power of two
/// \tparam boundary Boundary to align to (NOTE: must be power of 2)
/// \param value Input value that is to be aligned
/// \return Value aligned to boundary
template <int32_t boundary>
__host__ __device__ __forceinline__ int align(const int& value) {
  static_assert((boundary & (boundary - 1)) == 0, "Boundary for align must be power of 2");
  return (value + boundary) & ~(boundary - 1);
}

template <typename SpeciesT, typename DataT, typename IndexT = int>
__global__ void pairwiseDistance(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    PairDist<DataT>* d_Rij,
    IndexT max_natoms_per_mol) {
  extern __shared__ DataT spos[];
  DataT* sx = &spos[0];
  DataT* sy = &spos[max_natoms_per_mol];
  DataT* sz = &spos[2 * max_natoms_per_mol];

  int mol_idx = blockIdx.x;
  int tidx = threadIdx.y * blockDim.x + threadIdx.x;

  for (int i = tidx; i < max_natoms_per_mol; i += blockDim.x * blockDim.y) {
    sx[i] = pos_t[mol_idx][i][0];
    sy[i] = pos_t[mol_idx][i][1];
    sz[i] = pos_t[mol_idx][i][2];
  }

  __syncthreads();

  int natom_pairs = max_natoms_per_mol * max_natoms_per_mol;

  for (int i = threadIdx.y; i < max_natoms_per_mol; i += blockDim.y) {
    SpeciesT type_i = species_t[mol_idx][i];
    DataT xi = sx[i];
    DataT yi = sy[i];
    DataT zi = sz[i];

    for (int j = threadIdx.x; j < max_natoms_per_mol; j += blockDim.x) {
      SpeciesT type_j = species_t[mol_idx][j];
      const DataT xj = sx[j];
      const DataT yj = sy[j];
      const DataT zj = sz[j];
      const DataT delx = xj - xi;
      const DataT dely = yj - yi;
      const DataT delz = zj - zi;
      const DataT Rsq = delx * delx + dely * dely + delz * delz;

      if (type_i != -1 && type_j != -1 && i != j) {
        DataT Rij = sqrt(Rsq);

        PairDist<DataT> d;
        d.Rij = Rij;
        d.midx = mol_idx;
        d.i = i;
        d.j = j;

        d_Rij[mol_idx * natom_pairs + i * max_natoms_per_mol + j] = d;
      }
    }
  }
}

template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void cuAngularAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfZ_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> Zeta_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    PairDist<DataT>* d_centralAtom,
    int* d_nPairsPerCenterAtom,
    int* d_centerAtomStartIdx,
    AEVScalarParams<DataT, IndexT> aev_params,
    int maxnbrs_per_atom_aligned,
    int angular_length_aligned,
    int ncentral_atoms) {
  extern __shared__ DataT smem[];

  int threads_per_catom = TILEX * TILEY;
  int gIdx = blockIdx.x * blockDim.x + threadIdx.x;
  int cIdx = gIdx / threads_per_catom; // central atom id

  if (cIdx >= ncentral_atoms)
    return;

  int groupIdx = threadIdx.x / threads_per_catom;
  int laneIdx = threadIdx.x % threads_per_catom;
  int ncatom_per_tpb = blockDim.x / threads_per_catom;

  DataT* saev = &smem[groupIdx * angular_length_aligned];

  int offset = ncatom_per_tpb * angular_length_aligned;
  DataT* sdx = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdy = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdz = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdist = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sfc = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  int* stype = (int*)&smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  DataT EtaA = EtaA_t[0];
  DataT Zeta = Zeta_t[0];

  IndexT nShfA = ShfA_t.size(0);
  IndexT nShfZ = ShfZ_t.size(0);
  DataT Rca = aev_params.Rca;
  IndexT num_species = aev_params.num_species;

  PairDist<DataT> d = d_centralAtom[cIdx];
  int start_idx = d_centerAtomStartIdx[cIdx];
  int jnum = d_nPairsPerCenterAtom[cIdx];

  // center atom
  int i = d.i;
  int mol_idx = d.midx;

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    saev[iaev] = 0;
  }

  DataT xi = pos_t[mol_idx][i][0];
  DataT yi = pos_t[mol_idx][i][1];
  DataT zi = pos_t[mol_idx][i][2];

  for (int jj = laneIdx; jj < jnum; jj += threads_per_catom) {
    PairDist<DataT> dij = d_Rij[start_idx + jj];
    int j = dij.j;
    DataT Rij = dij.Rij;
    SpeciesT type_j = species_t[mol_idx][j];
    sdx[jj] = pos_t[mol_idx][j][0] - xi;
    sdy[jj] = pos_t[mol_idx][j][1] - yi;
    sdz[jj] = pos_t[mol_idx][j][2] - zi;
    stype[jj] = type_j;
    sdist[jj] = Rij;
    DataT fc_ij = 0.5 * cos(PI * Rij / Rca) + 0.5;
    sfc[jj] = fc_ij;
  }

  short2 tile = make_short2(laneIdx % TILEX, laneIdx / TILEX);

  for (int jj = 0; jj < jnum; jj++) {
    const DataT Rij = sdist[jj];
    SpeciesT type_j = stype[jj];
    DataT fc_ij = sfc[jj];

    for (int kk_start = jj + 1; kk_start < jnum; kk_start += threads_per_catom) {
      int kk = kk_start + laneIdx;
      DataT theta = 0;
      if (kk < jnum) {
        const DataT Rik = sdist[kk];
        theta = acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik));
      }
      for (int srcLane = 0; kk_start + srcLane < min(32, jnum); ++srcLane) {
        int kk = kk_start + srcLane;
        DataT theta_ijk = __shfl_sync(0xFFFFFFFF, theta, srcLane);
        const DataT Rik = sdist[kk];
        SpeciesT type_k = stype[kk];
        DataT fc_ik = sfc[kk];

        DataT Rijk = (Rij + Rik) / 2;
        DataT fc_ijk = fc_ij * fc_ik;

        IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k];
        IndexT aev_offset = aev_params.radial_length + subaev_offset;

        for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
          DataT ShfZ = ShfZ_t[itheta];
          DataT factor1 = pow((1 + cos(theta_ijk - ShfZ)) / 2, Zeta);

          for (int ishfr = tile.y; ishfr < nShfA; ishfr += TILEY) {
            DataT ShfA = ShfA_t[ishfr];
            DataT factor2 = exp(-EtaA * (Rijk - ShfA) * (Rijk - ShfA));

            DataT res = 2 * factor1 * factor2 * fc_ijk;

            saev[subaev_offset + ishfr * nShfZ + itheta] += res;
          }
        }
      }
    }
  }

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    aev_t[mol_idx][i][aev_params.radial_length + iaev] = saev[iaev];
  }
}

template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfR_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    AEVScalarParams<DataT, int> aev_params,
    int nRadialRij) {
  int gidx = blockIdx.x * blockDim.x + threadIdx.x;
  int idx = gidx / THREADS_PER_RIJ;

  int nShfR = ShfR_t.size(0);
  DataT EtaR = EtaR_t[0];

  if (idx >= nRadialRij)
    return;

  int laneIdx = threadIdx.x % THREADS_PER_RIJ;

  PairDist<DataT> d = d_Rij[idx];
  DataT Rij = d.Rij;
  int mol_idx = d.midx;
  int i = d.i;
  int j = d.j;

  SpeciesT type_i = species_t[mol_idx][i];
  SpeciesT type_j = species_t[mol_idx][j];

  DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5;

  for (int ishfr = laneIdx; ishfr < nShfR; ishfr += THREADS_PER_RIJ) {
    DataT ShfR = ShfR_t[ishfr];

    DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR)) * fc;

    atomicAdd(&aev_t[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], GmR);
  }
}

template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run exclusive prefix sum
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}
template <typename DataT, typename IndexT>
int cubEncode(
    const DataT* d_in,
    DataT* d_unique_out,
    IndexT* d_counts_out,
    int num_items,
    int* d_num_runs_out,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run encoding
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);
  return num_selected;
}

template <typename DataT, typename LambdaOpT>
int cubDeviceSelect(
    const DataT* d_in,
    DataT* d_out,
    int num_items,
    int* d_num_selected_out,
    LambdaOpT select_op,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run selection
  cub::DeviceSelect::If(
      d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);

  return num_selected;
}

template <typename DataT>
DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run max-reduction
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  int maxVal = 0;
  cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);

  return maxVal;
}

void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
  int num_species = aev_params.num_species;
  assert(num_species <= MAX_NSPECIES);
  // precompute the aev offsets and load to constant memory
  int* subaev_offsets = new int[num_species * num_species];
  for (int t = 0; t < num_species; ++t) {
    int offset = 0;
    for (int s = 0; s < num_species; s++) {
      if (t < num_species - s) {
        subaev_offsets[s * num_species + s + t] = aev_params.angular_sublength * (offset + t);
        subaev_offsets[(s + t) * num_species + s] = aev_params.angular_sublength * (offset + t);
      }
      offset += num_species - s;
    }
  }
  cudaMemcpyToSymbolAsync(
      csubaev_offsets, subaev_offsets, sizeof(int) * num_species * num_species, 0, cudaMemcpyDefault, stream);
  delete[] subaev_offsets;
}

// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV(
    torch::Tensor coordinates_t,
    torch::Tensor species_t,
    double Rcr_,
    double Rca_,
    torch::Tensor EtaR_t,
    torch::Tensor ShfR_t,
    torch::Tensor EtaA_t,
    torch::Tensor Zeta_t,
    torch::Tensor ShfA_t,
    torch::Tensor ShfZ_t,
    int64_t num_species_) {
  TORCH_CHECK(
      (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
  TORCH_CHECK(
      EtaR_t.size(0) == 1 || EtaA_t.size(0) == 1 || Zeta_t.size(0) == 1,
      "cuda extension is currently not supported for the specified "
      "configuration");

  ScalarRealT Rcr = Rcr_;
  ScalarRealT Rca = Rca_;
  int num_species = num_species_;

  const int n_molecules = species_t.size(0);
  const int max_natoms_per_mol = species_t.size(1);

  AEVScalarParams<float> aev_params;
  aev_params.Rca = Rca;
  aev_params.Rcr = Rcr;
  aev_params.num_species = num_species;

  aev_params.radial_sublength = EtaR_t.size(0) * ShfR_t.size(0);
  aev_params.radial_length = aev_params.radial_sublength * num_species;

  aev_params.angular_sublength = EtaA_t.size(0) * Zeta_t.size(0) * ShfA_t.size(0) * ShfZ_t.size(0);
  aev_params.angular_length = aev_params.angular_sublength * (num_species * (num_species + 1) / 2);

  int aev_length = aev_params.radial_length + aev_params.angular_length;

  auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options());

  if (species_t.numel() == 0) {
    return aev_t;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto thrust_allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(thrust_allocator).on(stream);
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // precompute the aev offsets and load to constant memory
  initConsts(aev_params, stream);

  // buffer to store all the pairwise distances (Rij)
  auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
  auto buffer_Rij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_Rij = (PairDist<float>*)buffer_Rij.get();

  // init all Rij to inf
  PairDist<float> init;
  init.Rij = std::numeric_limits<float>::infinity();
  thrust::fill(policy, d_Rij, d_Rij + total_natom_pairs, init);

  // buffer to store all the pairwise distances needed for the radial AEV
  // computation
  auto buffer_radialRij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_radialRij = (PairDist<float>*)buffer_radialRij.get();

  auto buffer_count = allocator.allocate(sizeof(int));
  int* d_count_out = (int*)buffer_count.get();

  const int block_size = 64;
  dim3 block(8, 8, 1);

  // Compute pairwise distance (Rij) for all atom pairs in a molecule
  pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_Rij,
      max_natoms_per_mol);

  // Extract the Rijs needed for the radial AEV computation, i.e. all Rij <= Rcr
  int nRadialRij = cubDeviceSelect(
      d_Rij,
      d_radialRij,
      total_natom_pairs,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rcr; },
      stream);

  int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
  cuRadialAEVs<int, float, 8><<<nblocks, block_size, 0, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_radialRij,
      aev_params,
      nRadialRij);

  // reuse buffer allocated for all Rij
  // d_angularRij will store all the Rij required in Angular AEV computation
  PairDist<float>* d_angularRij = d_Rij;

  // Extract the Rijs needed for the angular AEV computation, i.e. all Rij <= Rca
  int nAngularRij = cubDeviceSelect(
      d_radialRij,
      d_angularRij,
      nRadialRij,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rca; },
      stream);

  auto buffer_centralAtom = allocator.allocate(sizeof(PairDist<float>) * nAngularRij);
  PairDist<float>* d_centralAtom = (PairDist<float>*)buffer_centralAtom.get();

  auto buffer_numPairsPerCenterAtom = allocator.allocate(sizeof(int) * nAngularRij);
  int* d_numPairsPerCenterAtom = (int*)buffer_numPairsPerCenterAtom.get();

  // group by center atom
  int ncenter_atoms =
      cubEncode(d_angularRij, d_centralAtom, d_numPairsPerCenterAtom, nAngularRij, d_count_out, stream);

  auto buffer_centerAtomStartIdx = allocator.allocate(sizeof(int) * ncenter_atoms);
  int* d_centerAtomStartIdx = (int*)buffer_centerAtomStartIdx.get();

  cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream);

  {
    const int nthreads_per_catom = 32;
    const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
    auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
      int sm_aev = sizeof(float) * align<4>(aev_params.angular_length);
      int sxyz = sizeof(float) * max_nbrs * 3;
      int sRij = sizeof(float) * max_nbrs;
      int sfc = sizeof(float) * max_nbrs;
      int sj = sizeof(int) * max_nbrs;

      return (sm_aev + sxyz + sRij + sfc + sj) * ncatom_per_tpb;
    };

    int maxNbrsPerCenterAtom = cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream);
    int maxnbrs_per_atom_aligned = align<4>(maxNbrsPerCenterAtom);

    cuAngularAEVs<<<nblocks_angAEV,
                    block_size,
                    smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom),
                    stream>>>(
        species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        d_angularRij,
        d_centralAtom,
        d_numPairsPerCenterAtom,
        d_centerAtomStartIdx,
        aev_params,
        maxnbrs_per_atom_aligned,
        align<4>(aev_params.angular_length),
        ncenter_atoms);
  }
  return aev_t;
}

TORCH_LIBRARY(cuaev, m) {
  m.def("cuComputeAEV", &cuComputeAEV<float>);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}