Commit 8c693ef7 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1816 from peastman/integratorfunctions

CustomIntegrator can use tabulated functions
parents 7e2d6195 212e8367
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2016 Stanford University and the Authors. * * Portions copyright (c) 2011-2017 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "Integrator.h" #include "Integrator.h"
#include "TabulatedFunction.h"
#include "Vec3.h" #include "Vec3.h"
#include "openmm/Kernel.h" #include "openmm/Kernel.h"
#include "internal/windowsExport.h" #include "internal/windowsExport.h"
...@@ -241,6 +242,9 @@ namespace OpenMM { ...@@ -241,6 +242,9 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* select(x,y,z) = z if x = 0, y otherwise. An expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * select(x,y,z) = z if x = 0, y otherwise. An expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in expressions.
*/ */
class OPENMM_EXPORT CustomIntegrator : public Integrator { class OPENMM_EXPORT CustomIntegrator : public Integrator {
...@@ -292,6 +296,7 @@ public: ...@@ -292,6 +296,7 @@ public:
* @param stepSize the step size with which to integrate the system (in picoseconds) * @param stepSize the step size with which to integrate the system (in picoseconds)
*/ */
CustomIntegrator(double stepSize); CustomIntegrator(double stepSize);
~CustomIntegrator();
/** /**
* Get the number of global variables that have been defined. * Get the number of global variables that have been defined.
*/ */
...@@ -310,6 +315,12 @@ public: ...@@ -310,6 +315,12 @@ public:
int getNumComputations() const { int getNumComputations() const {
return computations.size(); return computations.size();
} }
/**
* Get the number of tabulated functions that have been defined.
*/
int getNumTabulatedFunctions() const {
return functions.size();
}
/** /**
* Define a new global variable. * Define a new global variable.
* *
...@@ -495,6 +506,37 @@ public: ...@@ -495,6 +506,37 @@ public:
* will be an empty string. * will be an empty string.
*/ */
void getComputationStep(int index, ComputationType& type, std::string& variable, std::string& expression) const; void getComputationStep(int index, ComputationType& type, std::string& variable, std::string& expression) const;
/**
* Add a tabulated function that may appear in expressions.
*
* @param name the name of the function as it appears in expressions
* @param function a TabulatedFunction object defining the function. The TabulatedFunction
* should have been created on the heap with the "new" operator. The
* integrator takes over ownership of it, and deletes it when the integrator itself is deleted.
* @return the index of the function that was added
*/
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Get the expression to use for computing the kinetic energy. The expression is evaluated * Get the expression to use for computing the kinetic energy. The expression is evaluated
* for every degree of freedom. Those values are then added together, and the sum * for every degree of freedom. Those values are then added together, and the sum
...@@ -560,11 +602,13 @@ protected: ...@@ -560,11 +602,13 @@ protected:
double computeKineticEnergy(); double computeKineticEnergy();
private: private:
class ComputationInfo; class ComputationInfo;
class FunctionInfo;
std::vector<std::string> globalNames; std::vector<std::string> globalNames;
std::vector<std::string> perDofNames; std::vector<std::string> perDofNames;
mutable std::vector<double> globalValues; mutable std::vector<double> globalValues;
std::vector<std::vector<Vec3> > perDofValues; std::vector<std::vector<Vec3> > perDofValues;
std::vector<ComputationInfo> computations; std::vector<ComputationInfo> computations;
std::vector<FunctionInfo> functions;
std::string kineticEnergy; std::string kineticEnergy;
mutable bool globalsAreCurrent; mutable bool globalsAreCurrent;
int randomNumberSeed; int randomNumberSeed;
...@@ -587,6 +631,20 @@ public: ...@@ -587,6 +631,20 @@ public:
} }
}; };
/**
* This is an internal class used to record information about a tabulated function.
* @private
*/
class CustomIntegrator::FunctionInfo {
public:
std::string name;
TabulatedFunction* function;
FunctionInfo() {
}
FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
}
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_CUSTOMINTEGRATOR_H_*/ #endif /*OPENMM_CUSTOMINTEGRATOR_H_*/
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
*/ */
static void analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, std::vector<std::vector<Lepton::ParsedExpression> >& expressions, static void analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, std::vector<std::vector<Lepton::ParsedExpression> >& expressions,
std::vector<Comparison>& comparisons, std::vector<int>& blockEnd, std::vector<bool>& invalidatesForces, std::vector<bool>& needsForces, std::vector<Comparison>& comparisons, std::vector<int>& blockEnd, std::vector<bool>& invalidatesForces, std::vector<bool>& needsForces,
std::vector<bool>& needsEnergy, std::vector<bool>& computeBoth, std::vector<int>& forceGroup); std::vector<bool>& needsEnergy, std::vector<bool>& computeBoth, std::vector<int>& forceGroup, const std::map<std::string, Lepton::CustomFunction*>& functions);
/** /**
* Determine whether an expression involves a particular variable. * Determine whether an expression involves a particular variable.
*/ */
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2016 Stanford University and the Authors. * * Portions copyright (c) 2011-2017 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -48,6 +48,11 @@ CustomIntegrator::CustomIntegrator(double stepSize) : globalsAreCurrent(true), f ...@@ -48,6 +48,11 @@ CustomIntegrator::CustomIntegrator(double stepSize) : globalsAreCurrent(true), f
kineticEnergy = "m*v*v/2"; kineticEnergy = "m*v*v/2";
} }
CustomIntegrator::~CustomIntegrator() {
for (auto function : functions)
delete function.function;
}
void CustomIntegrator::initialize(ContextImpl& contextRef) { void CustomIntegrator::initialize(ContextImpl& contextRef) {
if (owner != NULL && &contextRef.getOwner() != owner) if (owner != NULL && &contextRef.getOwner() != owner)
throw OpenMMException("This Integrator is already bound to a context"); throw OpenMMException("This Integrator is already bound to a context");
...@@ -281,6 +286,26 @@ void CustomIntegrator::getComputationStep(int index, ComputationType& type, stri ...@@ -281,6 +286,26 @@ void CustomIntegrator::getComputationStep(int index, ComputationType& type, stri
expression = computations[index].expression; expression = computations[index].expression;
} }
int CustomIntegrator::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomIntegrator::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomIntegrator::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomIntegrator::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
const string& CustomIntegrator::getKineticEnergyExpression() const { const string& CustomIntegrator::getKineticEnergyExpression() const {
return kineticEnergy; return kineticEnergy;
} }
......
...@@ -71,7 +71,7 @@ bool CustomIntegratorUtilities::usesVariable(const Lepton::ParsedExpression& exp ...@@ -71,7 +71,7 @@ bool CustomIntegratorUtilities::usesVariable(const Lepton::ParsedExpression& exp
void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, vector<vector<Lepton::ParsedExpression> >& expressions, void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, const CustomIntegrator& integrator, vector<vector<Lepton::ParsedExpression> >& expressions,
vector<Comparison>& comparisons, vector<int>& blockEnd, vector<bool>& invalidatesForces, vector<bool>& needsForces, vector<bool>& needsEnergy, vector<Comparison>& comparisons, vector<int>& blockEnd, vector<bool>& invalidatesForces, vector<bool>& needsForces, vector<bool>& needsEnergy,
vector<bool>& computeBoth, vector<int>& forceGroup) { vector<bool>& computeBoth, vector<int>& forceGroup, const map<string, Lepton::CustomFunction*>& functions) {
int numSteps = integrator.getNumComputations(); int numSteps = integrator.getNumComputations();
expressions.resize(numSteps); expressions.resize(numSteps);
comparisons.resize(numSteps); comparisons.resize(numSteps);
...@@ -82,7 +82,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -82,7 +82,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
forceGroup.resize(numSteps, -2); forceGroup.resize(numSteps, -2);
vector<CustomIntegrator::ComputationType> stepType(numSteps); vector<CustomIntegrator::ComputationType> stepType(numSteps);
vector<string> stepVariable(numSteps); vector<string> stepVariable(numSteps);
map<string, Lepton::CustomFunction*> customFunctions; map<string, Lepton::CustomFunction*> customFunctions = functions;
DerivFunction derivFunction; DerivFunction derivFunction;
customFunctions["deriv"] = &derivFunction; customFunctions["deriv"] = &derivFunction;
......
...@@ -1485,7 +1485,9 @@ private: ...@@ -1485,7 +1485,9 @@ private:
class ReorderListener; class ReorderListener;
class GlobalTarget; class GlobalTarget;
class DerivFunction; class DerivFunction;
std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const std::string& forceName, const std::string& energyName); std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator,
const std::string& forceName, const std::string& energyName, std::vector<const TabulatedFunction*>& functions,
std::vector<std::pair<std::string, std::string> >& functionNames);
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid); void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context); Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes); void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
...@@ -1504,6 +1506,7 @@ private: ...@@ -1504,6 +1506,7 @@ private:
CudaArray* uniformRandoms; CudaArray* uniformRandoms;
CudaArray* randomSeed; CudaArray* randomSeed;
CudaArray* perDofEnergyParamDerivs; CudaArray* perDofEnergyParamDerivs;
std::vector<CudaArray*> tabulatedFunctions;
std::map<int, CudaArray*> savedForces; std::map<int, CudaArray*> savedForces;
std::set<int> validSavedForces; std::set<int> validSavedForces;
CudaParameterSet* perDofValues; CudaParameterSet* perDofValues;
......
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include "lepton/Parser.h" #include "lepton/Parser.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include "ReferenceTabulatedFunction.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
#include <algorithm> #include <algorithm>
...@@ -7061,6 +7062,8 @@ CudaIntegrateCustomStepKernel::~CudaIntegrateCustomStepKernel() { ...@@ -7061,6 +7062,8 @@ CudaIntegrateCustomStepKernel::~CudaIntegrateCustomStepKernel() {
delete perDofEnergyParamDerivs; delete perDofEnergyParamDerivs;
if (perDofValues != NULL) if (perDofValues != NULL)
delete perDofValues; delete perDofValues;
for (auto function : tabulatedFunctions)
delete function;
for (auto& f : savedForces) for (auto& f : savedForces)
delete f.second; delete f.second;
} }
...@@ -7078,7 +7081,8 @@ void CudaIntegrateCustomStepKernel::initialize(const System& system, const Custo ...@@ -7078,7 +7081,8 @@ void CudaIntegrateCustomStepKernel::initialize(const System& system, const Custo
SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed()); SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed());
} }
string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) { string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator,
const string& forceName, const string& energyName, vector<const TabulatedFunction*>& functions, vector<pair<string, string> >& functionNames) {
const string suffixes[] = {".x", ".y", ".z"}; const string suffixes[] = {".x", ".y", ".z"};
string suffix = suffixes[component]; string suffix = suffixes[component];
map<string, Lepton::ParsedExpression> expressions; map<string, Lepton::ParsedExpression> expressions;
...@@ -7111,8 +7115,6 @@ string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& vari ...@@ -7111,8 +7115,6 @@ string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& vari
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i); variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "globals["+cu.intToString(parameterVariableIndex[i])+"]"; variables[parameterNames[i]] = "globals["+cu.intToString(parameterVariableIndex[i])+"]";
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
findExpressionsForDerivs(expr.getRootNode(), variableNodes); findExpressionsForDerivs(expr.getRootNode(), variableNodes);
for (auto& var : variables) for (auto& var : variables)
...@@ -7149,13 +7151,35 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7149,13 +7151,35 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
defines["PADDED_NUM_ATOMS"] = cu.intToString(cu.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = cu.intToString(cu.getPaddedNumAtoms());
defines["WORK_GROUP_SIZE"] = cu.intToString(CudaContext::ThreadBlockSize); defines["WORK_GROUP_SIZE"] = cu.intToString(CudaContext::ThreadBlockSize);
defines["SUM_BUFFER_SIZE"] = "0"; defines["SUM_BUFFER_SIZE"] = "0";
// Record the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionNames;
vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++) {
functionList.push_back(&integrator.getTabulatedFunction(i));
string name = integrator.getTabulatedFunctionName(i);
string arrayName = "table"+cu.intToString(i);
functionNames.push_back(make_pair(name, arrayName));
functions[name] = createReferenceTabulatedFunction(integrator.getTabulatedFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(integrator.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cu.intToString(width));
}
// Record information about all the computation steps. // Record information about all the computation steps.
vector<string> variable(numSteps); vector<string> variable(numSteps);
vector<int> forceGroup; vector<int> forceGroup;
vector<vector<Lepton::ParsedExpression> > expression; vector<vector<Lepton::ParsedExpression> > expression;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup); CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
for (int step = 0; step < numSteps; step++) { for (int step = 0; step < numSteps; step++) {
string expr; string expr;
integrator.getComputationStep(step, stepType[step], variable[step], expr); integrator.getComputationStep(step, stepType[step], variable[step], expr);
...@@ -7326,7 +7350,7 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7326,7 +7350,7 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
if (numUniform > 0) if (numUniform > 0)
compute << "float4 uniform = uniformValues[uniformIndex+index];\n"; compute << "float4 uniform = uniformValues[uniformIndex+index];\n";
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j][0], i, integrator, forceName[j], energyName[j]); compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j][0], i, integrator, forceName[j], energyName[j], functionList, functionNames);
if (variable[j] == "x") { if (variable[j] == "x") {
if (storePosAsDelta[j]) if (storePosAsDelta[j])
compute << "posDelta[index] = convertFromDouble4(position-convertToDouble4(loadPos(posq, posqCorrection, index)));\n"; compute << "posDelta[index] = convertFromDouble4(position-convertToDouble4(loadPos(posq, posqCorrection, index)));\n";
...@@ -7357,6 +7381,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7357,6 +7381,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
string valueName = "perDofValues"+cu.intToString(i+1); string valueName = "perDofValues"+cu.intToString(i+1);
args << ", " << buffer.getType() << "* __restrict__ " << valueName; args << ", " << buffer.getType() << "* __restrict__ " << valueName;
} }
for (int i = 0; i < (int) tableTypes.size(); i++)
args << ", const " << tableTypes[i]<< "* __restrict__ table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
if (loadPosAsDelta[step]) if (loadPosAsDelta[step])
defines["LOAD_POS_AS_DELTA"] = "1"; defines["LOAD_POS_AS_DELTA"] = "1";
...@@ -7386,6 +7412,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7386,6 +7412,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
args1.push_back(&perDofEnergyParamDerivs->getDevicePointer()); args1.push_back(&perDofEnergyParamDerivs->getDevicePointer());
for (auto& buffer : perDofValues->getBuffers()) for (auto& buffer : perDofValues->getBuffers())
args1.push_back(&buffer.getMemory()); args1.push_back(&buffer.getMemory());
for (auto array : tabulatedFunctions)
args1.push_back(&array->getDevicePointer());
kernelArgs[step].push_back(args1); kernelArgs[step].push_back(args1);
if (stepType[step] == CustomIntegrator::ComputeSum) { if (stepType[step] == CustomIntegrator::ComputeSum) {
// Create a second kernel for this step that sums the values. // Create a second kernel for this step that sums the values.
...@@ -7448,7 +7476,7 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7448,7 +7476,7 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
} }
Lepton::ParsedExpression keExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize(); Lepton::ParsedExpression keExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize();
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
computeKE << createPerDofComputation("", keExpression, i, integrator, "f", ""); computeKE << createPerDofComputation("", keExpression, i, integrator, "f", "", functionList, functionNames);
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_STEP"] = computeKE.str(); replacements["COMPUTE_STEP"] = computeKE.str();
stringstream args; stringstream args;
...@@ -7457,6 +7485,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7457,6 +7485,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
string valueName = "perDofValues"+cu.intToString(i+1); string valueName = "perDofValues"+cu.intToString(i+1);
args << ", " << buffer.getType() << "* __restrict__ " << valueName; args << ", " << buffer.getType() << "* __restrict__ " << valueName;
} }
for (int i = 0; i < (int) tableTypes.size(); i++)
args << ", const " << tableTypes[i]<< "* __restrict__ table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
defines["SUM_BUFFER_SIZE"] = cu.intToString(3*numAtoms); defines["SUM_BUFFER_SIZE"] = cu.intToString(3*numAtoms);
if (defines.find("LOAD_POS_AS_DELTA") != defines.end()) if (defines.find("LOAD_POS_AS_DELTA") != defines.end())
...@@ -7481,6 +7511,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7481,6 +7511,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
kineticEnergyArgs.push_back(&perDofEnergyParamDerivs->getDevicePointer()); kineticEnergyArgs.push_back(&perDofEnergyParamDerivs->getDevicePointer());
for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++)
kineticEnergyArgs.push_back(&perDofValues->getBuffers()[i].getMemory()); kineticEnergyArgs.push_back(&perDofValues->getBuffers()[i].getMemory());
for (auto array : tabulatedFunctions)
kineticEnergyArgs.push_back(&array->getDevicePointer());
keNeedsForce = usesVariable(keExpression, "f"); keNeedsForce = usesVariable(keExpression, "f");
// Create a second kernel to sum the values. // Create a second kernel to sum the values.
...@@ -7488,6 +7520,11 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7488,6 +7520,11 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
defines["SUM_BUFFER_SIZE"] = cu.intToString(3*numAtoms); defines["SUM_BUFFER_SIZE"] = cu.intToString(3*numAtoms);
module = cu.createModule(CudaKernelSources::customIntegrator, defines); module = cu.createModule(CudaKernelSources::customIntegrator, defines);
sumKineticEnergyKernel = cu.getKernel(module, useDouble ? "computeDoubleSum" : "computeFloatSum"); sumKineticEnergyKernel = cu.getKernel(module, useDouble ? "computeDoubleSum" : "computeFloatSum");
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
} }
// Make sure all values (variables, parameters, etc.) are up to date. // Make sure all values (variables, parameters, etc.) are up to date.
......
...@@ -1472,7 +1472,9 @@ private: ...@@ -1472,7 +1472,9 @@ private:
class ReorderListener; class ReorderListener;
class GlobalTarget; class GlobalTarget;
class DerivFunction; class DerivFunction;
std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const std::string& forceName, const std::string& energyName); std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator,
const std::string& forceName, const std::string& energyName, std::vector<const TabulatedFunction*>& functions,
std::vector<std::pair<std::string, std::string> >& functionNames);
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid); void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context); Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes); void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
...@@ -1491,6 +1493,7 @@ private: ...@@ -1491,6 +1493,7 @@ private:
OpenCLArray* uniformRandoms; OpenCLArray* uniformRandoms;
OpenCLArray* randomSeed; OpenCLArray* randomSeed;
OpenCLArray* perDofEnergyParamDerivs; OpenCLArray* perDofEnergyParamDerivs;
std::vector<OpenCLArray*> tabulatedFunctions;
std::map<int, OpenCLArray*> savedForces; std::map<int, OpenCLArray*> savedForces;
std::set<int> validSavedForces; std::set<int> validSavedForces;
OpenCLParameterSet* perDofValues; OpenCLParameterSet* perDofValues;
......
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include "lepton/Parser.h" #include "lepton/Parser.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include "ReferenceTabulatedFunction.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
#include <algorithm> #include <algorithm>
...@@ -7408,6 +7409,8 @@ OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() { ...@@ -7408,6 +7409,8 @@ OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() {
delete perDofEnergyParamDerivs; delete perDofEnergyParamDerivs;
if (perDofValues != NULL) if (perDofValues != NULL)
delete perDofValues; delete perDofValues;
for (auto function : tabulatedFunctions)
delete function;
for (auto& f : savedForces) for (auto& f : savedForces)
delete f.second; delete f.second;
} }
...@@ -7424,7 +7427,8 @@ void OpenCLIntegrateCustomStepKernel::initialize(const System& system, const Cus ...@@ -7424,7 +7427,8 @@ void OpenCLIntegrateCustomStepKernel::initialize(const System& system, const Cus
SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed()); SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed());
} }
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) { string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator,
const string& forceName, const string& energyName, vector<const TabulatedFunction*>& functions, vector<pair<string, string> >& functionNames) {
const string suffixes[] = {".x", ".y", ".z"}; const string suffixes[] = {".x", ".y", ".z"};
string suffix = suffixes[component]; string suffix = suffixes[component];
map<string, Lepton::ParsedExpression> expressions; map<string, Lepton::ParsedExpression> expressions;
...@@ -7457,8 +7461,6 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va ...@@ -7457,8 +7461,6 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i); variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "globals["+cl.intToString(parameterVariableIndex[i])+"]"; variables[parameterNames[i]] = "globals["+cl.intToString(parameterVariableIndex[i])+"]";
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float"); string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float");
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
findExpressionsForDerivs(expr.getRootNode(), variableNodes); findExpressionsForDerivs(expr.getRootNode(), variableNodes);
...@@ -7492,13 +7494,35 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7492,13 +7494,35 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
map<string, string> defines; map<string, string> defines;
defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms()); defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
defines["WORK_GROUP_SIZE"] = cl.intToString(OpenCLContext::ThreadBlockSize); defines["WORK_GROUP_SIZE"] = cl.intToString(OpenCLContext::ThreadBlockSize);
// Record the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionNames;
vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++) {
functionList.push_back(&integrator.getTabulatedFunction(i));
string name = integrator.getTabulatedFunctionName(i);
string arrayName = "table"+cl.intToString(i);
functionNames.push_back(make_pair(name, arrayName));
functions[name] = createReferenceTabulatedFunction(integrator.getTabulatedFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(integrator.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cl.intToString(width));
}
// Record information about all the computation steps. // Record information about all the computation steps.
vector<string> variable(numSteps); vector<string> variable(numSteps);
vector<int> forceGroup; vector<int> forceGroup;
vector<vector<Lepton::ParsedExpression> > expression; vector<vector<Lepton::ParsedExpression> > expression;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup); CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
for (int step = 0; step < numSteps; step++) { for (int step = 0; step < numSteps; step++) {
string expr; string expr;
integrator.getComputationStep(step, stepType[step], variable[step], expr); integrator.getComputationStep(step, stepType[step], variable[step], expr);
...@@ -7669,7 +7693,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7669,7 +7693,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
if (numUniform > 0) if (numUniform > 0)
compute << "float4 uniform = uniformValues[uniformIndex+index];\n"; compute << "float4 uniform = uniformValues[uniformIndex+index];\n";
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j][0], i, integrator, forceName[j], energyName[j]); compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j][0], i, integrator, forceName[j], energyName[j], functionList, functionNames);
if (variable[j] == "x") { if (variable[j] == "x") {
if (storePosAsDelta[j]) { if (storePosAsDelta[j]) {
if (cl.getSupportsDoublePrecision()) if (cl.getSupportsDoublePrecision())
...@@ -7704,6 +7728,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7704,6 +7728,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
string valueName = "perDofValues"+cl.intToString(i+1); string valueName = "perDofValues"+cl.intToString(i+1);
args << ", __global " << buffer.getType() << "* restrict " << valueName; args << ", __global " << buffer.getType() << "* restrict " << valueName;
} }
for (int i = 0; i < (int) tableTypes.size(); i++)
args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
if (loadPosAsDelta[step]) if (loadPosAsDelta[step])
defines["LOAD_POS_AS_DELTA"] = "1"; defines["LOAD_POS_AS_DELTA"] = "1";
...@@ -7727,6 +7753,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7727,6 +7753,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
kernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer());
for (auto& buffer : perDofValues->getBuffers()) for (auto& buffer : perDofValues->getBuffers())
kernel.setArg<cl::Memory>(index++, buffer.getMemory()); kernel.setArg<cl::Memory>(index++, buffer.getMemory());
for (auto array : tabulatedFunctions)
kernel.setArg<cl::Buffer>(index++, array->getDeviceBuffer());
if (stepType[step] == CustomIntegrator::ComputeSum) { if (stepType[step] == CustomIntegrator::ComputeSum) {
// Create a second kernel for this step that sums the values. // Create a second kernel for this step that sums the values.
...@@ -7789,7 +7817,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7789,7 +7817,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
} }
Lepton::ParsedExpression keExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize(); Lepton::ParsedExpression keExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize();
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
computeKE << createPerDofComputation("", keExpression, i, integrator, "f", ""); computeKE << createPerDofComputation("", keExpression, i, integrator, "f", "", functionList, functionNames);
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_STEP"] = computeKE.str(); replacements["COMPUTE_STEP"] = computeKE.str();
stringstream args; stringstream args;
...@@ -7798,6 +7826,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7798,6 +7826,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
string valueName = "perDofValues"+cl.intToString(i+1); string valueName = "perDofValues"+cl.intToString(i+1);
args << ", __global " << buffer.getType() << "* restrict " << valueName; args << ", __global " << buffer.getType() << "* restrict " << valueName;
} }
for (int i = 0; i < (int) tableTypes.size(); i++)
args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
if (defines.find("LOAD_POS_AS_DELTA") != defines.end()) if (defines.find("LOAD_POS_AS_DELTA") != defines.end())
defines.erase("LOAD_POS_AS_DELTA"); defines.erase("LOAD_POS_AS_DELTA");
...@@ -7821,6 +7851,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7821,6 +7851,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
kineticEnergyKernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer()); kineticEnergyKernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer());
for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++)
kineticEnergyKernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory()); kineticEnergyKernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory());
for (auto array : tabulatedFunctions)
kineticEnergyKernel.setArg<cl::Buffer>(index++, array->getDeviceBuffer());
keNeedsForce = usesVariable(keExpression, "f"); keNeedsForce = usesVariable(keExpression, "f");
// Create a second kernel to sum the values. // Create a second kernel to sum the values.
...@@ -7831,8 +7863,13 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7831,8 +7863,13 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
sumKineticEnergyKernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer()); sumKineticEnergyKernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer());
sumKineticEnergyKernel.setArg<cl::Buffer>(index++, summedValue->getDeviceBuffer()); sumKineticEnergyKernel.setArg<cl::Buffer>(index++, summedValue->getDeviceBuffer());
sumKineticEnergyKernel.setArg<cl_int>(index++, 3*numAtoms); sumKineticEnergyKernel.setArg<cl_int>(index++, 3*numAtoms);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
} }
// Make sure all values (variables, parameters, etc.) are up to date. // Make sure all values (variables, parameters, etc.) are up to date.
if (!deviceValuesAreCurrent) { if (!deviceValuesAreCurrent) {
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
#include "ReferenceVirtualSites.h" #include "ReferenceVirtualSites.h"
#include "ReferenceCustomDynamics.h" #include "ReferenceCustomDynamics.h"
#include "ReferenceTabulatedFunction.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ForceImpl.h" #include "openmm/internal/ForceImpl.h"
...@@ -113,12 +114,18 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m ...@@ -113,12 +114,18 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m
variableLocations[ename.str()] = &energy; variableLocations[ename.str()] = &energy;
} }
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++)
functions[integrator.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(integrator.getTabulatedFunction(i));
// Parse the expressions. // Parse the expressions.
int numSteps = stepType.size(); int numSteps = stepType.size();
vector<int> forceGroup; vector<int> forceGroup;
vector<vector<ParsedExpression> > expressions; vector<vector<ParsedExpression> > expressions;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expressions, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup); CustomIntegratorUtilities::analyzeComputations(context, integrator, expressions, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
stepExpressions.resize(expressions.size()); stepExpressions.resize(expressions.size());
for (int i = 0; i < numSteps; i++) { for (int i = 0; i < numSteps; i++) {
stepExpressions[i].resize(expressions[i].size()); stepExpressions[i].resize(expressions[i].size());
...@@ -137,6 +144,11 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m ...@@ -137,6 +144,11 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<double>& m
if (kineticEnergyExpression.getVariables().find("f") != kineticEnergyExpression.getVariables().end()) if (kineticEnergyExpression.getVariables().find("f") != kineticEnergyExpression.getVariables().end())
kineticEnergyNeedsForce = true; kineticEnergyNeedsForce = true;
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
// Record the force group flags for each step. // Record the force group flags for each step.
forceGroupFlags.resize(numSteps, -1); forceGroupFlags.resize(numSteps, -1);
......
...@@ -40,22 +40,22 @@ CustomIntegratorProxy::CustomIntegratorProxy() : SerializationProxy("CustomInteg ...@@ -40,22 +40,22 @@ CustomIntegratorProxy::CustomIntegratorProxy() : SerializationProxy("CustomInteg
} }
void CustomIntegratorProxy::serialize(const void* object, SerializationNode& node) const { void CustomIntegratorProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1); node.setIntProperty("version", 2);
const CustomIntegrator& integrator = *reinterpret_cast<const CustomIntegrator*>(object); const CustomIntegrator& integrator = *reinterpret_cast<const CustomIntegrator*>(object);
SerializationNode& globalVariablesNode = node.createChildNode("GlobalVariables"); SerializationNode& globalVariablesNode = node.createChildNode("GlobalVariables");
for(int i=0; i<integrator.getNumGlobalVariables(); i++) { for (int i = 0; i < integrator.getNumGlobalVariables(); i++) {
globalVariablesNode.setDoubleProperty(integrator.getGlobalVariableName(i), integrator.getGlobalVariable(i)); globalVariablesNode.setDoubleProperty(integrator.getGlobalVariableName(i), integrator.getGlobalVariable(i));
} }
SerializationNode& perDofVariablesNode = node.createChildNode("PerDofVariables"); SerializationNode& perDofVariablesNode = node.createChildNode("PerDofVariables");
for(int i=0; i<integrator.getNumPerDofVariables(); i++) { for (int i = 0; i < integrator.getNumPerDofVariables(); i++) {
SerializationNode& perDofValuesNode = perDofVariablesNode.createChildNode(integrator.getPerDofVariableName(i)); SerializationNode& perDofValuesNode = perDofVariablesNode.createChildNode(integrator.getPerDofVariableName(i));
vector<Vec3> perDofValues; integrator.getPerDofVariable(i, perDofValues); vector<Vec3> perDofValues; integrator.getPerDofVariable(i, perDofValues);
for(int j=0; j<perDofValues.size(); j++) { for (int j = 0; j < perDofValues.size(); j++) {
perDofValuesNode.createChildNode("Value").setDoubleProperty("x",perDofValues[j][0]).setDoubleProperty("y",perDofValues[j][1]).setDoubleProperty("z",perDofValues[j][2]); perDofValuesNode.createChildNode("Value").setDoubleProperty("x",perDofValues[j][0]).setDoubleProperty("y",perDofValues[j][1]).setDoubleProperty("z",perDofValues[j][2]);
} }
} }
SerializationNode& computationsNode = node.createChildNode("Computations"); SerializationNode& computationsNode = node.createChildNode("Computations");
for(int i=0; i<integrator.getNumComputations(); i++) { for (int i = 0; i < integrator.getNumComputations(); i++) {
CustomIntegrator::ComputationType computationType; CustomIntegrator::ComputationType computationType;
string computationVariable; string computationVariable;
string computationExpression; string computationExpression;
...@@ -63,6 +63,9 @@ void CustomIntegratorProxy::serialize(const void* object, SerializationNode& nod ...@@ -63,6 +63,9 @@ void CustomIntegratorProxy::serialize(const void* object, SerializationNode& nod
computationsNode.createChildNode("Computation").setIntProperty("computationType",static_cast<int>(computationType)) computationsNode.createChildNode("Computation").setIntProperty("computationType",static_cast<int>(computationType))
.setStringProperty("computationVariable",computationVariable).setStringProperty("computationExpression",computationExpression); .setStringProperty("computationVariable",computationVariable).setStringProperty("computationExpression",computationExpression);
} }
SerializationNode& functions = node.createChildNode("Functions");
for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++)
functions.createChildNode("Function", &integrator.getTabulatedFunction(i)).setStringProperty("name", integrator.getTabulatedFunctionName(i));
node.setStringProperty("kineticEnergyExpression",integrator.getKineticEnergyExpression()); node.setStringProperty("kineticEnergyExpression",integrator.getKineticEnergyExpression());
node.setIntProperty("randomSeed",integrator.getRandomNumberSeed()); node.setIntProperty("randomSeed",integrator.getRandomNumberSeed());
node.setDoubleProperty("stepSize",integrator.getStepSize()); node.setDoubleProperty("stepSize",integrator.getStepSize());
...@@ -70,7 +73,8 @@ void CustomIntegratorProxy::serialize(const void* object, SerializationNode& nod ...@@ -70,7 +73,8 @@ void CustomIntegratorProxy::serialize(const void* object, SerializationNode& nod
} }
void* CustomIntegratorProxy::deserialize(const SerializationNode& node) const { void* CustomIntegratorProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1) int version = node.getIntProperty("version");
if (version < 1 || version > 2)
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
CustomIntegrator* integrator = new CustomIntegrator(node.getDoubleProperty("stepSize")); CustomIntegrator* integrator = new CustomIntegrator(node.getDoubleProperty("stepSize"));
const SerializationNode& globalVariablesNode = node.getChildNode("GlobalVariables"); const SerializationNode& globalVariablesNode = node.getChildNode("GlobalVariables");
...@@ -112,6 +116,11 @@ void* CustomIntegratorProxy::deserialize(const SerializationNode& node) const { ...@@ -112,6 +116,11 @@ void* CustomIntegratorProxy::deserialize(const SerializationNode& node) const {
throw(OpenMMException("Custom Integrator Deserialization: Unknown computation type")); throw(OpenMMException("Custom Integrator Deserialization: Unknown computation type"));
} }
} }
if (version > 1) {
const SerializationNode& functions = node.getChildNode("Functions");
for (auto& function : functions.getChildren())
integrator->addTabulatedFunction(function.getStringProperty("name"), function.decodeObject<TabulatedFunction>());
}
integrator->setKineticEnergyExpression(node.getStringProperty("kineticEnergyExpression")); integrator->setKineticEnergyExpression(node.getStringProperty("kineticEnergyExpression"));
integrator->setRandomNumberSeed(node.getIntProperty("randomSeed")); integrator->setRandomNumberSeed(node.getIntProperty("randomSeed"));
integrator->setConstraintTolerance(node.getDoubleProperty("constraintTolerance")); integrator->setConstraintTolerance(node.getDoubleProperty("constraintTolerance"));
......
...@@ -156,6 +156,10 @@ void testSerializeCustomIntegrator() { ...@@ -156,6 +156,10 @@ void testSerializeCustomIntegrator() {
intg->addComputeSum("summand2", "v*v+f*f"); intg->addComputeSum("summand2", "v*v+f*f");
intg->setConstraintTolerance(1e-5); intg->setConstraintTolerance(1e-5);
intg->setKineticEnergyExpression("m*v1*v1/2; v1=v+0.5*dt*f/m"); intg->setKineticEnergyExpression("m*v1*v1/2; v1=v+0.5*dt*f/m");
vector<double> values(10);
for (int i = 0; i < 10; i++)
values[i] = sin((double) i);
intg->addTabulatedFunction("f", new Continuous1DFunction(values, 0.5, 1.5));
stringstream ss; stringstream ss;
XmlSerializer::serialize<Integrator>(intg, "CustomIntegrator", ss); XmlSerializer::serialize<Integrator>(intg, "CustomIntegrator", ss);
CustomIntegrator *intg2 = dynamic_cast<CustomIntegrator*>(XmlSerializer::deserialize<Integrator>(ss)); CustomIntegrator *intg2 = dynamic_cast<CustomIntegrator*>(XmlSerializer::deserialize<Integrator>(ss));
...@@ -190,6 +194,19 @@ void testSerializeCustomIntegrator() { ...@@ -190,6 +194,19 @@ void testSerializeCustomIntegrator() {
ASSERT_EQUAL(intg->getRandomNumberSeed(), intg2->getRandomNumberSeed()); ASSERT_EQUAL(intg->getRandomNumberSeed(), intg2->getRandomNumberSeed());
ASSERT_EQUAL(intg->getStepSize(), intg2->getStepSize()); ASSERT_EQUAL(intg->getStepSize(), intg2->getStepSize());
ASSERT_EQUAL(intg->getConstraintTolerance(), intg2->getConstraintTolerance()); ASSERT_EQUAL(intg->getConstraintTolerance(), intg2->getConstraintTolerance());
ASSERT_EQUAL(intg->getNumTabulatedFunctions(), intg2->getNumTabulatedFunctions());
for (int i = 0; i < intg->getNumTabulatedFunctions(); i++) {
double min1, min2, max1, max2;
vector<double> val1, val2;
dynamic_cast<Continuous1DFunction&>(intg->getTabulatedFunction(i)).getFunctionParameters(val1, min1, max1);
dynamic_cast<Continuous1DFunction&>(intg2->getTabulatedFunction(i)).getFunctionParameters(val2, min2, max2);
ASSERT_EQUAL(intg->getTabulatedFunctionName(i), intg2->getTabulatedFunctionName(i));
ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2);
ASSERT_EQUAL(val1.size(), val2.size());
for (int j = 0; j < (int) val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]);
}
delete intg; delete intg;
delete intg2; delete intg2;
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2017 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -895,6 +895,32 @@ void testChangeDT() { ...@@ -895,6 +895,32 @@ void testChangeDT() {
} }
} }
/**
* Test an integrator that uses a tabulated function.
*/
void testTabulatedFunction() {
System system;
system.addParticle(1.0);
CustomIntegrator integrator(1.0);
integrator.addGlobalVariable("global", 1.5);
integrator.addPerDofVariable("dof", 0.0);
integrator.addComputeGlobal("global", "fn(global)");
integrator.addComputePerDof("dof", "fn(x)");
vector<double> table;
table.push_back(10.0);
table.push_back(20.0);
integrator.addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 2.0));
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(1.2, 1.3, 1.4);
context.setPositions(positions);
integrator.step(1);
ASSERT_EQUAL_TOL(15.0, integrator.getGlobalVariable(0), 1e-5);
vector<Vec3> values;
integrator.getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(12.0, 13.0, 14.0), values[0], 1e-5);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -917,6 +943,7 @@ int main(int argc, char* argv[]) { ...@@ -917,6 +943,7 @@ int main(int argc, char* argv[]) {
testChangingGlobal(); testChangingGlobal();
testEnergyParameterDerivatives(); testEnergyParameterDerivatives();
testChangeDT(); testChangeDT();
testTabulatedFunction();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment