Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
8292bb3a
Unverified
Commit
8292bb3a
authored
Jun 21, 2022
by
Peter Eastman
Committed by
GitHub
Jun 21, 2022
Browse files
Reduced the cost of updating tabulated functions (#3649)
parent
d7da750a
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
89 additions
and
74 deletions
+89
-74
openmmapi/include/openmm/TabulatedFunction.h
openmmapi/include/openmm/TabulatedFunction.h
+9
-1
openmmapi/src/TabulatedFunction.cpp
openmmapi/src/TabulatedFunction.cpp
+11
-1
platforms/common/include/openmm/common/CommonKernels.h
platforms/common/include/openmm/common/CommonKernels.h
+6
-6
platforms/common/src/CommonKernels.cpp
platforms/common/src/CommonKernels.cpp
+18
-19
platforms/cpu/include/CpuKernels.h
platforms/cpu/include/CpuKernels.h
+3
-3
platforms/cpu/src/CpuKernels.cpp
platforms/cpu/src/CpuKernels.cpp
+12
-13
platforms/reference/include/ReferenceKernels.h
platforms/reference/include/ReferenceKernels.h
+6
-6
platforms/reference/src/ReferenceKernels.cpp
platforms/reference/src/ReferenceKernels.cpp
+24
-25
No files found.
openmmapi/include/openmm/TabulatedFunction.h
View file @
8292bb3a
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2014 Stanford University and the Authors.
*
* Portions copyright (c) 2014
-2022
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -57,6 +57,8 @@ namespace OpenMM {
...
@@ -57,6 +57,8 @@ namespace OpenMM {
class
OPENMM_EXPORT
TabulatedFunction
{
class
OPENMM_EXPORT
TabulatedFunction
{
public:
public:
TabulatedFunction
()
:
updateCount
(
0
)
{
}
virtual
~
TabulatedFunction
()
{
virtual
~
TabulatedFunction
()
{
}
}
/**
/**
...
@@ -67,12 +69,18 @@ public:
...
@@ -67,12 +69,18 @@ public:
* Get the periodicity status of the tabulated function.
* Get the periodicity status of the tabulated function.
*/
*/
bool
getPeriodic
()
const
;
bool
getPeriodic
()
const
;
/**
* Get the value of a counter that is updated every time setFunctionParameters()
* is called. This provides a fast way to detect when a function has changed.
*/
int
getUpdateCount
()
const
;
virtual
bool
operator
==
(
const
TabulatedFunction
&
other
)
const
=
0
;
virtual
bool
operator
==
(
const
TabulatedFunction
&
other
)
const
=
0
;
virtual
bool
operator
!=
(
const
TabulatedFunction
&
other
)
const
{
virtual
bool
operator
!=
(
const
TabulatedFunction
&
other
)
const
{
return
!
(
*
this
==
other
);
return
!
(
*
this
==
other
);
}
}
protected:
protected:
bool
periodic
;
bool
periodic
;
int
updateCount
;
};
};
/**
/**
...
...
openmmapi/src/TabulatedFunction.cpp
View file @
8292bb3a
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2014-202
1
Stanford University and the Authors. *
* Portions copyright (c) 2014-202
2
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -39,6 +39,10 @@ bool TabulatedFunction::getPeriodic() const {
...
@@ -39,6 +39,10 @@ bool TabulatedFunction::getPeriodic() const {
return
periodic
;
return
periodic
;
}
}
int
TabulatedFunction
::
getUpdateCount
()
const
{
return
updateCount
;
}
Continuous1DFunction
::
Continuous1DFunction
(
const
vector
<
double
>&
values
,
double
min
,
double
max
,
bool
periodic
)
{
Continuous1DFunction
::
Continuous1DFunction
(
const
vector
<
double
>&
values
,
double
min
,
double
max
,
bool
periodic
)
{
this
->
periodic
=
periodic
;
this
->
periodic
=
periodic
;
setFunctionParameters
(
values
,
min
,
max
);
setFunctionParameters
(
values
,
min
,
max
);
...
@@ -66,6 +70,7 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
...
@@ -66,6 +70,7 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this
->
values
=
values
;
this
->
values
=
values
;
this
->
min
=
min
;
this
->
min
=
min
;
this
->
max
=
max
;
this
->
max
=
max
;
updateCount
++
;
}
}
Continuous1DFunction
*
Continuous1DFunction
::
Copy
()
const
{
Continuous1DFunction
*
Continuous1DFunction
::
Copy
()
const
{
...
@@ -120,6 +125,7 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
...
@@ -120,6 +125,7 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this
->
xmax
=
xmax
;
this
->
xmax
=
xmax
;
this
->
ymin
=
ymin
;
this
->
ymin
=
ymin
;
this
->
ymax
=
ymax
;
this
->
ymax
=
ymax
;
updateCount
++
;
}
}
Continuous2DFunction
*
Continuous2DFunction
::
Copy
()
const
{
Continuous2DFunction
*
Continuous2DFunction
::
Copy
()
const
{
...
@@ -186,6 +192,7 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
...
@@ -186,6 +192,7 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this
->
ymax
=
ymax
;
this
->
ymax
=
ymax
;
this
->
zmin
=
zmin
;
this
->
zmin
=
zmin
;
this
->
zmax
=
zmax
;
this
->
zmax
=
zmax
;
updateCount
++
;
}
}
Continuous3DFunction
*
Continuous3DFunction
::
Copy
()
const
{
Continuous3DFunction
*
Continuous3DFunction
::
Copy
()
const
{
...
@@ -220,6 +227,7 @@ void Discrete1DFunction::getFunctionParameters(vector<double>& values) const {
...
@@ -220,6 +227,7 @@ void Discrete1DFunction::getFunctionParameters(vector<double>& values) const {
void
Discrete1DFunction
::
setFunctionParameters
(
const
vector
<
double
>&
values
)
{
void
Discrete1DFunction
::
setFunctionParameters
(
const
vector
<
double
>&
values
)
{
this
->
values
=
values
;
this
->
values
=
values
;
updateCount
++
;
}
}
Discrete1DFunction
*
Discrete1DFunction
::
Copy
()
const
{
Discrete1DFunction
*
Discrete1DFunction
::
Copy
()
const
{
...
@@ -256,6 +264,7 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
...
@@ -256,6 +264,7 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this
->
xsize
=
xsize
;
this
->
xsize
=
xsize
;
this
->
ysize
=
ysize
;
this
->
ysize
=
ysize
;
this
->
values
=
values
;
this
->
values
=
values
;
updateCount
++
;
}
}
Discrete2DFunction
*
Discrete2DFunction
::
Copy
()
const
{
Discrete2DFunction
*
Discrete2DFunction
::
Copy
()
const
{
...
@@ -297,6 +306,7 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
...
@@ -297,6 +306,7 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this
->
ysize
=
ysize
;
this
->
ysize
=
ysize
;
this
->
zsize
=
zsize
;
this
->
zsize
=
zsize
;
this
->
values
=
values
;
this
->
values
=
values
;
updateCount
++
;
}
}
Discrete3DFunction
*
Discrete3DFunction
::
Copy
()
const
{
Discrete3DFunction
*
Discrete3DFunction
::
Copy
()
const
{
...
...
platforms/common/include/openmm/common/CommonKernels.h
View file @
8292bb3a
...
@@ -530,7 +530,7 @@ private:
...
@@ -530,7 +530,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
const
System
&
system
;
const
System
&
system
;
};
};
...
@@ -579,7 +579,7 @@ private:
...
@@ -579,7 +579,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
std
::
vector
<
void
*>
groupForcesArgs
;
std
::
vector
<
void
*>
groupForcesArgs
;
ComputeKernel
computeCentersKernel
,
groupForcesKernel
,
applyForcesKernel
;
ComputeKernel
computeCentersKernel
,
groupForcesKernel
,
applyForcesKernel
;
const
System
&
system
;
const
System
&
system
;
...
@@ -632,7 +632,7 @@ private:
...
@@ -632,7 +632,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
std
::
vector
<
std
::
string
>
paramNames
,
computedValueNames
;
std
::
vector
<
std
::
string
>
paramNames
,
computedValueNames
;
std
::
vector
<
ComputeParameterInfo
>
paramBuffers
,
computedValueBuffers
;
std
::
vector
<
ComputeParameterInfo
>
paramBuffers
,
computedValueBuffers
;
double
longRangeCoefficient
;
double
longRangeCoefficient
;
...
@@ -735,7 +735,7 @@ private:
...
@@ -735,7 +735,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
std
::
vector
<
bool
>
pairValueUsesParam
,
pairEnergyUsesParam
,
pairEnergyUsesValue
;
std
::
vector
<
bool
>
pairValueUsesParam
,
pairEnergyUsesParam
,
pairEnergyUsesValue
;
const
System
&
system
;
const
System
&
system
;
ComputeKernel
pairValueKernel
,
perParticleValueKernel
,
pairEnergyKernel
,
perParticleEnergyKernel
,
gradientChainRuleKernel
;
ComputeKernel
pairValueKernel
,
perParticleValueKernel
,
pairEnergyKernel
,
perParticleEnergyKernel
,
gradientChainRuleKernel
;
...
@@ -793,7 +793,7 @@ private:
...
@@ -793,7 +793,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
const
System
&
system
;
const
System
&
system
;
ComputeKernel
donorKernel
,
acceptorKernel
;
ComputeKernel
donorKernel
,
acceptorKernel
;
};
};
...
@@ -845,7 +845,7 @@ private:
...
@@ -845,7 +845,7 @@ private:
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
float
>
globalParamValues
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
vector
<
ComputeArray
>
tabulatedFunctionArrays
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
const
System
&
system
;
const
System
&
system
;
ComputeKernel
forceKernel
,
blockBoundsKernel
,
neighborsKernel
,
startIndicesKernel
,
copyPairsKernel
;
ComputeKernel
forceKernel
,
blockBoundsKernel
,
neighborsKernel
,
startIndicesKernel
,
copyPairsKernel
;
ComputeEvent
event
;
ComputeEvent
event
;
...
...
platforms/common/src/CommonKernels.cpp
View file @
8292bb3a
...
@@ -35,7 +35,6 @@
...
@@ -35,7 +35,6 @@
#include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h"
#include "openmm/serialization/XmlSerializer.h"
#include "CommonKernelSources.h"
#include "CommonKernelSources.h"
#include "lepton/CustomFunction.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
#include "lepton/ExpressionTreeNode.h"
...
@@ -1294,7 +1293,7 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
...
@@ -1294,7 +1293,7 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
...
@@ -1417,8 +1416,8 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
...
@@ -1417,8 +1416,8 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if (force.getTabulatedFunction(i)
.getUpdateCount()
!= tabulatedFunction
UpdateCount
[name]) {
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
int width;
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
tabulatedFunctionArrays[i].upload(f);
...
@@ -1553,7 +1552,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
...
@@ -1553,7 +1552,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
string arrayName = "table"+cc.intToString(i);
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...
@@ -1747,8 +1746,8 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
...
@@ -1747,8 +1746,8 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if (force.getTabulatedFunction(i)
.getUpdateCount()
!= tabulatedFunction
UpdateCount
[name]) {
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
int width;
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
tabulatedFunctionArrays[i].upload(f);
...
@@ -1902,7 +1901,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -1902,7 +1901,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
string arrayName = prefix+"table"+cc.intToString(i);
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...
@@ -2459,8 +2458,8 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
...
@@ -2459,8 +2458,8 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if (force.getTabulatedFunction(i)
.getUpdateCount()
!= tabulatedFunction
UpdateCount
[name]) {
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
int width;
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
tabulatedFunctionArrays[i].upload(f);
...
@@ -2798,7 +2797,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
...
@@ -2798,7 +2797,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
string arrayName = prefix+"table"+cc.intToString(i);
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...
@@ -3785,8 +3784,8 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
...
@@ -3785,8 +3784,8 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if (force.getTabulatedFunction(i)
.getUpdateCount()
!= tabulatedFunction
UpdateCount
[name]) {
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
int width;
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
tabulatedFunctionArrays[i].upload(f);
...
@@ -4012,7 +4011,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
...
@@ -4012,7 +4011,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
string arrayName = "table"+cc.intToString(i);
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...
@@ -4336,8 +4335,8 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
...
@@ -4336,8 +4335,8 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if (force.getTabulatedFunction(i)
.getUpdateCount()
!= tabulatedFunction
UpdateCount
[name]) {
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
int width;
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
tabulatedFunctionArrays[i].upload(f);
...
@@ -4425,7 +4424,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
...
@@ -4425,7 +4424,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
string arrayName = "table"+cc.intToString(i);
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...
@@ -4855,8 +4854,8 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
...
@@ -4855,8 +4854,8 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
string name = force.getTabulatedFunctionName(i);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if (force.getTabulatedFunction(i)
.getUpdateCount()
!= tabulatedFunction
UpdateCount
[name]) {
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount[name] =
force.getTabulatedFunction(i)
.getUpdateCount(
);
int width;
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
tabulatedFunctionArrays[i].upload(f);
...
...
platforms/cpu/include/CpuKernels.h
View file @
8292bb3a
...
@@ -334,7 +334,7 @@ private:
...
@@ -334,7 +334,7 @@ private:
std
::
vector
<
std
::
string
>
parameterNames
,
globalParameterNames
,
computedValueNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
string
>
parameterNames
,
globalParameterNames
,
computedValueNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
pair
<
std
::
set
<
int
>
,
std
::
set
<
int
>
>
>
interactionGroups
;
std
::
vector
<
std
::
pair
<
std
::
set
<
int
>
,
std
::
set
<
int
>
>
>
interactionGroups
;
std
::
vector
<
double
>
longRangeCoefficientDerivs
;
std
::
vector
<
double
>
longRangeCoefficientDerivs
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NonbondedMethod
nonbondedMethod
;
CpuCustomNonbondedForce
*
nonbonded
;
CpuCustomNonbondedForce
*
nonbonded
;
};
};
...
@@ -424,7 +424,7 @@ private:
...
@@ -424,7 +424,7 @@ private:
std
::
vector
<
std
::
string
>
particleParameterNames
,
globalParameterNames
,
energyParamDerivNames
,
valueNames
;
std
::
vector
<
std
::
string
>
particleParameterNames
,
globalParameterNames
,
energyParamDerivNames
,
valueNames
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
valueTypes
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
valueTypes
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
energyTypes
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
energyTypes
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NonbondedMethod
nonbondedMethod
;
};
};
...
@@ -467,7 +467,7 @@ private:
...
@@ -467,7 +467,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
particleParamArray
;
std
::
vector
<
std
::
vector
<
double
>
>
particleParamArray
;
CpuCustomManyParticleForce
*
ixn
;
CpuCustomManyParticleForce
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NonbondedMethod
nonbondedMethod
;
};
};
...
...
platforms/cpu/src/CpuKernels.cpp
View file @
8292bb3a
...
@@ -45,7 +45,6 @@
...
@@ -45,7 +45,6 @@
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h"
#include "openmm/internal/vectorize.h"
#include "openmm/serialization/XmlSerializer.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
#include "lepton/Operation.h"
...
@@ -888,10 +887,10 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
...
@@ -888,10 +887,10 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
switchingDistance
=
force
.
getSwitchingDistance
();
switchingDistance
=
force
.
getSwitchingDistance
();
}
}
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Record information for the long range correction.
// Record information for the long range correction.
...
@@ -1053,8 +1052,8 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
...
@@ -1053,8 +1052,8 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
@@ -1166,10 +1165,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
...
@@ -1166,10 +1165,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
neighborList
=
new
CpuNeighborList
(
4
);
neighborList
=
new
CpuNeighborList
(
4
);
data
.
isPeriodic
|=
(
force
.
getNonbondedMethod
()
==
CustomGBForce
::
CutoffPeriodic
);
data
.
isPeriodic
|=
(
force
.
getNonbondedMethod
()
==
CustomGBForce
::
CutoffPeriodic
);
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
// Create the interaction.
...
@@ -1319,8 +1318,8 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
...
@@ -1319,8 +1318,8 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
@@ -1349,10 +1348,10 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
...
@@ -1349,10 +1348,10 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
// Create the interaction.
...
@@ -1399,8 +1398,8 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
...
@@ -1399,8 +1398,8 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
...
platforms/reference/include/ReferenceKernels.h
View file @
8292bb3a
...
@@ -706,7 +706,7 @@ private:
...
@@ -706,7 +706,7 @@ private:
std
::
vector
<
std
::
string
>
parameterNames
,
globalParameterNames
,
computedValueNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
string
>
parameterNames
,
globalParameterNames
,
computedValueNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
pair
<
std
::
set
<
int
>
,
std
::
set
<
int
>
>
>
interactionGroups
;
std
::
vector
<
std
::
pair
<
std
::
set
<
int
>
,
std
::
set
<
int
>
>
>
interactionGroups
;
std
::
vector
<
double
>
longRangeCoefficientDerivs
;
std
::
vector
<
double
>
longRangeCoefficientDerivs
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NonbondedMethod
nonbondedMethod
;
NeighborList
*
neighborList
;
NeighborList
*
neighborList
;
};
};
...
@@ -797,7 +797,7 @@ private:
...
@@ -797,7 +797,7 @@ private:
std
::
vector
<
std
::
vector
<
Lepton
::
CompiledExpression
>
>
energyGradientExpressions
;
std
::
vector
<
std
::
vector
<
Lepton
::
CompiledExpression
>
>
energyGradientExpressions
;
std
::
vector
<
std
::
vector
<
Lepton
::
CompiledExpression
>
>
energyParamDerivExpressions
;
std
::
vector
<
std
::
vector
<
Lepton
::
CompiledExpression
>
>
energyParamDerivExpressions
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
energyTypes
;
std
::
vector
<
OpenMM
::
CustomGBForce
::
ComputationType
>
energyTypes
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NonbondedMethod
nonbondedMethod
;
NeighborList
*
neighborList
;
NeighborList
*
neighborList
;
};
};
...
@@ -884,7 +884,7 @@ private:
...
@@ -884,7 +884,7 @@ private:
ReferenceCustomHbondIxn
*
ixn
;
ReferenceCustomHbondIxn
*
ixn
;
std
::
vector
<
std
::
set
<
int
>
>
exclusions
;
std
::
vector
<
std
::
set
<
int
>
>
exclusions
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
};
};
/**
/**
...
@@ -927,7 +927,7 @@ private:
...
@@ -927,7 +927,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
bondParamArray
;
std
::
vector
<
std
::
vector
<
double
>
>
bondParamArray
;
ReferenceCustomCentroidBondIxn
*
ixn
;
ReferenceCustomCentroidBondIxn
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
string
>
globalParameterNames
,
energyParamDerivNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
bool
usePeriodic
;
bool
usePeriodic
;
Vec3
*
boxVectors
;
Vec3
*
boxVectors
;
};
};
...
@@ -970,7 +970,7 @@ private:
...
@@ -970,7 +970,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
bondParamArray
;
std
::
vector
<
std
::
vector
<
double
>
>
bondParamArray
;
ReferenceCustomCompoundBondIxn
*
ixn
;
ReferenceCustomCompoundBondIxn
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
,
energyParamDerivNames
;
std
::
vector
<
std
::
string
>
globalParameterNames
,
energyParamDerivNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
bool
usePeriodic
;
bool
usePeriodic
;
Vec3
*
boxVectors
;
Vec3
*
boxVectors
;
};
};
...
@@ -1012,7 +1012,7 @@ private:
...
@@ -1012,7 +1012,7 @@ private:
std
::
vector
<
std
::
vector
<
double
>
>
particleParamArray
;
std
::
vector
<
std
::
vector
<
double
>
>
particleParamArray
;
ReferenceCustomManyParticleIxn
*
ixn
;
ReferenceCustomManyParticleIxn
*
ixn
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
vector
<
std
::
string
>
globalParameterNames
;
std
::
map
<
std
::
string
,
const
T
abulatedFunction
*>
tabulatedFunctions
;
std
::
map
<
std
::
string
,
int
>
t
abulatedFunction
UpdateCount
;
NonbondedMethod
nonbondedMethod
;
NonbondedMethod
nonbondedMethod
;
};
};
...
...
platforms/reference/src/ReferenceKernels.cpp
View file @
8292bb3a
...
@@ -80,7 +80,6 @@
...
@@ -80,7 +80,6 @@
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/Integrator.h"
#include "openmm/Integrator.h"
#include "openmm/OpenMMException.h"
#include "openmm/OpenMMException.h"
#include "openmm/serialization/XmlSerializer.h"
#include "SimTKOpenMMUtilities.h"
#include "SimTKOpenMMUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
#include "lepton/Operation.h"
...
@@ -1177,10 +1176,10 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
...
@@ -1177,10 +1176,10 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
switchingDistance
=
force
.
getSwitchingDistance
();
switchingDistance
=
force
.
getSwitchingDistance
();
}
}
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the expressions.
// Create the expressions.
...
@@ -1349,8 +1348,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
...
@@ -1349,8 +1348,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
@@ -1465,10 +1464,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
...
@@ -1465,10 +1464,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
else
else
neighborList
=
new
NeighborList
();
neighborList
=
new
NeighborList
();
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the expressions.
// Create the expressions.
...
@@ -1624,8 +1623,8 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont
...
@@ -1624,8 +1623,8 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
@@ -1750,10 +1749,10 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
...
@@ -1750,10 +1749,10 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
nonbondedCutoff
=
force
.
getCutoffDistance
();
nonbondedCutoff
=
force
.
getCutoffDistance
();
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
// Create the interaction.
...
@@ -1839,8 +1838,8 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c
...
@@ -1839,8 +1838,8 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
@@ -1873,10 +1872,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
...
@@ -1873,10 +1872,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
for
(
int
i
=
0
;
i
<
numBonds
;
++
i
)
for
(
int
i
=
0
;
i
<
numBonds
;
++
i
)
force
.
getBondParameters
(
i
,
bondGroups
[
i
],
bondParamArray
[
i
]);
force
.
getBondParameters
(
i
,
bondGroups
[
i
],
bondParamArray
[
i
]);
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
// Create the interaction.
...
@@ -1962,8 +1961,8 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context
...
@@ -1962,8 +1961,8 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
@@ -1990,10 +1989,10 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
...
@@ -1990,10 +1989,10 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
for
(
int
i
=
0
;
i
<
numBonds
;
++
i
)
for
(
int
i
=
0
;
i
<
numBonds
;
++
i
)
force
.
getBondParameters
(
i
,
bondParticles
[
i
],
bondParamArray
[
i
]);
force
.
getBondParameters
(
i
,
bondParticles
[
i
],
bondParamArray
[
i
]);
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
// Create the interaction.
...
@@ -2078,8 +2077,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context
...
@@ -2078,8 +2077,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
@@ -2107,10 +2106,10 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system
...
@@ -2107,10 +2106,10 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
// Record the tabulated functions for future reference.
// Record the tabulated function
update count
s for future reference.
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunction
s
[
force
.
getTabulatedFunctionName
(
i
)]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
force
.
getTabulatedFunctionName
(
i
)]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
// Create the interaction.
// Create the interaction.
...
@@ -2158,8 +2157,8 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context
...
@@ -2158,8 +2157,8 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context
bool
changed
=
false
;
bool
changed
=
false
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
string
name
=
force
.
getTabulatedFunctionName
(
i
);
string
name
=
force
.
getTabulatedFunctionName
(
i
);
if
(
force
.
getTabulatedFunction
(
i
)
!=
*
tabulatedFunction
s
[
name
])
{
if
(
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
()
!=
tabulatedFunction
UpdateCount
[
name
])
{
tabulatedFunction
s
[
name
]
=
XmlSerializer
::
clone
(
force
.
getTabulatedFunction
(
i
));
tabulatedFunction
UpdateCount
[
name
]
=
force
.
getTabulatedFunction
(
i
)
.
getUpdateCount
(
);
changed
=
true
;
changed
=
true
;
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment