Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
8292bb3a
Unverified
Commit
8292bb3a
authored
Jun 21, 2022
by
Peter Eastman
Committed by
GitHub
Jun 21, 2022
Browse files
Reduced the cost of updating tabulated functions (#3649)
parent
d7da750a
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
89 additions
and
74 deletions
+89
-74
openmmapi/include/openmm/TabulatedFunction.h
openmmapi/include/openmm/TabulatedFunction.h
+9
-1
openmmapi/src/TabulatedFunction.cpp
openmmapi/src/TabulatedFunction.cpp
+11
-1
platforms/common/include/openmm/common/CommonKernels.h
platforms/common/include/openmm/common/CommonKernels.h
+6
-6
platforms/common/src/CommonKernels.cpp
platforms/common/src/CommonKernels.cpp
+18
-19
platforms/cpu/include/CpuKernels.h
platforms/cpu/include/CpuKernels.h
+3
-3
platforms/cpu/src/CpuKernels.cpp
platforms/cpu/src/CpuKernels.cpp
+12
-13
platforms/reference/include/ReferenceKernels.h
platforms/reference/include/ReferenceKernels.h
+6
-6
platforms/reference/src/ReferenceKernels.cpp
platforms/reference/src/ReferenceKernels.cpp
+24
-25
No files found.
openmmapi/include/openmm/TabulatedFunction.h
View file @
8292bb3a
...
...
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors.
*
* Portions copyright (c) 2014
-2022
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
...
...
@@ -57,6 +57,8 @@ namespace OpenMM {
class
OPENMM_EXPORT
TabulatedFunction
{
public:
TabulatedFunction
()
:
updateCount
(
0
)
{
}
virtual
~
TabulatedFunction
()
{
}
/**
...
...
@@ -67,12 +69,18 @@ public:
* Get the periodicity status of the tabulated function.
*/
bool
getPeriodic
()
const
;
/**
* Get the value of a counter that is updated every time setFunctionParameters()
* is called. This provides a fast way to detect when a function has changed.
*/
int
getUpdateCount
()
const
;
virtual
bool
operator
==
(
const
TabulatedFunction
&
other
)
const
=
0
;
virtual
bool
operator
!=
(
const
TabulatedFunction
&
other
)
const
{
return
!
(
*
this
==
other
);
}
protected:
bool
periodic
;
int
updateCount
;
};
/**
...
...
openmmapi/src/TabulatedFunction.cpp
View file @
8292bb3a
...
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014-202
1
Stanford University and the Authors. *
* Portions copyright (c) 2014-202
2
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
...
...
@@ -39,6 +39,10 @@ bool TabulatedFunction::getPeriodic() const {
return
periodic
;
}
int
TabulatedFunction
::
getUpdateCount
()
const
{
return
updateCount
;
}
Continuous1DFunction
::
Continuous1DFunction
(
const
vector
<
double
>&
values
,
double
min
,
double
max
,
bool
periodic
)
{
this
->
periodic
=
periodic
;
setFunctionParameters
(
values
,
min
,
max
);
...
...
@@ -66,6 +70,7 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this
->
values
=
values
;
this
->
min
=
min
;
this
->
max
=
max
;
updateCount
++
;
}
Continuous1DFunction
*
Continuous1DFunction
::
Copy
()
const
{
...
...
@@ -120,6 +125,7 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this
->
xmax
=
xmax
;
this
->
ymin
=
ymin
;
this
->
ymax
=
ymax
;
updateCount
++
;
}
Continuous2DFunction
*
Continuous2DFunction
::
Copy
()
const
{
...
...
@@ -186,6 +192,7 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this
->
ymax
=
ymax
;
this
->
zmin
=
zmin
;
this
->
zmax
=
zmax
;
updateCount
++
;
}
Continuous3DFunction
*
Continuous3DFunction
::
Copy
()
const
{
...
...
@@ -220,6 +227,7 @@ void Discrete1DFunction::getFunctionParameters(vector<double>& values) const {
void
Discrete1DFunction
::
setFunctionParameters
(
const
vector
<
double
>&
values
)
{
this
->
values
=
values
;
updateCount
++
;
}
Discrete1DFunction
*
Discrete1DFunction
::
Copy
()
const
{
...
...
@@ -256,6 +264,7 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this
->
xsize
=
xsize
;
this
->
ysize
=
ysize
;
this
->
values
=
values
;
updateCount
++
;
}
Discrete2DFunction
*
Discrete2DFunction
::
Copy
()
const
{
...
...
@@ -297,6 +306,7 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this
->
ysize
=
ysize
;
this
->
zsize
=
zsize
;
this
->
values
=
values
;
updateCount
++
;
}
Discrete3DFunction
*
Discrete3DFunction
::
Copy
()
const
{
...
...
platforms/common/include/openmm/common/CommonKernels.h
View file @
8292bb3a
...
...
@@ -530,7 +530,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
const
System
&
system
;
};
...
...
@@ -579,7 +579,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
std
::
vector
<
void
*>
groupForcesArgs
;
ComputeKernel
computeCentersKernel
,
groupForcesKernel
,
applyForcesKernel
;
const
System
&
system
;
...
...
@@ -632,7 +632,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
std
::
vector
<
std
::
string
>
paramNames
,
computedValueNames
;
std
::
vector
<
ComputeParameterInfo
>
paramBuffers
,
computedValueBuffers
;
double
longRangeCoefficient
;
...
...
@@ -735,7 +735,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
std
::
vector
<
bool
>
pairValueUsesParam
,
pairEnergyUsesParam
,
pairEnergyUsesValue
;
const
System
&
system
;
ComputeKernel
pairValueKernel
,
perParticleValueKernel
,
pairEnergyKernel
,
perParticleEnergyKernel
,
gradientChainRuleKernel
;
...
...
@@ -793,7 +793,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
const
System
&
system
;
ComputeKernel
donorKernel
,
acceptorKernel
;
};
...
...
@@ -845,7 +845,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
const
System
&
system
;
ComputeKernel
forceKernel
,
blockBoundsKernel
,
neighborsKernel
,
startIndicesKernel
,
copyPairsKernel
;
ComputeEvent
event
;
...
...
platforms/common/src/CommonKernels.cpp
View file @
8292bb3a
...
...
@@ -35,7 +35,6 @@
#include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h"
#include "openmm/serialization/XmlSerializer.h"
#include "CommonKernelSources.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
...
...
@@ -1294,7 +1293,7 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getTabulatedFunction
(
i
));
string
name
=
force
.
getTabulatedFunctionName
(
i
);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
functions
[
name
]
=
cc
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getTabulatedFunction
(
i
));
int
width
;
vector
<
float
>
f
=
cc
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getTabulatedFunction
(
i
),
width
);
...
...
@@ -1417,8 +1416,8 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
int
width
;
vector
<
float
>
f
=
cc
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getTabulatedFunction
(
i
),
width
);
tabulatedFunctionArrays
[
i
].
upload
(
f
);
...
...
@@ -1553,7 +1552,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getTabulatedFunction
(
i
));
string
name
=
force
.
getTabulatedFunctionName
(
i
);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
string
arrayName
=
"table"
+
cc
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
cc
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getTabulatedFunction
(
i
));
...
...
@@ -1747,8 +1746,8 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
int
width
;
vector
<
float
>
f
=
cc
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getTabulatedFunction
(
i
),
width
);
tabulatedFunctionArrays
[
i
].
upload
(
f
);
...
...
@@ -1902,7 +1901,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getTabulatedFunction
(
i
));
string
name
=
force
.
getTabulatedFunctionName
(
i
);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
string
arrayName
=
prefix
+
"table"
+
cc
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
cc
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getTabulatedFunction
(
i
));
...
...
@@ -2459,8 +2458,8 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
int
width
;
vector
<
float
>
f
=
cc
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getTabulatedFunction
(
i
),
width
);
tabulatedFunctionArrays
[
i
].
upload
(
f
);
...
...
@@ -2798,7 +2797,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getTabulatedFunction
(
i
));
string
name
=
force
.
getTabulatedFunctionName
(
i
);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
string
arrayName
=
prefix
+
"table"
+
cc
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
cc
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getTabulatedFunction
(
i
));
...
...
@@ -3785,8 +3784,8 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
int
width
;
vector
<
float
>
f
=
cc
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getTabulatedFunction
(
i
),
width
);
tabulatedFunctionArrays
[
i
].
upload
(
f
);
...
...
@@ -4012,7 +4011,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getTabulatedFunction
(
i
));
string
name
=
force
.
getTabulatedFunctionName
(
i
);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
string
arrayName
=
"table"
+
cc
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
cc
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getTabulatedFunction
(
i
));
...
...
@@ -4336,8 +4335,8 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
int
width
;
vector
<
float
>
f
=
cc
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getTabulatedFunction
(
i
),
width
);
tabulatedFunctionArrays
[
i
].
upload
(
f
);
...
...
@@ -4425,7 +4424,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getTabulatedFunction
(
i
));
string
name
=
force
.
getTabulatedFunctionName
(
i
);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
string
arrayName
=
"table"
+
cc
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
cc
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getTabulatedFunction
(
i
));
...
...
@@ -4855,8 +4854,8 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
int
width
;
vector
<
float
>
f
=
cc
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getTabulatedFunction
(
i
),
width
);
tabulatedFunctionArrays
[
i
].
upload
(
f
);
...
...
platforms/cpu/include/CpuKernels.h
View file @
8292bb3a
...
...
@@ -334,7 +334,7 @@ private:
std
::
vector
<
std
::
string
>
parameterNames
,
globalParameterNames
,
computedValueNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
pair
<
std
::
set
<
int
>
,
std
::
set
<
int
>
>
>
interactionGroups
;
std
::
vector
<
double
>
longRangeCoefficientDerivs
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
CpuCustomNonbondedForce
*
nonbonded
;
};
...
...
@@ -424,7 +424,7 @@ private:
std
::
vector
<
std
::
string
>
particleParameterNames
,
globalParameterNames
,
energyParamDerivNames
,
valueNames
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
valueTypes
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
energyTypes
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
};
...
...
@@ -467,7 +467,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
particleParamArray
;
CpuCustomManyParticleForce
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
};
...
...
platforms/cpu/src/CpuKernels.cpp
View file @
8292bb3a
...
...
@@ -45,7 +45,6 @@
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h"
#include "openmm/serialization/XmlSerializer.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
...
...
@@ -888,10 +887,10 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
switchingDistance
=
force
.
getSwitchingDistance
();
}
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Record information for the long range correction.
...
...
@@ -1053,8 +1052,8 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
@@ -1166,10 +1165,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
neighborList
=
new
CpuNeighborList
(
4
);
data
.
isPeriodic
|=
(
force
.
getNonbondedMethod
()
==
CustomGBForce
::
CutoffPeriodic
);
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
...
...
@@ -1319,8 +1318,8 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
@@ -1349,10 +1348,10 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
...
...
@@ -1399,8 +1398,8 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
platforms/reference/include/ReferenceKernels.h
View file @
8292bb3a
...
...
@@ -706,7 +706,7 @@ private:
std
::
vector
<
std
::
string
>
parameterNames
,
globalParameterNames
,
computedValueNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
pair
<
std
::
set
<
int
>
,
std
::
set
<
int
>
>
>
interactionGroups
;
std
::
vector
<
double
>
longRangeCoefficientDerivs
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NeighborList
*
neighborList
;
};
...
...
@@ -797,7 +797,7 @@ private:
std
::
vector
<
std
::
vector
<
Lepton
::
CompiledExpression
>
>
energyGradientExpressions
;
std
::
vector
<
std
::
vector
<
Lepton
::
CompiledExpression
>
>
energyParamDerivExpressions
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
energyTypes
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NeighborList
*
neighborList
;
};
...
...
@@ -884,7 +884,7 @@ private:
ReferenceCustomHbondIxn
*
ixn
;
std
::
vector
<
std
::
set
<
int
>
>
exclusions
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
};
/**
...
...
@@ -927,7 +927,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
bondParamArray
;
ReferenceCustomCentroidBondIxn
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
,
energyParamDerivNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
bool
usePeriodic
;
Vec3
*
boxVectors
;
};
...
...
@@ -970,7 +970,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
bondParamArray
;
ReferenceCustomCompoundBondIxn
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
,
energyParamDerivNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
bool
usePeriodic
;
Vec3
*
boxVectors
;
};
...
...
@@ -1012,7 +1012,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
particleParamArray
;
ReferenceCustomManyParticleIxn
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
};
...
...
platforms/reference/src/ReferenceKernels.cpp
View file @
8292bb3a
...
...
@@ -80,7 +80,6 @@
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/Integrator.h"
#include "openmm/OpenMMException.h"
#include "openmm/serialization/XmlSerializer.h"
#include "SimTKOpenMMUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
...
...
@@ -1177,10 +1176,10 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
switchingDistance
=
force
.
getSwitchingDistance
();
}
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the expressions.
...
...
@@ -1349,8 +1348,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
@@ -1465,10 +1464,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
else
neighborList
=
new
NeighborList
();
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the expressions.
...
...
@@ -1624,8 +1623,8 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
@@ -1750,10 +1749,10 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
nonbondedCutoff
=
force
.
getCutoffDistance
();
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
...
...
@@ -1839,8 +1838,8 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
@@ -1873,10 +1872,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
for
(
int
i
=
0
;
i
<
numBonds
;
++
i
)
force
.
getBondParameters
(
i
,
bondGroups
[
i
],
bondParamArray
[
i
]);
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
...
...
@@ -1962,8 +1961,8 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
@@ -1990,10 +1989,10 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
for
(
int
i
=
0
;
i
<
numBonds
;
++
i
)
force
.
getBondParameters
(
i
,
bondParticles
[
i
],
bondParamArray
[
i
]);
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
...
...
@@ -2078,8 +2077,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
@@ -2107,10 +2106,10 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
...
...
@@ -2158,8 +2157,8 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment