Unverified Commit a85c2428 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Point based geometric functions for custom forces (#3037)

* Began implementing geometric functions on points

* Started common implementation of point functions

* Completed implementation of point functions for CustomCompoundBondForce

* Implemented point functions for CustomCentroidBondForce

* Implemented point functions for CustomManyParticleForce

* Use point functions to simplify implementation of custom forces

* Removed unnecessary code

* Fixed typo
parent 2da337e9
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -107,6 +107,15 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, atan2, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* select(x,y,z) = z if x = 0, y otherwise.
*
* This class also supports the functions pointdistance(x1, y1, z1, x2, y2, z2),
* pointangle(x1, y1, z1, x2, y2, z2, x3, y3, z3), and pointdihedral(x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4).
* These functions are similar to distance(), angle(), and dihedral(), but the arguments are the
* coordinates of points to perform the calculation based on rather than the names of groups.
* This enables more flexible geometric calculations. For example, the following computes the distance
* from group g1 to the midpoint between groups g2 and g3.
*
* <tt>CustomCentroidBondForce* force = new CustomCentroidBondForce(3, "pointdistance(x1, y1, z1, (x2+x3)/2, (y2+y3)/2, (z2+z3)/2)");</tt>
*
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression.
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -96,6 +96,15 @@ namespace OpenMM {
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, atan2, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* select(x,y,z) = z if x = 0, y otherwise.
*
* This class also supports the functions pointdistance(x1, y1, z1, x2, y2, z2),
* pointangle(x1, y1, z1, x2, y2, z2, x3, y3, z3), and pointdihedral(x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4).
* These functions are similar to distance(), angle(), and dihedral(), but the arguments are the
* coordinates of points to perform the calculation based on rather than the names of particles.
* This enables more flexible geometric calculations. For example, the following computes the distance
* from particle p1 to the midpoint between particles p2 and p3.
*
* <tt>CustomCompoundBondForce* force = new CustomCompoundBondForce(3, "pointdistance(x1, y1, z1, (x2+x3)/2, (y2+y3)/2, (z2+z3)/2)");</tt>
*
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression.
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -154,6 +154,15 @@ namespace OpenMM {
* select(x,y,z) = z if x = 0, y otherwise. The names of per-particle parameters have the suffix "1", "2", etc. appended to them to indicate the values for
* the multiple interacting particles. For example, if you define a per-particle parameter called "charge", then the variable "charge2" is the charge of particle p2.
* As seen above, the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*
* This class also supports the functions pointdistance(x1, y1, z1, x2, y2, z2),
* pointangle(x1, y1, z1, x2, y2, z2, x3, y3, z3), and pointdihedral(x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4).
* These functions are similar to distance(), angle(), and dihedral(), but the arguments are the
* coordinates of points to perform the calculation based on rather than the names of particles.
* This enables more flexible geometric calculations. For example, the following computes the distance
* from particle p1 to the midpoint between particles p2 and p3.
*
* <tt>CustomManyParticleForce* force = new CustomManyParticleForce(3, "pointdistance(x1, y1, z1, (x2+x3)/2, (y2+y3)/2, (z2+z3)/2)");</tt>
*
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
* creating a TabulatedFunction object. That function can then appear in the expression.
......
......@@ -67,21 +67,14 @@ public:
std::vector<std::pair<int, int> > getBondedParticles() const;
void updateParametersInContext(ContextImpl& context);
/**
* This is a utility routine that parses the energy expression, identifies the angles and dihedrals
* in it, and replaces them with variables.
* This is a utility routine that parses the energy expression, identifies group based functions,
* and replaces them with equivalent point based ones.
*
* @param force the CustomCentroidBondForce to process
* @param functions definitions of custom function that may appear in the expression
* @param distances on exit, this will contain an entry for each distance used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param angles on exit, this will contain an entry for each angle used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param dihedrals on exit, this will contain an entry for each dihedral used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @return a Parsed expression for the energy
*/
static Lepton::ParsedExpression prepareExpression(const CustomCentroidBondForce& force, const std::map<std::string, Lepton::CustomFunction*>& functions, std::map<std::string, std::vector<int> >& distances,
std::map<std::string, std::vector<int> >& angles, std::map<std::string, std::vector<int> >& dihedrals);
static Lepton::ParsedExpression prepareExpression(const CustomCentroidBondForce& force, const std::map<std::string, Lepton::CustomFunction*>& functions);
/**
* Compute the normalized weights to use for each particle in each group.
*
......@@ -92,9 +85,8 @@ public:
static void computeNormalizedWeights(const CustomCentroidBondForce& force, const System& system, std::vector<std::vector<double> >& weights);
private:
class FunctionPlaceholder;
static Lepton::ExpressionTreeNode replaceFunctions(const Lepton::ExpressionTreeNode& node, std::map<std::string, int> atoms,
std::map<std::string, std::vector<int> >& distances, std::map<std::string, std::vector<int> >& angles,
std::map<std::string, std::vector<int> >& dihedrals, std::set<std::string>& variables);
static Lepton::ExpressionTreeNode replaceFunctions(const Lepton::ExpressionTreeNode& node, std::map<std::string, int> groups,
const std::map<std::string, Lepton::CustomFunction*>& functions, std::set<std::string>& variables);
void addBondsBetweenGroups(int group1, int group2, std::vector<std::pair<int, int> >& bonds) const;
const CustomCentroidBondForce& owner;
Kernel kernel;
......
......@@ -65,26 +65,18 @@ public:
std::vector<std::string> getKernelNames();
void updateParametersInContext(ContextImpl& context);
/**
* This is a utility routine that parses the energy expression, identifies the angles and dihedrals
* in it, and replaces them with variables.
* This is a utility routine that parses the energy expression, identifies particle based functions,
* and replaces them with equivalent point based ones.
*
* @param force the CustomCompoundBondForce to process
* @param functions definitions of custom function that may appear in the expression
* @param distances on exit, this will contain an entry for each distance used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param angles on exit, this will contain an entry for each angle used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param dihedrals on exit, this will contain an entry for each dihedral used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @return a Parsed expression for the energy
*/
static Lepton::ParsedExpression prepareExpression(const CustomCompoundBondForce& force, const std::map<std::string, Lepton::CustomFunction*>& functions, std::map<std::string, std::vector<int> >& distances,
std::map<std::string, std::vector<int> >& angles, std::map<std::string, std::vector<int> >& dihedrals);
static Lepton::ParsedExpression prepareExpression(const CustomCompoundBondForce& force, const std::map<std::string, Lepton::CustomFunction*>& functions);
private:
class FunctionPlaceholder;
static Lepton::ExpressionTreeNode replaceFunctions(const Lepton::ExpressionTreeNode& node, std::map<std::string, int> atoms,
std::map<std::string, std::vector<int> >& distances, std::map<std::string, std::vector<int> >& angles,
std::map<std::string, std::vector<int> >& dihedrals, std::set<std::string>& variables);
const std::map<std::string, Lepton::CustomFunction*>& functions, std::set<std::string>& variables);
const CustomCompoundBondForce& owner;
Kernel kernel;
};
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -65,21 +65,14 @@ public:
std::vector<std::string> getKernelNames();
void updateParametersInContext(ContextImpl& context);
/**
* This is a utility routine that parses the energy expression, identifies the angles and dihedrals
* in it, and replaces them with variables.
* This is a utility routine that parses the energy expression, identifies particle based functions,
* and replaces them with equivalent point based ones.
*
* @param force the CustomManyParticleForce to process
* @param functions definitions of custom function that may appear in the expression
* @param distances on exit, this will contain an entry for each distance used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param angles on exit, this will contain an entry for each angle used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param dihedrals on exit, this will contain an entry for each dihedral used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @return a Parsed expression for the energy
*/
static Lepton::ParsedExpression prepareExpression(const CustomManyParticleForce& force, const std::map<std::string, Lepton::CustomFunction*>& functions, std::map<std::string, std::vector<int> >& distances,
std::map<std::string, std::vector<int> >& angles, std::map<std::string, std::vector<int> >& dihedrals);
static Lepton::ParsedExpression prepareExpression(const CustomManyParticleForce& force, const std::map<std::string, Lepton::CustomFunction*>& functions);
/**
* Analyze the type filters for a force and build a set of arrays that can be used for reordering the
* particles in an interaction.
......@@ -98,8 +91,7 @@ public:
private:
class FunctionPlaceholder;
static Lepton::ExpressionTreeNode replaceFunctions(const Lepton::ExpressionTreeNode& node, std::map<std::string, int> atoms,
std::map<std::string, std::vector<int> >& distances, std::map<std::string, std::vector<int> >& angles,
std::map<std::string, std::vector<int> >& dihedrals, std::set<std::string>& variables);
const std::map<std::string, Lepton::CustomFunction*>& functions, std::set<std::string>& variables);
static void generatePermutations(std::vector<int>& values, int numFixed, std::vector<std::vector<int> >& result);
const CustomManyParticleForce& owner;
Kernel kernel;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -138,16 +138,23 @@ map<string, double> CustomCentroidBondForceImpl::getDefaultParameters() {
return parameters;
}
ParsedExpression CustomCentroidBondForceImpl::prepareExpression(const CustomCentroidBondForce& force, const map<string, CustomFunction*>& customFunctions, map<string, vector<int> >& distances,
map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals) {
CustomCentroidBondForceImpl::FunctionPlaceholder custom(1);
ParsedExpression CustomCentroidBondForceImpl::prepareExpression(const CustomCentroidBondForce& force, const map<string, CustomFunction*>& customFunctions) {
CustomCentroidBondForceImpl::FunctionPlaceholder distance(2);
CustomCentroidBondForceImpl::FunctionPlaceholder angle(3);
CustomCentroidBondForceImpl::FunctionPlaceholder dihedral(4);
CustomCentroidBondForceImpl::FunctionPlaceholder pointdistance(6);
CustomCentroidBondForceImpl::FunctionPlaceholder pointangle(9);
CustomCentroidBondForceImpl::FunctionPlaceholder pointdihedral(12);
map<string, CustomFunction*> functions = customFunctions;
functions["distance"] = &distance;
functions["angle"] = &angle;
functions["dihedral"] = &dihedral;
if (functions.find("pointdistance") == functions.end())
functions["pointdistance"] = &pointdistance;
if (functions.find("pointangle") == functions.end())
functions["pointangle"] = &pointangle;
if (functions.find("pointdihedral") == functions.end())
functions["pointdihedral"] = &pointdihedral;
ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions);
map<string, int> groups;
set<string> variables;
......@@ -166,21 +173,20 @@ ParsedExpression CustomCentroidBondForceImpl::prepareExpression(const CustomCent
variables.insert(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumPerBondParameters(); i++)
variables.insert(force.getPerBondParameterName(i));
return ParsedExpression(replaceFunctions(expression.getRootNode(), groups, distances, angles, dihedrals, variables)).optimize();
return ParsedExpression(replaceFunctions(expression.getRootNode(), groups, functions, variables)).optimize();
}
ExpressionTreeNode CustomCentroidBondForceImpl::replaceFunctions(const ExpressionTreeNode& node, map<string, int> groups,
map<string, vector<int> >& distances, map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals, set<string>& variables) {
const map<string, CustomFunction*>& functions, set<string>& variables) {
const Operation& op = node.getOperation();
if (op.getId() == Operation::VARIABLE && variables.find(op.getName()) == variables.end())
throw OpenMMException("CustomCentroidBondForce: Unknown variable '"+op.getName()+"'");
if (op.getId() != Operation::CUSTOM || (op.getName() != "distance" && op.getName() != "angle" && op.getName() != "dihedral"))
{
// This is not an angle or dihedral, so process its children.
vector<ExpressionTreeNode> children;
if (op.getId() != Operation::CUSTOM || (op.getName() != "distance" && op.getName() != "angle" && op.getName() != "dihedral")) {
// The arguments are not group identifiers, so process its children.
vector<ExpressionTreeNode> children;
for (auto& child : node.getChildren())
children.push_back(replaceFunctions(child, groups, distances, angles, dihedrals, variables));
children.push_back(replaceFunctions(child, groups, functions, variables));
return ExpressionTreeNode(op.clone(), children);
}
const Operation::Custom& custom = static_cast<const Operation::Custom&>(op);
......@@ -195,29 +201,25 @@ ExpressionTreeNode CustomCentroidBondForceImpl::replaceFunctions(const Expressio
throw OpenMMException("CustomCentroidBondForce: Unknown group '"+node.getChildren()[i].getOperation().getName()+"'");
indices[i] = iter->second;
}
// Select a name for the variable and add it to the appropriate map.
stringstream variable;
if (numArgs == 2)
variable << "distance";
else if (numArgs == 3)
variable << "angle";
else
variable << "dihedral";
for (int i = 0; i < numArgs; i++)
variable << indices[i];
string name = variable.str();
if (numArgs == 2)
distances[name] = indices;
else if (numArgs == 3)
angles[name] = indices;
else
dihedrals[name] = indices;
// Return a new node that represents it as a simple variable.
return ExpressionTreeNode(new Operation::Variable(name));
// Replace it by the corresponding point based function.
for (int i = 0; i < numArgs; i++) {
stringstream x, y, z;
x << 'x' << (indices[i]+1);
y << 'y' << (indices[i]+1);
z << 'z' << (indices[i]+1);
children.push_back(ExpressionTreeNode(new Operation::Variable(x.str())));
children.push_back(ExpressionTreeNode(new Operation::Variable(y.str())));
children.push_back(ExpressionTreeNode(new Operation::Variable(z.str())));
}
if (op.getName() == "distance")
return ExpressionTreeNode(new Operation::Custom("pointdistance", functions.at("pointdistance")->clone()), children);
if (op.getName() == "angle")
return ExpressionTreeNode(new Operation::Custom("pointangle", functions.at("pointangle")->clone()), children);
if (op.getName() == "dihedral")
return ExpressionTreeNode(new Operation::Custom("pointdihedral", functions.at("pointdihedral")->clone()), children);
throw OpenMMException("Internal error. Unexpected function '"+op.getName()+"'");
}
vector<pair<int, int> > CustomCentroidBondForceImpl::getBondedParticles() const {
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -124,16 +124,23 @@ map<string, double> CustomCompoundBondForceImpl::getDefaultParameters() {
return parameters;
}
ParsedExpression CustomCompoundBondForceImpl::prepareExpression(const CustomCompoundBondForce& force, const map<string, CustomFunction*>& customFunctions, map<string, vector<int> >& distances,
map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals) {
CustomCompoundBondForceImpl::FunctionPlaceholder custom(1);
ParsedExpression CustomCompoundBondForceImpl::prepareExpression(const CustomCompoundBondForce& force, const map<string, CustomFunction*>& customFunctions) {
CustomCompoundBondForceImpl::FunctionPlaceholder distance(2);
CustomCompoundBondForceImpl::FunctionPlaceholder angle(3);
CustomCompoundBondForceImpl::FunctionPlaceholder dihedral(4);
CustomCompoundBondForceImpl::FunctionPlaceholder pointdistance(6);
CustomCompoundBondForceImpl::FunctionPlaceholder pointangle(9);
CustomCompoundBondForceImpl::FunctionPlaceholder pointdihedral(12);
map<string, CustomFunction*> functions = customFunctions;
functions["distance"] = &distance;
functions["angle"] = &angle;
functions["dihedral"] = &dihedral;
if (functions.find("pointdistance") == functions.end())
functions["pointdistance"] = &pointdistance;
if (functions.find("pointangle") == functions.end())
functions["pointangle"] = &pointangle;
if (functions.find("pointdihedral") == functions.end())
functions["pointdihedral"] = &pointdihedral;
ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions);
map<string, int> atoms;
set<string> variables;
......@@ -152,21 +159,20 @@ ParsedExpression CustomCompoundBondForceImpl::prepareExpression(const CustomComp
variables.insert(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumPerBondParameters(); i++)
variables.insert(force.getPerBondParameterName(i));
return ParsedExpression(replaceFunctions(expression.getRootNode(), atoms, distances, angles, dihedrals, variables)).optimize();
return ParsedExpression(replaceFunctions(expression.getRootNode(), atoms, functions, variables)).optimize();
}
ExpressionTreeNode CustomCompoundBondForceImpl::replaceFunctions(const ExpressionTreeNode& node, map<string, int> atoms,
map<string, vector<int> >& distances, map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals, set<string>& variables) {
const map<string, CustomFunction*>& functions, set<string>& variables) {
const Operation& op = node.getOperation();
if (op.getId() == Operation::VARIABLE && variables.find(op.getName()) == variables.end())
throw OpenMMException("CustomCompoundBondForce: Unknown variable '"+op.getName()+"'");
if (op.getId() != Operation::CUSTOM || (op.getName() != "distance" && op.getName() != "angle" && op.getName() != "dihedral"))
{
// This is not an angle or dihedral, so process its children.
vector<ExpressionTreeNode> children;
if (op.getId() != Operation::CUSTOM || (op.getName() != "distance" && op.getName() != "angle" && op.getName() != "dihedral")) {
// The arguments are not particle identifiers, so process its children.
vector<ExpressionTreeNode> children;
for (auto& child : node.getChildren())
children.push_back(replaceFunctions(child, atoms, distances, angles, dihedrals, variables));
children.push_back(replaceFunctions(child, atoms, functions, variables));
return ExpressionTreeNode(op.clone(), children);
}
const Operation::Custom& custom = static_cast<const Operation::Custom&>(op);
......@@ -181,29 +187,25 @@ ExpressionTreeNode CustomCompoundBondForceImpl::replaceFunctions(const Expressio
throw OpenMMException("CustomCompoundBondForce: Unknown particle '"+node.getChildren()[i].getOperation().getName()+"'");
indices[i] = iter->second;
}
// Select a name for the variable and add it to the appropriate map.
stringstream variable;
if (numArgs == 2)
variable << "distance";
else if (numArgs == 3)
variable << "angle";
else
variable << "dihedral";
for (int i = 0; i < numArgs; i++)
variable << indices[i];
string name = variable.str();
if (numArgs == 2)
distances[name] = indices;
else if (numArgs == 3)
angles[name] = indices;
else
dihedrals[name] = indices;
// Return a new node that represents it as a simple variable.
return ExpressionTreeNode(new Operation::Variable(name));
// Replace it by the corresponding point based function.
for (int i = 0; i < numArgs; i++) {
stringstream x, y, z;
x << 'x' << (indices[i]+1);
y << 'y' << (indices[i]+1);
z << 'z' << (indices[i]+1);
children.push_back(ExpressionTreeNode(new Operation::Variable(x.str())));
children.push_back(ExpressionTreeNode(new Operation::Variable(y.str())));
children.push_back(ExpressionTreeNode(new Operation::Variable(z.str())));
}
if (op.getName() == "distance")
return ExpressionTreeNode(new Operation::Custom("pointdistance", functions.at("pointdistance")->clone()), children);
if (op.getName() == "angle")
return ExpressionTreeNode(new Operation::Custom("pointangle", functions.at("pointangle")->clone()), children);
if (op.getName() == "dihedral")
return ExpressionTreeNode(new Operation::Custom("pointdihedral", functions.at("pointdihedral")->clone()), children);
throw OpenMMException("Internal error. Unexpected function '"+op.getName()+"'");
}
void CustomCompoundBondForceImpl::updateParametersInContext(ContextImpl& context) {
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -153,16 +153,23 @@ map<string, double> CustomManyParticleForceImpl::getDefaultParameters() {
return parameters;
}
ParsedExpression CustomManyParticleForceImpl::prepareExpression(const CustomManyParticleForce& force, const map<string, CustomFunction*>& customFunctions, map<string, vector<int> >& distances,
map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals) {
CustomManyParticleForceImpl::FunctionPlaceholder custom(1);
ParsedExpression CustomManyParticleForceImpl::prepareExpression(const CustomManyParticleForce& force, const map<string, CustomFunction*>& customFunctions) {
CustomManyParticleForceImpl::FunctionPlaceholder distance(2);
CustomManyParticleForceImpl::FunctionPlaceholder angle(3);
CustomManyParticleForceImpl::FunctionPlaceholder dihedral(4);
CustomManyParticleForceImpl::FunctionPlaceholder pointdistance(6);
CustomManyParticleForceImpl::FunctionPlaceholder pointangle(9);
CustomManyParticleForceImpl::FunctionPlaceholder pointdihedral(12);
map<string, CustomFunction*> functions = customFunctions;
functions["distance"] = &distance;
functions["angle"] = &angle;
functions["dihedral"] = &dihedral;
if (functions.find("pointdistance") == functions.end())
functions["pointdistance"] = &pointdistance;
if (functions.find("pointangle") == functions.end())
functions["pointangle"] = &pointangle;
if (functions.find("pointdihedral") == functions.end())
functions["pointdihedral"] = &pointdihedral;
ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions);
map<string, int> atoms;
set<string> variables;
......@@ -184,21 +191,20 @@ ParsedExpression CustomManyParticleForceImpl::prepareExpression(const CustomMany
}
for (int i = 0; i < force.getNumGlobalParameters(); i++)
variables.insert(force.getGlobalParameterName(i));
return ParsedExpression(replaceFunctions(expression.getRootNode(), atoms, distances, angles, dihedrals, variables)).optimize();
return ParsedExpression(replaceFunctions(expression.getRootNode(), atoms, functions, variables)).optimize();
}
ExpressionTreeNode CustomManyParticleForceImpl::replaceFunctions(const ExpressionTreeNode& node, map<string, int> atoms,
map<string, vector<int> >& distances, map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals, set<string>& variables) {
const map<string, CustomFunction*>& functions, set<string>& variables) {
const Operation& op = node.getOperation();
if (op.getId() == Operation::VARIABLE && variables.find(op.getName()) == variables.end())
throw OpenMMException("CustomManyParticleForce: Unknown variable '"+op.getName()+"'");
if (op.getId() != Operation::CUSTOM || (op.getName() != "distance" && op.getName() != "angle" && op.getName() != "dihedral"))
{
// This is not an angle or dihedral, so process its children.
vector<ExpressionTreeNode> children;
if (op.getId() != Operation::CUSTOM || (op.getName() != "distance" && op.getName() != "angle" && op.getName() != "dihedral")) {
// The arguments are not particle identifiers, so process its children.
vector<ExpressionTreeNode> children;
for (auto& child : node.getChildren())
children.push_back(replaceFunctions(child, atoms, distances, angles, dihedrals, variables));
children.push_back(replaceFunctions(child, atoms, functions, variables));
return ExpressionTreeNode(op.clone(), children);
}
const Operation::Custom& custom = static_cast<const Operation::Custom&>(op);
......@@ -213,29 +219,25 @@ ExpressionTreeNode CustomManyParticleForceImpl::replaceFunctions(const Expressio
throw OpenMMException("CustomManyParticleForce: Unknown particle '"+node.getChildren()[i].getOperation().getName()+"'");
indices[i] = iter->second;
}
// Select a name for the variable and add it to the appropriate map.
stringstream variable;
if (numArgs == 2)
variable << "distance";
else if (numArgs == 3)
variable << "angle";
else
variable << "dihedral";
for (int i = 0; i < numArgs; i++)
variable << indices[i];
string name = variable.str();
if (numArgs == 2)
distances[name] = indices;
else if (numArgs == 3)
angles[name] = indices;
else
dihedrals[name] = indices;
// Return a new node that represents it as a simple variable.
return ExpressionTreeNode(new Operation::Variable(name));
// Replace it by the corresponding point based function.
for (int i = 0; i < numArgs; i++) {
stringstream x, y, z;
x << 'x' << (indices[i]+1);
y << 'y' << (indices[i]+1);
z << 'z' << (indices[i]+1);
children.push_back(ExpressionTreeNode(new Operation::Variable(x.str())));
children.push_back(ExpressionTreeNode(new Operation::Variable(y.str())));
children.push_back(ExpressionTreeNode(new Operation::Variable(z.str())));
}
if (op.getName() == "distance")
return ExpressionTreeNode(new Operation::Custom("pointdistance", functions.at("pointdistance")->clone()), children);
if (op.getName() == "angle")
return ExpressionTreeNode(new Operation::Custom("pointangle", functions.at("pointangle")->clone()), children);
if (op.getName() == "dihedral")
return ExpressionTreeNode(new Operation::Custom("pointdihedral", functions.at("pointdihedral")->clone()), children);
throw OpenMMException("Internal error. Unexpected function '"+op.getName()+"'");
}
void CustomManyParticleForceImpl::updateParametersInContext(ContextImpl& context) {
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -57,10 +57,12 @@ public:
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param tempType the type of value to use for temporary variables (defaults to "real")
* @param distancesArePeriodic whether the distances in pointdistance(), pointangle(), and pointdihedral() functions
* should have periodic boundary conditions applied
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& tempType="real");
const std::string& prefix, const std::string& tempType="real", bool distancesArePeriodic=false);
/**
* Generate the source code for calculating a set of expressions.
*
......@@ -71,10 +73,12 @@ public:
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param tempType the type of value to use for temporary variables (defaults to "real")
* @param distancesArePeriodic whether the distances in pointdistance(), pointangle(), and pointdihedral() functions
* should have periodic boundary conditions applied
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& tempType="real");
const std::string& prefix, const std::string& tempType="real", bool distancesArePeriodic=false);
/**
* Calculate the spline coefficients for a tabulated function that appears in expressions.
*
......@@ -116,7 +120,8 @@ private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions,
const std::string& tempType, bool distancesArePeriodic);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedCustomFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes);
......@@ -124,6 +129,8 @@ private:
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
void callFunction(std::stringstream& out, std::string singleFn, std::string doubleFn, const std::string& arg, const std::string& tempType);
void callFunction2(std::stringstream& out, std::string singleFn, std::string doubleFn, const std::string& arg1, const std::string& arg2, const std::string& tempType);
void computeDelta(std::stringstream& out, const std::string& varName, const Lepton::ExpressionTreeNode& node, int index1, int index2, const std::string& tempType,
bool periodic, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
ComputeContext& context;
FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
......
This diff is collapsed.
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -37,15 +37,17 @@ ExpressionUtilities::ExpressionUtilities(ComputeContext& context) : context(cont
}
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& tempType, bool distancesArePeriodic) {
vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType, distancesArePeriodic);
}
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType,
bool distancesArePeriodic) {
stringstream out;
vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
......@@ -53,7 +55,7 @@ string ExpressionUtilities::createExpressions(const map<string, ParsedExpression
vector<pair<ExpressionTreeNode, string> > temps = variables;
vector<vector<double> > functionParams = computeFunctionParameters(functions);
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType, distancesArePeriodic);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
......@@ -61,12 +63,12 @@ string ExpressionUtilities::createExpressions(const map<string, ParsedExpression
void ExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
const vector<ParsedExpression>& allExpressions, const string& tempType, bool distancesArePeriodic) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType, distancesArePeriodic);
string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false;
bool isVecType = (tempType[tempType.size()-1] == '3');
......@@ -109,43 +111,136 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
temps.push_back(make_pair(*nodes[j], name2));
}
out << "{\n";
if (node.getOperation().getName() == "periodicdistance") {
// This is the periodicdistance() function.
out << tempType << "3 periodicDistance_delta = make_real3(";
for (int i = 0; i < 3; i++) {
if (i > 0)
out << ", ";
out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
}
out << ");\n";
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
if (node.getOperation().getName() == "pointdistance" || node.getOperation().getName() == "periodicdistance") {
// This is a pointdistance() or periodicdistance() function.
bool periodic = (node.getOperation().getName() == "periodicdistance" || distancesArePeriodic);
computeDelta(out, "distance_delta", node, 0, 3, tempType, periodic, temps);
out << tempType << " distance_rinv = RSQRT(distance_delta.w);\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1;
for (int k = 0; k < 6; k++) {
if (derivOrder[k] > 0) {
if (derivOrder[k] > 1 || argIndex != -1)
throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
throw OpenMMException("Unsupported derivative of "+node.getOperation().getName()); // Should be impossible for this to happen.
argIndex = k;
}
}
if (argIndex == -1)
out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
out << nodeNames[j] << " = RECIP(distance_rinv);\n";
else if (argIndex == 0)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
out << nodeNames[j] << " = (distance_delta.w > 0 ? distance_delta.x*distance_rinv : 0);\n";
else if (argIndex == 1)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
out << nodeNames[j] << " = (distance_delta.w > 0 ? distance_delta.y*distance_rinv : 0);\n";
else if (argIndex == 2)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
out << nodeNames[j] << " = (distance_delta.w > 0 ? distance_delta.z*distance_rinv : 0);\n";
else if (argIndex == 3)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
out << nodeNames[j] << " = (distance_delta.w > 0 ? -distance_delta.x*distance_rinv : 0);\n";
else if (argIndex == 4)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
out << nodeNames[j] << " = (distance_delta.w > 0 ? -distance_delta.y*distance_rinv : 0);\n";
else if (argIndex == 5)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
out << nodeNames[j] << " = (distance_delta.w > 0 ? -distance_delta.z*distance_rinv : 0);\n";
}
}
else if (node.getOperation().getName() == "pointangle") {
// This is a pointangle() function.
computeDelta(out, "angle_delta21", node, 3, 0, tempType, distancesArePeriodic, temps);
computeDelta(out, "angle_delta23", node, 3, 6, tempType, distancesArePeriodic, temps);
out << tempType << " angle_theta = computeAngle(angle_delta21, angle_delta23);\n";
out << tempType << "3 angle_crossProd = trimTo3(cross(angle_delta23, angle_delta21));\n";
out << "real angle_lengthCross = max(SQRT(dot(angle_crossProd, angle_crossProd)), (real) 1e-6f);\n";
out << "real3 angle_deltaCross0 = cross(trimTo3(angle_delta21), angle_crossProd)/(angle_delta21.w*angle_lengthCross);\n";
out << "real3 angle_deltaCross2 = -cross(trimTo3(angle_delta23), angle_crossProd)/(angle_delta23.w*angle_lengthCross);\n";
out << "real3 angle_deltaCross1 = -(angle_deltaCross0+angle_deltaCross2);\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1;
for (int k = 0; k < 9; k++) {
if (derivOrder[k] > 0) {
if (derivOrder[k] > 1 || argIndex != -1)
throw OpenMMException("Unsupported derivative of "+node.getOperation().getName()); // Should be impossible for this to happen.
argIndex = k;
}
}
if (argIndex == -1)
out << nodeNames[j] << " = angle_theta;\n";
else if (argIndex == 0)
out << nodeNames[j] << " = angle_deltaCross0.x;\n";
else if (argIndex == 1)
out << nodeNames[j] << " = angle_deltaCross0.y;\n";
else if (argIndex == 2)
out << nodeNames[j] << " = angle_deltaCross0.z;\n";
else if (argIndex == 3)
out << nodeNames[j] << " = angle_deltaCross1.x;\n";
else if (argIndex == 4)
out << nodeNames[j] << " = angle_deltaCross1.y;\n";
else if (argIndex == 5)
out << nodeNames[j] << " = angle_deltaCross1.z;\n";
else if (argIndex == 6)
out << nodeNames[j] << " = angle_deltaCross2.x;\n";
else if (argIndex == 7)
out << nodeNames[j] << " = angle_deltaCross2.y;\n";
else if (argIndex == 8)
out << nodeNames[j] << " = angle_deltaCross2.z;\n";
}
}
else if (node.getOperation().getName() == "pointdihedral") {
// This is a pointdihedral() function.
computeDelta(out, "dihedral_delta12", node, 0, 3, tempType, distancesArePeriodic, temps);
computeDelta(out, "dihedral_delta32", node, 6, 3, tempType, distancesArePeriodic, temps);
computeDelta(out, "dihedral_delta34", node, 6, 9, tempType, distancesArePeriodic, temps);
out << tempType << "4 dihedral_cross1 = computeCross(dihedral_delta12, dihedral_delta32);\n";
out << tempType << "4 dihedral_cross2 = computeCross(dihedral_delta32, dihedral_delta34);\n";
out << tempType << " dihedral_theta = computeAngle(dihedral_cross1, dihedral_cross2);\n";
out << "dihedral_theta *= (dihedral_delta12.x*dihedral_cross2.x + dihedral_delta12.y*dihedral_cross2.y + dihedral_delta12.z*dihedral_cross2.z < 0 ? -1 : 1);\n";
out << tempType << " dihedral_r = SQRT(dihedral_delta32.w);\n";
out << tempType << "4 dihedral_ff;\n";
out << "dihedral_ff.x = -dihedral_r/dihedral_cross1.w;\n";
out << "dihedral_ff.y = (dihedral_delta12.x*dihedral_delta32.x + dihedral_delta12.y*dihedral_delta32.y + dihedral_delta12.z*dihedral_delta32.z)/dihedral_delta32.w;\n";
out << "dihedral_ff.z = (dihedral_delta34.x*dihedral_delta32.x + dihedral_delta34.y*dihedral_delta32.y + dihedral_delta34.z*dihedral_delta32.z)/dihedral_delta32.w;\n";
out << "dihedral_ff.w = dihedral_r/dihedral_cross2.w;\n";
out << tempType << "3 dihedral_internalF0 = dihedral_ff.x*trimTo3(dihedral_cross1);\n";
out << tempType << "3 dihedral_internalF3 = dihedral_ff.w*trimTo3(dihedral_cross2);\n";
out << tempType << "3 dihedral_s = dihedral_ff.y*dihedral_internalF0 - dihedral_ff.z*dihedral_internalF3;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1;
for (int k = 0; k < 12; k++) {
if (derivOrder[k] > 0) {
if (derivOrder[k] > 1 || argIndex != -1)
throw OpenMMException("Unsupported derivative of "+node.getOperation().getName()); // Should be impossible for this to happen.
argIndex = k;
}
}
if (argIndex == -1)
out << nodeNames[j] << " = dihedral_theta;\n";
else if (argIndex == 0)
out << nodeNames[j] << " = -dihedral_internalF0.x;\n";
else if (argIndex == 1)
out << nodeNames[j] << " = -dihedral_internalF0.y;\n";
else if (argIndex == 2)
out << nodeNames[j] << " = -dihedral_internalF0.z;\n";
else if (argIndex == 3)
out << nodeNames[j] << " = -dihedral_s.x+dihedral_internalF0.x;\n";
else if (argIndex == 4)
out << nodeNames[j] << " = -dihedral_s.y+dihedral_internalF0.y;\n";
else if (argIndex == 5)
out << nodeNames[j] << " = -dihedral_s.z+dihedral_internalF0.z;\n";
else if (argIndex == 6)
out << nodeNames[j] << " = dihedral_s.x+dihedral_internalF3.x;\n";
else if (argIndex == 7)
out << nodeNames[j] << " = dihedral_s.y+dihedral_internalF3.y;\n";
else if (argIndex == 8)
out << nodeNames[j] << " = dihedral_s.z+dihedral_internalF3.z;\n";
else if (argIndex == 9)
out << nodeNames[j] << " = -dihedral_internalF3.x;\n";
else if (argIndex == 10)
out << nodeNames[j] << " = -dihedral_internalF3.y;\n";
else if (argIndex == 11)
out << nodeNames[j] << " = -dihedral_internalF3.z;\n";
}
}
else if (node.getOperation().getName() == "dot") {
......@@ -998,3 +1093,19 @@ void ExpressionUtilities::callFunction2(stringstream& out, string singleFn, stri
else
out<<fn<<"(("<<tempType<<") "<<arg1<<", ("<<tempType<<") "<<arg2<<")";
}
void ExpressionUtilities::computeDelta(stringstream& out, const string& varName, const ExpressionTreeNode& node, int index1, int index2, const string& tempType, bool periodic, const vector<pair<ExpressionTreeNode, string> >& temps) {
// Compute the (optionally periodic) displacement between two points, storing the distance
// into the w component.
out << tempType << "4 " << varName << " = make_" << tempType << "4(";
for (int i = 0; i < 3; i++) {
if (i > 0)
out << ", ";
out << getTempName(node.getChildren()[index1+i], temps) << "-" << getTempName(node.getChildren()[index2+i], temps);
}
out << ", 0);\n";
if (periodic)
out << "APPLY_PERIODIC_TO_DELTA(" << varName << ")\n";
out << varName << ".w = " << varName << ".x*" << varName << ".x + " << varName << ".y*" << varName << ".y + " << varName << ".z*" << varName << ".z;\n";
}
......@@ -61,47 +61,6 @@ KERNEL void computeGroupCenters(int numParticleGroups, GLOBAL const real4* RESTR
}
}
/**
* Compute the difference between two vectors, setting the fourth component to the squared magnitude.
*/
DEVICE real4 delta(real4 vec1, real4 vec2, bool periodic, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ) {
real4 result = make_real4(vec1.x-vec2.x, vec1.y-vec2.y, vec1.z-vec2.z, 0);
if (periodic)
APPLY_PERIODIC_TO_DELTA(result);
result.w = result.x*result.x + result.y*result.y + result.z*result.z;
return result;
}
/**
* Compute the angle between two vectors. The w component of each vector should contain the squared magnitude.
*/
DEVICE real computeAngle(real4 vec1, real4 vec2) {
real dotProduct = vec1.x*vec2.x + vec1.y*vec2.y + vec1.z*vec2.z;
real cosine = dotProduct*RSQRT(vec1.w*vec2.w);
real angle;
if (cosine > 0.99f || cosine < -0.99f) {
// We're close to the singularity in acos(), so take the cross product and use asin() instead.
real3 crossProduct = cross(trimTo3(vec1), trimTo3(vec2));
real scale = vec1.w*vec2.w;
angle = ASIN(SQRT(dot(crossProduct, crossProduct)/scale));
if (cosine < 0)
angle = M_PI-angle;
}
else
angle = ACOS(cosine);
return angle;
}
/**
* Compute the cross product of two vectors, setting the fourth component to the squared magnitude.
*/
DEVICE real4 computeCross(real4 vec1, real4 vec2) {
real3 cp = cross(trimTo3(vec1), trimTo3(vec2));
return make_real4(cp.x, cp.y, cp.z, cp.x*cp.x+cp.y*cp.y+cp.z*cp.z);
}
/**
* Compute the forces on groups based on the bonds.
*/
......
......@@ -20,35 +20,6 @@ inline DEVICE real4 delta(real3 vec1, real3 vec2, real4 periodicBoxSize, real4 i
return result;
}
/**
* Compute the angle between two vectors. The w component of each vector should contain the squared magnitude.
*/
DEVICE real computeAngle(real4 vec1, real4 vec2) {
real dotProduct = vec1.x*vec2.x + vec1.y*vec2.y + vec1.z*vec2.z;
real cosine = dotProduct*RSQRT(vec1.w*vec2.w);
real angle;
if (cosine > 0.99f || cosine < -0.99f) {
// We're close to the singularity in acos(), so take the cross product and use asin() instead.
real3 crossProduct = trimTo3(cross(vec1, vec2));
real scale = vec1.w*vec2.w;
angle = ASIN(SQRT(dot(crossProduct, crossProduct)/scale));
if (cosine < 0.0f)
angle = M_PI-angle;
}
else
angle = ACOS(cosine);
return angle;
}
/**
* Compute the cross product of two vectors, setting the fourth component to the squared magnitude.
*/
inline DEVICE real4 computeCross(real4 vec1, real4 vec2) {
real3 cp = trimTo3(cross(vec1, vec2));
return make_real4(cp.x, cp.y, cp.z, cp.x*cp.x+cp.y*cp.y+cp.z*cp.z);
}
/**
* Determine whether a particular interaction is in the list of exclusions.
*/
......
/**
* Compute the difference between two vectors, setting the fourth component to the squared magnitude.
*/
DEVICE real4 ccb_delta(real4 vec1, real4 vec2, bool periodic, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ) {
real4 result = make_real4(vec1.x-vec2.x, vec1.y-vec2.y, vec1.z-vec2.z, 0);
if (periodic)
APPLY_PERIODIC_TO_DELTA(result);
result.w = result.x*result.x + result.y*result.y + result.z*result.z;
return result;
}
/**
* Compute the angle between two vectors. The w component of each vector should contain the squared magnitude.
*/
DEVICE real ccb_computeAngle(real4 vec1, real4 vec2) {
DEVICE real computeAngle(real4 vec1, real4 vec2) {
real dotProduct = vec1.x*vec2.x + vec1.y*vec2.y + vec1.z*vec2.z;
real cosine = dotProduct*RSQRT(vec1.w*vec2.w);
real angle;
......@@ -34,7 +22,7 @@ DEVICE real ccb_computeAngle(real4 vec1, real4 vec2) {
/**
* Compute the cross product of two vectors, setting the fourth component to the squared magnitude.
*/
DEVICE real4 ccb_computeCross(real4 vec1, real4 vec2) {
DEVICE real4 computeCross(real4 vec1, real4 vec2) {
real3 cp = cross(trimTo3(vec1), trimTo3(vec2));
return make_real4(cp.x, cp.y, cp.z, cp.x*cp.x+cp.y*cp.y+cp.z*cp.z);
}
......
/* Portions copyright (c) 2009-2018 Stanford University and Simbios.
/* Portions copyright (c) 2009-2021 Stanford University and Simbios.
* Contributors: Peter Eastman
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -46,15 +45,13 @@ class CpuCustomManyParticleForce {
private:
class ParticleTermInfo;
class DistanceTermInfo;
class AngleTermInfo;
class DihedralTermInfo;
class ThreadData;
int numParticles, numParticlesPerSet, numPerParticleParameters, numTypes;
bool useCutoff, usePeriodic, triclinic, centralParticleMode;
double cutoffDistance;
float recipBoxSize[3];
Vec3 periodicBoxVectors[3];
Vec3* boxVectorsRef;
AlignedArray<fvec4> periodicBoxVec4;
CpuNeighborList* neighborList;
ThreadPool& threads;
......@@ -112,10 +109,6 @@ private:
* periodic boundary conditions.
*/
void computeDelta(const fvec4& posI, const fvec4& posJ, fvec4& deltaR, float& r2, const fvec4& boxSize, const fvec4& invBoxSize) const;
static float computeAngle(const fvec4& vi, const fvec4& vj, float v2i, float v2j, float sign);
static float getDihedralAngleBetweenThreeVectors(const fvec4& v1, const fvec4& v2, const fvec4& v3, fvec4& cross1, fvec4& cross2, const fvec4& signVector);
public:
/**
......@@ -167,57 +160,16 @@ public:
ParticleTermInfo(const std::string& name, int atom, int component, const Lepton::CompiledExpression& forceExpression, ThreadData& data);
};
class CpuCustomManyParticleForce::DistanceTermInfo {
public:
std::string name;
int p1, p2, variableIndex;
Lepton::CompiledExpression forceExpression;
int delta;
float deltaSign;
DistanceTermInfo(const std::string& name, const std::vector<int>& atoms, const Lepton::CompiledExpression& forceExpression, ThreadData& data);
};
class CpuCustomManyParticleForce::AngleTermInfo {
public:
std::string name;
int p1, p2, p3, variableIndex;
Lepton::CompiledExpression forceExpression;
int delta1, delta2;
float delta1Sign, delta2Sign;
AngleTermInfo(const std::string& name, const std::vector<int>& atoms, const Lepton::CompiledExpression& forceExpression, ThreadData& data);
};
class CpuCustomManyParticleForce::DihedralTermInfo {
public:
std::string name;
int p1, p2, p3, p4, variableIndex;
Lepton::CompiledExpression forceExpression;
int delta1, delta2, delta3;
DihedralTermInfo(const std::string& name, const std::vector<int>& atoms, const Lepton::CompiledExpression& forceExpression, ThreadData& data);
};
class CpuCustomManyParticleForce::ThreadData {
public:
CompiledExpressionSet expressionSet;
Lepton::CompiledExpression energyExpression;
std::vector<std::vector<int> > particleParamIndices;
std::vector<int> permutedParticles;
std::vector<std::pair<int, int> > deltaPairs;
std::vector<ParticleTermInfo> particleTerms;
std::vector<DistanceTermInfo> distanceTerms;
std::vector<AngleTermInfo> angleTerms;
std::vector<DihedralTermInfo> dihedralTerms;
AlignedArray<fvec4> delta, cross1, cross2;
std::vector<float> normDelta;
std::vector<float> norm2Delta;
AlignedArray<fvec4> f;
double energy;
ThreadData(const CustomManyParticleForce& force, Lepton::ParsedExpression& energyExpr,
std::map<std::string, std::vector<int> >& distances, std::map<std::string, std::vector<int> >& angles, std::map<std::string, std::vector<int> >& dihedrals);
/**
* Request a pair of particles whose distance or displacement vector is needed in the computation.
*/
void requestDeltaPair(int p1, int p2, int& pairIndex, float& pairSign, bool allowReversed);
ThreadData(const CustomManyParticleForce& force, Lepton::ParsedExpression& energyExpr);
};
} // namespace OpenMM
......
/* Portions copyright (c) 2009-2018 Stanford University and Simbios.
/* Portions copyright (c) 2009-2021 Stanford University and Simbios.
* Contributors: Peter Eastman
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -29,6 +28,7 @@
#include "SimTKOpenMMUtilities.h"
#include "ReferenceForce.h"
#include "CpuCustomManyParticleForce.h"
#include "ReferencePointFunctions.h"
#include "ReferenceTabulatedFunction.h"
#include "openmm/internal/CustomManyParticleForceImpl.h"
#include "lepton/CustomFunction.h"
......@@ -49,14 +49,17 @@ CpuCustomManyParticleForce::CpuCustomManyParticleForce(const CustomManyParticleF
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create implementations of point functions.
functions["pointdistance"] = new ReferencePointDistanceFunction(force.usesPeriodicBoundaryConditions(), &boxVectorsRef);
functions["pointangle"] = new ReferencePointAngleFunction(force.usesPeriodicBoundaryConditions(), &boxVectorsRef);
functions["pointdihedral"] = new ReferencePointDihedralFunction(force.usesPeriodicBoundaryConditions(), &boxVectorsRef);
// Parse the expression and create the objects used to calculate the interaction.
map<string, vector<int> > distances;
map<string, vector<int> > angles;
map<string, vector<int> > dihedrals;
Lepton::ParsedExpression energyExpr = CustomManyParticleForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
Lepton::ParsedExpression energyExpr = CustomManyParticleForceImpl::prepareExpression(force, functions);
for (int i = 0; i < threads.getNumThreads(); i++)
threadData.push_back(new ThreadData(force, energyExpr, distances, angles, dihedrals));
threadData.push_back(new ThreadData(force, energyExpr));
if (force.getNonbondedMethod() != CustomManyParticleForce::NoCutoff)
setUseCutoff(force.getCutoffDistance());
......@@ -190,6 +193,7 @@ void CpuCustomManyParticleForce::setPeriodic(Vec3* periodicBoxVectors) {
assert(periodicBoxVectors[1][1] >= 2.0*cutoffDistance);
assert(periodicBoxVectors[2][2] >= 2.0*cutoffDistance);
usePeriodic = true;
this->boxVectorsRef = periodicBoxVectors;
this->periodicBoxVectors[0] = periodicBoxVectors[0];
this->periodicBoxVectors[1] = periodicBoxVectors[1];
this->periodicBoxVectors[2] = periodicBoxVectors[2];
......@@ -266,37 +270,14 @@ void CpuCustomManyParticleForce::calculateOneIxn(vector<int>& particleSet, vecto
for (int i = 0; i < numParticlesPerSet; i++)
for (int j = 0; j < numPerParticleParameters; j++)
expressionSet.setVariable(data.particleParamIndices[i][j], particleParameters[permutedParticles[i]][j]);
// Compute inter-particle deltas.
int numDeltas = data.deltaPairs.size();
AlignedArray<fvec4>& delta = data.delta;
AlignedArray<fvec4>& cross1 = data.cross1;
AlignedArray<fvec4>& cross2 = data.cross2;
vector<float>& normDelta = data.normDelta;
vector<float>& norm2Delta = data.norm2Delta;
for (int i = 0; i < numDeltas; i++) {
int p1 = permutedParticles[data.deltaPairs[i].first];
int p2 = permutedParticles[data.deltaPairs[i].second];
computeDelta(fvec4(posq+4*p1), fvec4(posq+4*p2), delta[i], norm2Delta[i], boxSize, invBoxSize);
normDelta[i] = sqrtf(norm2Delta[i]);
}
// Compute all of the variables the energy can depend on.
// Record particle coordinates.
for (auto& term : data.particleTerms)
expressionSet.setVariable(term.variableIndex, posq[4*permutedParticles[term.atom]+term.component]);
for (auto& term : data.distanceTerms)
expressionSet.setVariable(term.variableIndex, normDelta[term.delta]);
for (auto& term : data.angleTerms)
expressionSet.setVariable(term.variableIndex, computeAngle(delta[term.delta1], delta[term.delta2], norm2Delta[term.delta1], norm2Delta[term.delta2], term.delta1Sign*term.delta2Sign));
for (int i = 0; i < (int) data.dihedralTerms.size(); i++) {
const DihedralTermInfo& term = data.dihedralTerms[i];
expressionSet.setVariable(term.variableIndex, getDihedralAngleBetweenThreeVectors(delta[term.delta1], delta[term.delta2], delta[term.delta3], cross1[i], cross2[i], delta[term.delta1]));
}
if (includeForces) {
// Apply forces based on individual particle coordinates.
// Apply forces based on particle coordinates.
AlignedArray<fvec4>& f = data.f;
for (int i = 0; i < numParticlesPerSet; i++)
......@@ -308,59 +289,6 @@ void CpuCustomManyParticleForce::calculateOneIxn(vector<int>& particleSet, vecto
f[term.atom] = fvec4(temp);
}
// Apply forces based on distances.
for (auto& term : data.distanceTerms) {
float dEdR = (float) (term.forceExpression.evaluate()*term.deltaSign/(normDelta[term.delta]));
fvec4 force = -dEdR*delta[term.delta];
f[term.p1] -= force;
f[term.p2] += force;
}
// Apply forces based on angles.
for (auto& term : data.angleTerms) {
float dEdTheta = (float) term.forceExpression.evaluate();
fvec4 thetaCross = cross(delta[term.delta1], delta[term.delta2]);
float lengthThetaCross = sqrtf(dot3(thetaCross, thetaCross));
if (lengthThetaCross < 1.0e-6f)
lengthThetaCross = 1.0e-6f;
float termA = dEdTheta*term.delta2Sign/(norm2Delta[term.delta1]*lengthThetaCross);
float termC = -dEdTheta*term.delta1Sign/(norm2Delta[term.delta2]*lengthThetaCross);
fvec4 deltaCross1 = cross(delta[term.delta1], thetaCross);
fvec4 deltaCross2 = cross(delta[term.delta2], thetaCross);
fvec4 force1 = termA*deltaCross1;
fvec4 force3 = termC*deltaCross2;
fvec4 force2 = -(force1+force3);
f[term.p1] += force1;
f[term.p2] += force2;
f[term.p3] += force3;
}
// Apply forces based on dihedrals.
for (int i = 0; i < (int) data.dihedralTerms.size(); i++) {
const DihedralTermInfo& term = data.dihedralTerms[i];
float dEdTheta = (float) term.forceExpression.evaluate();
float normCross1 = dot3(cross1[i], cross1[i]);
float normBC = normDelta[term.delta2];
float forceFactors[4];
forceFactors[0] = (-dEdTheta*normBC)/normCross1;
float normCross2 = dot3(cross2[i], cross2[i]);
forceFactors[3] = (dEdTheta*normBC)/normCross2;
forceFactors[1] = dot3(delta[term.delta1], delta[term.delta2]);
forceFactors[1] /= norm2Delta[term.delta2];
forceFactors[2] = dot3(delta[term.delta3], delta[term.delta2]);
forceFactors[2] /= norm2Delta[term.delta2];
fvec4 force1 = forceFactors[0]*cross1[i];
fvec4 force4 = forceFactors[3]*cross2[i];
fvec4 s = forceFactors[1]*force1 - forceFactors[2]*force4;
f[term.p1] += force1;
f[term.p2] -= force1-s;
f[term.p3] -= force4+s;
f[term.p4] += force4;
}
// Store the forces.
for (int i = 0; i < numParticlesPerSet; i++) {
......@@ -391,61 +319,12 @@ void CpuCustomManyParticleForce::computeDelta(const fvec4& posI, const fvec4& po
r2 = dot3(deltaR, deltaR);
}
float CpuCustomManyParticleForce::computeAngle(const fvec4& vi, const fvec4& vj, float v2i, float v2j, float sign) {
float dot = dot3(vi, vj)*sign;
float cosine = dot/sqrtf(v2i*v2j);
if (cosine > 0.99f || cosine < -0.99f) {
// We're close to the singularity in acos(), so take the cross product and use asin() instead.
fvec4 cross12 = cross(vi, vj);
float scale = v2i*v2j;
float angle = asinf(sqrtf(dot3(cross12, cross12)/scale));
if (cosine < 0.0f)
angle = (float) (M_PI-angle);
return angle;
}
return acosf(cosine);
}
float CpuCustomManyParticleForce::getDihedralAngleBetweenThreeVectors(const fvec4& v1, const fvec4& v2, const fvec4& v3, fvec4& cross1, fvec4& cross2, const fvec4& signVector) {
cross1 = cross(v1, v2);
cross2 = cross(v2, v3);
float angle = computeAngle(cross1, cross2, dot3(cross1, cross1), dot3(cross2, cross2), 1.0f);
float dotProduct = dot3(signVector, cross2);
if (dotProduct < 0)
angle = -angle;
return angle;
}
CpuCustomManyParticleForce::ParticleTermInfo::ParticleTermInfo(const string& name, int atom, int component, const Lepton::CompiledExpression& forceExpression, ThreadData& data) :
name(name), atom(atom), component(component), forceExpression(forceExpression) {
variableIndex = data.expressionSet.getVariableIndex(name);
}
CpuCustomManyParticleForce::DistanceTermInfo::DistanceTermInfo(const string& name, const vector<int>& atoms, const Lepton::CompiledExpression& forceExpression, ThreadData& data) :
name(name), p1(atoms[0]), p2(atoms[1]), forceExpression(forceExpression) {
variableIndex = data.expressionSet.getVariableIndex(name);
data.requestDeltaPair(p1, p2, delta, deltaSign, true);
}
CpuCustomManyParticleForce::AngleTermInfo::AngleTermInfo(const string& name, const vector<int>& atoms, const Lepton::CompiledExpression& forceExpression, ThreadData& data) :
name(name), p1(atoms[0]), p2(atoms[1]), p3(atoms[2]), forceExpression(forceExpression) {
variableIndex = data.expressionSet.getVariableIndex(name);
data.requestDeltaPair(p1, p2,delta1, delta1Sign, true);
data.requestDeltaPair(p3, p2, delta2, delta2Sign, true);
}
CpuCustomManyParticleForce::DihedralTermInfo::DihedralTermInfo(const string& name, const vector<int>& atoms, const Lepton::CompiledExpression& forceExpression, ThreadData& data) :
name(name), p1(atoms[0]), p2(atoms[1]), p3(atoms[2]), p4(atoms[3]), forceExpression(forceExpression) {
variableIndex = data.expressionSet.getVariableIndex(name);
float sign;
data.requestDeltaPair(p2, p1, delta1, sign, false);
data.requestDeltaPair(p2, p3, delta2, sign, false);
data.requestDeltaPair(p4, p3, delta3, sign, false);
}
CpuCustomManyParticleForce::ThreadData::ThreadData(const CustomManyParticleForce& force, Lepton::ParsedExpression& energyExpr,
map<string, vector<int> >& distances, map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals) {
CpuCustomManyParticleForce::ThreadData::ThreadData(const CustomManyParticleForce& force, Lepton::ParsedExpression& energyExpr) {
int numParticlesPerSet = force.getNumParticlesPerSet();
int numPerParticleParameters = force.getNumPerParticleParameters();
particleParamIndices.resize(numParticlesPerSet);
......@@ -470,43 +349,6 @@ CpuCustomManyParticleForce::ThreadData::ThreadData(const CustomManyParticleForce
particleParamIndices[i].push_back(expressionSet.getVariableIndex(paramname.str()));
}
}
for (auto& term : dihedrals)
dihedralTerms.push_back(CpuCustomManyParticleForce::DihedralTermInfo(term.first, term.second, energyExpr.differentiate(term.first).optimize().createCompiledExpression(), *this));
for (auto& term : distances)
distanceTerms.push_back(CpuCustomManyParticleForce::DistanceTermInfo(term.first, term.second, energyExpr.differentiate(term.first).optimize().createCompiledExpression(), *this));
for (auto& term : angles)
angleTerms.push_back(CpuCustomManyParticleForce::AngleTermInfo(term.first, term.second, energyExpr.differentiate(term.first).optimize().createCompiledExpression(), *this));
for (auto& term : particleTerms)
expressionSet.registerExpression(term.forceExpression);
for (auto& term : distanceTerms)
expressionSet.registerExpression(term.forceExpression);
for (auto& term : angleTerms)
expressionSet.registerExpression(term.forceExpression);
for (auto& term : dihedralTerms)
expressionSet.registerExpression(term.forceExpression);
int numDeltas = deltaPairs.size();
delta.resize(numDeltas);
normDelta.resize(numDeltas);
norm2Delta.resize(numDeltas);
cross1.resize(numDeltas);
cross2.resize(numDeltas);
}
void CpuCustomManyParticleForce::ThreadData::requestDeltaPair(int p1, int p2, int& pairIndex, float& pairSign, bool allowReversed) {
for (int i = 0; i < (int) deltaPairs.size(); i++) {
if (deltaPairs[i].first == p1 && deltaPairs[i].second == p2) {
pairIndex = i;
pairSign = 1;
return;
}
if (deltaPairs[i].first == p2 && deltaPairs[i].second == p1 && allowReversed) {
pairIndex = i;
pairSign = -1;
return;
}
}
pairIndex = deltaPairs.size();
pairSign = 1;
deltaPairs.push_back(make_pair(p1, p2));
}
......@@ -419,7 +419,7 @@ CudaContext::~CudaContext() {
string errorMessage = "Error deleting Context";
if (contextIsValid && !isLinkedContext) {
cuProfilerStop();
CHECK_RESULT(cuCtxDestroy(context));
cuCtxDestroy(context);
}
contextIsValid = false;
}
......
......@@ -82,7 +82,7 @@ class OPENMM_EXPORT ReferenceBondIxn {
--------------------------------------------------------------------------------------- */
static double getNormedDotProduct(double* vector1, double* vector2, int hasREntry);
static double getNormedDotProduct(double* vector1, double* vector2, int hasREntry=0);
/**---------------------------------------------------------------------------------------
......@@ -99,7 +99,7 @@ class OPENMM_EXPORT ReferenceBondIxn {
--------------------------------------------------------------------------------------- */
static double getAngleBetweenTwoVectors(double* vector1, double* vector2,
double* outputDotProduct, int hasREntry);
double* outputDotProduct=NULL, int hasREntry=0);
/**---------------------------------------------------------------------------------------
......@@ -120,9 +120,9 @@ class OPENMM_EXPORT ReferenceBondIxn {
--------------------------------------------------------------------------------------- */
static double getDihedralAngleBetweenThreeVectors(double* vector1, double* vector2,
double* vector3, double** outputCrossProduct,
double* cosineOfAngle, double* signVector,
double* signOfAngle, int hasREntry);
double* vector3, double** outputCrossProduct=NULL,
double* cosineOfAngle=NULL, double* signVector=NULL,
double* signOfAngle=NULL, int hasREntry=0);
};
......
......@@ -38,9 +38,6 @@ class ReferenceCustomCentroidBondIxn : public ReferenceBondIxn {
private:
class PositionTermInfo;
class DistanceTermInfo;
class AngleTermInfo;
class DihedralTermInfo;
std::vector<std::vector<int> > groupAtoms;
std::vector<std::vector<double> > normalizedWeights;
std::vector<std::vector<int> > bondGroups;
......@@ -49,9 +46,6 @@ class ReferenceCustomCentroidBondIxn : public ReferenceBondIxn {
std::vector<Lepton::CompiledExpression> energyParamDerivExpressions;
std::vector<int> bondParamIndex;
std::vector<PositionTermInfo> positionTerms;
std::vector<DistanceTermInfo> distanceTerms;
std::vector<AngleTermInfo> angleTerms;
std::vector<DihedralTermInfo> dihedralTerms;
int numParameters;
bool usePeriodic;
Vec3 boxVectors[3];
......@@ -86,9 +80,7 @@ class ReferenceCustomCentroidBondIxn : public ReferenceBondIxn {
ReferenceCustomCentroidBondIxn(int numGroupsPerBond, const std::vector<std::vector<int> >& groupAtoms,
const std::vector<std::vector<double> >& normalizedWeights, const std::vector<std::vector<int> >& bondGroups, const Lepton::ParsedExpression& energyExpression,
const std::vector<std::string>& bondParameterNames, const std::map<std::string, std::vector<int> >& distances,
const std::map<std::string, std::vector<int> >& angles, const std::map<std::string, std::vector<int> >& dihedrals,
const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions);
const std::vector<std::string>& bondParameterNames, const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions);
/**---------------------------------------------------------------------------------------
......@@ -148,44 +140,6 @@ public:
}
};
class ReferenceCustomCentroidBondIxn::DistanceTermInfo {
public:
std::string name;
int g1, g2, index;
Lepton::CompiledExpression forceExpression;
mutable double delta[ReferenceForce::LastDeltaRIndex];
DistanceTermInfo(const std::string& name, const std::vector<int>& groups, const Lepton::CompiledExpression& forceExpression) :
name(name), g1(groups[0]), g2(groups[1]), forceExpression(forceExpression) {
}
};
class ReferenceCustomCentroidBondIxn::AngleTermInfo {
public:
std::string name;
int g1, g2, g3, index;
Lepton::CompiledExpression forceExpression;
mutable double delta1[ReferenceForce::LastDeltaRIndex];
mutable double delta2[ReferenceForce::LastDeltaRIndex];
AngleTermInfo(const std::string& name, const std::vector<int>& groups, const Lepton::CompiledExpression& forceExpression) :
name(name), g1(groups[0]), g2(groups[1]), g3(groups[2]), forceExpression(forceExpression) {
}
};
class ReferenceCustomCentroidBondIxn::DihedralTermInfo {
public:
std::string name;
int g1, g2, g3, g4, index;
Lepton::CompiledExpression forceExpression;
mutable double delta1[ReferenceForce::LastDeltaRIndex];
mutable double delta2[ReferenceForce::LastDeltaRIndex];
mutable double delta3[ReferenceForce::LastDeltaRIndex];
mutable double cross1[3];
mutable double cross2[3];
DihedralTermInfo(const std::string& name, const std::vector<int>& groups, const Lepton::CompiledExpression& forceExpression) :
name(name), g1(groups[0]), g2(groups[1]), g3(groups[2]), g4(groups[3]), forceExpression(forceExpression) {
}
};
} // namespace OpenMM
#endif // __ReferenceCustomCentroidBondIxn_H__
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment