Commit 94a31b8d authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented delta() for custom expressions

parent 7c836d2d
...@@ -62,7 +62,7 @@ public: ...@@ -62,7 +62,7 @@ public:
* can be used when processing or analyzing parsed expressions. * can be used when processing or analyzing parsed expressions.
*/ */
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG, enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC, STEP, SQUARE, CUBE, RECIPROCAL, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC, STEP, DELTA, SQUARE, CUBE, RECIPROCAL,
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT, MIN, MAX, ABS}; ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT, MIN, MAX, ABS};
/** /**
* Get the name of this Operation. * Get the name of this Operation.
...@@ -142,6 +142,7 @@ public: ...@@ -142,6 +142,7 @@ public:
class Erf; class Erf;
class Erfc; class Erfc;
class Step; class Step;
class Delta;
class Square; class Square;
class Cube; class Cube;
class Reciprocal; class Reciprocal;
...@@ -807,6 +808,28 @@ public: ...@@ -807,6 +808,28 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class LEPTON_EXPORT Operation::Delta : public Operation {
public:
Delta() {
}
std::string getName() const {
return "delta";
}
Id getId() const {
return DELTA;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Delta();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return (args[0] == 0.0 ? 1.0 : 0.0);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class LEPTON_EXPORT Operation::Square : public Operation { class LEPTON_EXPORT Operation::Square : public Operation {
public: public:
Square() { Square() {
......
...@@ -249,6 +249,10 @@ ExpressionTreeNode Operation::Step::differentiate(const std::vector<ExpressionTr ...@@ -249,6 +249,10 @@ ExpressionTreeNode Operation::Step::differentiate(const std::vector<ExpressionTr
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
} }
ExpressionTreeNode Operation::Delta::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(2.0), ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
......
...@@ -314,6 +314,7 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -314,6 +314,7 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["erf"] = Operation::ERF; opMap["erf"] = Operation::ERF;
opMap["erfc"] = Operation::ERFC; opMap["erfc"] = Operation::ERFC;
opMap["step"] = Operation::STEP; opMap["step"] = Operation::STEP;
opMap["delta"] = Operation::DELTA;
opMap["square"] = Operation::SQUARE; opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE; opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL; opMap["recip"] = Operation::RECIPROCAL;
...@@ -371,6 +372,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -371,6 +372,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Erfc(); return new Operation::Erfc();
case Operation::STEP: case Operation::STEP:
return new Operation::Step(); return new Operation::Step();
case Operation::DELTA:
return new Operation::Delta();
case Operation::SQUARE: case Operation::SQUARE:
return new Operation::Square(); return new Operation::Square();
case Operation::CUBE: case Operation::CUBE:
......
...@@ -64,8 +64,8 @@ namespace OpenMM { ...@@ -64,8 +64,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/ */
class OPENMM_EXPORT CustomAngleForce : public Force { class OPENMM_EXPORT CustomAngleForce : public Force {
......
...@@ -64,8 +64,8 @@ namespace OpenMM { ...@@ -64,8 +64,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/ */
class OPENMM_EXPORT CustomBondForce : public Force { class OPENMM_EXPORT CustomBondForce : public Force {
......
...@@ -67,8 +67,8 @@ namespace OpenMM { ...@@ -67,8 +67,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/ */
class OPENMM_EXPORT CustomExternalForce : public Force { class OPENMM_EXPORT CustomExternalForce : public Force {
......
...@@ -127,8 +127,8 @@ namespace OpenMM { ...@@ -127,8 +127,8 @@ namespace OpenMM {
* particular piece of the computation. * particular piece of the computation.
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. In expressions for * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. In expressions for
* particle pair calculations, the names of per-particle parameters and computed values * particle pair calculations, the names of per-particle parameters and computed values
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example, * have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
......
...@@ -88,8 +88,8 @@ namespace OpenMM { ...@@ -88,8 +88,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and a natural spline is created from them. That function can then appear in the expression. * values, and a natural spline is created from them. That function can then appear in the expression.
......
...@@ -178,8 +178,8 @@ namespace OpenMM { ...@@ -178,8 +178,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. An expression * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. An expression
* may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*/ */
......
...@@ -74,8 +74,8 @@ namespace OpenMM { ...@@ -74,8 +74,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. The names of per-particle parameters * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. The names of per-particle parameters
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example, * have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
* *
......
...@@ -64,8 +64,8 @@ namespace OpenMM { ...@@ -64,8 +64,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/ */
class OPENMM_EXPORT CustomTorsionForce : public Force { class OPENMM_EXPORT CustomTorsionForce : public Force {
......
...@@ -217,6 +217,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -217,6 +217,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::STEP: case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f"; out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
break; break;
case Operation::DELTA:
out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f";
break;
case Operation::SQUARE: case Operation::SQUARE:
{ {
string arg = getTempName(node.getChildren()[0], temps); string arg = getTempName(node.getChildren()[0], temps);
......
...@@ -205,6 +205,7 @@ int main() { ...@@ -205,6 +205,7 @@ int main() {
verifyEvaluation("max(x, y)", 2.0, 3.0, 3.0); verifyEvaluation("max(x, y)", 2.0, 3.0, 3.0);
verifyEvaluation("max(x, -1)", 2.0, 3.0, 2.0); verifyEvaluation("max(x, -1)", 2.0, 3.0, 2.0);
verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0); verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0);
verifyEvaluation("delta(x)+3*delta(y-1.5)", 2.0, 1.5, 3.0);
verifyInvalidExpression("1..2"); verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3"); verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4"); verifyInvalidExpression("5++4");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment