Commit 0fe7612b authored by peastman's avatar peastman
Browse files

Merge pull request #337 from peastman/functions

Created new API for tabulated functions
parents 7a7055b3 ed31a458
...@@ -214,7 +214,7 @@ void testCustomFunctions() { ...@@ -214,7 +214,7 @@ void testCustomFunctions() {
vector<double> function(2); vector<double> function(2);
function[0] = 0; function[0] = 0;
function[1] = 1; function[1] = 1;
custom->addFunction("foo", function, 0, 10); custom->addTabulatedFunction("foo", new Continuous1DFunction(function, 0, 10));
system.addForce(custom); system.addForce(custom);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(3); vector<Vec3> positions(3);
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -261,7 +261,7 @@ void testPeriodic() { ...@@ -261,7 +261,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction() { void testContinuous1DFunction() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
system.addParticle(1.0); system.addParticle(1.0);
...@@ -271,21 +271,20 @@ void testTabulatedFunction() { ...@@ -271,21 +271,20 @@ void testTabulatedFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0); forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0); positions[0] = Vec3(0, 0, 0);
double tol = 0.01;
for (int i = 1; i < 30; i++) { for (int i = 1; i < 30; i++) {
double x = (7.0/30.0)*i; double x = (7.0/30.0)*i;
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
...@@ -295,11 +294,212 @@ void testTabulatedFunction() { ...@@ -295,11 +294,212 @@ void testTabulatedFunction() {
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
} }
void testContinuous2DFunction() {
const int xsize = 20;
const int ysize = 21;
const double xmin = 0.4;
const double xmax = 1.5;
const double ymin = 0.0;
const double ymax = 2.1;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
double force = 0;
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1.0;
force = -0.25*cos(0.25*x)*cos(0.33*y);
}
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
const double zmin = 0.2;
const double zmax = 1.3;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
double z = zmin + k*(zmax-zmin)/zsize;
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(0.33*y)*(1+z);
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.1) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setParameter("b", z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
double force = 0;
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax && z >= zmin && z <= zmax) {
energy = sin(0.25*x)*cos(0.33*y)*(1.0+z)+1.0;
force = -0.25*cos(0.25*x)*cos(0.33*y)*(1.0+z);
}
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
void testDiscrete1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testDiscrete2DFunction() {
const int xsize = 10;
const int ysize = 5;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addTabulatedFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", i/xsize);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testDiscrete3DFunction() {
const int xsize = 8;
const int ysize = 5;
const int zsize = 6;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addTabulatedFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", (i/xsize)%ysize);
context.setParameter("b", i/(xsize*ysize));
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -725,7 +925,12 @@ int main(int argc, char* argv[]) { ...@@ -725,7 +925,12 @@ int main(int argc, char* argv[]) {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(); testContinuous1DFunction();
testContinuous2DFunction();
testContinuous3DFunction();
testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testParallelComputation(); testParallelComputation();
testSwitchingFunction(); testSwitchingFunction();
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2011 Stanford University and the Authors. * * Portions copyright (c) 2009-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "OpenCLContext.h" #include "OpenCLContext.h"
#include "openmm/TabulatedFunction.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
...@@ -45,75 +46,81 @@ namespace OpenMM { ...@@ -45,75 +46,81 @@ namespace OpenMM {
class OPENMM_EXPORT_OPENCL OpenCLExpressionUtilities { class OPENMM_EXPORT_OPENCL OpenCLExpressionUtilities {
public: public:
OpenCLExpressionUtilities(OpenCLContext& context) : context(context) { OpenCLExpressionUtilities(OpenCLContext& context);
}
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are * @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are
* variable names, and the values are the code to generate for them. * variable names, and the values are the code to generate for them.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& tempType="real");
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions. * @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions.
* Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears. * Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "float") * @param tempType the type of value to use for temporary variables (defaults to "float")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="float"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& tempType="float");
/** /**
* Calculate the spline coefficients for a tabulated function that appears in expressions. * Calculate the spline coefficients for a tabulated function that appears in expressions.
* *
* @param values the tabulated values of the function * @param function the function for which to compute coefficients
* @param min the value of the independent variable corresponding to the first element of values * @param width on output, the number of floats used for each value
* @param max the value of the independent variable corresponding to the last element of values
* @return the spline coefficients * @return the spline coefficients
*/ */
std::vector<mm_float4> computeFunctionCoefficients(const std::vector<double>& values, double min, double max); std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
class FunctionPlaceholder; /**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
*
* @param function the function for which to get a placeholder
*/
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
private: private:
class FunctionPlaceholder : public Lepton::CustomFunction {
public:
FunctionPlaceholder(int numArgs) : numArgs(numArgs) {
}
int getNumArguments() const {
return numArgs;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new FunctionPlaceholder(numArgs);
}
private:
int numArgs;
};
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
OpenCLContext& context; OpenCLContext& context;
}; FunctionPlaceholder fp1, fp2, fp3;
/**
* This class serves as a placeholder for custom functions in expressions.
*/
class OpenCLExpressionUtilities::FunctionPlaceholder : public Lepton::CustomFunction {
public:
int getNumArguments() const {
return 1;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new FunctionPlaceholder();
}
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -639,7 +639,7 @@ private: ...@@ -639,7 +639,7 @@ private:
class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel { class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public: public:
OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomNonbondedForceKernel(name, platform), OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomNonbondedForceKernel(name, platform),
cl(cl), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) { cl(cl), params(NULL), globals(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
} }
~OpenCLCalcCustomNonbondedForceKernel(); ~OpenCLCalcCustomNonbondedForceKernel();
/** /**
...@@ -670,7 +670,6 @@ private: ...@@ -670,7 +670,6 @@ private:
OpenCLContext& cl; OpenCLContext& cl;
OpenCLParameterSet* params; OpenCLParameterSet* params;
OpenCLArray* globals; OpenCLArray* globals;
OpenCLArray* tabulatedFunctionParams;
OpenCLArray* interactionGroupData; OpenCLArray* interactionGroupData;
cl::Kernel interactionGroupKernel; cl::Kernel interactionGroupKernel;
std::vector<void*> interactionGroupArgs; std::vector<void*> interactionGroupArgs;
...@@ -742,7 +741,7 @@ class OpenCLCalcCustomGBForceKernel : public CalcCustomGBForceKernel { ...@@ -742,7 +741,7 @@ class OpenCLCalcCustomGBForceKernel : public CalcCustomGBForceKernel {
public: public:
OpenCLCalcCustomGBForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomGBForceKernel(name, platform), OpenCLCalcCustomGBForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomGBForceKernel(name, platform),
hasInitializedKernels(false), cl(cl), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL), hasInitializedKernels(false), cl(cl), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL),
valueBuffers(NULL), longValueBuffers(NULL), tabulatedFunctionParams(NULL), system(system) { valueBuffers(NULL), longValueBuffers(NULL), system(system) {
} }
~OpenCLCalcCustomGBForceKernel(); ~OpenCLCalcCustomGBForceKernel();
/** /**
...@@ -780,7 +779,6 @@ private: ...@@ -780,7 +779,6 @@ private:
OpenCLArray* globals; OpenCLArray* globals;
OpenCLArray* valueBuffers; OpenCLArray* valueBuffers;
OpenCLArray* longValueBuffers; OpenCLArray* longValueBuffers;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions; std::vector<OpenCLArray*> tabulatedFunctions;
...@@ -841,8 +839,7 @@ class OpenCLCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel { ...@@ -841,8 +839,7 @@ class OpenCLCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel {
public: public:
OpenCLCalcCustomHbondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomHbondForceKernel(name, platform), OpenCLCalcCustomHbondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomHbondForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL), hasInitializedKernel(false), cl(cl), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL),
donorBufferIndices(NULL), acceptorBufferIndices(NULL), globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), donorBufferIndices(NULL), acceptorBufferIndices(NULL), globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), system(system) {
tabulatedFunctionParams(NULL), system(system) {
} }
~OpenCLCalcCustomHbondForceKernel(); ~OpenCLCalcCustomHbondForceKernel();
/** /**
...@@ -881,7 +878,6 @@ private: ...@@ -881,7 +878,6 @@ private:
OpenCLArray* acceptorBufferIndices; OpenCLArray* acceptorBufferIndices;
OpenCLArray* donorExclusions; OpenCLArray* donorExclusions;
OpenCLArray* acceptorExclusions; OpenCLArray* acceptorExclusions;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions; std::vector<OpenCLArray*> tabulatedFunctions;
...@@ -895,7 +891,7 @@ private: ...@@ -895,7 +891,7 @@ private:
class OpenCLCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel { class OpenCLCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel {
public: public:
OpenCLCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomCompoundBondForceKernel(name, platform), OpenCLCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomCompoundBondForceKernel(name, platform),
cl(cl), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), system(system) { cl(cl), params(NULL), globals(NULL), system(system) {
} }
~OpenCLCalcCustomCompoundBondForceKernel(); ~OpenCLCalcCustomCompoundBondForceKernel();
/** /**
...@@ -927,7 +923,6 @@ private: ...@@ -927,7 +923,6 @@ private:
OpenCLContext& cl; OpenCLContext& cl;
OpenCLParameterSet* params; OpenCLParameterSet* params;
OpenCLArray* globals; OpenCLArray* globals;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions; std::vector<OpenCLArray*> tabulatedFunctions;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2011 Stanford University and the Authors. * * Portions copyright (c) 2009-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,35 +33,40 @@ using namespace OpenMM; ...@@ -33,35 +33,40 @@ using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context) : context(context), fp1(1), fp2(2), fp3(3) {
}
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter) for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second)); variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, prefix, functionParams, tempType); return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
} }
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
stringstream out; stringstream out;
vector<ParsedExpression> allExpressions; vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second); allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables; vector<pair<ExpressionTreeNode, string> > temps = variables;
vector<vector<double> > functionParams = computeFunctionParameters(functions);
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) { for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n"; out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
} }
return out.str(); return out.str();
} }
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps, void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++) for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node) if (temps[i].first == node)
return; return;
for (int i = 0; i < (int) node.getChildren().size(); i++) for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
string name = prefix+context.intToString(temps.size()); string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false; bool hasRecordedNode = false;
...@@ -75,11 +80,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -75,11 +80,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::CUSTOM: case Operation::CUSTOM:
{ {
int i; int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++) for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
; ;
if (i == functions.size()) if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n"; out << "0.0f;\n";
temps.push_back(make_pair(node, name)); temps.push_back(make_pair(node, name));
hasRecordedNode = true; hasRecordedNode = true;
...@@ -87,39 +91,190 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -87,39 +91,190 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
// If both the value and derivative of the function are needed, it's faster to calculate them both // If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed. // at once, so check to see if both are needed.
const ExpressionTreeNode* valueNode = NULL; vector<const ExpressionTreeNode*> nodes;
const ExpressionTreeNode* derivNode = NULL;
for (int j = 0; j < (int) allExpressions.size(); j++) for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), valueNode, derivNode); findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), nodes);
string valueName = name; vector<string> nodeNames;
string derivName = name; nodeNames.push_back(name);
if (valueNode != NULL && derivNode != NULL) { for (int j = 1; j < (int) nodes.size(); j++) {
string name2 = prefix+context.intToString(temps.size()); string name2 = prefix+context.intToString(temps.size());
out << tempType << " " << name2 << " = 0.0f;\n"; out << tempType << " " << name2 << " = 0.0f;\n";
if (isDeriv) { nodeNames.push_back(name2);
valueName = name2; temps.push_back(make_pair(*nodes[j], name2));
temps.push_back(make_pair(*valueNode, name2)); }
out << "{\n";
vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x-" << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
} }
else { out << "}\n";
derivName = name2; }
temps.push_back(make_pair(*derivNode, name2)); else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x-" << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y-" << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
out << "}\n";
}
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n";
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x-" << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y-" << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z-" << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << ");\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n";
for (int j = 0; j < 16; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
out << "real dc = z-u;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real value[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real derivx[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
}
out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
const string suffixes[] = {".x", ".y", ".z", ".w"};
out << "real derivy[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = 4*m;
string suffix = suffixes[m];
out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
}
out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
out << "real derivz[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
out << "}\n";
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n";
out << "int zsize = " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
} }
} }
out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "x = (x-params.x)*params.z;\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "float b = x-index;\n";
out << "float a = 1.0f-b;\n";
if (valueNode != NULL)
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
if (derivNode != NULL)
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << "}\n";
out << "}"; out << "}";
break; break;
} }
...@@ -312,16 +467,27 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co ...@@ -312,16 +467,27 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
} }
void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
const ExpressionTreeNode*& valueNode, const ExpressionTreeNode*& derivNode) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) { if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
if (dynamic_cast<const Operation::Custom*>(&searchNode.getOperation())->getDerivOrder()[0] == 0) // Make sure the arguments are identical.
valueNode = &searchNode;
else for (int i = 0; i < (int) node.getChildren().size(); i++)
derivNode = &searchNode; if (node.getChildren()[i] != searchNode.getChildren()[i])
return;
// See if we already have an identical node.
for (int i = 0; i < (int) nodes.size(); i++)
if (*nodes[i] == searchNode)
return;
// Add the node.
nodes.push_back(&searchNode);
} }
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode); findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
} }
void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) { void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
...@@ -341,16 +507,209 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node ...@@ -341,16 +507,209 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node
findRelatedPowers(node, searchNode.getChildren()[i], powers); findRelatedPowers(node, searchNode.getChildren()[i], powers);
} }
vector<mm_float4> OpenCLExpressionUtilities::computeFunctionCoefficients(const vector<double>& values, double min, double max) { vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
// Compute the spline coefficients. if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
int numValues = values.size();
vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs);
vector<float> f(4*(numValues-1));
for (int i = 0; i < (int) values.size()-1; i++) {
f[4*i] = (float) values[i];
f[4*i+1] = (float) values[i+1];
f[4*i+2] = (float) (derivs[i]/6.0);
f[4*i+3] = (float) (derivs[i+1]/6.0);
}
width = 4;
return f;
}
if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const Continuous2DFunction& fn = dynamic_cast<const Continuous2DFunction&>(function);
vector<double> values;
int xsize, ysize;
double xmin, xmax, ymin, ymax;
fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
vector<double> x(xsize), y(ysize);
for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1);
vector<vector<double> > c;
SplineFitter::create2DNaturalSpline(x, y, values, c);
vector<float> f(16*c.size());
for (int i = 0; i < (int) c.size(); i++) {
for (int j = 0; j < 16; j++)
f[16*i+j] = (float) c[i][j];
}
width = 4;
return f;
}
if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const Continuous3DFunction& fn = dynamic_cast<const Continuous3DFunction&>(function);
vector<double> values;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
vector<double> x(xsize), y(ysize), z(zsize);
for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1);
for (int i = 0; i < zsize; i++)
z[i] = zmin+i*(zmax-zmin)/(zsize-1);
vector<vector<double> > c;
SplineFitter::create3DNaturalSpline(x, y, z, values, c);
vector<float> f(64*c.size());
for (int i = 0; i < (int) c.size(); i++) {
for (int j = 0; j < 64; j++)
f[64*i+j] = (float) c[i][j];
}
width = 4;
return f;
}
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(function);
vector<double> values;
fn.getFunctionParameters(values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(function);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(function);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
throw OpenMMException("computeFunctionCoefficients: Unknown function type");
}
vector<vector<double> > OpenCLExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<vector<double> > params(functions.size());
for (int i = 0; i < (int) functions.size(); i++) {
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
params[i].push_back(min);
params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
}
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
const Continuous2DFunction& fn = dynamic_cast<const Continuous2DFunction&>(*functions[i]);
vector<double> values;
int xsize, ysize;
double xmin, xmax, ymin, ymax;
fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
params[i].push_back(xsize-1);
params[i].push_back(ysize-1);
params[i].push_back(xmin);
params[i].push_back(xmax);
params[i].push_back(ymin);
params[i].push_back(ymax);
params[i].push_back((xsize-1)/(xmax-xmin));
params[i].push_back((ysize-1)/(ymax-ymin));
}
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
const Continuous3DFunction& fn = dynamic_cast<const Continuous3DFunction&>(*functions[i]);
vector<double> values;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
params[i].push_back(xsize-1);
params[i].push_back(ysize-1);
params[i].push_back(zsize-1);
params[i].push_back(xmin);
params[i].push_back(xmax);
params[i].push_back(ymin);
params[i].push_back(ymax);
params[i].push_back(zmin);
params[i].push_back(zmax);
params[i].push_back((xsize-1)/(xmax-xmin));
params[i].push_back((ysize-1)/(ymax-ymin));
params[i].push_back((zsize-1)/(zmax-zmin));
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i].push_back(values.size());
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
params[i].push_back(xsize);
params[i].push_back(ysize);
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i].push_back(xsize);
params[i].push_back(ysize);
params[i].push_back(zsize);
}
else
throw OpenMMException("computeFunctionParameters: Unknown function type");
}
return params;
}
int numValues = values.size(); Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
vector<double> x(numValues), derivs; if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
for (int i = 0; i < numValues; i++) return &fp1;
x[i] = min+i*(max-min)/(numValues-1); if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL)
SplineFitter::createNaturalSpline(x, values, derivs); return &fp2;
vector<mm_float4> f(numValues-1); if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL)
for (int i = 0; i < (int) values.size()-1; i++) return &fp3;
f[i] = mm_float4((cl_float) values[i], (cl_float) values[i+1], (cl_float) (derivs[i]/6.0), (cl_float) (derivs[i+1]/6.0)); if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return f; return &fp1;
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL)
return &fp2;
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL)
return &fp3;
throw OpenMMException("getFunctionPlaceholder: Unknown function type");
} }
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -614,8 +614,9 @@ void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const Cus ...@@ -614,8 +614,9 @@ void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const Cus
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
...@@ -843,8 +844,9 @@ void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const Cu ...@@ -843,8 +844,9 @@ void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const Cu
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
...@@ -1247,8 +1249,9 @@ void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const ...@@ -1247,8 +1249,9 @@ void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
...@@ -1923,8 +1926,6 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() { ...@@ -1923,8 +1926,6 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
delete params; delete params;
if (globals != NULL) if (globals != NULL)
delete globals; delete globals;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
if (interactionGroupData != NULL) if (interactionGroupData != NULL)
delete interactionGroupData; delete interactionGroupData;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
...@@ -1965,28 +1966,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1965,28 +1966,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<mm_float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max;
force.getFunctionParameters(i, name, values, min, max);
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); int width;
vector<mm_float4> f = cl.getExpressionUtilities().computeFunctionCoefficients(values, min, max); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<mm_float4>(cl, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
}
if (force.getNumFunctions() > 0) {
tabulatedFunctionParams = OpenCLArray::create<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", CL_MEM_READ_ONLY);
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"functionParams", "float", 4, sizeof(cl_float4), tabulatedFunctionParams->getDeviceBuffer()));
} }
// Record information for the expressions. // Record information for the expressions.
...@@ -2025,7 +2018,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -2025,7 +2018,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
variables.push_back(makeVariable(name, prefix+value)); variables.push_back(makeVariable(name, prefix+value));
} }
stringstream compute; stringstream compute;
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, prefix+"temp", prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0"); replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
...@@ -2664,8 +2657,6 @@ OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() { ...@@ -2664,8 +2657,6 @@ OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() {
delete valueBuffers; delete valueBuffers;
if (longValueBuffers != NULL) if (longValueBuffers != NULL)
delete longValueBuffers; delete longValueBuffers;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i]; delete tabulatedFunctions[i];
} }
...@@ -2722,31 +2713,25 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2722,31 +2713,25 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<mm_float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max;
force.getFunctionParameters(i, name, values, min, max);
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); int width;
vector<mm_float4> f = cl.getExpressionUtilities().computeFunctionCoefficients(values, min, max); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<mm_float4>(cl, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
tableArgs << ", __global const float4* restrict " << arrayName; tableArgs << ", __global const float";
} if (width > 1)
if (force.getNumFunctions() > 0) { tableArgs << width;
tabulatedFunctionParams = OpenCLArray::create<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", CL_MEM_READ_ONLY); tableArgs << "* restrict " << arrayName;
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"functionParams", "float", 4, sizeof(cl_float4), tabulatedFunctionParams->getDeviceBuffer()));
tableArgs << ", __global const float4* " << prefix << "functionParams";
} }
// Record the global parameters. // Record the global parameters.
...@@ -2838,7 +2823,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2838,7 +2823,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex; n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename); n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp");
map<string, string> replacements; map<string, string> replacements;
string n2ValueStr = n2ValueSource.str(); string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr; replacements["COMPUTE_VALUE"] = n2ValueStr;
...@@ -2914,7 +2899,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2914,7 +2899,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1); variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map<string, Lepton::ParsedExpression> valueExpressions; map<string, Lepton::ParsedExpression> valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize(); valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionDefinitions, "value"+cl.intToString(i)+"_temp", prefix+"functionParams"); reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cl.intToString(i)+"_temp");
} }
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) { for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
string valueName = "values"+cl.intToString(i+1); string valueName = "values"+cl.intToString(i+1);
...@@ -2978,7 +2963,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2978,7 +2963,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
} }
if (exclude) if (exclude)
n2EnergySource << "if (!isExcluded) {\n"; n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp");
if (exclude) if (exclude)
n2EnergySource << "}\n"; n2EnergySource << "}\n";
} }
...@@ -3149,7 +3134,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3149,7 +3134,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 1; i < force.getNumComputedValues(); i++) for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++) for (int j = 0; j < i; j++)
expressions["real dV"+cl.intToString(i)+"dV"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j]; expressions["real dV"+cl.intToString(i)+"dV"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp");
// Record values. // Record values.
...@@ -3219,7 +3204,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3219,7 +3204,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
map<string, Lepton::ParsedExpression> derivExpressions; map<string, Lepton::ParsedExpression> derivExpressions;
string js = cl.intToString(j); string js = cl.intToString(j);
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j]; derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js);
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n"; compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
} }
} }
...@@ -3230,7 +3215,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3230,7 +3215,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1]; gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2])) if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2]; gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp");
} }
for (int i = 1; i < force.getNumComputedValues(); i++) { for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = cl.intToString(i); string is = cl.intToString(i);
...@@ -3271,7 +3256,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3271,7 +3256,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize(); Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["real dV0dR1 = "] = dVdR; derivExpressions["real dV0dR1 = "] = dVdR;
derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename); derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams"); chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_");
if (needChainForValue[0]) { if (needChainForValue[0]) {
if (useExclusionsForValue) if (useExclusionsForValue)
chainSource << "if (!isExcluded) {\n"; chainSource << "if (!isExcluded) {\n";
...@@ -3412,11 +3397,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3412,11 +3397,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairValueKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL); pairValueKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL);
} }
} }
if (tabulatedFunctionParams != NULL) { for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) pairValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
pairValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
pairValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctionParams->getDeviceBuffer());
}
index = 0; index = 0;
perParticleValueKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms()); perParticleValueKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms());
perParticleValueKernel.setArg<cl_int>(index++, nb.getNumForceBuffers()); perParticleValueKernel.setArg<cl_int>(index++, nb.getNumForceBuffers());
...@@ -3428,11 +3410,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3428,11 +3410,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
perParticleValueKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory()); perParticleValueKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory());
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) for (int i = 0; i < (int) computedValues->getBuffers().size(); i++)
perParticleValueKernel.setArg<cl::Memory>(index++, computedValues->getBuffers()[i].getMemory()); perParticleValueKernel.setArg<cl::Memory>(index++, computedValues->getBuffers()[i].getMemory());
if (tabulatedFunctionParams != NULL) { for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) perParticleValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
perParticleValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
perParticleValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctionParams->getDeviceBuffer());
}
index = 0; index = 0;
pairEnergyKernel.setArg<cl::Buffer>(index++, useLong ? cl.getLongForceBuffer().getDeviceBuffer() : cl.getForceBuffers().getDeviceBuffer()); pairEnergyKernel.setArg<cl::Buffer>(index++, useLong ? cl.getLongForceBuffer().getDeviceBuffer() : cl.getForceBuffers().getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer()); pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
...@@ -3480,11 +3459,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3480,11 +3459,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL); pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL);
} }
} }
if (tabulatedFunctionParams != NULL) { for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) pairEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctionParams->getDeviceBuffer());
}
index = 0; index = 0;
perParticleEnergyKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms()); perParticleEnergyKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms());
perParticleEnergyKernel.setArg<cl_int>(index++, nb.getNumForceBuffers()); perParticleEnergyKernel.setArg<cl_int>(index++, nb.getNumForceBuffers());
...@@ -3503,11 +3479,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3503,11 +3479,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
perParticleEnergyKernel.setArg<cl::Memory>(index++, energyDerivChain->getBuffers()[i].getMemory()); perParticleEnergyKernel.setArg<cl::Memory>(index++, energyDerivChain->getBuffers()[i].getMemory());
if (useLong) if (useLong)
perParticleEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs->getDeviceBuffer()); perParticleEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs->getDeviceBuffer());
if (tabulatedFunctionParams != NULL) { for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) perParticleEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
perParticleEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
perParticleEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctionParams->getDeviceBuffer());
}
if (needParameterGradient) { if (needParameterGradient) {
index = 0; index = 0;
gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer()); gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
...@@ -3681,8 +3654,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const ...@@ -3681,8 +3654,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customExternalForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customExternalForce, replacements), force.getForceGroup());
...@@ -3825,8 +3799,6 @@ OpenCLCalcCustomHbondForceKernel::~OpenCLCalcCustomHbondForceKernel() { ...@@ -3825,8 +3799,6 @@ OpenCLCalcCustomHbondForceKernel::~OpenCLCalcCustomHbondForceKernel() {
delete donorExclusions; delete donorExclusions;
if (acceptorExclusions != NULL) if (acceptorExclusions != NULL)
delete acceptorExclusions; delete acceptorExclusions;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i]; delete tabulatedFunctions[i];
} }
...@@ -3947,29 +3919,24 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -3947,29 +3919,24 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<mm_float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max;
force.getFunctionParameters(i, name, values, min, max);
string arrayName = "table"+cl.intToString(i); string arrayName = "table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); int width;
vector<mm_float4> f = cl.getExpressionUtilities().computeFunctionCoefficients(values, min, max); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<mm_float4>(cl, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", __global const float4* restrict " << arrayName; tableArgs << ", __global const float";
} if (width > 1)
if (force.getNumFunctions() > 0) { tableArgs << width;
tabulatedFunctionParams = OpenCLArray::create<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", CL_MEM_READ_ONLY); tableArgs << "* restrict " << arrayName;
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
tableArgs << ", __global const float4* restrict functionParams";
} }
// Record information about parameters. // Record information about parameters.
...@@ -4085,9 +4052,9 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -4085,9 +4052,9 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Now evaluate the expressions. // Now evaluate the expressions.
computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4200,11 +4167,8 @@ double OpenCLCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl ...@@ -4200,11 +4167,8 @@ double OpenCLCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
donorKernel.setArg<cl::Memory>(index++, buffer.getMemory()); donorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
} }
if (tabulatedFunctionParams != NULL) { for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) donorKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
donorKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
donorKernel.setArg<cl::Buffer>(index++, tabulatedFunctionParams->getDeviceBuffer());
}
index = 0; index = 0;
acceptorKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer()); acceptorKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
acceptorKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer()); acceptorKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
...@@ -4225,11 +4189,8 @@ double OpenCLCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl ...@@ -4225,11 +4189,8 @@ double OpenCLCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
acceptorKernel.setArg<cl::Memory>(index++, buffer.getMemory()); acceptorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
} }
if (tabulatedFunctionParams != NULL) { for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) acceptorKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
acceptorKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
acceptorKernel.setArg<cl::Buffer>(index++, tabulatedFunctionParams->getDeviceBuffer());
}
} }
setPeriodicBoxSizeArg(cl, donorKernel, 8); setPeriodicBoxSizeArg(cl, donorKernel, 8);
setInvPeriodicBoxSizeArg(cl, donorKernel, 9); setInvPeriodicBoxSizeArg(cl, donorKernel, 9);
...@@ -4314,8 +4275,6 @@ OpenCLCalcCustomCompoundBondForceKernel::~OpenCLCalcCustomCompoundBondForceKerne ...@@ -4314,8 +4275,6 @@ OpenCLCalcCustomCompoundBondForceKernel::~OpenCLCalcCustomCompoundBondForceKerne
delete params; delete params;
if (globals != NULL) if (globals != NULL)
delete globals; delete globals;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i]; delete tabulatedFunctions[i];
} }
...@@ -4343,31 +4302,22 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4343,31 +4302,22 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<mm_float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
force.getFunctionParameters(i, name, values, min, max); int width;
functions[name] = &fp; vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction");
vector<mm_float4> f = cl.getExpressionUtilities().computeFunctionCoefficients(values, min, max);
OpenCLArray* array = OpenCLArray::create<mm_float4>(cl, values.size()-1, "TabulatedFunction");
tabulatedFunctions.push_back(array); tabulatedFunctions.push_back(array);
array->upload(f); array->upload(f);
string arrayName = cl.getBondedUtilities().addArgument(array->getDeviceBuffer(), "float4"); string arrayName = cl.getBondedUtilities().addArgument(array->getDeviceBuffer(), width == 1 ? "float" : "float"+cl.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
} }
string functionParamsName;
if (force.getNumFunctions() > 0) {
tabulatedFunctionParams = OpenCLArray::create<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", CL_MEM_READ_ONLY);
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
functionParamsName = cl.getBondedUtilities().addArgument(tabulatedFunctionParams->getDeviceBuffer(), "float4");
}
// Record information about parameters. // Record information about parameters.
...@@ -4482,7 +4432,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4482,7 +4432,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", functionParamsName); compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4504,7 +4454,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4504,7 +4454,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
if (!isZeroExpression(forceExpressionZ)) if (!isZeroExpression(forceExpressionZ))
expressions[forceName+".z -= "] = forceExpressionZ; expressions[forceName+".z -= "] = forceExpressionZ;
if (expressions.size() > 0) if (expressions.size() > 0)
compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "coordtemp", functionParamsName); compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp");
compute<<"}\n"; compute<<"}\n";
} }
index = 0; index = 0;
...@@ -5194,8 +5144,9 @@ string OpenCLIntegrateCustomStepKernel::createGlobalComputation(const string& va ...@@ -5194,8 +5144,9 @@ string OpenCLIntegrateCustomStepKernel::createGlobalComputation(const string& va
variables[integrator.getGlobalVariableName(i)] = "globals["+cl.intToString(i)+"]"; variables[integrator.getGlobalVariableName(i)] = "globals["+cl.intToString(i)+"]";
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cl.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cl.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
} }
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) { string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
...@@ -5231,9 +5182,10 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va ...@@ -5231,9 +5182,10 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i); variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cl.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cl.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float"); string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float");
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp"+cl.intToString(component)+"_", "", tempType); return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+cl.intToString(component)+"_", tempType);
} }
void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
......
...@@ -199,6 +199,103 @@ void testParallelComputation() { ...@@ -199,6 +199,103 @@ void testParallelComputation() {
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5); ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
} }
void testContinuous2DFunction() {
const int xsize = 10;
const int ysize = 11;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
System system;
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomCompoundBondForce* forceField = new CustomCompoundBondForce(1, "fn(x1,y1)+1");
vector<int> particles(1, 0);
forceField->addBond(particles, vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
const double zmin = 0.2;
const double zmax = 1.3;
System system;
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomCompoundBondForce* forceField = new CustomCompoundBondForce(1, "fn(x1,y1,z1)+1");
vector<int> particles(1, 0);
forceField->addBond(particles, vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
double z = zmin + k*(zmax-zmin)/zsize;
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(0.33*y)*(1+z);
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.1) {
positions[0] = Vec3(x, y, z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax && z >= zmin && z <= zmax) {
energy = sin(0.25*x)*cos(0.33*y)*(1.0+z)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y)*(1.0+z);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y)*(1.0+z);
force[2] = -sin(0.25*x)*cos(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -206,6 +303,8 @@ int main(int argc, char* argv[]) { ...@@ -206,6 +303,8 @@ int main(int argc, char* argv[]) {
testBond(); testBond();
testPositionDependence(); testPositionDependence();
testParallelComputation(); testParallelComputation();
testContinuous2DFunction();
testContinuous3DFunction();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -277,7 +277,7 @@ void testTabulatedFunction() { ...@@ -277,7 +277,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0); force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
......
...@@ -214,7 +214,7 @@ void testCustomFunctions() { ...@@ -214,7 +214,7 @@ void testCustomFunctions() {
vector<double> function(2); vector<double> function(2);
function[0] = 0; function[0] = 0;
function[1] = 1; function[1] = 1;
custom->addFunction("foo", function, 0, 10); custom->addTabulatedFunction("foo", new Continuous1DFunction(function, 0, 10));
system.addForce(custom); system.addForce(custom);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(3); vector<Vec3> positions(3);
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -261,7 +261,7 @@ void testPeriodic() { ...@@ -261,7 +261,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction() { void testContinuous1DFunction() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
system.addParticle(1.0); system.addParticle(1.0);
...@@ -271,21 +271,20 @@ void testTabulatedFunction() { ...@@ -271,21 +271,20 @@ void testTabulatedFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0); forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0); positions[0] = Vec3(0, 0, 0);
double tol = 0.01;
for (int i = 1; i < 30; i++) { for (int i = 1; i < 30; i++) {
double x = (7.0/30.0)*i; double x = (7.0/30.0)*i;
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
...@@ -295,11 +294,211 @@ void testTabulatedFunction() { ...@@ -295,11 +294,211 @@ void testTabulatedFunction() {
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
} }
void testContinuous2DFunction() {
const int xsize = 20;
const int ysize = 21;
const double xmin = 0.4;
const double xmax = 1.5;
const double ymin = 0.0;
const double ymax = 2.1;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
double force = 0;
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1.0;
force = -0.25*cos(0.25*x)*cos(0.33*y);
}
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
const double zmin = 0.2;
const double zmax = 1.3;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
double z = zmin + k*(zmax-zmin)/zsize;
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(0.33*y)*(1+z);
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.1) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setParameter("b", z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
double force = 0;
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax && z >= zmin && z <= zmax) {
energy = sin(0.25*x)*cos(0.33*y)*(1.0+z)+1.0;
force = -0.25*cos(0.25*x)*cos(0.33*y)*(1.0+z);
}
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
void testDiscrete1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testDiscrete2DFunction() {
const int xsize = 10;
const int ysize = 5;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addTabulatedFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", i/xsize);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testDiscrete3DFunction() {
const int xsize = 8;
const int ysize = 5;
const int zsize = 6;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addTabulatedFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", (i/xsize)%ysize);
context.setParameter("b", i/(xsize*ysize));
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -725,7 +924,12 @@ int main(int argc, char* argv[]) { ...@@ -725,7 +924,12 @@ int main(int argc, char* argv[]) {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(); testContinuous1DFunction();
testContinuous2DFunction();
testContinuous3DFunction();
testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testParallelComputation(); testParallelComputation();
testSwitchingFunction(); testSwitchingFunction();
......
#ifndef OPENMM_REFERENCETABULATEDFUNCTION_H_
#define OPENMM_REFERENCETABULATEDFUNCTION_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/TabulatedFunction.h"
#include "openmm/internal/windowsExport.h"
#include "lepton/CustomFunction.h"
#include <vector>
namespace OpenMM {
/**
* Given a TabulatedFunction, wrap it in an appropriate subclass of Lepton::CustomFunction.
*/
extern "C" Lepton::CustomFunction* createReferenceTabulatedFunction(const TabulatedFunction& function);
/**
* This class adapts a Continuous1DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceContinuous1DFunction : public Lepton::CustomFunction {
public:
ReferenceContinuous1DFunction(const Continuous1DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Continuous1DFunction& function;
double min, max;
std::vector<double> x, values, derivs;
};
/**
* This class adapts a Continuous2DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceContinuous2DFunction : public Lepton::CustomFunction {
public:
ReferenceContinuous2DFunction(const Continuous2DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Continuous2DFunction& function;
int xsize, ysize;
double xmin, xmax, ymin, ymax;
std::vector<double> x, y, values;
std::vector<std::vector<double> > c;
};
/**
* This class adapts a Continuous3DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceContinuous3DFunction : public Lepton::CustomFunction {
public:
ReferenceContinuous3DFunction(const Continuous3DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Continuous3DFunction& function;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
std::vector<double> x, y, z, values;
std::vector<std::vector<double> > c;
};
/**
* This class adapts a Discrete1DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceDiscrete1DFunction : public Lepton::CustomFunction {
public:
ReferenceDiscrete1DFunction(const Discrete1DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Discrete1DFunction& function;
std::vector<double> values;
};
/**
* This class adapts a Discrete2DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceDiscrete2DFunction : public Lepton::CustomFunction {
public:
ReferenceDiscrete2DFunction(const Discrete2DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Discrete2DFunction& function;
int xsize, ysize;
std::vector<double> values;
};
/**
* This class adapts a Discrete3DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceDiscrete3DFunction : public Lepton::CustomFunction {
public:
ReferenceDiscrete3DFunction(const Discrete3DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Discrete3DFunction& function;
int xsize, ysize, zsize;
std::vector<double> values;
};
} // namespace OpenMM
#endif /*OPENMM_REFERENCETABULATEDFUNCTION_H_*/
...@@ -55,6 +55,7 @@ ...@@ -55,6 +55,7 @@
#include "ReferenceProperDihedralBond.h" #include "ReferenceProperDihedralBond.h"
#include "ReferenceRbDihedralBond.h" #include "ReferenceRbDihedralBond.h"
#include "ReferenceStochasticDynamics.h" #include "ReferenceStochasticDynamics.h"
#include "ReferenceTabulatedFunction.h"
#include "ReferenceVariableStochasticDynamics.h" #include "ReferenceVariableStochasticDynamics.h"
#include "ReferenceVariableVerletDynamics.h" #include "ReferenceVariableVerletDynamics.h"
#include "ReferenceVerletDynamics.h" #include "ReferenceVerletDynamics.h"
...@@ -69,7 +70,6 @@ ...@@ -69,7 +70,6 @@
#include "openmm/internal/CustomNonbondedForceImpl.h" #include "openmm/internal/CustomNonbondedForceImpl.h"
#include "openmm/internal/CMAPTorsionForceImpl.h" #include "openmm/internal/CMAPTorsionForceImpl.h"
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/SplineFitter.h"
#include "openmm/Integrator.h" #include "openmm/Integrator.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
...@@ -923,38 +923,6 @@ void ReferenceCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& con ...@@ -923,38 +923,6 @@ void ReferenceCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& con
dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(context.getSystem(), force); dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(context.getSystem(), force);
} }
class ReferenceTabulatedFunction : public Lepton::CustomFunction {
public:
ReferenceTabulatedFunction(double min, double max, const vector<double>& values) :
min(min), max(max), values(values) {
int numValues = values.size();
x.resize(numValues);
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs);
}
int getNumArguments() const {
return 1;
}
double evaluate(const double* arguments) const {
double t = arguments[0];
if (t < min || t > max)
return 0.0;
return SplineFitter::evaluateSpline(x, values, derivs, t);
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
double t = arguments[0];
if (t < min || t > max)
return 0.0;
return SplineFitter::evaluateSplineDerivative(x, values, derivs, t);
}
CustomFunction* clone() const {
return new ReferenceTabulatedFunction(min, max, values);
}
double min, max;
vector<double> x, values, derivs;
};
ReferenceCalcCustomNonbondedForceKernel::~ReferenceCalcCustomNonbondedForceKernel() { ReferenceCalcCustomNonbondedForceKernel::~ReferenceCalcCustomNonbondedForceKernel() {
disposeRealArray(particleParamArray, numParticles); disposeRealArray(particleParamArray, numParticles);
if (neighborList != NULL) if (neighborList != NULL)
...@@ -1001,13 +969,8 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -1001,13 +969,8 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++)
string name; functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
vector<double> values;
double min, max;
force.getFunctionParameters(i, name, values, min, max);
functions[name] = new ReferenceTabulatedFunction(min, max, values);
}
// Parse the various expressions used to calculate the force. // Parse the various expressions used to calculate the force.
...@@ -1288,13 +1251,8 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1288,13 +1251,8 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++)
string name; functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
vector<double> values;
double min, max;
force.getFunctionParameters(i, name, values, min, max);
functions[name] = new ReferenceTabulatedFunction(min, max, values);
}
// Parse the expressions for computed values. // Parse the expressions for computed values.
...@@ -1507,13 +1465,8 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1507,13 +1465,8 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++)
string name; functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
vector<double> values;
double min, max;
force.getFunctionParameters(i, name, values, min, max);
functions[name] = new ReferenceTabulatedFunction(min, max, values);
}
// Parse the expression and create the object used to calculate the interaction. // Parse the expression and create the object used to calculate the interaction.
...@@ -1609,13 +1562,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system ...@@ -1609,13 +1562,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++)
string name; functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
vector<double> values;
double min, max;
force.getFunctionParameters(i, name, values, min, max);
functions[name] = new ReferenceTabulatedFunction(min, max, values);
}
// Parse the expression and create the object used to calculate the interaction. // Parse the expression and create the object used to calculate the interaction.
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "ReferenceTabulatedFunction.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include <cmath>
using namespace OpenMM;
using namespace std;
using Lepton::CustomFunction;
extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunction& function) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
return new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function));
if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL)
return new ReferenceContinuous2DFunction(dynamic_cast<const Continuous2DFunction&>(function));
if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL)
return new ReferenceContinuous3DFunction(dynamic_cast<const Continuous3DFunction&>(function));
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return new ReferenceDiscrete1DFunction(dynamic_cast<const Discrete1DFunction&>(function));
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL)
return new ReferenceDiscrete2DFunction(dynamic_cast<const Discrete2DFunction&>(function));
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL)
return new ReferenceDiscrete3DFunction(dynamic_cast<const Discrete3DFunction&>(function));
throw OpenMMException("createReferenceTabulatedFunction: Unknown function type");
}
ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const Continuous1DFunction& function) : function(function) {
function.getFunctionParameters(values, min, max);
int numValues = values.size();
x.resize(numValues);
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs);
}
int ReferenceContinuous1DFunction::getNumArguments() const {
return 1;
}
double ReferenceContinuous1DFunction::evaluate(const double* arguments) const {
double t = arguments[0];
if (t < min || t > max)
return 0.0;
return SplineFitter::evaluateSpline(x, values, derivs, t);
}
double ReferenceContinuous1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double t = arguments[0];
if (t < min || t > max)
return 0.0;
return SplineFitter::evaluateSplineDerivative(x, values, derivs, t);
}
CustomFunction* ReferenceContinuous1DFunction::clone() const {
return new ReferenceContinuous1DFunction(function);
}
ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const Continuous2DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
x.resize(xsize);
y.resize(ysize);
for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1);
SplineFitter::create2DNaturalSpline(x, y, values, c);
}
int ReferenceContinuous2DFunction::getNumArguments() const {
return 2;
}
double ReferenceContinuous2DFunction::evaluate(const double* arguments) const {
double u = arguments[0];
if (u < xmin || u > xmax)
return 0.0;
double v = arguments[1];
if (v < ymin || v > ymax)
return 0.0;
return SplineFitter::evaluate2DSpline(x, y, values, c, u, v);
}
double ReferenceContinuous2DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double u = arguments[0];
if (u < xmin || u > xmax)
return 0.0;
double v = arguments[1];
if (v < ymin || v > ymax)
return 0.0;
double dx, dy;
SplineFitter::evaluate2DSplineDerivatives(x, y, values, c, u, v, dx, dy);
if (derivOrder[0] == 1 && derivOrder[1] == 0)
return dx;
if (derivOrder[0] == 0 && derivOrder[1] == 1)
return dy;
throw OpenMMException("ReferenceContinuous2DFunction: Unsupported derivative order");
}
CustomFunction* ReferenceContinuous2DFunction::clone() const {
return new ReferenceContinuous2DFunction(function);
}
ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const Continuous3DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
x.resize(xsize);
y.resize(ysize);
z.resize(zsize);
for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1);
for (int i = 0; i < zsize; i++)
z[i] = zmin+i*(zmax-zmin)/(zsize-1);
SplineFitter::create3DNaturalSpline(x, y, z, values, c);
}
int ReferenceContinuous3DFunction::getNumArguments() const {
return 3;
}
double ReferenceContinuous3DFunction::evaluate(const double* arguments) const {
double u = arguments[0];
if (u < xmin || u > xmax)
return 0.0;
double v = arguments[1];
if (v < ymin || v > ymax)
return 0.0;
double w = arguments[2];
if (w < zmin || w > zmax)
return 0.0;
return SplineFitter::evaluate3DSpline(x, y, z, values, c, u, v, w);
}
double ReferenceContinuous3DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double u = arguments[0];
if (u < xmin || u > xmax)
return 0.0;
double v = arguments[1];
if (v < ymin || v > ymax)
return 0.0;
double w = arguments[2];
if (w < zmin || w > zmax)
return 0.0;
double dx, dy, dz;
SplineFitter::evaluate3DSplineDerivatives(x, y, z, values, c, u, v, w, dx, dy, dz);
if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0)
return dx;
if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0)
return dy;
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1)
return dz;
throw OpenMMException("ReferenceContinuous3DFunction: Unsupported derivative order");
}
CustomFunction* ReferenceContinuous3DFunction::clone() const {
return new ReferenceContinuous3DFunction(function);
}
ReferenceDiscrete1DFunction::ReferenceDiscrete1DFunction(const Discrete1DFunction& function) : function(function) {
function.getFunctionParameters(values);
}
int ReferenceDiscrete1DFunction::getNumArguments() const {
return 1;
}
double ReferenceDiscrete1DFunction::evaluate(const double* arguments) const {
int i = (int) round(arguments[0]);
if (i < 0 || i >= values.size())
throw OpenMMException("ReferenceDiscrete1DFunction: argument out of range");
return values[i];
}
double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* ReferenceDiscrete1DFunction::clone() const {
return new ReferenceDiscrete1DFunction(function);
}
ReferenceDiscrete2DFunction::ReferenceDiscrete2DFunction(const Discrete2DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, values);
}
int ReferenceDiscrete2DFunction::getNumArguments() const {
return 2;
}
double ReferenceDiscrete2DFunction::evaluate(const double* arguments) const {
int i = (int) round(arguments[0]);
int j = (int) round(arguments[1]);
if (i < 0 || i >= xsize || j < 0 || j >= ysize)
throw OpenMMException("ReferenceDiscrete2DFunction: argument out of range");
return values[i+j*xsize];
}
double ReferenceDiscrete2DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* ReferenceDiscrete2DFunction::clone() const {
return new ReferenceDiscrete2DFunction(function);
}
ReferenceDiscrete3DFunction::ReferenceDiscrete3DFunction(const Discrete3DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, zsize, values);
}
int ReferenceDiscrete3DFunction::getNumArguments() const {
return 3;
}
double ReferenceDiscrete3DFunction::evaluate(const double* arguments) const {
int i = (int) round(arguments[0]);
int j = (int) round(arguments[1]);
int k = (int) round(arguments[2]);
if (i < 0 || i >= xsize || j < 0 || j >= ysize || k < 0 || k >= zsize)
throw OpenMMException("ReferenceDiscrete3DFunction: argument out of range");
return values[i+(j+k*ysize)*xsize];
}
double ReferenceDiscrete3DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* ReferenceDiscrete3DFunction::clone() const {
return new ReferenceDiscrete3DFunction(function);
}
...@@ -167,10 +167,111 @@ void testPositionDependence() { ...@@ -167,10 +167,111 @@ void testPositionDependence() {
ASSERT_EQUAL_VEC(Vec3(-0.3, -2, 0), state.getForces()[1], 1e-5); ASSERT_EQUAL_VEC(Vec3(-0.3, -2, 0), state.getForces()[1], 1e-5);
} }
void testContinuous2DFunction() {
const int xsize = 10;
const int ysize = 11;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomCompoundBondForce* forceField = new CustomCompoundBondForce(1, "fn(x1,y1)+1");
vector<int> particles(1, 0);
forceField->addBond(particles, vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
const double zmin = 0.2;
const double zmax = 1.3;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomCompoundBondForce* forceField = new CustomCompoundBondForce(1, "fn(x1,y1,z1)+1");
vector<int> particles(1, 0);
forceField->addBond(particles, vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
double z = zmin + k*(zmax-zmin)/zsize;
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(0.33*y)*(1+z);
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.1) {
positions[0] = Vec3(x, y, z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax && z >= zmin && z <= zmax) {
energy = sin(0.25*x)*cos(0.33*y)*(1.0+z)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y)*(1.0+z);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y)*(1.0+z);
force[2] = -sin(0.25*x)*cos(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
int main() { int main() {
try { try {
testBond(); testBond();
testPositionDependence(); testPositionDependence();
testContinuous2DFunction();
testContinuous3DFunction();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -281,7 +281,7 @@ void testTabulatedFunction() { ...@@ -281,7 +281,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0); force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
......
...@@ -217,7 +217,7 @@ void testCustomFunctions() { ...@@ -217,7 +217,7 @@ void testCustomFunctions() {
vector<double> function(2); vector<double> function(2);
function[0] = 0; function[0] = 0;
function[1] = 1; function[1] = 1;
custom->addFunction("foo", function, 0, 10); custom->addTabulatedFunction("foo", new Continuous1DFunction(function, 0, 10));
system.addForce(custom); system.addForce(custom);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(3); vector<Vec3> positions(3);
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -227,7 +227,7 @@ void testPeriodic() { ...@@ -227,7 +227,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction() { void testContinuous1DFunction() {
ReferencePlatform platform; ReferencePlatform platform;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -238,21 +238,20 @@ void testTabulatedFunction() { ...@@ -238,21 +238,20 @@ void testTabulatedFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0); forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0); positions[0] = Vec3(0, 0, 0);
double tol = 0.01;
for (int i = 1; i < 30; i++) { for (int i = 1; i < 30; i++) {
double x = (7.0/30.0)*i; double x = (7.0/30.0)*i;
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
...@@ -262,11 +261,217 @@ void testTabulatedFunction() { ...@@ -262,11 +261,217 @@ void testTabulatedFunction() {
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
} }
void testContinuous2DFunction() {
const int xsize = 20;
const int ysize = 21;
const double xmin = 0.4;
const double xmax = 1.5;
const double ymin = 0.0;
const double ymax = 2.1;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
double force = 0;
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1.0;
force = -0.25*cos(0.25*x)*cos(0.33*y);
}
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
const double zmin = 0.2;
const double zmax = 1.3;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
double z = zmin + k*(zmax-zmin)/zsize;
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(0.33*y)*(1+z);
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.1) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setParameter("b", z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
double force = 0;
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax && z >= zmin && z <= zmax) {
energy = sin(0.25*x)*cos(0.33*y)*(1.0+z)+1.0;
force = -0.25*cos(0.25*x)*cos(0.33*y)*(1.0+z);
}
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
void testDiscrete1DFunction() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL(table[i]+1.0, state.getPotentialEnergy());
}
}
void testDiscrete2DFunction() {
const int xsize = 10;
const int ysize = 5;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addTabulatedFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i%xsize, 0, 0);
context.setPositions(positions);
context.setParameter("a", i/xsize);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL(table[i]+1.0, state.getPotentialEnergy());
}
}
void testDiscrete3DFunction() {
const int xsize = 8;
const int ysize = 5;
const int zsize = 6;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addTabulatedFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i%xsize, 0, 0);
context.setPositions(positions);
context.setParameter("a", (i/xsize)%ysize);
context.setParameter("b", i/(xsize*ysize));
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL(table[i]+1.0, state.getPotentialEnergy());
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -658,7 +863,12 @@ int main() { ...@@ -658,7 +863,12 @@ int main() {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(); testContinuous1DFunction();
testContinuous2DFunction();
testContinuous3DFunction();
testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testSwitchingFunction(); testSwitchingFunction();
testLongRangeCorrection(); testLongRangeCorrection();
......
#ifndef OPENMM_TABULATEDFUNCTION_PROXIES_H_
#define OPENMM_TABULATEDFUNCTION_PROXIES_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/internal/windowsExport.h"
#include "openmm/serialization/SerializationProxy.h"
namespace OpenMM {
/**
* This is a proxy for serializing Continuous1DFunction objects.
*/
class OPENMM_EXPORT Continuous1DFunctionProxy : public SerializationProxy {
public:
Continuous1DFunctionProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
/**
* This is a proxy for serializing Continuous2DFunction objects.
*/
class OPENMM_EXPORT Continuous2DFunctionProxy : public SerializationProxy {
public:
Continuous2DFunctionProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
/**
* This is a proxy for serializing Continuous3DFunction objects.
*/
class OPENMM_EXPORT Continuous3DFunctionProxy : public SerializationProxy {
public:
Continuous3DFunctionProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
/**
* This is a proxy for serializing Discrete1DFunction objects.
*/
class OPENMM_EXPORT Discrete1DFunctionProxy : public SerializationProxy {
public:
Discrete1DFunctionProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
/**
* This is a proxy for serializing Discrete2DFunction objects.
*/
class OPENMM_EXPORT Discrete2DFunctionProxy : public SerializationProxy {
public:
Discrete2DFunctionProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
/**
* This is a proxy for serializing Discrete3DFunction objects.
*/
class OPENMM_EXPORT Discrete3DFunctionProxy : public SerializationProxy {
public:
Discrete3DFunctionProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
} // namespace OpenMM
#endif /*OPENMM_TABULATEDFUNCTION_PROXIES_H_*/
...@@ -73,6 +73,9 @@ void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNo ...@@ -73,6 +73,9 @@ void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNo
node.setDoubleProperty(key.str(), params[j]); node.setDoubleProperty(key.str(), params[j]);
} }
} }
SerializationNode& functions = node.createChildNode("Functions");
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions.createChildNode("Function", &force.getTabulatedFunction(i)).setStringProperty("name", force.getTabulatedFunctionName(i));
} }
void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) const { void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) const {
...@@ -110,6 +113,22 @@ void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) c ...@@ -110,6 +113,22 @@ void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) c
} }
force->addBond(particles, params); force->addBond(particles, params);
} }
const SerializationNode& functions = node.getChildNode("Functions");
for (int i = 0; i < (int) functions.getChildren().size(); i++) {
const SerializationNode& function = functions.getChildren()[i];
if (function.hasProperty("type")) {
force->addTabulatedFunction(function.getStringProperty("name"), function.decodeObject<TabulatedFunction>());
}
else {
// This is an old file created before TabulatedFunction existed.
const SerializationNode& valuesNode = function.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
force->addTabulatedFunction(function.getStringProperty("name"), new Continuous1DFunction(values, function.getDoubleProperty("min"), function.getDoubleProperty("max")));
}
}
return force; return force;
} }
catch (...) { catch (...) {
......
...@@ -87,16 +87,8 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node) ...@@ -87,16 +87,8 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node)
exclusions.createChildNode("Exclusion").setIntProperty("p1", particle1).setIntProperty("p2", particle2); exclusions.createChildNode("Exclusion").setIntProperty("p1", particle1).setIntProperty("p2", particle2);
} }
SerializationNode& functions = node.createChildNode("Functions"); SerializationNode& functions = node.createChildNode("Functions");
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
string name; functions.createChildNode("Function", &force.getTabulatedFunction(i)).setStringProperty("name", force.getTabulatedFunctionName(i));
vector<double> values;
double min, max;
force.getFunctionParameters(i, name, values, min, max);
SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
} }
void* CustomGBForceProxy::deserialize(const SerializationNode& node) const { void* CustomGBForceProxy::deserialize(const SerializationNode& node) const {
...@@ -147,11 +139,18 @@ void* CustomGBForceProxy::deserialize(const SerializationNode& node) const { ...@@ -147,11 +139,18 @@ void* CustomGBForceProxy::deserialize(const SerializationNode& node) const {
const SerializationNode& functions = node.getChildNode("Functions"); const SerializationNode& functions = node.getChildNode("Functions");
for (int i = 0; i < (int) functions.getChildren().size(); i++) { for (int i = 0; i < (int) functions.getChildren().size(); i++) {
const SerializationNode& function = functions.getChildren()[i]; const SerializationNode& function = functions.getChildren()[i];
const SerializationNode& valuesNode = function.getChildNode("Values"); if (function.hasProperty("type")) {
vector<double> values; force->addTabulatedFunction(function.getStringProperty("name"), function.decodeObject<TabulatedFunction>());
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++) }
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v")); else {
force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max")); // This is an old file created before TabulatedFunction existed.
const SerializationNode& valuesNode = function.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
force->addTabulatedFunction(function.getStringProperty("name"), new Continuous1DFunction(values, function.getDoubleProperty("min"), function.getDoubleProperty("max")));
}
} }
return force; return force;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment