OpenDAS / torchani / Commits / a6d819ed

Unverified commit a6d819ed, authored Nov 20, 2020 by Richard Xue, committed via GitHub on Nov 20, 2020

clang-format line-limit to 120 (#552)

* clang-format
* line-limit 120

parent 5ff2f8fc
Showing 2 changed files with 274 additions and 158 deletions (+274 −158):

.clang-format          +88  −0
torchani/cuaev/aev.cu  +186 −158
.clang-format  (new file, mode 0 → 100644)
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
  - Regex: '^<.*\.h(pp)?>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
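For illustration (not part of the commit): under this configuration, a declaration longer than the column limit is broken after the opening parenthesis (AlignAfterOpenBracket: AlwaysBreak) with one parameter per line (BinPackParameters: false) at a 4-column continuation indent — exactly the shape the reformatted aev.cu signatures below take. A hypothetical before/after using names from the file below:

// Before (one long line, over the 120-column limit):
void cuRadialAEVs(torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> species_t, PairDist<float>* d_Rij, AEVScalarParams<float, int> aev_params, int nRadialRij);

// After clang-format with this config (break after '(', one parameter per line):
void cuRadialAEVs(
    torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> species_t,
    PairDist<float>* d_Rij,
    AEVScalarParams<float, int> aev_params,
    int nRadialRij);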
torchani/cuaev/aev.cu
#include <ATen/Context.h>
#include <THC/THC.h>
#include <THC/THCThrustAllocator.cuh>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cub/cub.cuh>
#include <thrust/equal.h>
#include <torch/extension.h>
#define PI 3.141592653589793
template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
  DataT Rcr;
  DataT Rca;
  ...
@@ -23,7 +24,8 @@ template <typename DataT, typename IndexT = int> struct AEVScalarParams {

#define MAX_NSPECIES 10
__constant__ int csubaev_offsets[MAX_NSPECIES * MAX_NSPECIES];

template <typename DataT>
struct PairDist {
  DataT Rij;
  int midx;
  short i;
  ...
@@ -32,8 +34,7 @@ template <typename DataT> struct PairDist {

// used to group Rijs by atom id
template <typename DataT>
__host__ __device__ bool operator==(const PairDist<DataT>& lhs, const PairDist<DataT>& rhs) {
  return lhs.midx == rhs.midx && lhs.i == rhs.i;
}
...
@@ -42,23 +43,21 @@ __host__ __device__ bool operator==(const PairDist<DataT> &lhs,

/// \param value Input value that is to be aligned
/// \return Value aligned to boundary
template <int32_t boundary>
__host__ __device__ __forceinline__ int align(const int& value) {
  static_assert((boundary & (boundary - 1)) == 0, "Boundary for align must be power of 2");
  return (value + boundary) & ~(boundary - 1);
}
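A quick worked example of the alignment helper (a host-side sketch, not part of the commit): ~(boundary - 1) clears the low bits, so the expression rounds value + boundary down to a multiple of boundary. Note that an input that is already a multiple gets bumped up a full boundary, which slightly over-allocates but stays aligned:

#include <cassert>

// Host-side replay of align<4> under the formula above (illustration only).
constexpr int align4(int value) {
  return (value + 4) & ~(4 - 1);
}

int main() {
  assert(align4(3) == 4);  // (3 + 4) & ~3 = 7 & ~3 = 4
  assert(align4(5) == 8);  // (5 + 4) & ~3 = 9 & ~3 = 8
  assert(align4(8) == 12); // an exact multiple is bumped a full boundary
  return 0;
}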
template <typename SpeciesT, typename DataT, typename IndexT = int>
__global__ void pairwiseDistance(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    PairDist<DataT>* d_Rij,
    IndexT max_natoms_per_mol) {
  extern __shared__ DataT spos[];
  DataT* sx = &spos[0];
  DataT* sy = &spos[max_natoms_per_mol];
  DataT* sz = &spos[2 * max_natoms_per_mol];

  int mol_idx = blockIdx.x;
  int tidx = threadIdx.y * blockDim.x + threadIdx.x;
  ...
@@ -74,7 +73,6 @@ __global__ void pairwiseDistance(

  int natom_pairs = max_natoms_per_mol * max_natoms_per_mol;

  for (int i = threadIdx.y; i < max_natoms_per_mol; i += blockDim.y) {
    SpeciesT type_i = species_t[mol_idx][i];
    DataT xi = sx[i];
    ...
@@ -107,21 +105,23 @@ __global__ void pairwiseDistance(
  }
}

template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void cuAngularAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfZ_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> Zeta_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    PairDist<DataT>* d_centralAtom,
    int* d_nPairsPerCenterAtom,
    int* d_centerAtomStartIdx,
    AEVScalarParams<DataT, IndexT> aev_params,
    int maxnbrs_per_atom_aligned,
    int angular_length_aligned,
    int ncentral_atoms) {
  extern __shared__ DataT smem[];
  int threads_per_catom = TILEX * TILEY;
  ...
@@ -135,25 +135,25 @@ __global__ void cuAngularAEVs(

  int laneIdx = threadIdx.x % threads_per_catom;
  int ncatom_per_tpb = blockDim.x / threads_per_catom;

  DataT* saev = &smem[groupIdx * angular_length_aligned];

  int offset = ncatom_per_tpb * angular_length_aligned;
  DataT* sdx = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdy = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdz = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdist = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sfc = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  int* stype = (int*)&smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  DataT EtaA = EtaA_t[0];
  DataT Zeta = Zeta_t[0];
  ...
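The pointer arithmetic above carves the single extern __shared__ arena into per-center-atom slices: one angular-AEV accumulator (saev) of angular_length_aligned floats, then six neighbor arrays (sdx, sdy, sdz, sdist, sfc, stype) of maxnbrs_per_atom_aligned entries each. A minimal host-side sketch of the implied byte budget (an assumption that mirrors this layout; the actual launch uses the smem_size lambda further down):

#include <cstddef>

// Hypothetical helper: dynamic shared-memory bytes per thread block,
// given ncatom_per_tpb center atoms processed per block.
size_t angularSmemBytes(int angular_length_aligned, int maxnbrs_per_atom_aligned, int ncatom_per_tpb) {
  size_t per_catom = sizeof(float) * angular_length_aligned  // saev accumulator
      + 5 * sizeof(float) * maxnbrs_per_atom_aligned         // sdx, sdy, sdz, sdist, sfc
      + sizeof(int) * maxnbrs_per_atom_aligned;              // stype
  return per_catom * ncatom_per_tpb;
}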
@@ -171,8 +171,7 @@ __global__ void cuAngularAEVs(

  int i = d.i;
  int mol_idx = d.midx;

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    saev[iaev] = 0;
  }
  ...
@@ -202,16 +201,13 @@ __global__ void cuAngularAEVs(

      DataT fc_ij = sfc[jj];

      for (int kk_start = jj + 1; kk_start < jnum; kk_start += threads_per_catom) {
        int kk = kk_start + laneIdx;
        DataT theta = 0;
        if (kk < jnum) {
          const DataT Rik = sdist[kk];
          theta = acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik));
        }
        for (int srcLane = 0; kk_start + srcLane < min(32, jnum); ++srcLane) {
          ...
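For reference, the guarded branch above computes the pair angle, with sdx/sdy/sdz holding the displacement components and sdist the distances:

$$ \theta_{ijk} = \arccos\!\Big( 0.95 \, \frac{\vec r_{ij} \cdot \vec r_{ik}}{R_{ij}\,R_{ik}} \Big) $$

The 0.95 factor keeps the argument strictly inside (−1, 1), so acos and its gradient stay finite for exactly (anti)parallel neighbor pairs; torchani's Python implementation applies the same safeguard.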
@@ -247,22 +243,20 @@ __global__ void cuAngularAEVs(
    }
  }

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    aev_t[mol_idx][i][aev_params.radial_length + iaev] = saev[iaev];
  }
}

template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfR_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    AEVScalarParams<DataT, int> aev_params,
    int nRadialRij) {
  int gidx = blockIdx.x * blockDim.x + threadIdx.x;
  int idx = gidx / THREADS_PER_RIJ;
  ...
@@ -290,104 +284,123 @@ __global__ void cuRadialAEVs(

    DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR)) * fc;
    atomicAdd(&aev_t[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], GmR);
  }
}
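For reference, GmR above is the ANI radial symmetry-function term; the cutoff factor fc is computed in context elided from this hunk, and taking it to be the standard ANI cosine cutoff is an assumption here:

$$ G^{R}_{m} = \tfrac{1}{4}\, e^{-\eta_R (R_{ij} - R_s)^2}\, f_C(R_{ij}), \qquad f_C(R) = \tfrac{1}{2}\cos\!\big(\pi R / R_c\big) + \tfrac{1}{2} \quad (R \le R_c) $$

where EtaR plays the role of $\eta_R$ and ShfR of the shift $R_s$; the result is atomically accumulated into the radial block of atom i's AEV at the sub-block for species type_j.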
template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run exclusive prefix sum
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}

template <typename DataT, typename IndexT>
int cubEncode(
    const DataT* d_in,
    DataT* d_unique_out,
    IndexT* d_counts_out,
    int num_items,
    int* d_num_runs_out,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run encoding
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);
  return num_selected;
}

template <typename DataT, typename LambdaOpT>
int cubDeviceSelect(
    const DataT* d_in,
    DataT* d_out,
    int num_items,
    int* d_num_selected_out,
    LambdaOpT select_op,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run selection
  cub::DeviceSelect::If(
      d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);
  return num_selected;
}

template <typename DataT>
DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run max-reduction
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  int maxVal = 0;
  cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
  ...
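All four helpers above follow the same two-phase CUB protocol: call the algorithm once with d_temp_storage == NULL so it only reports the scratch size, allocate that many bytes, then call it again to do the work. A self-contained sketch of the idiom (illustration only; the commit routes the scratch allocation through c10's CUDACachingAllocator rather than raw cudaMalloc):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

int main() {
  int h_in[4] = {3, 1, 4, 1};
  int *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_in));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  // Pass 1: d_temp_storage == NULL, so CUB only reports the scratch size.
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, 4);

  // Allocate the scratch, then pass 2 actually runs the scan.
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, 4);

  int h_out[4];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost); // {0, 3, 4, 8}

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_temp_storage);
  return 0;
}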
@@ -396,42 +409,52 @@ DataT cubMax(const DataT *d_in, int num_items, DataT *d_out,

  return maxVal;
}

void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
  int num_species = aev_params.num_species;
  assert(num_species <= MAX_NSPECIES);
  // precompute the aev offsets and load to constant memory
  int* subaev_offsets = new int[num_species * num_species];
  for (int t = 0; t < num_species; ++t) {
    int offset = 0;
    for (int s = 0; s < num_species; s++) {
      if (t < num_species - s) {
        subaev_offsets[s * num_species + s + t] = aev_params.angular_sublength * (offset + t);
        subaev_offsets[(s + t) * num_species + s] = aev_params.angular_sublength * (offset + t);
      }
      offset += num_species - s;
    }
  }
  cudaMemcpyToSymbolAsync(
      csubaev_offsets, subaev_offsets, sizeof(int) * num_species * num_species, 0, cudaMemcpyDefault, stream);
  delete[] subaev_offsets;
}
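As a sanity check on the offset layout (a host-only replay, not part of the commit): for num_species = 2 and angular_sublength = A, the double loop above produces the symmetric table {0, A; A, 2A}, i.e. one angular sub-AEV block per unordered species pair — (0,0) → 0, (0,1)/(1,0) → A, (1,1) → 2A — matching num_species * (num_species + 1) / 2 = 3 blocks:

#include <cassert>

int main() {
  const int num_species = 2;
  const int A = 1; // stand-in for angular_sublength
  int subaev_offsets[4] = {0, 0, 0, 0};
  for (int t = 0; t < num_species; ++t) {
    int offset = 0;
    for (int s = 0; s < num_species; s++) {
      if (t < num_species - s) {
        subaev_offsets[s * num_species + s + t] = A * (offset + t);
        subaev_offsets[(s + t) * num_species + s] = A * (offset + t);
      }
      offset += num_species - s;
    }
  }
  // Symmetric lookup: pair (0,0) -> 0, (0,1)/(1,0) -> A, (1,1) -> 2A.
  assert(subaev_offsets[0] == 0 && subaev_offsets[1] == A);
  assert(subaev_offsets[2] == A && subaev_offsets[3] == 2 * A);
  return 0;
}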
// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV(
    torch::Tensor coordinates_t,
    torch::Tensor species_t,
    double Rcr_,
    double Rca_,
    torch::Tensor EtaR_t,
    torch::Tensor ShfR_t,
    torch::Tensor EtaA_t,
    torch::Tensor Zeta_t,
    torch::Tensor ShfA_t,
    torch::Tensor ShfZ_t,
    int64_t num_species_) {
  TORCH_CHECK(
      (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
  TORCH_CHECK(
      EtaR_t.size(0) == 1 || EtaA_t.size(0) == 1 || Zeta_t.size(0) == 1,
      "cuda extension is currently not supported for the specified configuration");

  ScalarRealT Rcr = Rcr_;
  ScalarRealT Rca = Rca_;
  ...
@@ -448,35 +471,29 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,

  aev_params.radial_sublength = EtaR_t.size(0) * ShfR_t.size(0);
  aev_params.radial_length = aev_params.radial_sublength * num_species;
  aev_params.angular_sublength = EtaA_t.size(0) * Zeta_t.size(0) * ShfA_t.size(0) * ShfZ_t.size(0);
  aev_params.angular_length = aev_params.angular_sublength * (num_species * (num_species + 1) / 2);

  int aev_length = aev_params.radial_length + aev_params.angular_length;

  auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options());
  if (species_t.numel() == 0) {
    return aev_t;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto thrust_allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(thrust_allocator).on(stream);
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // precompute the aev offsets and load to constant memory
  initConsts(aev_params, stream);

  // buffer to store all the pairwise distances (Rij)
  auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
  auto buffer_Rij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_Rij = (PairDist<float>*)buffer_Rij.get();

  // init all Rij to inf
  PairDist<float> init;
  ...
@@ -485,28 +502,31 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,

  // buffer to store all the pairwise distances that are needed for Radial AEV computation
  auto buffer_radialRij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_radialRij = (PairDist<float>*)buffer_radialRij.get();

  auto buffer_count = allocator.allocate(sizeof(int));
  int* d_count_out = (int*)buffer_count.get();

  const int block_size = 64;
  dim3 block(8, 8, 1);
  // Compute pairwise distance (Rij) for all atom pairs in a molecule
  pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_Rij,
      max_natoms_per_mol);

  // Extract the Rijs that are needed for RadialAEV computation, i.e. all Rij <= Rcr
  int nRadialRij = cubDeviceSelect(
      d_Rij,
      d_radialRij,
      total_natom_pairs,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rcr; },
      stream);

  int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
  cuRadialAEVs<int, float, 8><<<nblocks, block_size, 0, stream>>>(
  ...
@@ -514,40 +534,41 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,

      ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_radialRij,
      aev_params,
      nRadialRij);

  // reuse the buffer allocated for all Rij
  // d_angularRij will store all the Rij required in Angular AEV computation
  PairDist<float>* d_angularRij = d_Rij;

  // Extract the Rijs that are needed for AngularAEV computation, i.e. all Rij <= Rca
  int nAngularRij = cubDeviceSelect(
      d_radialRij,
      d_angularRij,
      nRadialRij,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rca; },
      stream);

  auto buffer_centralAtom = allocator.allocate(sizeof(PairDist<float>) * nAngularRij);
  PairDist<float>* d_centralAtom = (PairDist<float>*)buffer_centralAtom.get();

  auto buffer_numPairsPerCenterAtom = allocator.allocate(sizeof(int) * nAngularRij);
  int* d_numPairsPerCenterAtom = (int*)buffer_numPairsPerCenterAtom.get();

  // group by center atom
  int ncenter_atoms = cubEncode(d_angularRij, d_centralAtom, d_numPairsPerCenterAtom, nAngularRij, d_count_out, stream);

  auto buffer_centerAtomStartIdx = allocator.allocate(sizeof(int) * ncenter_atoms);
  int* d_centerAtomStartIdx = (int*)buffer_centerAtomStartIdx.get();

  cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream);

  {
    const int nthreads_per_catom = 32;
    const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;

    auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
      int sm_aev = sizeof(float) * align<4>(aev_params.angular_length);
      int sxyz = sizeof(float) * max_nbrs * 3;
      ...
@@ -558,15 +579,15 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,

      return (sm_aev + sxyz + sRij + sfc + sj) * ncatom_per_tpb;
    };

    int maxNbrsPerCenterAtom = cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream);
    int maxnbrs_per_atom_aligned = align<4>(maxNbrsPerCenterAtom);

    cuAngularAEVs<<<nblocks_angAEV, block_size, smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom), stream>>>(
        species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        ...
@@ -574,13 +595,20 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,

        EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        d_angularRij,
        d_centralAtom,
        d_numPairsPerCenterAtom,
        d_centerAtomStartIdx,
        aev_params,
        maxnbrs_per_atom_aligned,
        align<4>(aev_params.angular_length),
        ncenter_atoms);
  }

  return aev_t;
}

TORCH_LIBRARY(cuaev, m) {
  m.def("cuComputeAEV", &cuComputeAEV<float>);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}