Commit 094e6c71 authored by peastman's avatar peastman
Browse files

Began implementing tabulated functions for CustomIntegrator

parent 54ac9892
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2016 Stanford University and the Authors. * * Portions copyright (c) 2011-2017 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "Integrator.h" #include "Integrator.h"
#include "TabulatedFunction.h"
#include "Vec3.h" #include "Vec3.h"
#include "openmm/Kernel.h" #include "openmm/Kernel.h"
#include "internal/windowsExport.h" #include "internal/windowsExport.h"
...@@ -241,6 +242,9 @@ namespace OpenMM { ...@@ -241,6 +242,9 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* select(x,y,z) = z if x = 0, y otherwise. An expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * select(x,y,z) = z if x = 0, y otherwise. An expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in expressions.
*/ */
class OPENMM_EXPORT CustomIntegrator : public Integrator { class OPENMM_EXPORT CustomIntegrator : public Integrator {
...@@ -292,6 +296,7 @@ public: ...@@ -292,6 +296,7 @@ public:
* @param stepSize the step size with which to integrate the system (in picoseconds) * @param stepSize the step size with which to integrate the system (in picoseconds)
*/ */
CustomIntegrator(double stepSize); CustomIntegrator(double stepSize);
~CustomIntegrator();
/** /**
* Get the number of global variables that have been defined. * Get the number of global variables that have been defined.
*/ */
...@@ -310,6 +315,12 @@ public: ...@@ -310,6 +315,12 @@ public:
int getNumComputations() const { int getNumComputations() const {
return computations.size(); return computations.size();
} }
/**
* Get the number of tabulated functions that have been defined.
*/
int getNumTabulatedFunctions() const {
return functions.size();
}
/** /**
* Define a new global variable. * Define a new global variable.
* *
...@@ -495,6 +506,37 @@ public: ...@@ -495,6 +506,37 @@ public:
* will be an empty string. * will be an empty string.
*/ */
void getComputationStep(int index, ComputationType& type, std::string& variable, std::string& expression) const; void getComputationStep(int index, ComputationType& type, std::string& variable, std::string& expression) const;
/**
* Add a tabulated function that may appear in expressions.
*
* @param name the name of the function as it appears in expressions
* @param function a TabulatedFunction object defining the function. The TabulatedFunction
* should have been created on the heap with the "new" operator. The
* integrator takes over ownership of it, and deletes it when the integrator itself is deleted.
* @return the index of the function that was added
*/
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Get the expression to use for computing the kinetic energy. The expression is evaluated * Get the expression to use for computing the kinetic energy. The expression is evaluated
* for every degree of freedom. Those values are then added together, and the sum * for every degree of freedom. Those values are then added together, and the sum
...@@ -560,11 +602,13 @@ protected: ...@@ -560,11 +602,13 @@ protected:
double computeKineticEnergy(); double computeKineticEnergy();
private: private:
class ComputationInfo; class ComputationInfo;
class FunctionInfo;
std::vector<std::string> globalNames; std::vector<std::string> globalNames;
std::vector<std::string> perDofNames; std::vector<std::string> perDofNames;
mutable std::vector<double> globalValues; mutable std::vector<double> globalValues;
std::vector<std::vector<Vec3> > perDofValues; std::vector<std::vector<Vec3> > perDofValues;
std::vector<ComputationInfo> computations; std::vector<ComputationInfo> computations;
std::vector<FunctionInfo> functions;
std::string kineticEnergy; std::string kineticEnergy;
mutable bool globalsAreCurrent; mutable bool globalsAreCurrent;
int randomNumberSeed; int randomNumberSeed;
...@@ -587,6 +631,20 @@ public: ...@@ -587,6 +631,20 @@ public:
} }
}; };
/**
* This is an internal class used to record information about a tabulated function.
* @private
*/
class CustomIntegrator::FunctionInfo {
public:
std::string name;
TabulatedFunction* function;
FunctionInfo() {
}
FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
}
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_CUSTOMINTEGRATOR_H_*/ #endif /*OPENMM_CUSTOMINTEGRATOR_H_*/
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
*/ */
static void analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, std::vector<std::vector<Lepton::ParsedExpression> >& expressions, static void analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, std::vector<std::vector<Lepton::ParsedExpression> >& expressions,
std::vector<Comparison>& comparisons, std::vector<int>& blockEnd, std::vector<bool>& invalidatesForces, std::vector<bool>& needsForces, std::vector<Comparison>& comparisons, std::vector<int>& blockEnd, std::vector<bool>& invalidatesForces, std::vector<bool>& needsForces,
std::vector<bool>& needsEnergy, std::vector<bool>& computeBoth, std::vector<int>& forceGroup); std::vector<bool>& needsEnergy, std::vector<bool>& computeBoth, std::vector<int>& forceGroup, const std::map<std::string, Lepton::CustomFunction*>& functions);
/** /**
* Determine whether an expression involves a particular variable. * Determine whether an expression involves a particular variable.
*/ */
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2016 Stanford University and the Authors. * * Portions copyright (c) 2011-2017 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -48,6 +48,11 @@ CustomIntegrator::CustomIntegrator(double stepSize) : globalsAreCurrent(true), f ...@@ -48,6 +48,11 @@ CustomIntegrator::CustomIntegrator(double stepSize) : globalsAreCurrent(true), f
kineticEnergy = "m*v*v/2"; kineticEnergy = "m*v*v/2";
} }
CustomIntegrator::~CustomIntegrator() {
for (auto function : functions)
delete function.function;
}
void CustomIntegrator::initialize(ContextImpl& contextRef) { void CustomIntegrator::initialize(ContextImpl& contextRef) {
if (owner != NULL && &contextRef.getOwner() != owner) if (owner != NULL && &contextRef.getOwner() != owner)
throw OpenMMException("This Integrator is already bound to a context"); throw OpenMMException("This Integrator is already bound to a context");
...@@ -281,6 +286,26 @@ void CustomIntegrator::getComputationStep(int index, ComputationType& type, stri ...@@ -281,6 +286,26 @@ void CustomIntegrator::getComputationStep(int index, ComputationType& type, stri
expression = computations[index].expression; expression = computations[index].expression;
} }
int CustomIntegrator::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomIntegrator::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomIntegrator::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomIntegrator::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
const string& CustomIntegrator::getKineticEnergyExpression() const { const string& CustomIntegrator::getKineticEnergyExpression() const {
return kineticEnergy; return kineticEnergy;
} }
......
...@@ -71,7 +71,7 @@ bool CustomIntegratorUtilities::usesVariable(const Lepton::ParsedExpression& exp ...@@ -71,7 +71,7 @@ bool CustomIntegratorUtilities::usesVariable(const Lepton::ParsedExpression& exp
void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, vector<vector<Lepton::ParsedExpression> >& expressions, void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, vector<vector<Lepton::ParsedExpression> >& expressions,
vector<Comparison>& comparisons, vector<int>& blockEnd, vector<bool>& invalidatesForces, vector<bool>& needsForces, vector<bool>& needsEnergy, vector<Comparison>& comparisons, vector<int>& blockEnd, vector<bool>& invalidatesForces, vector<bool>& needsForces, vector<bool>& needsEnergy,
vector<bool>& computeBoth, vector<int>& forceGroup) { vector<bool>& computeBoth, vector<int>& forceGroup, const map<string, Lepton::CustomFunction*>& functions) {
int numSteps = integrator.getNumComputations(); int numSteps = integrator.getNumComputations();
expressions.resize(numSteps); expressions.resize(numSteps);
comparisons.resize(numSteps); comparisons.resize(numSteps);
...@@ -82,7 +82,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -82,7 +82,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
forceGroup.resize(numSteps, -2); forceGroup.resize(numSteps, -2);
vector<CustomIntegrator::ComputationType> stepType(numSteps); vector<CustomIntegrator::ComputationType> stepType(numSteps);
vector<string> stepVariable(numSteps); vector<string> stepVariable(numSteps);
map<string, Lepton::CustomFunction*> customFunctions; map<string, Lepton::CustomFunction*> customFunctions = functions;
DerivFunction derivFunction; DerivFunction derivFunction;
customFunctions["deriv"] = &derivFunction; customFunctions["deriv"] = &derivFunction;
......
...@@ -7155,7 +7155,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7155,7 +7155,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
vector<string> variable(numSteps); vector<string> variable(numSteps);
vector<int> forceGroup; vector<int> forceGroup;
vector<vector<Lepton::ParsedExpression> > expression; vector<vector<Lepton::ParsedExpression> > expression;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup); map<string, Lepton::CustomFunction*> functions;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
for (int step = 0; step < numSteps; step++) { for (int step = 0; step < numSteps; step++) {
string expr; string expr;
integrator.getComputationStep(step, stepType[step], variable[step], expr); integrator.getComputationStep(step, stepType[step], variable[step], expr);
......
...@@ -7498,7 +7498,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7498,7 +7498,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
vector<string> variable(numSteps); vector<string> variable(numSteps);
vector<int> forceGroup; vector<int> forceGroup;
vector<vector<Lepton::ParsedExpression> > expression; vector<vector<Lepton::ParsedExpression> > expression;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup); map<string, Lepton::CustomFunction*> functions;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
for (int step = 0; step < numSteps; step++) { for (int step = 0; step < numSteps; step++) {
string expr; string expr;
integrator.getComputationStep(step, stepType[step], variable[step], expr); integrator.getComputationStep(step, stepType[step], variable[step], expr);
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
#include "ReferenceVirtualSites.h" #include "ReferenceVirtualSites.h"
#include "ReferenceCustomDynamics.h" #include "ReferenceCustomDynamics.h"
#include "ReferenceTabulatedFunction.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ForceImpl.h" #include "openmm/internal/ForceImpl.h"
...@@ -113,12 +114,18 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m ...@@ -113,12 +114,18 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m
variableLocations[ename.str()] = &energy; variableLocations[ename.str()] = &energy;
} }
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++)
functions[integrator.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(integrator.getTabulatedFunction(i));
// Parse the expressions. // Parse the expressions.
int numSteps = stepType.size(); int numSteps = stepType.size();
vector<int> forceGroup; vector<int> forceGroup;
vector<vector<ParsedExpression> > expressions; vector<vector<ParsedExpression> > expressions;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expressions, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup); CustomIntegratorUtilities::analyzeComputations(context, integrator, expressions, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
stepExpressions.resize(expressions.size()); stepExpressions.resize(expressions.size());
for (int i = 0; i < numSteps; i++) { for (int i = 0; i < numSteps; i++) {
stepExpressions[i].resize(expressions[i].size()); stepExpressions[i].resize(expressions[i].size());
...@@ -137,6 +144,11 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m ...@@ -137,6 +144,11 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m
if (kineticEnergyExpression.getVariables().find("f") != kineticEnergyExpression.getVariables().end()) if (kineticEnergyExpression.getVariables().find("f") != kineticEnergyExpression.getVariables().end())
kineticEnergyNeedsForce = true; kineticEnergyNeedsForce = true;
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
// Record the force group flags for each step. // Record the force group flags for each step.
forceGroupFlags.resize(numSteps, -1); forceGroupFlags.resize(numSteps, -1);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2017 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -895,6 +895,32 @@ void testChangeDT() { ...@@ -895,6 +895,32 @@ void testChangeDT() {
} }
} }
/**
* Test an integrator that uses a tabulated function.
*/
void testTabulatedFunction() {
System system;
system.addParticle(1.0);
CustomIntegrator integrator(1.0);
integrator.addGlobalVariable("global", 1.5);
integrator.addPerDofVariable("dof", 0.0);
integrator.addComputeGlobal("global", "fn(global)");
integrator.addComputePerDof("dof", "fn(x)");
vector<double> table;
table.push_back(10.0);
table.push_back(20.0);
integrator.addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 2.0));
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(1.2, 1.3, 1.4);
context.setPositions(positions);
integrator.step(1);
ASSERT_EQUAL_TOL(15.0, integrator.getGlobalVariable(0), 1e-5);
vector<Vec3> values;
integrator.getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(12.0, 13.0, 14.0), values[0], 1e-5);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -917,6 +943,7 @@ int main(int argc, char* argv[]) { ...@@ -917,6 +943,7 @@ int main(int argc, char* argv[]) {
testChangingGlobal(); testChangingGlobal();
testEnergyParameterDerivatives(); testEnergyParameterDerivatives();
testChangeDT(); testChangeDT();
testTabulatedFunction();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment