Commit 75a04e7f authored by peastman's avatar peastman
Browse files

Renamed methods for dealing with TabulatedFunctions to avoid problems in...

Renamed methods for dealing with TabulatedFunctions to avoid problems in wrapper APIs.  Also fixed a few bugs.
parent f7e5492d
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -92,7 +92,7 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression.
*/
......@@ -135,6 +135,14 @@ public:
/**
* Get the number of tabulated functions that have been defined.
*/
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const {
return functions.size();
}
......@@ -236,49 +244,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, TabulatedFunction* function);
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getFunction(int index) const;
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getFunction(int index);
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getFunctionName(int index) const;
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead.
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Get the parameters for a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -135,7 +135,7 @@ namespace OpenMM {
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in expressions.
*/
......@@ -210,6 +210,14 @@ public:
/**
* Get the number of tabulated functions that have been defined.
*/
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const {
return functions.size();
}
......@@ -462,49 +470,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, TabulatedFunction* function);
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getFunction(int index) const;
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getFunction(int index);
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in expressions.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getFunctionName(int index) const;
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in expressions.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead.
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Get the parameters for a tabulated function that may appear in expressions.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for a tabulated function that may appear in expressions.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -92,7 +92,7 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression.
*/
......@@ -165,6 +165,14 @@ public:
/**
* Get the number of tabulated functions that have been defined.
*/
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const {
return functions.size();
}
......@@ -381,49 +389,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, TabulatedFunction* function);
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getFunction(int index) const;
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getFunction(int index);
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getFunctionName(int index) const;
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead.
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Get the parameters for a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -125,7 +125,7 @@ namespace OpenMM {
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify the function by
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression.
*/
......@@ -185,6 +185,14 @@ public:
/**
* Get the number of tabulated functions that have been defined.
*/
int getNumTabulatedFunctions() const {
return functions.size();
}
/**
* Get the number of tabulated functions that have been defined.
*
* @deprecated This method exists only for backward compatibility. Use getNumTabulatedFunctions() instead.
*/
int getNumFunctions() const {
return functions.size();
}
......@@ -366,49 +374,46 @@ public:
* Force takes over ownership of it, and deletes it when the Force itself is deleted.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, TabulatedFunction* function);
int addTabulatedFunction(const std::string& name, TabulatedFunction* function);
/**
* Get a const reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
const TabulatedFunction& getFunction(int index) const;
const TabulatedFunction& getTabulatedFunction(int index) const;
/**
* Get a reference to a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the TabulatedFunction object defining the function
*/
TabulatedFunction& getFunction(int index);
TabulatedFunction& getTabulatedFunction(int index);
/**
* Get the name of a tabulated function that may appear in the energy expression.
*
* @param index the index of the function to get
* @return the name of the function as it appears in expressions
*/
const std::string& getFunctionName(int index) const;
const std::string& getTabulatedFunctionName(int index) const;
/**
* Add a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead.
* @deprecated This method exists only for backward compatibility. Use addTabulatedFunction() instead.
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Get the parameters for a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use getTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for a tabulated function that may appear in the energy expression.
*
* @deprecated This method exists only for backward compatibility. Use the version that takes
* a TabulatedFunction instead. If the specified function is not a Continuous1DFunction, this throws
* an exception.
* @deprecated This method exists only for backward compatibility. Use setTabulatedFunctionParameters() instead.
* If the specified function is not a Continuous1DFunction, this throws an exception.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -126,22 +126,22 @@ void CustomCompoundBondForce::setBondParameters(int index, const vector<int>& pa
bonds[index].parameters = parameters;
}
int CustomCompoundBondForce::addFunction(const std::string& name, TabulatedFunction* function) {
int CustomCompoundBondForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomCompoundBondForce::getFunction(int index) const {
const TabulatedFunction& CustomCompoundBondForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomCompoundBondForce::getFunction(int index) {
TabulatedFunction& CustomCompoundBondForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomCompoundBondForce::getFunctionName(int index) const {
const string& CustomCompoundBondForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -178,22 +178,22 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle
exclusions[index].particle2 = particle2;
}
int CustomGBForce::addFunction(const std::string& name, TabulatedFunction* function) {
int CustomGBForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomGBForce::getFunction(int index) const {
const TabulatedFunction& CustomGBForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomGBForce::getFunction(int index) {
TabulatedFunction& CustomGBForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomGBForce::getFunctionName(int index) const {
const string& CustomGBForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -193,22 +193,22 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor)
exclusions[index].acceptor = acceptor;
}
int CustomHbondForce::addFunction(const std::string& name, TabulatedFunction* function) {
int CustomHbondForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomHbondForce::getFunction(int index) const {
const TabulatedFunction& CustomHbondForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomHbondForce::getFunction(int index) {
TabulatedFunction& CustomHbondForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomHbondForce::getFunctionName(int index) const {
const string& CustomHbondForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -174,22 +174,22 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p
exclusions[index].particle1 = particle1;
exclusions[index].particle2 = particle2;
}
int CustomNonbondedForce::addFunction(const std::string& name, TabulatedFunction* function) {
int CustomNonbondedForce::addTabulatedFunction(const std::string& name, TabulatedFunction* function) {
functions.push_back(FunctionInfo(name, function));
return functions.size()-1;
}
const TabulatedFunction& CustomNonbondedForce::getFunction(int index) const {
const TabulatedFunction& CustomNonbondedForce::getTabulatedFunction(int index) const {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
TabulatedFunction& CustomNonbondedForce::getFunction(int index) {
TabulatedFunction& CustomNonbondedForce::getTabulatedFunction(int index) {
ASSERT_VALID_INDEX(index, functions);
return *functions[index].function;
}
const string& CustomNonbondedForce::getFunctionName(int index) const {
const string& CustomNonbondedForce::getTabulatedFunctionName(int index) const {
ASSERT_VALID_INDEX(index, functions);
return functions[index].name;
}
......
......@@ -468,8 +468,23 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0])
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical.
for (int i = 0; i < (int) node.getChildren().size(); i++)
if (node.getChildren()[i] != searchNode.getChildren()[i])
return;
// See if we already have an identical node.
for (int i = 0; i < (int) nodes.size(); i++)
if (*nodes[i] == searchNode)
return;
// Add the node.
nodes.push_back(&searchNode);
}
else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
......
......@@ -1960,13 +1960,13 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
......@@ -2664,18 +2664,21 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
tableArgs << ", const float4* __restrict__ " << arrayName;
tableArgs << ", const float";
if (width > 1)
tableArgs << width;
tableArgs << "* __restrict__ " << arrayName;
}
// Record the global parameters.
......@@ -3758,13 +3761,13 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string arrayName = "table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", const float";
......@@ -4139,11 +4142,11 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array);
array->upload(f);
......
......@@ -220,7 +220,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
......@@ -270,7 +270,7 @@ void testContinuous3DFunction() {
}
}
}
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
......
......@@ -277,7 +277,7 @@ void testTabulatedFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0);
force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......
......@@ -214,7 +214,7 @@ void testCustomFunctions() {
vector<double> function(2);
function[0] = 0;
function[1] = 1;
custom->addFunction("foo", function, 0, 10);
custom->addTabulatedFunction("foo", new Continuous1DFunction(function, 0, 10));
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
......
......@@ -272,7 +272,7 @@ void testContinuous1DFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -322,7 +322,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -377,7 +377,7 @@ void testContinuous3DFunction() {
}
}
}
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -416,7 +416,7 @@ void testDiscrete1DFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -447,7 +447,7 @@ void testDiscrete2DFunction() {
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addFunction("fn", new Discrete2DFunction(xsize, ysize, table));
forceField->addTabulatedFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -482,7 +482,7 @@ void testDiscrete3DFunction() {
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
forceField->addTabulatedFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......
......@@ -468,8 +468,23 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0])
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical.
for (int i = 0; i < (int) node.getChildren().size(); i++)
if (node.getChildren()[i] != searchNode.getChildren()[i])
return;
// See if we already have an identical node.
for (int i = 0; i < (int) nodes.size(); i++)
if (*nodes[i] == searchNode)
return;
// Add the node.
nodes.push_back(&searchNode);
}
else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
......
......@@ -1970,13 +1970,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
......@@ -2717,18 +2717,21 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
tableArgs << ", __global const float4* restrict " << arrayName;
tableArgs << ", __global const float";
if (width > 1)
tableArgs << width;
tableArgs << "* restrict " << arrayName;
}
// Record the global parameters.
......@@ -3921,13 +3924,13 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
string arrayName = "table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", __global const float";
......@@ -4304,11 +4307,11 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array);
array->upload(f);
......
......@@ -220,7 +220,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
......@@ -270,7 +270,7 @@ void testContinuous3DFunction() {
}
}
}
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
......
......@@ -277,7 +277,7 @@ void testTabulatedFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0);
force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......
......@@ -214,7 +214,7 @@ void testCustomFunctions() {
vector<double> function(2);
function[0] = 0;
function[1] = 1;
custom->addFunction("foo", function, 0, 10);
custom->addTabulatedFunction("foo", new Continuous1DFunction(function, 0, 10));
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
......
......@@ -272,7 +272,7 @@ void testContinuous1DFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -322,7 +322,7 @@ void testContinuous2DFunction() {
table[i+xsize*j] = sin(0.25*x)*cos(0.33*y);
}
}
forceField->addFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -377,7 +377,7 @@ void testContinuous3DFunction() {
}
}
}
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -416,7 +416,7 @@ void testDiscrete1DFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -446,7 +446,7 @@ void testDiscrete2DFunction() {
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addFunction("fn", new Discrete2DFunction(xsize, ysize, table));
forceField->addTabulatedFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -481,7 +481,7 @@ void testDiscrete3DFunction() {
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
forceField->addTabulatedFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment