Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
2a52e208
Commit
2a52e208
authored
Jul 25, 2016
by
peastman
Browse files
Reference implementation of parameter derivatives for CustomIntegrator
parent
74efa95f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
206 additions
and
15 deletions
+206
-15
libraries/lepton/include/lepton/CustomFunction.h
libraries/lepton/include/lepton/CustomFunction.h
+1
-1
libraries/lepton/src/ParsedExpression.cpp
libraries/lepton/src/ParsedExpression.cpp
+1
-1
openmmapi/include/openmm/CustomIntegrator.h
openmmapi/include/openmm/CustomIntegrator.h
+10
-0
openmmapi/include/openmm/internal/CustomIntegratorUtilities.h
...mmapi/include/openmm/internal/CustomIntegratorUtilities.h
+26
-1
openmmapi/src/CustomIntegratorUtilities.cpp
openmmapi/src/CustomIntegratorUtilities.cpp
+32
-5
platforms/reference/include/ReferenceCustomDynamics.h
platforms/reference/include/ReferenceCustomDynamics.h
+4
-0
platforms/reference/src/SimTKReference/ReferenceCustomDynamics.cpp
.../reference/src/SimTKReference/ReferenceCustomDynamics.cpp
+45
-5
tests/TestCustomIntegrator.h
tests/TestCustomIntegrator.h
+87
-2
No files found.
libraries/lepton/include/lepton/CustomFunction.h
View file @
2a52e208
...
@@ -48,7 +48,7 @@ public:
...
@@ -48,7 +48,7 @@ public:
virtual
~
CustomFunction
()
{
virtual
~
CustomFunction
()
{
}
}
/**
/**
* Get the number of arguments this function exp
r
ects.
* Get the number of arguments this function expects.
*/
*/
virtual
int
getNumArguments
()
const
=
0
;
virtual
int
getNumArguments
()
const
=
0
;
/**
/**
...
...
libraries/lepton/src/ParsedExpression.cpp
View file @
2a52e208
...
@@ -109,7 +109,7 @@ ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const Ex
...
@@ -109,7 +109,7 @@ ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const Ex
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
children
[
i
]
=
precalculateConstantSubexpressions
(
node
.
getChildren
()[
i
]);
children
[
i
]
=
precalculateConstantSubexpressions
(
node
.
getChildren
()[
i
]);
ExpressionTreeNode
result
=
ExpressionTreeNode
(
node
.
getOperation
().
clone
(),
children
);
ExpressionTreeNode
result
=
ExpressionTreeNode
(
node
.
getOperation
().
clone
(),
children
);
if
(
node
.
getOperation
().
getId
()
==
Operation
::
VARIABLE
)
if
(
node
.
getOperation
().
getId
()
==
Operation
::
VARIABLE
||
node
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
)
return
result
;
return
result
;
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
if
(
children
[
i
].
getOperation
().
getId
()
!=
Operation
::
CONSTANT
)
if
(
children
[
i
].
getOperation
().
getId
()
!=
Operation
::
CONSTANT
)
...
...
openmmapi/include/openmm/CustomIntegrator.h
View file @
2a52e208
...
@@ -202,6 +202,16 @@ namespace OpenMM {
...
@@ -202,6 +202,16 @@ namespace OpenMM {
* following comparison operators: =, <. >, !=, <=, >=. Blocks may be nested
* following comparison operators: =, <. >, !=, <=, >=. Blocks may be nested
* inside each other.
* inside each other.
*
*
* Another feature of CustomIntegrator is that it can use derivatives of the
* potential energy with respect to context parameters. These derivatives are
* typically computed by custom forces, and are only computed if a Force object
* has been specifically told to compute them by calling addEnergyParameterDerivative()
* on it. CustomIntegrator provides a deriv() function for accessing these
* derivatives in global or per-DOF expressions. For example, "deriv(energy, lambda)"
* is the derivative of the total potentially energy with respect to the parameter
* lambda. You can also restrict it to a single force group by specifying a different
* variable for the first argument, such as "deriv(energy1, lambda)".
*
* An Integrator has one other job in addition to evolving the equations of motion:
* An Integrator has one other job in addition to evolving the equations of motion:
* it defines how to compute the kinetic energy of the system. Depending on the
* it defines how to compute the kinetic energy of the system. Depending on the
* integration method used, simply summing mv<sup>2</sup>/2 over all degrees of
* integration method used, simply summing mv<sup>2</sup>/2 over all degrees of
...
...
openmmapi/include/openmm/internal/CustomIntegratorUtilities.h
View file @
2a52e208
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2015 Stanford University and the Authors.
*
* Portions copyright (c) 2015
-2016
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -34,8 +34,10 @@
...
@@ -34,8 +34,10 @@
#include "openmm/CustomIntegrator.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ContextImpl.h"
#include "lepton/CustomFunction.h"
#include "lepton/ParsedExpression.h"
#include "lepton/ParsedExpression.h"
#include <map>
#include <map>
#include <string>
#include <vector>
#include <vector>
namespace
OpenMM
{
namespace
OpenMM
{
...
@@ -48,6 +50,7 @@ class System;
...
@@ -48,6 +50,7 @@ class System;
class
OPENMM_EXPORT
CustomIntegratorUtilities
{
class
OPENMM_EXPORT
CustomIntegratorUtilities
{
public:
public:
class
DerivFunction
;
enum
Comparison
{
enum
Comparison
{
EQUAL
=
0
,
LESS_THAN
=
1
,
GREATER_THAN
=
2
,
NOT_EQUAL
=
3
,
LESS_THAN_OR_EQUAL
=
4
,
GREATER_THAN_OR_EQUAL
=
5
EQUAL
=
0
,
LESS_THAN
=
1
,
GREATER_THAN
=
2
,
NOT_EQUAL
=
3
,
LESS_THAN_OR_EQUAL
=
4
,
GREATER_THAN_OR_EQUAL
=
5
};
};
...
@@ -82,6 +85,28 @@ private:
...
@@ -82,6 +85,28 @@ private:
const
std
::
vector
<
bool
>&
invalidatesForces
,
const
std
::
vector
<
int
>&
forceGroup
,
std
::
vector
<
bool
>&
computeBoth
);
const
std
::
vector
<
bool
>&
invalidatesForces
,
const
std
::
vector
<
int
>&
forceGroup
,
std
::
vector
<
bool
>&
computeBoth
);
static
void
analyzeForceComputationsForPath
(
std
::
vector
<
int
>&
steps
,
const
std
::
vector
<
bool
>&
needsForces
,
const
std
::
vector
<
bool
>&
needsEnergy
,
static
void
analyzeForceComputationsForPath
(
std
::
vector
<
int
>&
steps
,
const
std
::
vector
<
bool
>&
needsForces
,
const
std
::
vector
<
bool
>&
needsEnergy
,
const
std
::
vector
<
bool
>&
invalidatesForces
,
const
std
::
vector
<
int
>&
forceGroup
,
std
::
vector
<
bool
>&
computeBoth
);
const
std
::
vector
<
bool
>&
invalidatesForces
,
const
std
::
vector
<
int
>&
forceGroup
,
std
::
vector
<
bool
>&
computeBoth
);
static
void
validateDerivatives
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
string
>&
derivNames
);
};
/**
* This class is used to implement the deriv() function when it appears in expressions.
*/
class
CustomIntegratorUtilities
::
DerivFunction
:
public
Lepton
::
CustomFunction
{
public:
DerivFunction
()
{
}
int
getNumArguments
()
const
{
return
2
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
0.0
;
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
DerivFunction
();
}
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
openmmapi/src/CustomIntegratorUtilities.cpp
View file @
2a52e208
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2015 Stanford University and the Authors.
*
* Portions copyright (c) 2015
-2016
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
#include "openmm/internal/ForceImpl.h"
#include "openmm/internal/ForceImpl.h"
#include "lepton/Operation.h"
#include "lepton/Operation.h"
#include "lepton/Parser.h"
#include "lepton/Parser.h"
#include <algorithm>
#include <set>
#include <set>
#include <sstream>
#include <sstream>
...
@@ -81,6 +82,9 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
...
@@ -81,6 +82,9 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
forceGroup
.
resize
(
numSteps
,
-
2
);
forceGroup
.
resize
(
numSteps
,
-
2
);
vector
<
CustomIntegrator
::
ComputationType
>
stepType
(
numSteps
);
vector
<
CustomIntegrator
::
ComputationType
>
stepType
(
numSteps
);
vector
<
string
>
stepVariable
(
numSteps
);
vector
<
string
>
stepVariable
(
numSteps
);
map
<
string
,
Lepton
::
CustomFunction
*>
customFunctions
;
DerivFunction
derivFunction
;
customFunctions
[
"deriv"
]
=
&
derivFunction
;
// Parse the expressions.
// Parse the expressions.
...
@@ -92,11 +96,11 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
...
@@ -92,11 +96,11 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
string
lhs
,
rhs
;
string
lhs
,
rhs
;
parseCondition
(
expression
,
lhs
,
rhs
,
comparisons
[
step
]);
parseCondition
(
expression
,
lhs
,
rhs
,
comparisons
[
step
]);
expressions
[
step
].
push_back
(
Lepton
::
Parser
::
parse
(
lhs
).
optimize
());
expressions
[
step
].
push_back
(
Lepton
::
Parser
::
parse
(
lhs
,
customFunctions
).
optimize
());
expressions
[
step
].
push_back
(
Lepton
::
Parser
::
parse
(
rhs
).
optimize
());
expressions
[
step
].
push_back
(
Lepton
::
Parser
::
parse
(
rhs
,
customFunctions
).
optimize
());
}
}
else
if
(
expression
.
size
()
>
0
)
else
if
(
expression
.
size
()
>
0
)
expressions
[
step
].
push_back
(
Lepton
::
Parser
::
parse
(
expression
).
optimize
());
expressions
[
step
].
push_back
(
Lepton
::
Parser
::
parse
(
expression
,
customFunctions
).
optimize
());
}
}
// Identify which steps invalidate the forces.
// Identify which steps invalidate the forces.
...
@@ -191,6 +195,14 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
...
@@ -191,6 +195,14 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
vector
<
int
>
jumps
(
numSteps
,
-
1
);
vector
<
int
>
jumps
(
numSteps
,
-
1
);
vector
<
int
>
stepsInPath
;
vector
<
int
>
stepsInPath
;
enumeratePaths
(
0
,
stepsInPath
,
jumps
,
blockEnd
,
stepType
,
needsForces
,
needsEnergy
,
invalidatesForces
,
forceGroup
,
computeBoth
);
enumeratePaths
(
0
,
stepsInPath
,
jumps
,
blockEnd
,
stepType
,
needsForces
,
needsEnergy
,
invalidatesForces
,
forceGroup
,
computeBoth
);
// Make sure calls to deriv() all valid.
vector
<
string
>
derivNames
=
energyGroupName
;
derivNames
.
push_back
(
"energy"
);
for
(
int
i
=
0
;
i
<
expressions
.
size
();
i
++
)
for
(
int
j
=
0
;
j
<
expressions
[
i
].
size
();
j
++
)
validateDerivatives
(
expressions
[
i
][
j
].
getRootNode
(),
derivNames
);
}
}
void
CustomIntegratorUtilities
::
enumeratePaths
(
int
firstStep
,
vector
<
int
>
steps
,
vector
<
int
>
jumps
,
const
vector
<
int
>&
blockEnd
,
void
CustomIntegratorUtilities
::
enumeratePaths
(
int
firstStep
,
vector
<
int
>
steps
,
vector
<
int
>
jumps
,
const
vector
<
int
>&
blockEnd
,
...
@@ -265,3 +277,18 @@ void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& ste
...
@@ -265,3 +277,18 @@ void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& ste
}
}
}
}
}
}
void
CustomIntegratorUtilities
::
validateDerivatives
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
vector
<
string
>&
derivNames
)
{
const
Lepton
::
Operation
&
op
=
node
.
getOperation
();
if
(
op
.
getId
()
==
Lepton
::
Operation
::
CUSTOM
&&
op
.
getName
()
==
"deriv"
)
{
const
Lepton
::
Operation
&
child
=
node
.
getChildren
()[
0
].
getOperation
();
if
(
child
.
getId
()
!=
Lepton
::
Operation
::
VARIABLE
||
find
(
derivNames
.
begin
(),
derivNames
.
end
(),
child
.
getName
())
==
derivNames
.
end
())
throw
OpenMMException
(
"The first argument to deriv() must be an energy variable"
);
if
(
node
.
getChildren
()[
1
].
getOperation
().
getId
()
!=
Lepton
::
Operation
::
VARIABLE
)
throw
OpenMMException
(
"The second argument to deriv() must be a context parameter"
);
}
else
{
for
(
int
i
=
0
;
i
<
node
.
getChildren
().
size
();
i
++
)
validateDerivatives
(
node
.
getChildren
()[
i
],
derivNames
);
}
}
platforms/reference/include/ReferenceCustomDynamics.h
View file @
2a52e208
...
@@ -41,6 +41,7 @@ namespace OpenMM {
...
@@ -41,6 +41,7 @@ namespace OpenMM {
class
ReferenceCustomDynamics
:
public
ReferenceDynamics
{
class
ReferenceCustomDynamics
:
public
ReferenceDynamics
{
private:
private:
class
DerivFunction
;
const
OpenMM
::
CustomIntegrator
&
integrator
;
const
OpenMM
::
CustomIntegrator
&
integrator
;
std
::
vector
<
RealOpenMM
>
inverseMasses
;
std
::
vector
<
RealOpenMM
>
inverseMasses
;
std
::
vector
<
OpenMM
::
RealVec
>
sumBuffer
,
oldPos
;
std
::
vector
<
OpenMM
::
RealVec
>
sumBuffer
,
oldPos
;
...
@@ -51,6 +52,7 @@ private:
...
@@ -51,6 +52,7 @@ private:
std
::
vector
<
bool
>
invalidatesForces
,
needsForces
,
needsEnergy
,
computeBothForceAndEnergy
;
std
::
vector
<
bool
>
invalidatesForces
,
needsForces
,
needsEnergy
,
computeBothForceAndEnergy
;
std
::
vector
<
int
>
forceGroupFlags
,
blockEnd
;
std
::
vector
<
int
>
forceGroupFlags
,
blockEnd
;
RealOpenMM
energy
;
RealOpenMM
energy
;
std
::
map
<
std
::
string
,
double
>
energyParamDerivs
;
Lepton
::
CompiledExpression
kineticEnergyExpression
;
Lepton
::
CompiledExpression
kineticEnergyExpression
;
bool
kineticEnergyNeedsForce
;
bool
kineticEnergyNeedsForce
;
CompiledExpressionSet
expressionSet
;
CompiledExpressionSet
expressionSet
;
...
@@ -59,6 +61,8 @@ private:
...
@@ -59,6 +61,8 @@ private:
void
initialize
(
OpenMM
::
ContextImpl
&
context
,
std
::
vector
<
RealOpenMM
>&
masses
,
std
::
map
<
std
::
string
,
RealOpenMM
>&
globals
);
void
initialize
(
OpenMM
::
ContextImpl
&
context
,
std
::
vector
<
RealOpenMM
>&
masses
,
std
::
map
<
std
::
string
,
RealOpenMM
>&
globals
);
Lepton
::
ExpressionTreeNode
replaceDerivFunctions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
OpenMM
::
ContextImpl
&
context
);
void
computePerDof
(
int
numberOfAtoms
,
std
::
vector
<
OpenMM
::
RealVec
>&
results
,
const
std
::
vector
<
OpenMM
::
RealVec
>&
atomCoordinates
,
void
computePerDof
(
int
numberOfAtoms
,
std
::
vector
<
OpenMM
::
RealVec
>&
results
,
const
std
::
vector
<
OpenMM
::
RealVec
>&
atomCoordinates
,
const
std
::
vector
<
OpenMM
::
RealVec
>&
velocities
,
const
std
::
vector
<
OpenMM
::
RealVec
>&
forces
,
const
std
::
vector
<
RealOpenMM
>&
masses
,
const
std
::
vector
<
OpenMM
::
RealVec
>&
velocities
,
const
std
::
vector
<
OpenMM
::
RealVec
>&
forces
,
const
std
::
vector
<
RealOpenMM
>&
masses
,
const
std
::
vector
<
std
::
vector
<
OpenMM
::
RealVec
>
>&
perDof
,
const
Lepton
::
CompiledExpression
&
expression
,
int
forceIndex
);
const
std
::
vector
<
std
::
vector
<
OpenMM
::
RealVec
>
>&
perDof
,
const
Lepton
::
CompiledExpression
&
expression
,
int
forceIndex
);
...
...
platforms/reference/src/SimTKReference/ReferenceCustomDynamics.cpp
View file @
2a52e208
...
@@ -36,6 +36,28 @@
...
@@ -36,6 +36,28 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
OpenMM
;
using
namespace
OpenMM
;
using
namespace
Lepton
;
class
ReferenceCustomDynamics
::
DerivFunction
:
public
CustomFunction
{
public:
DerivFunction
(
map
<
string
,
double
>&
energyParamDerivs
,
const
string
&
param
)
:
energyParamDerivs
(
energyParamDerivs
),
param
(
param
)
{
}
int
getNumArguments
()
const
{
return
0
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
energyParamDerivs
[
param
];
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0
;
}
CustomFunction
*
clone
()
const
{
return
new
DerivFunction
(
energyParamDerivs
,
param
);
}
private:
map
<
string
,
double
>&
energyParamDerivs
;
string
param
;
};
/**---------------------------------------------------------------------------------------
/**---------------------------------------------------------------------------------------
...
@@ -56,7 +78,7 @@ ReferenceCustomDynamics::ReferenceCustomDynamics(int numberOfAtoms, const Custom
...
@@ -56,7 +78,7 @@ ReferenceCustomDynamics::ReferenceCustomDynamics(int numberOfAtoms, const Custom
string
expression
;
string
expression
;
integrator
.
getComputationStep
(
i
,
stepType
[
i
],
stepVariable
[
i
],
expression
);
integrator
.
getComputationStep
(
i
,
stepType
[
i
],
stepVariable
[
i
],
expression
);
}
}
kineticEnergyExpression
=
Lepton
::
Parser
::
parse
(
integrator
.
getKineticEnergyExpression
()).
optimize
().
createCompiledExpression
();
kineticEnergyExpression
=
Parser
::
parse
(
integrator
.
getKineticEnergyExpression
()).
optimize
().
createCompiledExpression
();
expressionSet
.
registerExpression
(
kineticEnergyExpression
);
expressionSet
.
registerExpression
(
kineticEnergyExpression
);
kineticEnergyNeedsForce
=
false
;
kineticEnergyNeedsForce
=
false
;
if
(
kineticEnergyExpression
.
getVariables
().
find
(
"f"
)
!=
kineticEnergyExpression
.
getVariables
().
end
())
if
(
kineticEnergyExpression
.
getVariables
().
find
(
"f"
)
!=
kineticEnergyExpression
.
getVariables
().
end
())
...
@@ -78,13 +100,13 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
...
@@ -78,13 +100,13 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
int
numSteps
=
stepType
.
size
();
int
numSteps
=
stepType
.
size
();
vector
<
int
>
forceGroup
;
vector
<
int
>
forceGroup
;
vector
<
vector
<
Lepton
::
ParsedExpression
>
>
expressions
;
vector
<
vector
<
ParsedExpression
>
>
expressions
;
CustomIntegratorUtilities
::
analyzeComputations
(
context
,
integrator
,
expressions
,
comparisons
,
blockEnd
,
invalidatesForces
,
needsForces
,
needsEnergy
,
computeBothForceAndEnergy
,
forceGroup
);
CustomIntegratorUtilities
::
analyzeComputations
(
context
,
integrator
,
expressions
,
comparisons
,
blockEnd
,
invalidatesForces
,
needsForces
,
needsEnergy
,
computeBothForceAndEnergy
,
forceGroup
);
stepExpressions
.
resize
(
expressions
.
size
());
stepExpressions
.
resize
(
expressions
.
size
());
for
(
int
i
=
0
;
i
<
numSteps
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numSteps
;
i
++
)
{
stepExpressions
[
i
].
resize
(
expressions
[
i
].
size
());
stepExpressions
[
i
].
resize
(
expressions
[
i
].
size
());
for
(
int
j
=
0
;
j
<
(
int
)
expressions
[
i
].
size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
(
int
)
expressions
[
i
].
size
();
j
++
)
{
stepExpressions
[
i
][
j
]
=
expressions
[
i
][
j
]
.
createCompiledExpression
();
stepExpressions
[
i
][
j
]
=
ParsedExpression
(
replaceDerivFunctions
(
expressions
[
i
][
j
].
getRootNode
(),
context
))
.
createCompiledExpression
();
expressionSet
.
registerExpression
(
stepExpressions
[
i
][
j
]);
expressionSet
.
registerExpression
(
stepExpressions
[
i
][
j
]);
}
}
if
(
stepType
[
i
]
==
CustomIntegrator
::
WhileBlockStart
)
if
(
stepType
[
i
]
==
CustomIntegrator
::
WhileBlockStart
)
...
@@ -141,6 +163,22 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
...
@@ -141,6 +163,22 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
stepVariableIndex
.
push_back
(
expressionSet
.
getVariableIndex
(
stepVariable
[
i
]));
stepVariableIndex
.
push_back
(
expressionSet
.
getVariableIndex
(
stepVariable
[
i
]));
}
}
ExpressionTreeNode
ReferenceCustomDynamics
::
replaceDerivFunctions
(
const
ExpressionTreeNode
&
node
,
ContextImpl
&
context
)
{
const
Operation
&
op
=
node
.
getOperation
();
if
(
op
.
getId
()
==
Operation
::
CUSTOM
&&
op
.
getName
()
==
"deriv"
)
{
string
param
=
node
.
getChildren
()[
1
].
getOperation
().
getName
();
if
(
context
.
getParameters
().
find
(
param
)
==
context
.
getParameters
().
end
())
throw
OpenMMException
(
"The second argument to deriv() must be a context parameter"
);
return
ExpressionTreeNode
(
new
Operation
::
Custom
(
"deriv"
,
new
DerivFunction
(
energyParamDerivs
,
param
)));
}
else
{
vector
<
ExpressionTreeNode
>
children
;
for
(
int
i
=
0
;
i
<
(
int
)
node
.
getChildren
().
size
();
i
++
)
children
.
push_back
(
replaceDerivFunctions
(
node
.
getChildren
()[
i
],
context
));
return
ExpressionTreeNode
(
op
.
clone
(),
children
);
}
}
/**---------------------------------------------------------------------------------------
/**---------------------------------------------------------------------------------------
Update -- driver routine for performing Custom dynamics update of coordinates
Update -- driver routine for performing Custom dynamics update of coordinates
...
@@ -178,8 +216,10 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
...
@@ -178,8 +216,10 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
bool
computeEnergy
=
needsEnergy
[
step
]
||
computeBothForceAndEnergy
[
step
];
bool
computeEnergy
=
needsEnergy
[
step
]
||
computeBothForceAndEnergy
[
step
];
recordChangedParameters
(
context
,
globals
);
recordChangedParameters
(
context
,
globals
);
RealOpenMM
e
=
context
.
calcForcesAndEnergy
(
computeForce
,
computeEnergy
,
forceGroupFlags
[
step
]);
RealOpenMM
e
=
context
.
calcForcesAndEnergy
(
computeForce
,
computeEnergy
,
forceGroupFlags
[
step
]);
if
(
computeEnergy
)
if
(
computeEnergy
)
{
energy
=
e
;
energy
=
e
;
context
.
getEnergyParameterDerivatives
(
energyParamDerivs
);
}
forcesAreValid
=
true
;
forcesAreValid
=
true
;
}
}
expressionSet
.
setVariable
(
energyVariableIndex
[
step
],
energy
);
expressionSet
.
setVariable
(
energyVariableIndex
[
step
],
energy
);
...
@@ -266,7 +306,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
...
@@ -266,7 +306,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
void
ReferenceCustomDynamics
::
computePerDof
(
int
numberOfAtoms
,
vector
<
RealVec
>&
results
,
const
vector
<
RealVec
>&
atomCoordinates
,
void
ReferenceCustomDynamics
::
computePerDof
(
int
numberOfAtoms
,
vector
<
RealVec
>&
results
,
const
vector
<
RealVec
>&
atomCoordinates
,
const
vector
<
RealVec
>&
velocities
,
const
vector
<
RealVec
>&
forces
,
const
vector
<
RealOpenMM
>&
masses
,
const
vector
<
RealVec
>&
velocities
,
const
vector
<
RealVec
>&
forces
,
const
vector
<
RealOpenMM
>&
masses
,
const
vector
<
vector
<
RealVec
>
>&
perDof
,
const
Lepton
::
CompiledExpression
&
expression
,
int
forceIndex
)
{
const
vector
<
vector
<
RealVec
>
>&
perDof
,
const
CompiledExpression
&
expression
,
int
forceIndex
)
{
// Loop over all degrees of freedom.
// Loop over all degrees of freedom.
for
(
int
i
=
0
;
i
<
numberOfAtoms
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numberOfAtoms
;
i
++
)
{
...
...
tests/TestCustomIntegrator.h
View file @
2a52e208
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2008-201
5
Stanford University and the Authors. *
* Portions copyright (c) 2008-201
6
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -29,15 +29,21 @@
...
@@ -29,15 +29,21 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
* -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h"
#include "openmm/Context.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h"
#include "openmm/CustomBondForce.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h"
#include "openmm/NonbondedForce.h"
#include "openmm/System.h"
#include "openmm/System.h"
#include "openmm/CustomIntegrator.h"
#include "SimTKOpenMMRealType.h"
#include "SimTKOpenMMRealType.h"
#include "sfmt/SFMT.h"
#include "sfmt/SFMT.h"
#include <cmath>
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
...
@@ -770,6 +776,84 @@ void testChangingGlobal() {
...
@@ -770,6 +776,84 @@ void testChangingGlobal() {
}
}
}
}
/**
* Test steps that depend on derivatives of the energy with respect to parameters.
*/
void
testEnergyParameterDerivatives
()
{
System
system
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
system
.
addParticle
(
1.0
);
// Create some custom forces that depend on parameters.
CustomBondForce
*
bonds
=
new
CustomBondForce
(
"K*(A*r-r0)^2"
);
system
.
addForce
(
bonds
);
bonds
->
addGlobalParameter
(
"K"
,
2.0
);
bonds
->
addGlobalParameter
(
"A"
,
1.0
);
bonds
->
addGlobalParameter
(
"r0"
,
1.5
);
bonds
->
addEnergyParameterDerivative
(
"K"
);
bonds
->
addEnergyParameterDerivative
(
"r0"
);
bonds
->
addBond
(
0
,
1
);
bonds
->
setForceGroup
(
0
);
CustomAngleForce
*
angles
=
new
CustomAngleForce
(
"K*(B*theta-theta0)^2"
);
system
.
addForce
(
angles
);
angles
->
addGlobalParameter
(
"K"
,
2.0
);
angles
->
addGlobalParameter
(
"B"
,
1.0
);
angles
->
addGlobalParameter
(
"theta0"
,
M_PI
/
3
);
angles
->
addEnergyParameterDerivative
(
"K"
);
angles
->
addEnergyParameterDerivative
(
"theta0"
);
angles
->
addAngle
(
0
,
1
,
2
);
angles
->
setForceGroup
(
1
);
// Create an integrator that records parameter derivatives.
CustomIntegrator
integrator
(
0.1
);
integrator
.
addGlobalVariable
(
"dEdK"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdr0"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdtheta0"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdK_0"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdr0_0"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdtheta0_0"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdK_1"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdr0_1"
,
0.0
);
integrator
.
addGlobalVariable
(
"dEdtheta0_1"
,
0.0
);
integrator
.
addComputeGlobal
(
"dEdK"
,
"deriv(energy, K)"
);
integrator
.
addComputeGlobal
(
"dEdr0"
,
"deriv(energy, r0)"
);
integrator
.
addComputeGlobal
(
"dEdtheta0"
,
"deriv(energy, theta0)"
);
integrator
.
addComputeGlobal
(
"dEdK_0"
,
"deriv(energy0, K)"
);
integrator
.
addComputeGlobal
(
"dEdr0_0"
,
"deriv(energy0, r0)"
);
integrator
.
addComputeGlobal
(
"dEdtheta0_0"
,
"deriv(energy0, theta0)"
);
integrator
.
addComputeGlobal
(
"dEdK_1"
,
"deriv(energy1, K)"
);
integrator
.
addComputeGlobal
(
"dEdr0_1"
,
"deriv(energy1, r0)"
);
integrator
.
addComputeGlobal
(
"dEdtheta0_1"
,
"deriv(energy1, theta0)"
);
// Create a Context.
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
3
);
positions
[
0
]
=
Vec3
(
0
,
1
,
0
);
positions
[
1
]
=
Vec3
(
0
,
0
,
0
);
positions
[
2
]
=
Vec3
(
1
,
0
,
0
);
context
.
setPositions
(
positions
);
// Check the results.
integrator
.
step
(
1
);
double
dEdK_0
=
(
1.0
-
1.5
)
*
(
1.0
-
1.5
);
double
dEdK_1
=
(
M_PI
/
2
-
M_PI
/
3
)
*
(
M_PI
/
2
-
M_PI
/
3
);
ASSERT_EQUAL_TOL
(
dEdK_0
,
integrator
.
getGlobalVariableByName
(
"dEdK_0"
),
1e-5
);
ASSERT_EQUAL_TOL
(
dEdK_1
,
integrator
.
getGlobalVariableByName
(
"dEdK_1"
),
1e-5
);
ASSERT_EQUAL_TOL
(
dEdK_0
+
dEdK_1
,
integrator
.
getGlobalVariableByName
(
"dEdK"
),
1e-5
);
double
dEdr0
=
-
2.0
*
2.0
*
(
1.0
-
1.5
);
ASSERT_EQUAL_TOL
(
dEdr0
,
integrator
.
getGlobalVariableByName
(
"dEdr0_0"
),
1e-5
);
ASSERT_EQUAL_TOL
(
0.0
,
integrator
.
getGlobalVariableByName
(
"dEdr0_1"
),
1e-5
);
ASSERT_EQUAL_TOL
(
dEdr0
,
integrator
.
getGlobalVariableByName
(
"dEdr0"
),
1e-5
);
double
dEdtheta0
=
-
2.0
*
2.0
*
(
M_PI
/
2
-
M_PI
/
3
);
ASSERT_EQUAL_TOL
(
0.0
,
integrator
.
getGlobalVariableByName
(
"dEdtheta0_0"
),
1e-5
);
ASSERT_EQUAL_TOL
(
dEdtheta0
,
integrator
.
getGlobalVariableByName
(
"dEdtheta0_1"
),
1e-5
);
ASSERT_EQUAL_TOL
(
dEdtheta0
,
integrator
.
getGlobalVariableByName
(
"dEdtheta0"
),
1e-5
);
}
void
runPlatformTests
();
void
runPlatformTests
();
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
...
@@ -790,6 +874,7 @@ int main(int argc, char* argv[]) {
...
@@ -790,6 +874,7 @@ int main(int argc, char* argv[]) {
testIfBlock
();
testIfBlock
();
testWhileBlock
();
testWhileBlock
();
testChangingGlobal
();
testChangingGlobal
();
testEnergyParameterDerivatives
();
runPlatformTests
();
runPlatformTests
();
}
}
catch
(
const
exception
&
e
)
{
catch
(
const
exception
&
e
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment