Commit 0fe7612b authored by peastman's avatar peastman
Browse files

Merge pull request #337 from peastman/functions

Created new API for tabulated functions
parents 7a7055b3 ed31a458
...@@ -62,6 +62,7 @@ ...@@ -62,6 +62,7 @@
#include "openmm/RBTorsionForce.h" #include "openmm/RBTorsionForce.h"
#include "openmm/State.h" #include "openmm/State.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/TabulatedFunction.h"
#include "openmm/Units.h" #include "openmm/Units.h"
#include "openmm/VariableLangevinIntegrator.h" #include "openmm/VariableLangevinIntegrator.h"
#include "openmm/VariableVerletIntegrator.h" #include "openmm/VariableVerletIntegrator.h"
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "TabulatedFunction.h"
#include "Force.h" #include "Force.h"
#include "Vec3.h" #include "Vec3.h"
#include <vector> #include <vector>
...@@ -91,8 +92,8 @@ namespace OpenMM { ...@@ -91,8 +92,8 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* values, and a natural spline is created from them. That function can then appear in the expression. * creating a TabulatedFunction object. That function can then appear in the expression.
*/ */
class OPENMM_EXPORT CustomCompoundBondForce : public Force { class OPENMM_EXPORT CustomCompoundBondForce : public Force {
...@@ -106,6 +107,7 @@ public: ...@@ -106,6 +107,7 @@ public:
* and per-bond parameters * and per-bond parameters
*/ */
explicit CustomCompoundBondForce(int numParticles, const std::string& energy); explicit CustomCompoundBondForce(int numParticles, const std::string& energy);
~CustomCompoundBondForce();
/** /**
* Get the number of particles used to define each bond. * Get the number of particles used to define each bond.
*/ */
...@@ -133,6 +135,14 @@ public: ...@@ -133,6 +135,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -229,33 +239,51 @@ public: ...@@ -229,33 +239,51 @@ public:
* Add a tabulated function that may appear in the energy expression. * Add a tabulated function that may appear in the energy expression.
* *
* @param name the name of the function as it appears in expressions * @param name the name of the function as it appears in expressions
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max. * @param function a TabulatedFunction object defining the function. The TabulatedFunction
* The function is assumed to be zero for x &lt; min or x &gt; max. * should have been created on the heap with the "new" operator. The
* @param min the value of the independent variable corresponding to the first element of values * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @param max the value of the independent variable corresponding to the last element of values
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function for which to get parameters * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in algebraic expressions. * Set the parameters for a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function for which to set parameters * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
...@@ -333,12 +361,10 @@ public: ...@@ -333,12 +361,10 @@ public:
class CustomCompoundBondForce::FunctionInfo { class CustomCompoundBondForce::FunctionInfo {
public: public:
std::string name; std::string name;
std::vector<double> values; TabulatedFunction* function;
double min, max;
FunctionInfo() { FunctionInfo() {
} }
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) : FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
name(name), values(values), min(min), max(max) {
} }
}; };
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "TabulatedFunction.h"
#include "Force.h" #include "Force.h"
#include "Vec3.h" #include "Vec3.h"
#include <map> #include <map>
...@@ -134,8 +135,8 @@ namespace OpenMM { ...@@ -134,8 +135,8 @@ namespace OpenMM {
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example, * have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* values, and a natural spline is created from them. That function can then appear in expressions. * creating a TabulatedFunction object. That function can then appear in expressions.
*/ */
class OPENMM_EXPORT CustomGBForce : public Force { class OPENMM_EXPORT CustomGBForce : public Force {
...@@ -181,6 +182,7 @@ public: ...@@ -181,6 +182,7 @@ public:
* Create a CustomGBForce. * Create a CustomGBForce.
*/ */
CustomGBForce(); CustomGBForce();
~CustomGBForce();
/** /**
* Get the number of particles for which force field parameters have been defined. * Get the number of particles for which force field parameters have been defined.
*/ */
...@@ -208,6 +210,14 @@ public: ...@@ -208,6 +210,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -452,36 +462,54 @@ public: ...@@ -452,36 +462,54 @@ public:
*/ */
void setExclusionParticles(int index, int particle1, int particle2); void setExclusionParticles(int index, int particle1, int particle2);
/** /**
* Add a tabulated function that may appear in the energy expression. * Add a tabulated function that may appear in expressions.
* *
* @param name the name of the function as it appears in expressions * @param name the name of the function as it appears in expressions
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max. * @param function a TabulatedFunction object defining the function. The TabulatedFunction
* The function is assumed to be zero for x &lt; min or x &gt; max. * should have been created on the heap with the "new" operator. The
* @param min the value of the independent variable corresponding to the first element of values * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @param max the value of the independent variable corresponding to the last element of values
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in expressions.
*
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in expressions.
* *
* @param index the index of the function for which to get parameters * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in algebraic expressions. * Set the parameters for a tabulated function that may appear in expressions.
* *
* @param index the index of the function for which to set parameters * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
...@@ -577,12 +605,10 @@ public: ...@@ -577,12 +605,10 @@ public:
class CustomGBForce::FunctionInfo { class CustomGBForce::FunctionInfo {
public: public:
std::string name; std::string name;
std::vector<double> values; TabulatedFunction* function;
double min, max;
FunctionInfo() { FunctionInfo() {
} }
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) : FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
name(name), values(values), min(min), max(max) {
} }
}; };
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "TabulatedFunction.h"
#include "Force.h" #include "Force.h"
#include "Vec3.h" #include "Vec3.h"
#include <map> #include <map>
...@@ -91,8 +92,8 @@ namespace OpenMM { ...@@ -91,8 +92,8 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* values, and a natural spline is created from them. That function can then appear in the expression. * creating a TabulatedFunction object. That function can then appear in the expression.
*/ */
class OPENMM_EXPORT CustomHbondForce : public Force { class OPENMM_EXPORT CustomHbondForce : public Force {
...@@ -124,6 +125,7 @@ public: ...@@ -124,6 +125,7 @@ public:
* per-acceptor parameters * per-acceptor parameters
*/ */
explicit CustomHbondForce(const std::string& energy); explicit CustomHbondForce(const std::string& energy);
~CustomHbondForce();
/** /**
* Get the number of donors for which force field parameters have been defined. * Get the number of donors for which force field parameters have been defined.
*/ */
...@@ -163,6 +165,14 @@ public: ...@@ -163,6 +165,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -374,33 +384,51 @@ public: ...@@ -374,33 +384,51 @@ public:
* Add a tabulated function that may appear in the energy expression. * Add a tabulated function that may appear in the energy expression.
* *
* @param name the name of the function as it appears in expressions * @param name the name of the function as it appears in expressions
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max. * @param function a TabulatedFunction object defining the function. The TabulatedFunction
* The function is assumed to be zero for x &lt; min or x &gt; max. * should have been created on the heap with the "new" operator. The
* @param min the value of the independent variable corresponding to the first element of values * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @param max the value of the independent variable corresponding to the last element of values
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function for which to get parameters * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in algebraic expressions. * Set the parameters for a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function for which to set parameters * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
...@@ -499,12 +527,10 @@ public: ...@@ -499,12 +527,10 @@ public:
class CustomHbondForce::FunctionInfo { class CustomHbondForce::FunctionInfo {
public: public:
std::string name; std::string name;
std::vector<double> values; TabulatedFunction* function;
double min, max;
FunctionInfo() { FunctionInfo() {
} }
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) : FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
name(name), values(values), min(min), max(max) {
} }
}; };
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "TabulatedFunction.h"
#include "Force.h" #include "Force.h"
#include "Vec3.h" #include "Vec3.h"
#include <map> #include <map>
...@@ -124,8 +125,8 @@ namespace OpenMM { ...@@ -124,8 +125,8 @@ namespace OpenMM {
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example, * have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* values, and a natural spline is created from them. That function can then appear in the expression. * creating a TabulatedFunction object. That function can then appear in the expression.
*/ */
class OPENMM_EXPORT CustomNonbondedForce : public Force { class OPENMM_EXPORT CustomNonbondedForce : public Force {
...@@ -156,6 +157,7 @@ public: ...@@ -156,6 +157,7 @@ public:
* of r, the distance between them, as well as any global and per-particle parameters * of r, the distance between them, as well as any global and per-particle parameters
*/ */
explicit CustomNonbondedForce(const std::string& energy); explicit CustomNonbondedForce(const std::string& energy);
~CustomNonbondedForce();
/** /**
* Get the number of particles for which force field parameters have been defined. * Get the number of particles for which force field parameters have been defined.
*/ */
...@@ -183,6 +185,14 @@ public: ...@@ -183,6 +185,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -359,33 +369,51 @@ public: ...@@ -359,33 +369,51 @@ public:
* Add a tabulated function that may appear in the energy expression. * Add a tabulated function that may appear in the energy expression.
* *
* @param name the name of the function as it appears in expressions * @param name the name of the function as it appears in expressions
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max. * @param function a TabulatedFunction object defining the function. The TabulatedFunction
* The function is assumed to be zero for x &lt; min or x &gt; max. * should have been created on the heap with the "new" operator. The
* @param min the value of the independent variable corresponding to the first element of values * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @param max the value of the independent variable corresponding to the last element of values
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function for which to get parameters * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in algebraic expressions. * Set the parameters for a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function for which to set parameters * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* @param name the name of the function as it appears in expressions * If the specified function is not a Continuous1DFunction, this throws an exception.
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
...@@ -507,12 +535,10 @@ public: ...@@ -507,12 +535,10 @@ public:
class CustomNonbondedForce::FunctionInfo { class CustomNonbondedForce::FunctionInfo {
public: public:
std::string name; std::string name;
std::vector<double> values; TabulatedFunction* function;
double min, max;
FunctionInfo() { FunctionInfo() {
} }
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) : FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
name(name), values(values), min(min), max(max) {
} }
}; };
......
#ifndef OPENMM_TABULATEDFUNCTION_H_
#define OPENMM_TABULATEDFUNCTION_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "internal/windowsExport.h"
#include <vector>
namespace OpenMM {
/**
* A TabulatedFunction uses a set of tabulated values to define a mathematical function.
* It can be used by various custom forces.
*
* TabulatedFunction is an abstract class with concrete subclasses for more specific
* types of functions. There are subclasses for:
*
* <ul>
* <li>1, 2, and 3 dimensional functions. The dimensionality of a function means
* the number of input arguments it takes.</li>
* <li>Continuous and discrete functions. A continuous function is interpolated by
* fitting a natural cubic spline to the tabulated values. A discrete function is
* only defined for integer values of its arguments (that is, at the tabulated points),
* and does not try to interpolate between them. Discrete function can be evaluated
* more quickly than continuous ones.</li>
* </ul>
*/
class OPENMM_EXPORT TabulatedFunction {
public:
virtual ~TabulatedFunction() {
}
};
/**
* This is a TabulatedFunction that computes a continuous one dimensional function.
*/
class OPENMM_EXPORT Continuous1DFunction : public TabulatedFunction {
public:
/**
* Create a Continuous1DFunction f(x) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min
* and max. A natural cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values
*/
Continuous1DFunction(const std::vector<double>& values, double min, double max);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min
* and max. A natural cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values
*/
void getFunctionParameters(std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min
* and max. A natural cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values
*/
void setFunctionParameters(const std::vector<double>& values, double min, double max);
private:
std::vector<double> values;
double min, max;
};
/**
* This is a TabulatedFunction that computes a continuous two dimensional function.
*/
class OPENMM_EXPORT Continuous2DFunction : public TabulatedFunction {
public:
/**
* Create a Continuous2DFunction f(x,y) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x,y) at xsize uniformly spaced values of x between xmin
* and xmax, and ysize values of y between ymin and ymax. A natural cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be zero when x or y is outside its specified range. The values should be ordered so that
* values[i+xsize*j] = f(x_i,y_j), where x_i is the i'th uniformly spaced value of x. This must be of length xsize*ysize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
*/
Continuous2DFunction(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x,y) at xsize uniformly spaced values of x between xmin
* and xmax, and ysize values of y between ymin and ymax. A natural cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be zero when x or y is outside its specified range. The values should be ordered so that
* values[i+xsize*j] = f(x_i,y_j), where x_i is the i'th uniformly spaced value of x. This must be of length xsize*ysize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
*/
void getFunctionParameters(int& xsize, int& ysize, std::vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x,y) at xsize uniformly spaced values of x between xmin
* and xmax, and ysize values of y between ymin and ymax. A natural cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be zero when x or y is outside its specified range. The values should be ordered so that
* values[i+xsize*j] = f(x_i,y_j), where x_i is the i'th uniformly spaced value of x. This must be of length xsize*ysize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
*/
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
private:
std::vector<double> values;
int xsize, ysize;
double xmin, xmax, ymin, ymax;
};
/**
* This is a TabulatedFunction that computes a continuous three dimensional function.
*/
class OPENMM_EXPORT Continuous3DFunction : public TabulatedFunction {
public:
/**
* Create a Continuous3DFunction f(x,y,z) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x,y,z) at xsize uniformly spaced values of x between xmin
* and xmax, ysize values of y between ymin and ymax, and zsize values of z between zmin and zmax.
* A natural cubic spline is used to interpolate between the tabulated values. The function is
* assumed to be zero when x, y, or z is outside its specified range. The values should be ordered so
* that values[i+xsize*j+xsize*ysize*k] = f(x_i,y_j,z_k), where x_i is the i'th uniformly spaced value of x.
* This must be of length xsize*ysize*zsize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param ysize the number of table elements along the z direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values
*/
Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x,y,z) at xsize uniformly spaced values of x between xmin
* and xmax, ysize values of y between ymin and ymax, and zsize values of z between zmin and zmax.
* A natural cubic spline is used to interpolate between the tabulated values. The function is
* assumed to be zero when x, y, or z is outside its specified range. The values should be ordered so
* that values[i+xsize*j+xsize*ysize*k] = f(x_i,y_j,z_k), where x_i is the i'th uniformly spaced value of x.
* This must be of length xsize*ysize*zsize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param ysize the number of table elements along the z direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values
*/
void getFunctionParameters(int& xsize, int& ysize, int& zsize, std::vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x,y,z) at xsize uniformly spaced values of x between xmin
* and xmax, ysize values of y between ymin and ymax, and zsize values of z between zmin and zmax.
* A natural cubic spline is used to interpolate between the tabulated values. The function is
* assumed to be zero when x, y, or z is outside its specified range. The values should be ordered so
* that values[i+xsize*j+xsize*ysize*k] = f(x_i,y_j,z_k), where x_i is the i'th uniformly spaced value of x.
* This must be of length xsize*ysize*zsize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param ysize the number of table elements along the z direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
private:
std::vector<double> values;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
};
/**
* This is a TabulatedFunction that computes a discrete one dimensional function f(x).
* To evaluate it, x is rounded to the nearest integer and the table element with that
* index is returned. If the index is outside the range [0, size), the result is undefined.
*/
class OPENMM_EXPORT Discrete1DFunction : public TabulatedFunction {
public:
/**
* Create a Discrete1DFunction f(x) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x)
*/
Discrete1DFunction(const std::vector<double>& values);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x)
*/
void getFunctionParameters(std::vector<double>& values) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x)
*/
void setFunctionParameters(const std::vector<double>& values);
private:
std::vector<double> values;
};
/**
* This is a TabulatedFunction that computes a discrete two dimensional function f(x,y).
* To evaluate it, x and y are each rounded to the nearest integer and the table element with those
* indices is returned. If either index is outside the range [0, size), the result is undefined.
*/
class OPENMM_EXPORT Discrete2DFunction : public TabulatedFunction {
public:
/**
* Create a Discrete2DFunction f(x,y) based on a set of tabulated values.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
Discrete2DFunction(int xsize, int ysize, const std::vector<double>& values);
/**
* Get the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void getFunctionParameters(int& xsize, int& ysize, std::vector<double>& values) const;
/**
* Set the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values);
private:
int xsize, ysize;
std::vector<double> values;
};
/**
* This is a TabulatedFunction that computes a discrete three dimensional function f(x,y,z).
* To evaluate it, x, y, and z are each rounded to the nearest integer and the table element with those
* indices is returned. If any index is outside the range [0, size), the result is undefined.
*/
class OPENMM_EXPORT Discrete3DFunction : public TabulatedFunction {
public:
/**
* Create a Discrete3DFunction f(x,y,z) based on a set of tabulated values.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
Discrete3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values);
/**
* Get the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void getFunctionParameters(int& xsize, int& ysize, int& zsize, std::vector<double>& values) const;
/**
* Set the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values);
private:
int xsize, ysize, zsize;
std::vector<double> values;
};
} // namespace OpenMM
#endif /*OPENMM_TABULATEDFUNCTION_H_*/
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010 Stanford University and the Authors. * * Portions copyright (c) 2010-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
*/ */
static void createPeriodicSpline(const std::vector<double>& x, const std::vector<double>& y, std::vector<double>& deriv); static void createPeriodicSpline(const std::vector<double>& x, const std::vector<double>& y, std::vector<double>& deriv);
/** /**
* Evaluate a spline generated by one of the other methods in this class. * Evaluate a 1D spline generated by one of the other methods in this class.
* *
* @param x the values of the independent variable at the data points to interpolate * @param x the values of the independent variable at the data points to interpolate
* @param y the values of the dependent variable at the data points to interpolate * @param y the values of the dependent variable at the data points to interpolate
...@@ -77,7 +77,7 @@ public: ...@@ -77,7 +77,7 @@ public:
*/ */
static double evaluateSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& deriv, double t); static double evaluateSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& deriv, double t);
/** /**
* Evaluate the derivative of a spline generated by one of the other methods in this class. * Evaluate the derivative of a 1D spline generated by one of the other methods in this class.
* *
* @param x the values of the independent variable at the data points to interpolate * @param x the values of the independent variable at the data points to interpolate
* @param y the values of the dependent variable at the data points to interpolate * @param y the values of the dependent variable at the data points to interpolate
...@@ -86,6 +86,90 @@ public: ...@@ -86,6 +86,90 @@ public:
* @return the value of the spline's derivative at the specified point * @return the value of the spline's derivative at the specified point
*/ */
static double evaluateSplineDerivative(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& deriv, double t); static double evaluateSplineDerivative(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& deriv, double t);
/**
* Fit a natural cubic spline surface f(x,y) to a 2D set of data points. The resulting spline interpolates all the
* data points, has a continuous second derivative everywhere, and has a second derivative of 0 at the boundary.
*
* @param x the values of the first independent variable at the data points to interpolate. They must
* be strictly increasing: x[i] > x[i-1].
* @param y the values of the second independent variable at the data points to interpolate. They must
* be strictly increasing: y[i] > y[i-1].
* @param values the values of the dependent variable at the data points to interpolate. They must be ordered
* so that values[i+xsize*j] = f(x[i],y[j]), where xsize is the length of x.
* @param c on exit, this contains the spline coefficients at each of the data points
*/
static void create2DNaturalSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& values, std::vector<std::vector<double> >& c);
/**
* Evaluate a 2D spline generated by one of the other methods in this class.
*
* @param x the values of the first independent variable at the data points to interpolate
* @param y the values of the second independent variable at the data points to interpolate
* @param values the values of the dependent variable at the data points to interpolate
* @param c the vector of spline coefficients that was calculated by one of the other methods
* @param u the value of the first independent variable at which to evaluate the spline
* @param v the value of the second independent variable at which to evaluate the spline
* @return the value of the spline at the specified point
*/
static double evaluate2DSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& values, const std::vector<std::vector<double> >& c, double u, double v);
/**
* Evaluate the derivatives of a 2D spline generated by one of the other methods in this class.
*
* @param x the values of the first independent variable at the data points to interpolate
* @param y the values of the second independent variable at the data points to interpolate
* @param values the values of the dependent variable at the data points to interpolate
* @param c the vector of spline coefficients that was calculated by one of the other methods
* @param u the value of the first independent variable at which to evaluate the spline
* @param v the value of the second independent variable at which to evaluate the spline
* @param dx on exit, the x derivative of the spline at the specified point
* @param dy on exit, the y derivative of the spline at the specified point
*/
static void evaluate2DSplineDerivatives(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& values, const std::vector<std::vector<double> >& c, double u, double v, double& dx, double& dy);
/**
* Fit a natural cubic spline surface f(x,y,z) to a 3D set of data points. The resulting spline interpolates all the
* data points, has a continuous second derivative everywhere, and has a second derivative of 0 at the boundary.
*
* @param x the values of the first independent variable at the data points to interpolate. They must
* be strictly increasing: x[i] > x[i-1].
* @param y the values of the second independent variable at the data points to interpolate. They must
* be strictly increasing: y[i] > y[i-1].
* @param z the values of the third independent variable at the data points to interpolate. They must
* be strictly increasing: z[i] > z[i-1].
* @param values the values of the dependent variable at the data points to interpolate. They must be ordered
* so that values[i+xsize*j+xsize*ysize*k] = f(x[i],y[j],z[k]), where xsize is the length of x
* and ysize is the length of y.
* @param c on exit, this contains the spline coefficients at each of the data points
*/
static void create3DNaturalSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& z, const std::vector<double>& values, std::vector<std::vector<double> >& c);
/**
* Evaluate a 3D spline generated by one of the other methods in this class.
*
* @param x the values of the first independent variable at the data points to interpolate
* @param y the values of the second independent variable at the data points to interpolate
* @param z the values of the third independent variable at the data points to interpolate
* @param values the values of the dependent variable at the data points to interpolate
* @param c the vector of spline coefficients that was calculated by one of the other methods
* @param u the value of the first independent variable at which to evaluate the spline
* @param v the value of the second independent variable at which to evaluate the spline
* @param w the value of the third independent variable at which to evaluate the spline
* @return the value of the spline at the specified point
*/
static double evaluate3DSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& z, const std::vector<double>& values, const std::vector<std::vector<double> >& c, double u, double v, double w);
/**
* Evaluate the derivatives of a 3D spline generated by one of the other methods in this class.
*
* @param x the values of the first independent variable at the data points to interpolate
* @param y the values of the second independent variable at the data points to interpolate
* @param z the values of the third independent variable at the data points to interpolate
* @param values the values of the dependent variable at the data points to interpolate
* @param c the vector of spline coefficients that was calculated by one of the other methods
* @param u the value of the first independent variable at which to evaluate the spline
* @param v the value of the second independent variable at which to evaluate the spline
* @param w the value of the third independent variable at which to evaluate the spline
* @param dx on exit, the x derivative of the spline at the specified point
* @param dy on exit, the y derivative of the spline at the specified point
* @param dz on exit, the z derivative of the spline at the specified point
*/
static void evaluate3DSplineDerivatives(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& z, const std::vector<double>& values, const std::vector<std::vector<double> >& c, double u, double v, double w, double& dx, double& dy, double &dz);
private: private:
static void solveTridiagonalMatrix(const std::vector<double>& a, const std::vector<double>& b, const std::vector<double>& c, const std::vector<double>& rhs, std::vector<double>& sol); static void solveTridiagonalMatrix(const std::vector<double>& a, const std::vector<double>& b, const std::vector<double>& c, const std::vector<double>& rhs, std::vector<double>& sol);
}; };
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -51,6 +51,12 @@ using std::vector; ...@@ -51,6 +51,12 @@ using std::vector;
CustomCompoundBondForce::CustomCompoundBondForce(int numParticles, const string& energy) : particlesPerBond(numParticles), energyExpression(energy) { CustomCompoundBondForce::CustomCompoundBondForce(int numParticles, const string& energy) : particlesPerBond(numParticles), energyExpression(energy) {
} }
CustomCompoundBondForce::~CustomCompoundBondForce() {
for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function;
}
const string& CustomCompoundBondForce::getEnergyFunction() const { const string& CustomCompoundBondForce::getEnergyFunction() const {
return energyExpression; return energyExpression;
} }
...@@ -120,33 +126,47 @@ void CustomCompoundBondForce::setBondParameters(int index, const vector<int>& pa ...@@ -120,33 +126,47 @@ void CustomCompoundBondForce::setBondParameters(int index, const vector<int>& pa
bonds[index].parameters = parameters; bonds[index].parameters = parameters;
} }
int CustomCompoundBondForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomCompoundBondForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomCompoundBondForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomCompoundBondForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
int CustomCompoundBondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) { int CustomCompoundBondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) functions.push_back(FunctionInfo(name, new Continuous1DFunction(values, min, max)));
throw OpenMMException("CustomCompoundBondForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomCompoundBondForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1; return functions.size()-1;
} }
void CustomCompoundBondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const { void CustomCompoundBondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomCompoundBondForce: function is not a Continuous1DFunction");
name = functions[index].name; name = functions[index].name;
values = functions[index].values; function->getFunctionParameters(values, min, max);
min = functions[index].min;
max = functions[index].max;
} }
void CustomCompoundBondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) { void CustomCompoundBondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomCompoundBondForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomCompoundBondForce: a tabulated function must have at least two points");
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomCompoundBondForce: function is not a Continuous1DFunction");
functions[index].name = name; functions[index].name = name;
functions[index].values = values; function->setFunctionParameters(values, min, max);
functions[index].min = min;
functions[index].max = max;
} }
ForceImpl* CustomCompoundBondForce::createImpl() const { ForceImpl* CustomCompoundBondForce::createImpl() const {
......
...@@ -147,7 +147,7 @@ ParsedExpression CustomCompoundBondForceImpl::prepareExpression(const CustomComp ...@@ -147,7 +147,7 @@ ParsedExpression CustomCompoundBondForceImpl::prepareExpression(const CustomComp
ExpressionTreeNode CustomCompoundBondForceImpl::replaceFunctions(const ExpressionTreeNode& node, map<string, int> atoms, ExpressionTreeNode CustomCompoundBondForceImpl::replaceFunctions(const ExpressionTreeNode& node, map<string, int> atoms,
map<string, vector<int> >& distances, map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals) { map<string, vector<int> >& distances, map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals) {
const Operation& op = node.getOperation(); const Operation& op = node.getOperation();
if (op.getId() != Operation::CUSTOM || op.getNumArguments() < 2) if (op.getId() != Operation::CUSTOM || (op.getName() != "distance" && op.getName() != "angle" && op.getName() != "dihedral"))
{ {
// This is not an angle or dihedral, so process its children. // This is not an angle or dihedral, so process its children.
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -50,6 +50,11 @@ using std::vector; ...@@ -50,6 +50,11 @@ using std::vector;
CustomGBForce::CustomGBForce() : nonbondedMethod(NoCutoff), cutoffDistance(1.0) { CustomGBForce::CustomGBForce() : nonbondedMethod(NoCutoff), cutoffDistance(1.0) {
} }
CustomGBForce::~CustomGBForce() {
for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function;
}
CustomGBForce::NonbondedMethod CustomGBForce::getNonbondedMethod() const { CustomGBForce::NonbondedMethod CustomGBForce::getNonbondedMethod() const {
return nonbondedMethod; return nonbondedMethod;
} }
...@@ -173,33 +178,47 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle ...@@ -173,33 +178,47 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle
exclusions[index].particle2 = particle2; exclusions[index].particle2 = particle2;
} }
int CustomGBForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomGBForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomGBForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomGBForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
int CustomGBForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) { int CustomGBForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) functions.push_back(FunctionInfo(name, new Continuous1DFunction(values, min, max)));
throw OpenMMException("CustomGBForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomGBForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1; return functions.size()-1;
} }
void CustomGBForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const { void CustomGBForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomGBForce: function is not a Continuous1DFunction");
name = functions[index].name; name = functions[index].name;
values = functions[index].values; function->getFunctionParameters(values, min, max);
min = functions[index].min;
max = functions[index].max;
} }
void CustomGBForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) { void CustomGBForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomGBForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomGBForce: a tabulated function must have at least two points");
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomGBForce: function is not a Continuous1DFunction");
functions[index].name = name; functions[index].name = name;
functions[index].values = values; function->setFunctionParameters(values, min, max);
functions[index].min = min;
functions[index].max = max;
} }
ForceImpl* CustomGBForce::createImpl() const { ForceImpl* CustomGBForce::createImpl() const {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -50,6 +50,12 @@ using std::vector; ...@@ -50,6 +50,12 @@ using std::vector;
CustomHbondForce::CustomHbondForce(const string& energy) : energyExpression(energy), nonbondedMethod(NoCutoff), cutoffDistance(1.0) { CustomHbondForce::CustomHbondForce(const string& energy) : energyExpression(energy), nonbondedMethod(NoCutoff), cutoffDistance(1.0) {
} }
CustomHbondForce::~CustomHbondForce() {
for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function;
}
const string& CustomHbondForce::getEnergyFunction() const { const string& CustomHbondForce::getEnergyFunction() const {
return energyExpression; return energyExpression;
} }
...@@ -187,33 +193,47 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor) ...@@ -187,33 +193,47 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor)
exclusions[index].acceptor = acceptor; exclusions[index].acceptor = acceptor;
} }
int CustomHbondForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomHbondForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomHbondForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomHbondForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
int CustomHbondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) { int CustomHbondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) functions.push_back(FunctionInfo(name, new Continuous1DFunction(values, min, max)));
throw OpenMMException("CustomHbondForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomHbondForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1; return functions.size()-1;
} }
void CustomHbondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const { void CustomHbondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomHbondForce: function is not a Continuous1DFunction");
name = functions[index].name; name = functions[index].name;
values = functions[index].values; function->getFunctionParameters(values, min, max);
min = functions[index].min;
max = functions[index].max;
} }
void CustomHbondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) { void CustomHbondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomHbondForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomHbondForce: a tabulated function must have at least two points");
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomHbondForce: function is not a Continuous1DFunction");
functions[index].name = name; functions[index].name = name;
functions[index].values = values; function->setFunctionParameters(values, min, max);
functions[index].min = min;
functions[index].max = max;
} }
ForceImpl* CustomHbondForce::createImpl() const { ForceImpl* CustomHbondForce::createImpl() const {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -51,6 +51,11 @@ CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpress ...@@ -51,6 +51,11 @@ CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpress
switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false) { switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false) {
} }
CustomNonbondedForce::~CustomNonbondedForce() {
for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function;
}
const string& CustomNonbondedForce::getEnergyFunction() const { const string& CustomNonbondedForce::getEnergyFunction() const {
return energyExpression; return energyExpression;
} }
...@@ -169,34 +174,47 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p ...@@ -169,34 +174,47 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p
exclusions[index].particle1 = particle1; exclusions[index].particle1 = particle1;
exclusions[index].particle2 = particle2; exclusions[index].particle2 = particle2;
} }
int CustomNonbondedForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomNonbondedForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomNonbondedForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomNonbondedForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
int CustomNonbondedForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) { int CustomNonbondedForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) functions.push_back(FunctionInfo(name, new Continuous1DFunction(values, min, max)));
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomNonbondedForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1; return functions.size()-1;
} }
void CustomNonbondedForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const { void CustomNonbondedForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomNonbondedForce: function is not a Continuous1DFunction");
name = functions[index].name; name = functions[index].name;
values = functions[index].values; function->getFunctionParameters(values, min, max);
min = functions[index].min;
max = functions[index].max;
} }
void CustomNonbondedForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) { void CustomNonbondedForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomNonbondedForce: a tabulated function must have at least two points");
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
Continuous1DFunction* function = dynamic_cast<Continuous1DFunction*>(functions[index].function);
if (function == NULL)
throw OpenMMException("CustomNonbondedForce: function is not a Continuous1DFunction");
functions[index].name = name; functions[index].name = name;
functions[index].values = values; function->setFunctionParameters(values, min, max);
functions[index].min = min;
functions[index].max = max;
} }
int CustomNonbondedForce::addInteractionGroup(const std::set<int>& set1, const std::set<int>& set2) { int CustomNonbondedForce::addInteractionGroup(const std::set<int>& set1, const std::set<int>& set2) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010 Stanford University and the Authors. * * Portions copyright (c) 2010-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -190,3 +190,535 @@ void SplineFitter::solveTridiagonalMatrix(const vector<double>& a, const vector< ...@@ -190,3 +190,535 @@ void SplineFitter::solveTridiagonalMatrix(const vector<double>& a, const vector<
for (int i = n-2; i >= 0; i--) for (int i = n-2; i >= 0; i--)
sol[i] -= gamma[i+1]*sol[i+1]; sol[i] -= gamma[i+1]*sol[i+1];
} }
void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& values, vector<vector<double> >& c) {
int xsize = x.size(), ysize = y.size();
if (xsize < 2 || ysize < 2)
throw OpenMMException("create2DNaturalSpline: must have at least two points along each axis");
if (values.size() != xsize*ysize)
throw OpenMMException("create2DNaturalSpline: incorrect number of values");
vector<double> d1(xsize*ysize), d2(xsize*ysize), d12(xsize*ysize);
vector<double> t(xsize), deriv(xsize);
// Compute derivatives with respect to x.
for (int i = 0; i < ysize; i++) {
for (int j = 0; j < xsize; j++)
t[j] = values[j+xsize*i];
SplineFitter::createNaturalSpline(x, t, deriv);
for (int j = 0; j < xsize; j++)
d1[j+xsize*i] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[j]);
}
// Compute derivatives with respect to y.
t.resize(ysize);
deriv.resize(ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++)
t[j] = values[i+xsize*j];
SplineFitter::createNaturalSpline(y, t, deriv);
for (int j = 0; j < ysize; j++)
d2[i+xsize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[j]);
}
// Compute cross derivatives.
t.resize(xsize);
deriv.resize(xsize);
for (int i = 0; i < ysize; i++) {
for (int j = 0; j < xsize; j++)
t[j] = d2[j+xsize*i];
SplineFitter::createNaturalSpline(x, t, deriv);
for (int j = 0; j < xsize; j++)
d12[j+xsize*i] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[j]);
}
// Now compute the coefficients.
const int wt[] = {
1, 0, -3, 2, 0, 0, 0, 0, -3, 0, 9, -6, 2, 0, -6, 4,
0, 0, 0, 0, 0, 0, 0, 0, 3, 0, -9, 6, -2, 0, 6, -4,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, -6, 0, 0, -6, 4,
0, 0, 3, -2, 0, 0, 0, 0, 0, 0, -9, 6, 0, 0, 6, -4,
0, 0, 0, 0, 1, 0, -3, 2, -2, 0, 6, -4, 1, 0, -3, 2,
0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 3, -2, 1, 0, -3, 2,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 2, 0, 0, 3, -2,
0, 0, 0, 0, 0, 0, 3, -2, 0, 0, -6, 4, 0, 0, 3, -2,
0, 1, -2, 1, 0, 0, 0, 0, 0, -3, 6, -3, 0, 2, -4, 2,
0, 0, 0, 0, 0, 0, 0, 0, 0, 3, -6, 3, 0, -2, 4, -2,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, 2, -2,
0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 3, -3, 0, 0, -2, 2,
0, 0, 0, 0, 0, 1, -2, 1, 0, -2, 4, -2, 0, 1, -2, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 2, -1, 0, 1, -2, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, 1,
0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 2, -2, 0, 0, -1, 1
};
vector<double> rhs(16);
c.resize((xsize-1)*(ysize-1));
for (int i = 0; i < xsize-1; i++) {
for (int j = 0; j < ysize-1; j++) {
// Compute the 16 coefficients for patch (i, j).
int nexti = i+1;
int nextj = j+1;
double deltax = x[nexti]-x[i];
double deltay = y[nextj]-y[j];
double e[] = {values[i+j*xsize], values[nexti+j*xsize], values[nexti+nextj*xsize], values[i+nextj*xsize]};
double e1[] = {d1[i+j*xsize], d1[nexti+j*xsize], d1[nexti+nextj*xsize], d1[i+nextj*xsize]};
double e2[] = {d2[i+j*xsize], d2[nexti+j*xsize], d2[nexti+nextj*xsize], d2[i+nextj*xsize]};
double e12[] = {d12[i+j*xsize], d12[nexti+j*xsize], d12[nexti+nextj*xsize], d12[i+nextj*xsize]};
for (int k = 0; k < 4; k++) {
rhs[k] = e[k];
rhs[k+4] = e1[k]*deltax;
rhs[k+8] = e2[k]*deltay;
rhs[k+12] = e12[k]*deltax*deltay;
}
vector<double>& coeff = c[i+j*(xsize-1)];
coeff.resize(16);
for (int k = 0; k < 16; k++) {
double sum = 0.0;
for (int m = 0; m < 16; m++)
sum += wt[k+16*m]*rhs[m];
coeff[k] = sum;
}
}
}
}
double SplineFitter::evaluate2DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& values, const vector<vector<double> >& c, double u, double v) {
int xsize = x.size();
int ysize = y.size();
if (u < x[0] || u > x[xsize-1] || v < y[0] || v > y[ysize-1])
throw OpenMMException("evaluate2DSpline: specified point is outside the range defined by the spline");
// Perform a binary search to identify the interval containing the point to evaluate.
int lowerx = 0;
int upperx = xsize-1;
while (upperx-lowerx > 1) {
int middle = (upperx+lowerx)/2;
if (x[middle] > u)
upperx = middle;
else
lowerx = middle;
}
int lowery = 0;
int uppery = ysize-1;
while (uppery-lowery > 1) {
int middle = (uppery+lowery)/2;
if (y[middle] > v)
uppery = middle;
else
lowery = middle;
}
double deltax = x[upperx]-x[lowerx];
double deltay = y[uppery]-y[lowery];
double da = (u-x[lowerx])/deltax;
double db = (v-y[lowery])/deltay;
const vector<double>& coeff = c[lowerx+(xsize-1)*lowery];
// Evaluate the spline to determine the value.
double value = 0;
for (int i = 3; i >= 0; i--)
value = da*value + ((coeff[i*4+3]*db + coeff[i*4+2])*db + coeff[i*4+1])*db + coeff[i*4+0];
return value;
}
void SplineFitter::evaluate2DSplineDerivatives(const vector<double>& x, const vector<double>& y, const vector<double>& values, const vector<vector<double> >& c, double u, double v, double& dx, double &dy) {
int xsize = x.size();
int ysize = y.size();
if (u < x[0] || u > x[xsize-1] || v < y[0] || v > y[ysize-1])
throw OpenMMException("evaluate2DSplineDerivatives: specified point is outside the range defined by the spline");
// Perform a binary search to identify the interval containing the point to evaluate.
int lowerx = 0;
int upperx = xsize-1;
while (upperx-lowerx > 1) {
int middle = (upperx+lowerx)/2;
if (x[middle] > u)
upperx = middle;
else
lowerx = middle;
}
int lowery = 0;
int uppery = ysize-1;
while (uppery-lowery > 1) {
int middle = (uppery+lowery)/2;
if (y[middle] > v)
uppery = middle;
else
lowery = middle;
}
double deltax = x[upperx]-x[lowerx];
double deltay = y[uppery]-y[lowery];
double da = (u-x[lowerx])/deltax;
double db = (v-y[lowery])/deltay;
const vector<double>& coeff = c[lowerx+(xsize-1)*lowery];
// Evaluate the spline to determine the derivatives.
dx = 0;
dy = 0;
for (int i = 3; i >= 0; i--) {
dx = db*dx + (3.0*coeff[i+3*4]*da + 2.0*coeff[i+2*4])*da + coeff[i+1*4];
dy = da*dy + (3.0*coeff[i*4+3]*db + 2.0*coeff[i*4+2])*db + coeff[i*4+1];
}
dx /= deltax;
dy /= deltay;
}
void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, vector<vector<double> >& c) {
int xsize = x.size(), ysize = y.size(), zsize = z.size();
int xysize = xsize*ysize;
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("create2DNaturalSpline: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("create2DNaturalSpline: incorrect number of values");
vector<double> d1(xsize*ysize*zsize), d2(xsize*ysize*zsize), d3(xsize*ysize*zsize);
vector<double> d12(xsize*ysize*zsize), d13(xsize*ysize*zsize), d23(xsize*ysize*zsize), d123(xsize*ysize*zsize);
vector<double> t(xsize), deriv(xsize);
// Compute derivatives with respect to x.
for (int i = 0; i < ysize; i++) {
for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++)
t[k] = values[k+xsize*i+xysize*j];
SplineFitter::createNaturalSpline(x, t, deriv);
for (int k = 0; k < xsize; k++)
d1[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
}
}
// Compute derivatives with respect to y.
t.resize(ysize);
deriv.resize(ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < zsize; j++) {
for (int k = 0; k < ysize; k++)
t[k] = values[i+xsize*k+xysize*j];
SplineFitter::createNaturalSpline(y, t, deriv);
for (int k = 0; k < ysize; k++)
d2[i+xsize*k+xysize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]);
}
}
// Compute derivatives with respect to z.
t.resize(zsize);
deriv.resize(zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++)
t[k] = values[i+xsize*j+xysize*k];
SplineFitter::createNaturalSpline(z, t, deriv);
for (int k = 0; k < zsize; k++)
d3[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]);
}
}
// Compute second derivatives with respect to x and y.
t.resize(xsize);
deriv.resize(xsize);
for (int i = 0; i < ysize; i++) {
for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++)
t[k] = d2[k+xsize*i+xysize*j];
SplineFitter::createNaturalSpline(x, t, deriv);
for (int k = 0; k < xsize; k++)
d12[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
}
}
// Compute second derivatives with respect to y and z.
t.resize(ysize);
deriv.resize(ysize);
for (int i = 0; i < zsize; i++) {
for (int j = 0; j < xsize; j++) {
for (int k = 0; k < ysize; k++)
t[k] = d3[j+xsize*k+xysize*i];
SplineFitter::createNaturalSpline(y, t, deriv);
for (int k = 0; k < ysize; k++)
d23[j+xsize*k+xysize*i] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]);
}
}
// Compute second derivatives with respect to x and z.
t.resize(zsize);
deriv.resize(zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++)
t[k] = d1[i+xsize*j+xysize*k];
SplineFitter::createNaturalSpline(z, t, deriv);
for (int k = 0; k < zsize; k++)
d13[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]);
}
}
// Compute third derivatives with respect to x, y, and z.
t.resize(xsize);
deriv.resize(xsize);
for (int i = 0; i < ysize; i++) {
for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++)
t[k] = d23[k+xsize*i+xysize*j];
SplineFitter::createNaturalSpline(x, t, deriv);
for (int k = 0; k < xsize; k++)
d123[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
}
}
// Now compute the coefficients. This involves multiplying by a sparse 64x64 matrix, given
// here in packed form.
const int wt[] = {
1,0,1,
1,8,1,
4,0,-3,1,3,8,-2,9,-1,
4,0,2,1,-2,8,1,9,1,
1,16,1,
1,32,1,
4,16,-3,17,3,32,-2,33,-1,
4,16,2,17,-2,32,1,33,1,
4,0,-3,2,3,16,-2,18,-1,
4,8,-3,10,3,32,-2,34,-1,
16,0,9,1,-9,2,-9,3,9,8,6,9,3,10,-6,11,-3,16,6,17,-6,18,3,19,-3,32,4,33,2,34,2,35,1,
16,0,-6,1,6,2,6,3,-6,8,-3,9,-3,10,3,11,3,16,-4,17,4,18,-2,19,2,32,-2,33,-2,34,-1,35,-1,
4,0,2,2,-2,16,1,18,1,
4,8,2,10,-2,32,1,34,1,
16,0,-6,1,6,2,6,3,-6,8,-4,9,-2,10,4,11,2,16,-3,17,3,18,-3,19,3,32,-2,33,-1,34,-2,35,-1,
16,0,4,1,-4,2,-4,3,4,8,2,9,2,10,-2,11,-2,16,2,17,-2,18,2,19,-2,32,1,33,1,34,1,35,1,
1,24,1,
1,40,1,
4,24,-3,25,3,40,-2,41,-1,
4,24,2,25,-2,40,1,41,1,
1,48,1,
1,56,1,
4,48,-3,49,3,56,-2,57,-1,
4,48,2,49,-2,56,1,57,1,
4,24,-3,26,3,48,-2,50,-1,
4,40,-3,42,3,56,-2,58,-1,
16,24,9,25,-9,26,-9,27,9,40,6,41,3,42,-6,43,-3,48,6,49,-6,50,3,51,-3,56,4,57,2,58,2,59,1,
16,24,-6,25,6,26,6,27,-6,40,-3,41,-3,42,3,43,3,48,-4,49,4,50,-2,51,2,56,-2,57,-2,58,-1,59,-1,
4,24,2,26,-2,48,1,50,1,
4,40,2,42,-2,56,1,58,1,
16,24,-6,25,6,26,6,27,-6,40,-4,41,-2,42,4,43,2,48,-3,49,3,50,-3,51,3,56,-2,57,-1,58,-2,59,-1,
16,24,4,25,-4,26,-4,27,4,40,2,41,2,42,-2,43,-2,48,2,49,-2,50,2,51,-2,56,1,57,1,58,1,59,1,
4,0,-3,4,3,24,-2,28,-1,
4,8,-3,12,3,40,-2,44,-1,
16,0,9,1,-9,4,-9,5,9,8,6,9,3,12,-6,13,-3,24,6,25,-6,28,3,29,-3,40,4,41,2,44,2,45,1,
16,0,-6,1,6,4,6,5,-6,8,-3,9,-3,12,3,13,3,24,-4,25,4,28,-2,29,2,40,-2,41,-2,44,-1,45,-1,
4,16,-3,20,3,48,-2,52,-1,
4,32,-3,36,3,56,-2,60,-1,
16,16,9,17,-9,20,-9,21,9,32,6,33,3,36,-6,37,-3,48,6,49,-6,52,3,53,-3,56,4,57,2,60,2,61,1,
16,16,-6,17,6,20,6,21,-6,32,-3,33,-3,36,3,37,3,48,-4,49,4,52,-2,53,2,56,-2,57,-2,60,-1,61,-1,
16,0,9,2,-9,4,-9,6,9,16,6,18,3,20,-6,22,-3,24,6,26,-6,28,3,30,-3,48,4,50,2,52,2,54,1,
16,8,9,10,-9,12,-9,14,9,32,6,34,3,36,-6,38,-3,40,6,42,-6,44,3,46,-3,56,4,58,2,60,2,62,1,
64,0,-27,1,27,2,27,3,-27,4,27,5,-27,6,-27,7,27,8,-18,9,-9,10,18,11,9,12,18,13,9,14,-18,15,-9,16,-18,17,18,18,-9,19,9,20,18,21,-18,22,9,23,-9,24,-18,25,18,26,18,27,-18,28,-9,29,9,30,9,31,-9,32,-12,33,-6,34,-6,35,-3,36,12,37,6,38,6,39,3,40,-12,41,-6,42,12,43,6,44,-6,45,-3,46,6,47,3,48,-12,49,12,50,-6,51,6,52,-6,53,6,54,-3,55,3,56,-8,57,-4,58,-4,59,-2,60,-4,61,-2,62,-2,63,-1,
64,0,18,1,-18,2,-18,3,18,4,-18,5,18,6,18,7,-18,8,9,9,9,10,-9,11,-9,12,-9,13,-9,14,9,15,9,16,12,17,-12,18,6,19,-6,20,-12,21,12,22,-6,23,6,24,12,25,-12,26,-12,27,12,28,6,29,-6,30,-6,31,6,32,6,33,6,34,3,35,3,36,-6,37,-6,38,-3,39,-3,40,6,41,6,42,-6,43,-6,44,3,45,3,46,-3,47,-3,48,8,49,-8,50,4,51,-4,52,4,53,-4,54,2,55,-2,56,4,57,4,58,2,59,2,60,2,61,2,62,1,63,1,
16,0,-6,2,6,4,6,6,-6,16,-3,18,-3,20,3,22,3,24,-4,26,4,28,-2,30,2,48,-2,50,-2,52,-1,54,-1,
16,8,-6,10,6,12,6,14,-6,32,-3,34,-3,36,3,38,3,40,-4,42,4,44,-2,46,2,56,-2,58,-2,60,-1,62,-1,
64,0,18,1,-18,2,-18,3,18,4,-18,5,18,6,18,7,-18,8,12,9,6,10,-12,11,-6,12,-12,13,-6,14,12,15,6,16,9,17,-9,18,9,19,-9,20,-9,21,9,22,-9,23,9,24,12,25,-12,26,-12,27,12,28,6,29,-6,30,-6,31,6,32,6,33,3,34,6,35,3,36,-6,37,-3,38,-6,39,-3,40,8,41,4,42,-8,43,-4,44,4,45,2,46,-4,47,-2,48,6,49,-6,50,6,51,-6,52,3,53,-3,54,3,55,-3,56,4,57,2,58,4,59,2,60,2,61,1,62,2,63,1,
64,0,-12,1,12,2,12,3,-12,4,12,5,-12,6,-12,7,12,8,-6,9,-6,10,6,11,6,12,6,13,6,14,-6,15,-6,16,-6,17,6,18,-6,19,6,20,6,21,-6,22,6,23,-6,24,-8,25,8,26,8,27,-8,28,-4,29,4,30,4,31,-4,32,-3,33,-3,34,-3,35,-3,36,3,37,3,38,3,39,3,40,-4,41,-4,42,4,43,4,44,-2,45,-2,46,2,47,2,48,-4,49,4,50,-4,51,4,52,-2,53,2,54,-2,55,2,56,-2,57,-2,58,-2,59,-2,60,-1,61,-1,62,-1,63,-1,
4,0,2,4,-2,24,1,28,1,
4,8,2,12,-2,40,1,44,1,
16,0,-6,1,6,4,6,5,-6,8,-4,9,-2,12,4,13,2,24,-3,25,3,28,-3,29,3,40,-2,41,-1,44,-2,45,-1,
16,0,4,1,-4,4,-4,5,4,8,2,9,2,12,-2,13,-2,24,2,25,-2,28,2,29,-2,40,1,41,1,44,1,45,1,
4,16,2,20,-2,48,1,52,1,
4,32,2,36,-2,56,1,60,1,
16,16,-6,17,6,20,6,21,-6,32,-4,33,-2,36,4,37,2,48,-3,49,3,52,-3,53,3,56,-2,57,-1,60,-2,61,-1,
16,16,4,17,-4,20,-4,21,4,32,2,33,2,36,-2,37,-2,48,2,49,-2,52,2,53,-2,56,1,57,1,60,1,61,1,
16,0,-6,2,6,4,6,6,-6,16,-4,18,-2,20,4,22,2,24,-3,26,3,28,-3,30,3,48,-2,50,-1,52,-2,54,-1,
16,8,-6,10,6,12,6,14,-6,32,-4,34,-2,36,4,38,2,40,-3,42,3,44,-3,46,3,56,-2,58,-1,60,-2,62,-1,
64,0,18,1,-18,2,-18,3,18,4,-18,5,18,6,18,7,-18,8,12,9,6,10,-12,11,-6,12,-12,13,-6,14,12,15,6,16,12,17,-12,18,6,19,-6,20,-12,21,12,22,-6,23,6,24,9,25,-9,26,-9,27,9,28,9,29,-9,30,-9,31,9,32,8,33,4,34,4,35,2,36,-8,37,-4,38,-4,39,-2,40,6,41,3,42,-6,43,-3,44,6,45,3,46,-6,47,-3,48,6,49,-6,50,3,51,-3,52,6,53,-6,54,3,55,-3,56,4,57,2,58,2,59,1,60,4,61,2,62,2,63,1,
64,0,-12,1,12,2,12,3,-12,4,12,5,-12,6,-12,7,12,8,-6,9,-6,10,6,11,6,12,6,13,6,14,-6,15,-6,16,-8,17,8,18,-4,19,4,20,8,21,-8,22,4,23,-4,24,-6,25,6,26,6,27,-6,28,-6,29,6,30,6,31,-6,32,-4,33,-4,34,-2,35,-2,36,4,37,4,38,2,39,2,40,-3,41,-3,42,3,43,3,44,-3,45,-3,46,3,47,3,48,-4,49,4,50,-2,51,2,52,-4,53,4,54,-2,55,2,56,-2,57,-2,58,-1,59,-1,60,-2,61,-2,62,-1,63,-1,
16,0,4,2,-4,4,-4,6,4,16,2,18,2,20,-2,22,-2,24,2,26,-2,28,2,30,-2,48,1,50,1,52,1,54,1,
16,8,4,10,-4,12,-4,14,4,32,2,34,2,36,-2,38,-2,40,2,42,-2,44,2,46,-2,56,1,58,1,60,1,62,1,
64,0,-12,1,12,2,12,3,-12,4,12,5,-12,6,-12,7,12,8,-8,9,-4,10,8,11,4,12,8,13,4,14,-8,15,-4,16,-6,17,6,18,-6,19,6,20,6,21,-6,22,6,23,-6,24,-6,25,6,26,6,27,-6,28,-6,29,6,30,6,31,-6,32,-4,33,-2,34,-4,35,-2,36,4,37,2,38,4,39,2,40,-4,41,-2,42,4,43,2,44,-4,45,-2,46,4,47,2,48,-3,49,3,50,-3,51,3,52,-3,53,3,54,-3,55,3,56,-2,57,-1,58,-2,59,-1,60,-2,61,-1,62,-2,63,-1,
64,0,8,1,-8,2,-8,3,8,4,-8,5,8,6,8,7,-8,8,4,9,4,10,-4,11,-4,12,-4,13,-4,14,4,15,4,16,4,17,-4,18,4,19,-4,20,-4,21,4,22,-4,23,4,24,4,25,-4,26,-4,27,4,28,4,29,-4,30,-4,31,4,32,2,33,2,34,2,35,2,36,-2,37,-2,38,-2,39,-2,40,2,41,2,42,-2,43,-2,44,2,45,2,46,-2,47,-2,48,2,49,-2,50,2,51,-2,52,2,53,-2,54,2,55,-2,56,1,57,1,58,1,59,1,60,1,61,1,62,1,63,1
};
vector<vector<int> > weight(64);
int index = 0;
for (int i = 0; i < 64; i++) {
int numElements = wt[index++];
for (int j = 0; j < numElements; j++) {
weight[i].push_back(wt[index++]);
weight[i].push_back(wt[index++]);
}
}
vector<double> rhs(64);
c.resize((xsize-1)*(ysize-1)*(zsize-1));
for (int i = 0; i < xsize-1; i++) {
for (int j = 0; j < ysize-1; j++) {
for (int k = 0; k < zsize-1; k++) {
// Compute the 64 coefficients for patch (i, j, k).
int nexti = i+1;
int nextj = j+1;
int nextk = k+1;
double deltax = x[nexti]-x[i];
double deltay = y[nextj]-y[j];
double deltaz = z[nextj]-z[j];
double e[] = {values[i+j*xsize+k*xysize], values[nexti+j*xsize+k*xysize], values[i+nextj*xsize+k*xysize], values[nexti+nextj*xsize+k*xysize], values[i+j*xsize+nextk*xysize], values[nexti+j*xsize+nextk*xysize], values[i+nextj*xsize+nextk*xysize], values[nexti+nextj*xsize+nextk*xysize]};
double e1[] = {d1[i+j*xsize+k*xysize], d1[nexti+j*xsize+k*xysize], d1[i+nextj*xsize+k*xysize], d1[nexti+nextj*xsize+k*xysize], d1[i+j*xsize+nextk*xysize], d1[nexti+j*xsize+nextk*xysize], d1[i+nextj*xsize+nextk*xysize], d1[nexti+nextj*xsize+nextk*xysize]};
double e2[] = {d2[i+j*xsize+k*xysize], d2[nexti+j*xsize+k*xysize], d2[i+nextj*xsize+k*xysize], d2[nexti+nextj*xsize+k*xysize], d2[i+j*xsize+nextk*xysize], d2[nexti+j*xsize+nextk*xysize], d2[i+nextj*xsize+nextk*xysize], d2[nexti+nextj*xsize+nextk*xysize]};
double e3[] = {d3[i+j*xsize+k*xysize], d3[nexti+j*xsize+k*xysize], d3[i+nextj*xsize+k*xysize], d3[nexti+nextj*xsize+k*xysize], d3[i+j*xsize+nextk*xysize], d3[nexti+j*xsize+nextk*xysize], d3[i+nextj*xsize+nextk*xysize], d3[nexti+nextj*xsize+nextk*xysize]};
double e12[] = {d12[i+j*xsize+k*xysize], d12[nexti+j*xsize+k*xysize], d12[i+nextj*xsize+k*xysize], d12[nexti+nextj*xsize+k*xysize], d12[i+j*xsize+nextk*xysize], d12[nexti+j*xsize+nextk*xysize], d12[i+nextj*xsize+nextk*xysize], d12[nexti+nextj*xsize+nextk*xysize]};
double e13[] = {d13[i+j*xsize+k*xysize], d13[nexti+j*xsize+k*xysize], d13[i+nextj*xsize+k*xysize], d13[nexti+nextj*xsize+k*xysize], d13[i+j*xsize+nextk*xysize], d13[nexti+j*xsize+nextk*xysize], d13[i+nextj*xsize+nextk*xysize], d13[nexti+nextj*xsize+nextk*xysize]};
double e23[] = {d23[i+j*xsize+k*xysize], d23[nexti+j*xsize+k*xysize], d23[i+nextj*xsize+k*xysize], d23[nexti+nextj*xsize+k*xysize], d23[i+j*xsize+nextk*xysize], d23[nexti+j*xsize+nextk*xysize], d23[i+nextj*xsize+nextk*xysize], d23[nexti+nextj*xsize+nextk*xysize]};
double e123[] = {d123[i+j*xsize+k*xysize], d123[nexti+j*xsize+k*xysize], d123[i+nextj*xsize+k*xysize], d123[nexti+nextj*xsize+k*xysize], d123[i+j*xsize+nextk*xysize], d123[nexti+j*xsize+nextk*xysize], d123[i+nextj*xsize+nextk*xysize], d123[nexti+nextj*xsize+nextk*xysize]};
for (int m = 0; m < 8; m++) {
rhs[m] = e[m];
rhs[m+8] = e1[m]*deltax;
rhs[m+16] = e2[m]*deltay;
rhs[m+24] = e3[m]*deltaz;
rhs[m+32] = e12[m]*deltax*deltay;
rhs[m+40] = e13[m]*deltax*deltaz;
rhs[m+48] = e23[m]*deltay*deltaz;
rhs[m+56] = e123[m]*deltax*deltay*deltaz;
}
vector<double>& coeff = c[i+j*(xsize-1)+k*(xsize-1)*(ysize-1)];
coeff.resize(64);
for (int m = 0; m < 64; m++) {
double sum = 0.0;
int numElements = weight[m].size();
for (int n = 0; n < numElements; n += 2)
sum += weight[m][n+1]*rhs[weight[m][n]];
coeff[m] = sum;
}
}
}
}
}
double SplineFitter::evaluate3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, const vector<vector<double> >& c, double u, double v, double w) {
int xsize = x.size();
int ysize = y.size();
int zsize = z.size();
if (u < x[0] || u > x[xsize-1] || v < y[0] || v > y[ysize-1] || w < z[0] || w > z[zsize-1])
throw OpenMMException("evaluate3DSpline: specified point is outside the range defined by the spline");
// Perform a binary search to identify the interval containing the point to evaluate.
int lowerx = 0;
int upperx = xsize-1;
while (upperx-lowerx > 1) {
int middle = (upperx+lowerx)/2;
if (x[middle] > u)
upperx = middle;
else
lowerx = middle;
}
int lowery = 0;
int uppery = ysize-1;
while (uppery-lowery > 1) {
int middle = (uppery+lowery)/2;
if (y[middle] > v)
uppery = middle;
else
lowery = middle;
}
int lowerz = 0;
int upperz = zsize-1;
while (upperz-lowerz > 1) {
int middle = (upperz+lowerz)/2;
if (z[middle] > w)
upperz = middle;
else
lowerz = middle;
}
double deltax = x[upperx]-x[lowerx];
double deltay = y[uppery]-y[lowery];
double deltaz = z[upperz]-z[lowerz];
double da = (u-x[lowerx])/deltax;
double db = (v-y[lowery])/deltay;
double dc = (w-z[lowerz])/deltaz;
const vector<double>& coeff = c[lowerx+(xsize-1)*lowery+(xsize-1)*(ysize-1)*lowerz];
// Evaluate the spline to determine the value and gradients.
double value[] = {0, 0, 0, 0};
for (int i = 3; i >= 0; i--) {
for (int j = 0; j < 4; j++) {
int base = 4*i + 16*j;
value[j] = db*value[j] + ((coeff[base+3]*da + coeff[base+2])*da + coeff[base+1])*da + coeff[base];
}
}
return value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));
}
void SplineFitter::evaluate3DSplineDerivatives(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, const vector<vector<double> >& c, double u, double v, double w, double& dx, double& dy, double& dz) {
int xsize = x.size();
int ysize = y.size();
int zsize = z.size();
if (u < x[0] || u > x[xsize-1] || v < y[0] || v > y[ysize-1] || w < z[0] || w > z[zsize-1])
throw OpenMMException("evaluate3DSpline: specified point is outside the range defined by the spline");
// Perform a binary search to identify the interval containing the point to evaluate.
int lowerx = 0;
int upperx = xsize-1;
while (upperx-lowerx > 1) {
int middle = (upperx+lowerx)/2;
if (x[middle] > u)
upperx = middle;
else
lowerx = middle;
}
int lowery = 0;
int uppery = ysize-1;
while (uppery-lowery > 1) {
int middle = (uppery+lowery)/2;
if (y[middle] > v)
uppery = middle;
else
lowery = middle;
}
int lowerz = 0;
int upperz = zsize-1;
while (upperz-lowerz > 1) {
int middle = (upperz+lowerz)/2;
if (z[middle] > w)
upperz = middle;
else
lowerz = middle;
}
double deltax = x[upperx]-x[lowerx];
double deltay = y[uppery]-y[lowery];
double deltaz = z[upperz]-z[lowerz];
double da = (u-x[lowerx])/deltax;
double db = (v-y[lowery])/deltay;
double dc = (w-z[lowerz])/deltaz;
const vector<double>& coeff = c[lowerx+(xsize-1)*lowery+(xsize-1)*(ysize-1)*lowerz];
// Evaluate the spline to determine the derivatives.
double derivx[] = {0, 0, 0, 0};
double derivy[] = {0, 0, 0, 0};
double derivz[] = {0, 0, 0, 0};
for (int i = 3; i >= 0; i--) {
for (int j = 0; j < 4; j++) {
int base = 4*i + 16*j;
derivx[j] = db*derivx[j] + (3.0*coeff[base+3]*da + 2.0*coeff[base+2])*da + coeff[base+1];
derivz[j] = db*derivz[j] + ((coeff[base+3]*da + coeff[base+2])*da + coeff[base+1])*da + coeff[base];
base = i + 16*j;
derivy[j] = da*derivy[j] + (3.0*coeff[base+12]*db + 2.0*coeff[base+8])*db + coeff[base+4];
}
}
dx = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));
dy = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));
dz = derivz[1] + dc*(2.0*derivz[2] + 3.0*dc*derivz[3]);
dx /= deltax;
dy /= deltay;
dz /= deltaz;
}
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/TabulatedFunction.h"
#include "openmm/OpenMMException.h"
using namespace OpenMM;
using namespace std;
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points");
this->values = values;
this->min = min;
this->max = max;
}
void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& min, double& max) const {
values = this->values;
min = this->min;
max = this->max;
}
void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points");
this->values = values;
this->min = min;
this->max = max;
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous2DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous2DFunction: ymax <= ymin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
}
void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax) const {
values = this->values;
xsize = this->xsize;
ysize = this->ysize;
xmin = this->xmin;
xmax = this->xmax;
ymin = this->ymin;
ymax = this->ymax;
}
void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous2DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous2DFunction: ymax <= ymin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous3DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous3DFunction: ymax <= ymin for a tabulated function.");
if (zmax <= zmin)
throw OpenMMException("Continuous3DFunction: zmax <= zmin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
}
void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const {
values = this->values;
xsize = this->xsize;
ysize = this->ysize;
zsize = this->zsize;
xmin = this->xmin;
xmax = this->xmax;
ymin = this->ymin;
ymax = this->ymax;
zmin = this->zmin;
zmax = this->zmax;
}
void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous3DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous3DFunction: ymax <= ymin for a tabulated function.");
if (zmax <= zmin)
throw OpenMMException("Continuous3DFunction: zmax <= zmin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values;
}
void Discrete1DFunction::getFunctionParameters(vector<double>& values) const {
values = this->values;
}
void Discrete1DFunction::setFunctionParameters(const vector<double>& values) {
this->values = values;
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->values = values;
}
void Discrete2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<double>& values) const {
xsize = this->xsize;
ysize = this->ysize;
values = this->values;
}
void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->values = values;
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->values = values;
}
void Discrete3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values) const {
xsize = this->xsize;
ysize = this->ysize;
zsize = this->zsize;
values = this->values;
}
void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->values = values;
}
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2012 Stanford University and the Authors. * * Portions copyright (c) 2009-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "CudaContext.h" #include "CudaContext.h"
#include "openmm/TabulatedFunction.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
...@@ -45,65 +46,56 @@ namespace OpenMM { ...@@ -45,65 +46,56 @@ namespace OpenMM {
class OPENMM_EXPORT_CUDA CudaExpressionUtilities { class OPENMM_EXPORT_CUDA CudaExpressionUtilities {
public: public:
CudaExpressionUtilities(CudaContext& context) : context(context) { CudaExpressionUtilities(CudaContext& context);
}
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are * @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are
* variable names, and the values are the code to generate for them. * variable names, and the values are the code to generate for them.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& tempType="real");
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions. * @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions.
* Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears. * Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& tempType="real");
/** /**
* Calculate the spline coefficients for a tabulated function that appears in expressions. * Calculate the spline coefficients for a tabulated function that appears in expressions.
* *
* @param values the tabulated values of the function * @param function the function for which to compute coefficients
* @param min the value of the independent variable corresponding to the first element of values * @param width on output, the number of floats used for each value
* @param max the value of the independent variable corresponding to the last element of values
* @return the spline coefficients * @return the spline coefficients
*/ */
std::vector<float4> computeFunctionCoefficients(const std::vector<double>& values, double min, double max); std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
class FunctionPlaceholder; /**
private: * Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, *
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, * @param function the function for which to get a placeholder
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams,
const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
CudaContext& context;
};
/**
* This class serves as a placeholder for custom functions in expressions.
*/ */
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
class CudaExpressionUtilities::FunctionPlaceholder : public Lepton::CustomFunction { private:
public: class FunctionPlaceholder : public Lepton::CustomFunction {
public:
FunctionPlaceholder(int numArgs) : numArgs(numArgs) {
}
int getNumArguments() const { int getNumArguments() const {
return 1; return numArgs;
} }
double evaluate(const double* arguments) const { double evaluate(const double* arguments) const {
return 0.0; return 0.0;
...@@ -112,8 +104,23 @@ public: ...@@ -112,8 +104,23 @@ public:
return 0.0; return 0.0;
} }
CustomFunction* clone() const { CustomFunction* clone() const {
return new FunctionPlaceholder(); return new FunctionPlaceholder(numArgs);
} }
private:
int numArgs;
};
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -638,7 +638,7 @@ private: ...@@ -638,7 +638,7 @@ private:
class CudaCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel { class CudaCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public: public:
CudaCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomNonbondedForceKernel(name, platform), CudaCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomNonbondedForceKernel(name, platform),
cu(cu), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) { cu(cu), params(NULL), globals(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
} }
~CudaCalcCustomNonbondedForceKernel(); ~CudaCalcCustomNonbondedForceKernel();
/** /**
...@@ -669,7 +669,6 @@ private: ...@@ -669,7 +669,6 @@ private:
CudaContext& cu; CudaContext& cu;
CudaParameterSet* params; CudaParameterSet* params;
CudaArray* globals; CudaArray* globals;
CudaArray* tabulatedFunctionParams;
CudaArray* interactionGroupData; CudaArray* interactionGroupData;
CUfunction interactionGroupKernel; CUfunction interactionGroupKernel;
std::vector<void*> interactionGroupArgs; std::vector<void*> interactionGroupArgs;
...@@ -739,7 +738,7 @@ class CudaCalcCustomGBForceKernel : public CalcCustomGBForceKernel { ...@@ -739,7 +738,7 @@ class CudaCalcCustomGBForceKernel : public CalcCustomGBForceKernel {
public: public:
CudaCalcCustomGBForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomGBForceKernel(name, platform), CudaCalcCustomGBForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomGBForceKernel(name, platform),
hasInitializedKernels(false), cu(cu), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL), hasInitializedKernels(false), cu(cu), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL),
valueBuffers(NULL), tabulatedFunctionParams(NULL), system(system) { valueBuffers(NULL), system(system) {
} }
~CudaCalcCustomGBForceKernel(); ~CudaCalcCustomGBForceKernel();
/** /**
...@@ -776,7 +775,6 @@ private: ...@@ -776,7 +775,6 @@ private:
CudaArray* longEnergyDerivs; CudaArray* longEnergyDerivs;
CudaArray* globals; CudaArray* globals;
CudaArray* valueBuffers; CudaArray* valueBuffers;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
...@@ -838,7 +836,7 @@ class CudaCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel { ...@@ -838,7 +836,7 @@ class CudaCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel {
public: public:
CudaCalcCustomHbondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomHbondForceKernel(name, platform), CudaCalcCustomHbondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomHbondForceKernel(name, platform),
hasInitializedKernel(false), cu(cu), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL), hasInitializedKernel(false), cu(cu), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL),
globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), tabulatedFunctionParams(NULL), system(system) { globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), system(system) {
} }
~CudaCalcCustomHbondForceKernel(); ~CudaCalcCustomHbondForceKernel();
/** /**
...@@ -875,7 +873,6 @@ private: ...@@ -875,7 +873,6 @@ private:
CudaArray* acceptors; CudaArray* acceptors;
CudaArray* donorExclusions; CudaArray* donorExclusions;
CudaArray* acceptorExclusions; CudaArray* acceptorExclusions;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
...@@ -890,7 +887,7 @@ private: ...@@ -890,7 +887,7 @@ private:
class CudaCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel { class CudaCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel {
public: public:
CudaCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomCompoundBondForceKernel(name, platform), CudaCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomCompoundBondForceKernel(name, platform),
cu(cu), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), system(system) { cu(cu), params(NULL), globals(NULL), system(system) {
} }
~CudaCalcCustomCompoundBondForceKernel(); ~CudaCalcCustomCompoundBondForceKernel();
/** /**
...@@ -922,7 +919,6 @@ private: ...@@ -922,7 +919,6 @@ private:
CudaContext& cu; CudaContext& cu;
CudaParameterSet* params; CudaParameterSet* params;
CudaArray* globals; CudaArray* globals;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2012 Stanford University and the Authors. * * Portions copyright (c) 2009-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,35 +33,40 @@ using namespace OpenMM; ...@@ -33,35 +33,40 @@ using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3) {
}
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter) for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second)); variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, prefix, functionParams, tempType); return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
} }
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
stringstream out; stringstream out;
vector<ParsedExpression> allExpressions; vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second); allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables; vector<pair<ExpressionTreeNode, string> > temps = variables;
vector<vector<double> > functionParams = computeFunctionParameters(functions);
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) { for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n"; out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
} }
return out.str(); return out.str();
} }
void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps, void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++) for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node) if (temps[i].first == node)
return; return;
for (int i = 0; i < (int) node.getChildren().size(); i++) for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
string name = prefix+context.intToString(temps.size()); string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false; bool hasRecordedNode = false;
...@@ -75,11 +80,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -75,11 +80,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
case Operation::CUSTOM: case Operation::CUSTOM:
{ {
int i; int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++) for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
; ;
if (i == functions.size()) if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n"; out << "0.0f;\n";
temps.push_back(make_pair(node, name)); temps.push_back(make_pair(node, name));
hasRecordedNode = true; hasRecordedNode = true;
...@@ -87,39 +91,190 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -87,39 +91,190 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
// If both the value and derivative of the function are needed, it's faster to calculate them both // If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed. // at once, so check to see if both are needed.
const ExpressionTreeNode* valueNode = NULL; vector<const ExpressionTreeNode*> nodes;
const ExpressionTreeNode* derivNode = NULL;
for (int j = 0; j < (int) allExpressions.size(); j++) for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), valueNode, derivNode); findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), nodes);
string valueName = name; vector<string> nodeNames;
string derivName = name; nodeNames.push_back(name);
if (valueNode != NULL && derivNode != NULL) { for (int j = 1; j < (int) nodes.size(); j++) {
string name2 = prefix+context.intToString(temps.size()); string name2 = prefix+context.intToString(temps.size());
out << tempType << " " << name2 << " = 0.0f;\n"; out << tempType << " " << name2 << " = 0.0f;\n";
if (isDeriv) { nodeNames.push_back(name2);
valueName = name2; temps.push_back(make_pair(*nodes[j], name2));
temps.push_back(make_pair(*valueNode, name2));
}
else {
derivName = name2;
temps.push_back(make_pair(*derivNode, name2));
}
} }
out << "{\n"; out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n"; vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n"; out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x-params.x)*params.z;\n"; out << "x = (x-" << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n"; out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n"; out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functions[i].second << "[index];\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n"; out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n"; out << "real a = 1.0f-b;\n";
if (valueNode != NULL) for (int j = 0; j < nodes.size(); j++) {
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n"; const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivNode != NULL) if (derivOrder[0] == 0)
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n"; out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
}
out << "}\n";
}
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x-" << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y-" << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
out << "}\n";
}
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n";
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x-" << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y-" << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z-" << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << ");\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n";
for (int j = 0; j < 16; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
out << "real dc = z-u;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real value[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real derivx[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
}
out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
const string suffixes[] = {".x", ".y", ".z", ".w"};
out << "real derivy[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = 4*m;
string suffix = suffixes[m];
out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
}
out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
out << "real derivz[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
out << "}\n"; out << "}\n";
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int zsize = (int) " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
out << "}"; out << "}";
break; break;
} }
...@@ -312,16 +467,27 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons ...@@ -312,16 +467,27 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
} }
void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
const ExpressionTreeNode*& valueNode, const ExpressionTreeNode*& derivNode) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) { if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
if (dynamic_cast<const Operation::Custom*>(&searchNode.getOperation())->getDerivOrder()[0] == 0) // Make sure the arguments are identical.
valueNode = &searchNode;
else for (int i = 0; i < (int) node.getChildren().size(); i++)
derivNode = &searchNode; if (node.getChildren()[i] != searchNode.getChildren()[i])
return;
// See if we already have an identical node.
for (int i = 0; i < (int) nodes.size(); i++)
if (*nodes[i] == searchNode)
return;
// Add the node.
nodes.push_back(&searchNode);
} }
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode); findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
} }
void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) { void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
...@@ -341,16 +507,209 @@ void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, ...@@ -341,16 +507,209 @@ void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node,
findRelatedPowers(node, searchNode.getChildren()[i], powers); findRelatedPowers(node, searchNode.getChildren()[i], powers);
} }
vector<float4> CudaExpressionUtilities::computeFunctionCoefficients(const vector<double>& values, double min, double max) { vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
// Compute the spline coefficients. // Compute the spline coefficients.
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
int numValues = values.size(); int numValues = values.size();
vector<double> x(numValues), derivs; vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++) for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1); x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs); SplineFitter::createNaturalSpline(x, values, derivs);
vector<float4> f(numValues-1); vector<float> f(4*(numValues-1));
for (int i = 0; i < (int) values.size()-1; i++) for (int i = 0; i < (int) values.size()-1; i++) {
f[i] = make_float4((float) values[i], (float) values[i+1], (float) (derivs[i]/6.0), (float) (derivs[i+1]/6.0)); f[4*i] = (float) values[i];
f[4*i+1] = (float) values[i+1];
f[4*i+2] = (float) (derivs[i]/6.0);
f[4*i+3] = (float) (derivs[i+1]/6.0);
}
width = 4;
return f;
}
if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const Continuous2DFunction& fn = dynamic_cast<const Continuous2DFunction&>(function);
vector<double> values;
int xsize, ysize;
double xmin, xmax, ymin, ymax;
fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
vector<double> x(xsize), y(ysize);
for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1);
vector<vector<double> > c;
SplineFitter::create2DNaturalSpline(x, y, values, c);
vector<float> f(16*c.size());
for (int i = 0; i < (int) c.size(); i++) {
for (int j = 0; j < 16; j++)
f[16*i+j] = (float) c[i][j];
}
width = 4;
return f;
}
if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const Continuous3DFunction& fn = dynamic_cast<const Continuous3DFunction&>(function);
vector<double> values;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
vector<double> x(xsize), y(ysize), z(zsize);
for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1);
for (int i = 0; i < zsize; i++)
z[i] = zmin+i*(zmax-zmin)/(zsize-1);
vector<vector<double> > c;
SplineFitter::create3DNaturalSpline(x, y, z, values, c);
vector<float> f(64*c.size());
for (int i = 0; i < (int) c.size(); i++) {
for (int j = 0; j < 64; j++)
f[64*i+j] = (float) c[i][j];
}
width = 4;
return f;
}
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(function);
vector<double> values;
fn.getFunctionParameters(values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(function);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f; return f;
}
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(function);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
throw OpenMMException("computeFunctionCoefficients: Unknown function type");
}
vector<vector<double> > CudaExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<vector<double> > params(functions.size());
for (int i = 0; i < (int) functions.size(); i++) {
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
params[i].push_back(min);
params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
}
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
const Continuous2DFunction& fn = dynamic_cast<const Continuous2DFunction&>(*functions[i]);
vector<double> values;
int xsize, ysize;
double xmin, xmax, ymin, ymax;
fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
params[i].push_back(xsize-1);
params[i].push_back(ysize-1);
params[i].push_back(xmin);
params[i].push_back(xmax);
params[i].push_back(ymin);
params[i].push_back(ymax);
params[i].push_back((xsize-1)/(xmax-xmin));
params[i].push_back((ysize-1)/(ymax-ymin));
}
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
const Continuous3DFunction& fn = dynamic_cast<const Continuous3DFunction&>(*functions[i]);
vector<double> values;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
params[i].push_back(xsize-1);
params[i].push_back(ysize-1);
params[i].push_back(zsize-1);
params[i].push_back(xmin);
params[i].push_back(xmax);
params[i].push_back(ymin);
params[i].push_back(ymax);
params[i].push_back(zmin);
params[i].push_back(zmax);
params[i].push_back((xsize-1)/(xmax-xmin));
params[i].push_back((ysize-1)/(ymax-ymin));
params[i].push_back((zsize-1)/(zmax-zmin));
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i].push_back(values.size());
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
params[i].push_back(xsize);
params[i].push_back(ysize);
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i].push_back(xsize);
params[i].push_back(ysize);
params[i].push_back(zsize);
}
else
throw OpenMMException("computeFunctionParameters: Unknown function type");
}
return params;
}
Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL)
return &fp2;
if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL)
return &fp3;
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL)
return &fp2;
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL)
return &fp3;
throw OpenMMException("getFunctionPlaceholder: Unknown function type");
} }
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -595,8 +595,9 @@ void CudaCalcCustomBondForceKernel::initialize(const System& system, const Custo ...@@ -595,8 +595,9 @@ void CudaCalcCustomBondForceKernel::initialize(const System& system, const Custo
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::bondForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::bondForce, replacements), force.getForceGroup());
...@@ -830,8 +831,9 @@ void CudaCalcCustomAngleForceKernel::initialize(const System& system, const Cust ...@@ -830,8 +831,9 @@ void CudaCalcCustomAngleForceKernel::initialize(const System& system, const Cust
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::angleForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::angleForce, replacements), force.getForceGroup());
...@@ -1253,8 +1255,9 @@ void CudaCalcCustomTorsionForceKernel::initialize(const System& system, const Cu ...@@ -1253,8 +1255,9 @@ void CudaCalcCustomTorsionForceKernel::initialize(const System& system, const Cu
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::torsionForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::torsionForce, replacements), force.getForceGroup());
...@@ -1912,8 +1915,6 @@ CudaCalcCustomNonbondedForceKernel::~CudaCalcCustomNonbondedForceKernel() { ...@@ -1912,8 +1915,6 @@ CudaCalcCustomNonbondedForceKernel::~CudaCalcCustomNonbondedForceKernel() {
delete params; delete params;
if (globals != NULL) if (globals != NULL)
delete globals; delete globals;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
if (interactionGroupData != NULL) if (interactionGroupData != NULL)
delete interactionGroupData; delete interactionGroupData;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
...@@ -1955,28 +1956,20 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -1955,28 +1956,20 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max;
force.getFunctionParameters(i, name, values, min, max);
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
tabulatedFunctionParamsVec[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); int width;
vector<float4> f = cu.getExpressionUtilities().computeFunctionCoefficients(values, min, max); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float4>(cu, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer())); cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
}
if (force.getNumFunctions() > 0) {
tabulatedFunctionParams = CudaArray::create<float4>(cu, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters");
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(prefix+"functionParams", "float", 4, sizeof(float4), tabulatedFunctionParams->getDevicePointer()));
} }
// Record information for the expressions. // Record information for the expressions.
...@@ -2015,7 +2008,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -2015,7 +2008,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
variables.push_back(makeVariable(name, prefix+value)); variables.push_back(makeVariable(name, prefix+value));
} }
stringstream compute; stringstream compute;
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, prefix+"temp", prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0"); replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
...@@ -2610,8 +2603,6 @@ CudaCalcCustomGBForceKernel::~CudaCalcCustomGBForceKernel() { ...@@ -2610,8 +2603,6 @@ CudaCalcCustomGBForceKernel::~CudaCalcCustomGBForceKernel() {
delete globals; delete globals;
if (valueBuffers != NULL) if (valueBuffers != NULL)
delete valueBuffers; delete valueBuffers;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i]; delete tabulatedFunctions[i];
} }
...@@ -2669,31 +2660,25 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2669,31 +2660,25 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max;
force.getFunctionParameters(i, name, values, min, max);
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
tabulatedFunctionParamsVec[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); int width;
vector<float4> f = cu.getExpressionUtilities().computeFunctionCoefficients(values, min, max); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float4>(cu, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer())); cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
tableArgs << ", const float4* __restrict__ " << arrayName; tableArgs << ", const float";
} if (width > 1)
if (force.getNumFunctions() > 0) { tableArgs << width;
tabulatedFunctionParams = CudaArray::create<float4>(cu, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters"); tableArgs << "* __restrict__ " << arrayName;
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(prefix+"functionParams", "float", 4, sizeof(float4), tabulatedFunctionParams->getDevicePointer()));
tableArgs << ", const float4* " << prefix << "functionParams";
} }
// Record the global parameters. // Record the global parameters.
...@@ -2779,7 +2764,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2779,7 +2764,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex; n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename); n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << cu.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2ValueSource << cu.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp");
map<string, string> replacements; map<string, string> replacements;
string n2ValueStr = n2ValueSource.str(); string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr; replacements["COMPUTE_VALUE"] = n2ValueStr;
...@@ -2857,7 +2842,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2857,7 +2842,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1); variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map<string, Lepton::ParsedExpression> valueExpressions; map<string, Lepton::ParsedExpression> valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize(); valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
reductionSource << cu.getExpressionUtilities().createExpressions(valueExpressions, variables, functionDefinitions, "value"+cu.intToString(i)+"_temp", prefix+"functionParams"); reductionSource << cu.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cu.intToString(i)+"_temp");
} }
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) { for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
string valueName = "values"+cu.intToString(i+1); string valueName = "values"+cu.intToString(i+1);
...@@ -2911,7 +2896,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2911,7 +2896,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
} }
if (exclude) if (exclude)
n2EnergySource << "if (!isExcluded) {\n"; n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cu.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2EnergySource << cu.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp");
if (exclude) if (exclude)
n2EnergySource << "}\n"; n2EnergySource << "}\n";
} }
...@@ -3060,7 +3045,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3060,7 +3045,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
for (int i = 1; i < force.getNumComputedValues(); i++) for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++) for (int j = 0; j < i; j++)
expressions["real dV"+cu.intToString(i)+"dV"+cu.intToString(j)+" = "] = valueDerivExpressions[i][j]; expressions["real dV"+cu.intToString(i)+"dV"+cu.intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp");
// Record values. // Record values.
...@@ -3132,7 +3117,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3132,7 +3117,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
map<string, Lepton::ParsedExpression> derivExpressions; map<string, Lepton::ParsedExpression> derivExpressions;
string js = cu.intToString(j); string js = cu.intToString(j);
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j]; derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js);
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n"; compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
} }
} }
...@@ -3143,7 +3128,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3143,7 +3128,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1]; gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2])) if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2]; gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cu.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp");
} }
for (int i = 1; i < force.getNumComputedValues(); i++) { for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = cu.intToString(i); string is = cu.intToString(i);
...@@ -3185,7 +3170,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3185,7 +3170,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize(); Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["real dV0dR1 = "] = dVdR; derivExpressions["real dV0dR1 = "] = dVdR;
derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename); derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams"); chainSource << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_");
if (needChainForValue[0]) { if (needChainForValue[0]) {
if (useExclusionsForValue) if (useExclusionsForValue)
chainSource << "if (!isExcluded) {\n"; chainSource << "if (!isExcluded) {\n";
...@@ -3304,11 +3289,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -3304,11 +3289,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo
if (pairValueUsesParam[i]) if (pairValueUsesParam[i])
pairValueArgs.push_back(&params->getBuffers()[i].getMemory()); pairValueArgs.push_back(&params->getBuffers()[i].getMemory());
} }
if (tabulatedFunctionParams != NULL) {
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
pairValueArgs.push_back(&tabulatedFunctions[i]->getDevicePointer()); pairValueArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
pairValueArgs.push_back(&tabulatedFunctionParams->getDevicePointer());
}
perParticleValueArgs.push_back(&cu.getPosq().getDevicePointer()); perParticleValueArgs.push_back(&cu.getPosq().getDevicePointer());
perParticleValueArgs.push_back(&valueBuffers->getDevicePointer()); perParticleValueArgs.push_back(&valueBuffers->getDevicePointer());
if (globals != NULL) if (globals != NULL)
...@@ -3317,11 +3299,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -3317,11 +3299,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo
perParticleValueArgs.push_back(&params->getBuffers()[i].getMemory()); perParticleValueArgs.push_back(&params->getBuffers()[i].getMemory());
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) for (int i = 0; i < (int) computedValues->getBuffers().size(); i++)
perParticleValueArgs.push_back(&computedValues->getBuffers()[i].getMemory()); perParticleValueArgs.push_back(&computedValues->getBuffers()[i].getMemory());
if (tabulatedFunctionParams != NULL) {
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
perParticleValueArgs.push_back(&tabulatedFunctions[i]->getDevicePointer()); perParticleValueArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
perParticleValueArgs.push_back(&tabulatedFunctionParams->getDevicePointer());
}
pairEnergyArgs.push_back(&cu.getForce().getDevicePointer()); pairEnergyArgs.push_back(&cu.getForce().getDevicePointer());
pairEnergyArgs.push_back(&cu.getEnergyBuffer().getDevicePointer()); pairEnergyArgs.push_back(&cu.getEnergyBuffer().getDevicePointer());
pairEnergyArgs.push_back(&cu.getPosq().getDevicePointer()); pairEnergyArgs.push_back(&cu.getPosq().getDevicePointer());
...@@ -3350,11 +3329,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -3350,11 +3329,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo
pairEnergyArgs.push_back(&computedValues->getBuffers()[i].getMemory()); pairEnergyArgs.push_back(&computedValues->getBuffers()[i].getMemory());
} }
pairEnergyArgs.push_back(&longEnergyDerivs->getDevicePointer()); pairEnergyArgs.push_back(&longEnergyDerivs->getDevicePointer());
if (tabulatedFunctionParams != NULL) {
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
pairEnergyArgs.push_back(&tabulatedFunctions[i]->getDevicePointer()); pairEnergyArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
pairEnergyArgs.push_back(&tabulatedFunctionParams->getDevicePointer());
}
perParticleEnergyArgs.push_back(&cu.getForce().getDevicePointer()); perParticleEnergyArgs.push_back(&cu.getForce().getDevicePointer());
perParticleEnergyArgs.push_back(&cu.getEnergyBuffer().getDevicePointer()); perParticleEnergyArgs.push_back(&cu.getEnergyBuffer().getDevicePointer());
perParticleEnergyArgs.push_back(&cu.getPosq().getDevicePointer()); perParticleEnergyArgs.push_back(&cu.getPosq().getDevicePointer());
...@@ -3369,11 +3345,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -3369,11 +3345,8 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo
for (int i = 0; i < (int) energyDerivChain->getBuffers().size(); i++) for (int i = 0; i < (int) energyDerivChain->getBuffers().size(); i++)
perParticleEnergyArgs.push_back(&energyDerivChain->getBuffers()[i].getMemory()); perParticleEnergyArgs.push_back(&energyDerivChain->getBuffers()[i].getMemory());
perParticleEnergyArgs.push_back(&longEnergyDerivs->getDevicePointer()); perParticleEnergyArgs.push_back(&longEnergyDerivs->getDevicePointer());
if (tabulatedFunctionParams != NULL) {
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
perParticleEnergyArgs.push_back(&tabulatedFunctions[i]->getDevicePointer()); perParticleEnergyArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
perParticleEnergyArgs.push_back(&tabulatedFunctionParams->getDevicePointer());
}
if (needParameterGradient) { if (needParameterGradient) {
gradientChainRuleArgs.push_back(&cu.getForce().getDevicePointer()); gradientChainRuleArgs.push_back(&cu.getForce().getDevicePointer());
gradientChainRuleArgs.push_back(&cu.getPosq().getDevicePointer()); gradientChainRuleArgs.push_back(&cu.getPosq().getDevicePointer());
...@@ -3543,8 +3516,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C ...@@ -3543,8 +3516,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::customExternalForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::customExternalForce, replacements), force.getForceGroup());
...@@ -3685,8 +3659,6 @@ CudaCalcCustomHbondForceKernel::~CudaCalcCustomHbondForceKernel() { ...@@ -3685,8 +3659,6 @@ CudaCalcCustomHbondForceKernel::~CudaCalcCustomHbondForceKernel() {
delete donorExclusions; delete donorExclusions;
if (acceptorExclusions != NULL) if (acceptorExclusions != NULL)
delete acceptorExclusions; delete acceptorExclusions;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i]; delete tabulatedFunctions[i];
} }
...@@ -3784,29 +3756,24 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust ...@@ -3784,29 +3756,24 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max;
force.getFunctionParameters(i, name, values, min, max);
string arrayName = "table"+cu.intToString(i); string arrayName = "table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
tabulatedFunctionParamsVec[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); int width;
vector<float4> f = cu.getExpressionUtilities().computeFunctionCoefficients(values, min, max); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float4>(cu, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", const float4* __restrict__ " << arrayName; tableArgs << ", const float";
} if (width > 1)
if (force.getNumFunctions() > 0) { tableArgs << width;
tabulatedFunctionParams = CudaArray::create<float4>(cu, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters"); tableArgs << "* __restrict__ " << arrayName;
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
tableArgs << ", const float4* __restrict__ functionParams";
} }
// Record information about parameters. // Record information about parameters.
...@@ -3922,9 +3889,9 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust ...@@ -3922,9 +3889,9 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
// Now evaluate the expressions. // Now evaluate the expressions.
computeAcceptor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeAcceptor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
computeDonor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeDonor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4036,11 +4003,8 @@ double CudaCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includ ...@@ -4036,11 +4003,8 @@ double CudaCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includ
CudaNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i]; CudaNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
donorArgs.push_back(&buffer.getMemory()); donorArgs.push_back(&buffer.getMemory());
} }
if (tabulatedFunctionParams != NULL) {
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
donorArgs.push_back(&tabulatedFunctions[i]->getDevicePointer()); donorArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
donorArgs.push_back(&tabulatedFunctionParams->getDevicePointer());
}
index = 0; index = 0;
acceptorArgs.push_back(&cu.getForce().getDevicePointer()); acceptorArgs.push_back(&cu.getForce().getDevicePointer());
acceptorArgs.push_back(&cu.getEnergyBuffer().getDevicePointer()); acceptorArgs.push_back(&cu.getEnergyBuffer().getDevicePointer());
...@@ -4060,11 +4024,8 @@ double CudaCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includ ...@@ -4060,11 +4024,8 @@ double CudaCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includ
CudaNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i]; CudaNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
acceptorArgs.push_back(&buffer.getMemory()); acceptorArgs.push_back(&buffer.getMemory());
} }
if (tabulatedFunctionParams != NULL) {
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
acceptorArgs.push_back(&tabulatedFunctions[i]->getDevicePointer()); acceptorArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
acceptorArgs.push_back(&tabulatedFunctionParams->getDevicePointer());
}
} }
int sharedMemorySize = 3*CudaContext::ThreadBlockSize*sizeof(float4); int sharedMemorySize = 3*CudaContext::ThreadBlockSize*sizeof(float4);
cu.executeKernel(donorKernel, &donorArgs[0], max(numDonors, numAcceptors), CudaContext::ThreadBlockSize, sharedMemorySize); cu.executeKernel(donorKernel, &donorArgs[0], max(numDonors, numAcceptors), CudaContext::ThreadBlockSize, sharedMemorySize);
...@@ -4148,8 +4109,6 @@ CudaCalcCustomCompoundBondForceKernel::~CudaCalcCustomCompoundBondForceKernel() ...@@ -4148,8 +4109,6 @@ CudaCalcCustomCompoundBondForceKernel::~CudaCalcCustomCompoundBondForceKernel()
delete params; delete params;
if (globals != NULL) if (globals != NULL)
delete globals; delete globals;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i]; delete tabulatedFunctions[i];
} }
...@@ -4178,31 +4137,22 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4178,31 +4137,22 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<float4> tabulatedFunctionParamsVec(force.getNumFunctions()); vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
string name; functionList.push_back(&force.getTabulatedFunction(i));
vector<double> values; string name = force.getTabulatedFunctionName(i);
double min, max; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
force.getFunctionParameters(i, name, values, min, max); int width;
functions[name] = &fp; vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionParamsVec[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction");
vector<float4> f = cu.getExpressionUtilities().computeFunctionCoefficients(values, min, max);
CudaArray* array = CudaArray::create<float4>(cu, values.size()-1, "TabulatedFunction");
tabulatedFunctions.push_back(array); tabulatedFunctions.push_back(array);
array->upload(f); array->upload(f);
string arrayName = cu.getBondedUtilities().addArgument(array->getDevicePointer(), "float4"); string arrayName = cu.getBondedUtilities().addArgument(array->getDevicePointer(), width == 1 ? "float" : "float"+cu.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
} }
string functionParamsName;
if (force.getNumFunctions() > 0) {
tabulatedFunctionParams = CudaArray::create<float4>(cu, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters");
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
functionParamsName = cu.getBondedUtilities().addArgument(tabulatedFunctionParams->getDevicePointer(), "float4");
}
// Record information about parameters. // Record information about parameters.
...@@ -4317,7 +4267,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4317,7 +4267,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", functionParamsName); compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4339,7 +4289,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4339,7 +4289,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
if (!isZeroExpression(forceExpressionZ)) if (!isZeroExpression(forceExpressionZ))
expressions[forceName+".z -= "] = forceExpressionZ; expressions[forceName+".z -= "] = forceExpressionZ;
if (expressions.size() > 0) if (expressions.size() > 0)
compute<<cu.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "coordtemp", functionParamsName); compute<<cu.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp");
compute<<"}\n"; compute<<"}\n";
} }
index = 0; index = 0;
...@@ -4949,8 +4899,9 @@ string CudaIntegrateCustomStepKernel::createGlobalComputation(const string& vari ...@@ -4949,8 +4899,9 @@ string CudaIntegrateCustomStepKernel::createGlobalComputation(const string& vari
variables[integrator.getGlobalVariableName(i)] = "globals["+cu.intToString(i)+"]"; variables[integrator.getGlobalVariableName(i)] = "globals["+cu.intToString(i)+"]";
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cu.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cu.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
} }
string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) { string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
...@@ -4986,8 +4937,9 @@ string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& vari ...@@ -4986,8 +4937,9 @@ string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& vari
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i); variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cu.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cu.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp"+cu.intToString(component)+"_", "", "double"); vector<pair<string, string> > functionNames;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+cu.intToString(component)+"_", "double");
} }
void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
......
...@@ -199,6 +199,103 @@ void testParallelComputation() { ...@@ -199,6 +199,103 @@ void testParallelComputation() {
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5); ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
} }
void testContinuous2DFunction() {
const int xsize = 10;
const int ysize = 11;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
System system;
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomCompoundBondForce* forceField = new CustomCompoundBondForce(1, "fn(x1,y1)+1");
vector<int> particles(1, 0);
forceField->addBond(particles, vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
const double zmin = 0.2;
const double zmax = 1.3;
System system;
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomCompoundBondForce* forceField = new CustomCompoundBondForce(1, "fn(x1,y1,z1)+1");
vector<int> particles(1, 0);
forceField->addBond(particles, vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
double z = zmin + k*(zmax-zmin)/zsize;
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(0.33*y)*(1+z);
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.1) {
positions[0] = Vec3(x, y, z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax && z >= zmin && z <= zmax) {
energy = sin(0.25*x)*cos(0.33*y)*(1.0+z)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y)*(1.0+z);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y)*(1.0+z);
force[2] = -sin(0.25*x)*cos(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -206,6 +303,8 @@ int main(int argc, char* argv[]) { ...@@ -206,6 +303,8 @@ int main(int argc, char* argv[]) {
testBond(); testBond();
testPositionDependence(); testPositionDependence();
testParallelComputation(); testParallelComputation();
testContinuous2DFunction();
testContinuous3DFunction();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -277,7 +277,7 @@ void testTabulatedFunction() { ...@@ -277,7 +277,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0); force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment