Commit 75a04e7f authored by peastman's avatar peastman
Browse files

Renamed methods for dealing with TabulatedFunctions to avoid problems in...

Renamed methods for dealing with TabulatedFunctions to avoid problems in wrapper APIs.  Also fixed a few bugs.
parent f7e5492d
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -92,7 +92,7 @@ namespace OpenMM { ...@@ -92,7 +92,7 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression. * creating a TabulatedFunction object. That function can then appear in the expression.
*/ */
...@@ -135,6 +135,14 @@ public: ...@@ -135,6 +135,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -236,49 +244,46 @@ public: ...@@ -236,49 +244,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted. * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addFunction(const std::string& name, TabulatedFunction* function); int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/** /**
* Get a const reference to a tabulated function that may appear in the energy expression. * Get a const reference to a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
const TabulatedFunction& getFunction(int index) const; const TabulatedFunction& getTabulatedFunction(int index) const;
/** /**
* Get a reference to a tabulated function that may appear in the energy expression. * Get a reference to a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
TabulatedFunction& getFunction(int index); TabulatedFunction& getTabulatedFunction(int index);
/** /**
* Get the name of a tabulated function that may appear in the energy expression. * Get the name of a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the name of the function as it appears in expressions * @return the name of the function as it appears in expressions
*/ */
const std::string& getFunctionName(int index) const; const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Add a tabulated function that may appear in the energy expression. * Add a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
* a TabulatedFunction instead.
*/ */
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in the energy expression. * Set the parameters for a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -135,7 +135,7 @@ namespace OpenMM { ...@@ -135,7 +135,7 @@ namespace OpenMM {
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example, * have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in expressions. * creating a TabulatedFunction object. That function can then appear in expressions.
*/ */
...@@ -210,6 +210,14 @@ public: ...@@ -210,6 +210,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -462,49 +470,46 @@ public: ...@@ -462,49 +470,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted. * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addFunction(const std::string& name, TabulatedFunction* function); int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/** /**
* Get a const reference to a tabulated function that may appear in expressions. * Get a const reference to a tabulated function that may appear in expressions.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
const TabulatedFunction& getFunction(int index) const; const TabulatedFunction& getTabulatedFunction(int index) const;
/** /**
* Get a reference to a tabulated function that may appear in expressions. * Get a reference to a tabulated function that may appear in expressions.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
TabulatedFunction& getFunction(int index); TabulatedFunction& getTabulatedFunction(int index);
/** /**
* Get the name of a tabulated function that may appear in expressions. * Get the name of a tabulated function that may appear in expressions.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the name of the function as it appears in expressions * @return the name of the function as it appears in expressions
*/ */
const std::string& getFunctionName(int index) const; const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Add a tabulated function that may appear in expressions. * Add a tabulated function that may appear in expressions.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
* a TabulatedFunction instead.
*/ */
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in expressions. * Get the parameters for a tabulated function that may appear in expressions.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in expressions. * Set the parameters for a tabulated function that may appear in expressions.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -92,7 +92,7 @@ namespace OpenMM { ...@@ -92,7 +92,7 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression. * creating a TabulatedFunction object. That function can then appear in the expression.
*/ */
...@@ -165,6 +165,14 @@ public: ...@@ -165,6 +165,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -381,49 +389,46 @@ public: ...@@ -381,49 +389,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted. * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addFunction(const std::string& name, TabulatedFunction* function); int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/** /**
* Get a const reference to a tabulated function that may appear in the energy expression. * Get a const reference to a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
const TabulatedFunction& getFunction(int index) const; const TabulatedFunction& getTabulatedFunction(int index) const;
/** /**
* Get a reference to a tabulated function that may appear in the energy expression. * Get a reference to a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
TabulatedFunction& getFunction(int index); TabulatedFunction& getTabulatedFunction(int index);
/** /**
* Get the name of a tabulated function that may appear in the energy expression. * Get the name of a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the name of the function as it appears in expressions * @return the name of the function as it appears in expressions
*/ */
const std::string& getFunctionName(int index) const; const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Add a tabulated function that may appear in the energy expression. * Add a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
* a TabulatedFunction instead.
*/ */
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in the energy expression. * Set the parameters for a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -125,7 +125,7 @@ namespace OpenMM { ...@@ -125,7 +125,7 @@ namespace OpenMM {
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example, * have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by * In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression. * creating a TabulatedFunction object. That function can then appear in the expression.
*/ */
...@@ -185,6 +185,14 @@ public: ...@@ -185,6 +185,14 @@ public:
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const { int getNumFunctions() const {
return functions.size(); return functions.size();
} }
...@@ -366,49 +374,46 @@ public: ...@@ -366,49 +374,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted. * Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addFunction(const std::string& name, TabulatedFunction* function); int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/** /**
* Get a const reference to a tabulated function that may appear in the energy expression. * Get a const reference to a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
const TabulatedFunction& getFunction(int index) const; const TabulatedFunction& getTabulatedFunction(int index) const;
/** /**
* Get a reference to a tabulated function that may appear in the energy expression. * Get a reference to a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the TabulatedFunction object defining the function * @return the TabulatedFunction object defining the function
*/ */
TabulatedFunction& getFunction(int index); TabulatedFunction& getTabulatedFunction(int index);
/** /**
* Get the name of a tabulated function that may appear in the energy expression. * Get the name of a tabulated function that may appear in the energy expression.
* *
* @param index the index of the function to get * @param index the index of the function to get
* @return the name of the function as it appears in expressions * @return the name of the function as it appears in expressions
*/ */
const std::string& getFunctionName(int index) const; const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Add a tabulated function that may appear in the energy expression. * Add a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
* a TabulatedFunction instead.
*/ */
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in the energy expression. * Set the parameters for a tabulated function that may appear in the energy expression.
* *
* @deprecated This method exists only for backward compatibility. Use the version that takes * @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws * If the specified function is not a Continuous1DFunction, this throws an exception.
* an exception.
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -126,22 +126,22 @@ void CustomCompoundBondForce::setBondParameters(int index, const vector<int>& pa ...@@ -126,22 +126,22 @@ void CustomCompoundBondForce::setBondParameters(int index, const vector<int>& pa
bonds[index].parameters = parameters; bonds[index].parameters = parameters;
} }
int CustomCompoundBondForce::addFunction(const std::string& name, TabulatedFunction* function) { int CustomCompoundBondForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function)); functions.push_back(FunctionInfo(name, function));
return functions.size()-1; return functions.size()-1;
} }
const TabulatedFunction& CustomCompoundBondForce::getFunction(int index) const { const TabulatedFunction& CustomCompoundBondForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
TabulatedFunction& CustomCompoundBondForce::getFunction(int index) { TabulatedFunction& CustomCompoundBondForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
const string& CustomCompoundBondForce::getFunctionName(int index) const { const string& CustomCompoundBondForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return functions[index].name; return functions[index].name;
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -178,22 +178,22 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle ...@@ -178,22 +178,22 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle
exclusions[index].particle2 = particle2; exclusions[index].particle2 = particle2;
} }
int CustomGBForce::addFunction(const std::string& name, TabulatedFunction* function) { int CustomGBForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function)); functions.push_back(FunctionInfo(name, function));
return functions.size()-1; return functions.size()-1;
} }
const TabulatedFunction& CustomGBForce::getFunction(int index) const { const TabulatedFunction& CustomGBForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
TabulatedFunction& CustomGBForce::getFunction(int index) { TabulatedFunction& CustomGBForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
const string& CustomGBForce::getFunctionName(int index) const { const string& CustomGBForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return functions[index].name; return functions[index].name;
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -193,22 +193,22 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor) ...@@ -193,22 +193,22 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor)
exclusions[index].acceptor = acceptor; exclusions[index].acceptor = acceptor;
} }
int CustomHbondForce::addFunction(const std::string& name, TabulatedFunction* function) { int CustomHbondForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function)); functions.push_back(FunctionInfo(name, function));
return functions.size()-1; return functions.size()-1;
} }
const TabulatedFunction& CustomHbondForce::getFunction(int index) const { const TabulatedFunction& CustomHbondForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
TabulatedFunction& CustomHbondForce::getFunction(int index) { TabulatedFunction& CustomHbondForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
const string& CustomHbondForce::getFunctionName(int index) const { const string& CustomHbondForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return functions[index].name; return functions[index].name;
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -174,22 +174,22 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p ...@@ -174,22 +174,22 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p
exclusions[index].particle1 = particle1; exclusions[index].particle1 = particle1;
exclusions[index].particle2 = particle2; exclusions[index].particle2 = particle2;
} }
int CustomNonbondedForce::addFunction(const std::string& name, TabulatedFunction* function) { int CustomNonbondedForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function)); functions.push_back(FunctionInfo(name, function));
return functions.size()-1; return functions.size()-1;
} }
const TabulatedFunction& CustomNonbondedForce::getFunction(int index) const { const TabulatedFunction& CustomNonbondedForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
TabulatedFunction& CustomNonbondedForce::getFunction(int index) { TabulatedFunction& CustomNonbondedForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return *functions[index].function; return *functions[index].function;
} }
const string& CustomNonbondedForce::getFunctionName(int index) const { const string& CustomNonbondedForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions); ASSERT_VALID_INDEX(index, functions);
return functions[index].name; return functions[index].name;
} }
......
...@@ -468,8 +468,23 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons ...@@ -468,8 +468,23 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector<const Lepton::ExpressionTreeNode*>& nodes) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical.
for (int i = 0; i < (int) node.getChildren().size(); i++)
if (node.getChildren()[i] != searchNode.getChildren()[i])
return;
// See if we already have an identical node.
for (int i = 0; i < (int) nodes.size(); i++)
if (*nodes[i] == searchNode)
return;
// Add the node.
nodes.push_back(&searchNode); nodes.push_back(&searchNode);
}
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes); findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
......
...@@ -1960,13 +1960,13 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -1960,13 +1960,13 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer())); cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
...@@ -2664,18 +2664,21 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2664,18 +2664,21 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer())); cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
tableArgs << ", const float4* __restrict__ " << arrayName; tableArgs << ", const float";
if (width > 1)
tableArgs << width;
tableArgs << "* __restrict__ " << arrayName;
} }
// Record the global parameters. // Record the global parameters.
...@@ -3758,13 +3761,13 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust ...@@ -3758,13 +3761,13 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
string arrayName = "table"+cu.intToString(i); string arrayName = "table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", const float"; tableArgs << ", const float";
...@@ -4139,11 +4142,11 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4139,11 +4142,11 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction"); CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array); tabulatedFunctions.push_back(array);
array->upload(f); array->upload(f);
......
...@@ -220,7 +220,7 @@ void testContinuous2DFunction() { ...@@ -220,7 +220,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y); table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
} }
} }
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax)); forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(1); vector<Vec3> positions(1);
...@@ -270,7 +270,7 @@ void testContinuous3DFunction() { ...@@ -270,7 +270,7 @@ void testContinuous3DFunction() {
} }
} }
} }
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax)); forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(1); vector<Vec3> positions(1);
......
...@@ -277,7 +277,7 @@ void testTabulatedFunction() { ...@@ -277,7 +277,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0); force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
......
...@@ -214,7 +214,7 @@ void testCustomFunctions() { ...@@ -214,7 +214,7 @@ void testCustomFunctions() {
vector<double> function(2); vector<double> function(2);
function[0] = 0; function[0] = 0;
function[1] = 1; function[1] = 1;
custom->addFunction("foo", function, 0, 10); custom->addTabulatedFunction("foo", new Continuous1DFunction(function, 0, 10));
system.addForce(custom); system.addForce(custom);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(3); vector<Vec3> positions(3);
......
...@@ -272,7 +272,7 @@ void testContinuous1DFunction() { ...@@ -272,7 +272,7 @@ void testContinuous1DFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -322,7 +322,7 @@ void testContinuous2DFunction() { ...@@ -322,7 +322,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y); table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
} }
} }
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax)); forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -377,7 +377,7 @@ void testContinuous3DFunction() { ...@@ -377,7 +377,7 @@ void testContinuous3DFunction() {
} }
} }
} }
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax)); forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -416,7 +416,7 @@ void testDiscrete1DFunction() { ...@@ -416,7 +416,7 @@ void testDiscrete1DFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table)); forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -447,7 +447,7 @@ void testDiscrete2DFunction() { ...@@ -447,7 +447,7 @@ void testDiscrete2DFunction() {
for (int i = 0; i < xsize; i++) for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++) for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j)); table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addFunction("fn", new Discrete2DFunction(xsize, ysize, table)); forceField->addTabulatedFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -482,7 +482,7 @@ void testDiscrete3DFunction() { ...@@ -482,7 +482,7 @@ void testDiscrete3DFunction() {
for (int j = 0; j < ysize; j++) for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k); table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table)); forceField->addTabulatedFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
......
...@@ -468,8 +468,23 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co ...@@ -468,8 +468,23 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector<const Lepton::ExpressionTreeNode*>& nodes) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical.
for (int i = 0; i < (int) node.getChildren().size(); i++)
if (node.getChildren()[i] != searchNode.getChildren()[i])
return;
// See if we already have an identical node.
for (int i = 0; i < (int) nodes.size(); i++)
if (*nodes[i] == searchNode)
return;
// Add the node.
nodes.push_back(&searchNode); nodes.push_back(&searchNode);
}
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes); findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
......
...@@ -1970,13 +1970,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1970,13 +1970,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
...@@ -2717,18 +2717,21 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2717,18 +2717,21 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
tableArgs << ", __global const float4* restrict " << arrayName; tableArgs << ", __global const float";
if (width > 1)
tableArgs << width;
tableArgs << "* restrict " << arrayName;
} }
// Record the global parameters. // Record the global parameters.
...@@ -3921,13 +3924,13 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -3921,13 +3924,13 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
string arrayName = "table"+cl.intToString(i); string arrayName = "table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", __global const float"; tableArgs << ", __global const float";
...@@ -4304,11 +4307,11 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4304,11 +4307,11 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getFunctionName(i); string name = force.getTabulatedFunctionName(i);
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i)); functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"); OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array); tabulatedFunctions.push_back(array);
array->upload(f); array->upload(f);
......
...@@ -220,7 +220,7 @@ void testContinuous2DFunction() { ...@@ -220,7 +220,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y); table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
} }
} }
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax)); forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(1); vector<Vec3> positions(1);
...@@ -270,7 +270,7 @@ void testContinuous3DFunction() { ...@@ -270,7 +270,7 @@ void testContinuous3DFunction() {
} }
} }
} }
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax)); forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(1); vector<Vec3> positions(1);
......
...@@ -277,7 +277,7 @@ void testTabulatedFunction() { ...@@ -277,7 +277,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0); force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
......
...@@ -214,7 +214,7 @@ void testCustomFunctions() { ...@@ -214,7 +214,7 @@ void testCustomFunctions() {
vector<double> function(2); vector<double> function(2);
function[0] = 0; function[0] = 0;
function[1] = 1; function[1] = 1;
custom->addFunction("foo", function, 0, 10); custom->addTabulatedFunction("foo", new Continuous1DFunction(function, 0, 10));
system.addForce(custom); system.addForce(custom);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(3); vector<Vec3> positions(3);
......
...@@ -272,7 +272,7 @@ void testContinuous1DFunction() { ...@@ -272,7 +272,7 @@ void testContinuous1DFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -322,7 +322,7 @@ void testContinuous2DFunction() { ...@@ -322,7 +322,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y); table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
} }
} }
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax)); forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -377,7 +377,7 @@ void testContinuous3DFunction() { ...@@ -377,7 +377,7 @@ void testContinuous3DFunction() {
} }
} }
} }
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax)); forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -416,7 +416,7 @@ void testDiscrete1DFunction() { ...@@ -416,7 +416,7 @@ void testDiscrete1DFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table)); forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -446,7 +446,7 @@ void testDiscrete2DFunction() { ...@@ -446,7 +446,7 @@ void testDiscrete2DFunction() {
for (int i = 0; i < xsize; i++) for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++) for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j)); table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addFunction("fn", new Discrete2DFunction(xsize, ysize, table)); forceField->addTabulatedFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -481,7 +481,7 @@ void testDiscrete3DFunction() { ...@@ -481,7 +481,7 @@ void testDiscrete3DFunction() {
for (int j = 0; j < ysize; j++) for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k); table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table)); forceField->addTabulatedFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment