Commit 37ab709f authored by peastman's avatar peastman
Browse files

Merge pull request #809 from peastman/floor

Custom expressions support floor() and ceil()
parents a0bdc698 e0713aeb
......@@ -1040,7 +1040,7 @@ The following operators are supported: + (add), - (subtract), * (multiply), /
The following standard functions are supported: sqrt, exp, log, sin, cos, sec,
csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs,
step, delta. step(x) = 0 if x < 0, 1 otherwise. delta(x) = 1 if x is 0, 0
floor, ceil, step, delta. step(x) = 0 if x < 0, 1 otherwise. delta(x) = 1 if x is 0, 0
otherwise. Some custom forces allow additional functions to be defined from
tabulated values.
......
......@@ -64,7 +64,7 @@ public:
*/
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC, STEP, DELTA, SQUARE, CUBE, RECIPROCAL,
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT, MIN, MAX, ABS};
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT, MIN, MAX, ABS, FLOOR, CEIL};
/**
* Get the name of this Operation.
*/
......@@ -153,6 +153,8 @@ public:
class Min;
class Max;
class Abs;
class Floor;
class Ceil;
};
class LEPTON_EXPORT Operation::Constant : public Operation {
......@@ -1090,6 +1092,51 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class LEPTON_EXPORT Operation::Floor : public Operation {
public:
Floor() {
}
std::string getName() const {
return "floor";
}
Id getId() const {
return FLOOR;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Floor();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::floor(args[0]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class LEPTON_EXPORT Operation::Ceil : public Operation {
public:
Ceil() {
}
std::string getName() const {
return "ceil";
}
Id getId() const {
return CEIL;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Ceil();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::ceil(args[0]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
} // namespace Lepton
#endif /*LEPTON_OPERATION_H_*/
......@@ -343,6 +343,12 @@ void CompiledExpression::generateJitCode() {
case Operation::ABS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
break;
case Operation::FLOOR:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor);
break;
case Operation::CEIL:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil);
break;
default:
// Just invoke evaluateOperation().
......
......@@ -317,3 +317,11 @@ ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTre
ExpressionTreeNode(new Operation::AddConstant(-1),
ExpressionTreeNode(new Operation::MultiplyConstant(2), step)));
}
ExpressionTreeNode Operation::Floor::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Ceil::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0));
}
......@@ -326,6 +326,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["min"] = Operation::MIN;
opMap["max"] = Operation::MAX;
opMap["abs"] = Operation::ABS;
opMap["floor"] = Operation::FLOOR;
opMap["ceil"] = Operation::CEIL;
}
string trimmed = name.substr(0, name.size()-1);
......@@ -391,6 +393,10 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Max();
case Operation::ABS:
return new Operation::Abs();
case Operation::FLOOR:
return new Operation::Floor();
case Operation::CEIL:
return new Operation::Ceil();
default:
throw Exception("unknown function");
}
......
......@@ -65,7 +65,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/
......
......@@ -65,7 +65,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/
......
......@@ -89,7 +89,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
......
......@@ -68,7 +68,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/
......
......@@ -129,7 +129,7 @@ namespace OpenMM {
* particular piece of the computation.
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. In expressions for
* particle pair calculations, the names of per-particle parameters and computed values
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
......
......@@ -89,7 +89,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*
* In addition, you can call addTabulatedFunction() to define a new function based on tabulated values. You specify the function by
......
......@@ -204,7 +204,7 @@ namespace OpenMM {
* the force from a single force group, or a random number.
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. An expression
* may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*/
......
......@@ -149,7 +149,7 @@ namespace OpenMM {
* still only be evaluated once for each triplet, so it must still be symmetric with respect to p2 and p3.
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* The names of per-particle parameters have the suffix "1", "2", etc. appended to them to indicate the values for the multiple interacting particles.
* For example, if you define a per-particle parameter called "charge", then the variable "charge2" is the charge of particle p2.
......
......@@ -120,7 +120,7 @@ namespace OpenMM {
* frequently, the long range correction can be very slow. For this reason, it is disabled by default.
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. The names of per-particle parameters
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
......
......@@ -65,7 +65,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step, delta. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
*/
......
......@@ -457,6 +457,12 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
case Operation::ABS:
out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::FLOOR:
out << "floor(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CEIL:
out << "ceil(" << getTempName(node.getChildren()[0], temps) << ")";
break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
......
......@@ -449,6 +449,12 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::ABS:
out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::FLOOR:
out << "floor(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CEIL:
out << "ceil(" << getTempName(node.getChildren()[0], temps) << ")";
break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
......
......@@ -247,6 +247,8 @@ int main() {
verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0);
verifyEvaluation("delta(x)+3*delta(y-1.5)", 2.0, 1.5, 3.0);
verifyEvaluation("step(x-3)+y*step(x)", 2.0, 3.0, 3.0);
verifyEvaluation("floor(x)", -2.1, 3.0, -3.0);
verifyEvaluation("ceil(x)", -2.1, 3.0, -2.0);
verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4");
......@@ -279,6 +281,7 @@ int main() {
verifyDerivative("min(x, 2*x)", "step(x-2*x)*2+(1-step(x-2*x))*1");
verifyDerivative("max(5, x^2)", "(1-step(5-x^2))*2*x");
verifyDerivative("abs(3*x)", "step(3*x)*3+(1-step(3*x))*-3");
verifyDerivative("floor(x)+0.5*x*ceil(x)", "0.5*ceil(x)");
testCustomFunction("custom(x, y)/2", "x*y");
testCustomFunction("custom(x^2, 1)+custom(2, y-1)", "2*x^2+4*(y-1)");
cout << Parser::parse("x*x").optimize() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment