"openmmapi/vscode:/vscode.git/clone" did not exist on "9e1145ac70f00664caadd35e4488824e35e88713"
Commit 43ebedfb authored by Peter Eastman's avatar Peter Eastman
Browse files

Added support for sinh(), cosh(), and tanh() functions.

parent 48d93893
...@@ -62,7 +62,7 @@ public: ...@@ -62,7 +62,7 @@ public:
* can be used when processing or analyzing parsed expressions. * can be used when processing or analyzing parsed expressions.
*/ */
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG, enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL, ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT}; SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, SQUARE, CUBE, RECIPROCAL, ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT};
/** /**
* Get the name of this Operation. * Get the name of this Operation.
*/ */
...@@ -122,6 +122,9 @@ public: ...@@ -122,6 +122,9 @@ public:
class Asin; class Asin;
class Acos; class Acos;
class Atan; class Atan;
class Sinh;
class Cosh;
class Tanh;
class Square; class Square;
class Cube; class Cube;
class Reciprocal; class Reciprocal;
...@@ -635,6 +638,72 @@ public: ...@@ -635,6 +638,72 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Sinh : public Operation {
public:
Sinh() {
}
std::string getName() const {
return "sinh";
}
Id getId() const {
return SINH;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Sinh();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::sinh(args[0]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Cosh : public Operation {
public:
Cosh() {
}
std::string getName() const {
return "cosh";
}
Id getId() const {
return COSH;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Cosh();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::cosh(args[0]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Tanh : public Operation {
public:
Tanh() {
}
std::string getName() const {
return "tanh";
}
Id getId() const {
return TANH;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Tanh();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::tanh(args[0]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Square : public Operation { class Operation::Square : public Operation {
public: public:
Square() { Square() {
......
...@@ -193,6 +193,29 @@ ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTr ...@@ -193,6 +193,29 @@ ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTr
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cosh(),
children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sinh(),
children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Constant(1.0)),
ExpressionTreeNode(new Operation::Square(),
ExpressionTreeNode(new Operation::Tanh(), children[0]))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(2.0), ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
......
...@@ -241,6 +241,9 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -241,6 +241,9 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["asin"] = Operation::ASIN; opMap["asin"] = Operation::ASIN;
opMap["acos"] = Operation::ACOS; opMap["acos"] = Operation::ACOS;
opMap["atan"] = Operation::ATAN; opMap["atan"] = Operation::ATAN;
opMap["sinh"] = Operation::SINH;
opMap["cosh"] = Operation::COSH;
opMap["tanh"] = Operation::TANH;
opMap["square"] = Operation::SQUARE; opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE; opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL; opMap["recip"] = Operation::RECIPROCAL;
...@@ -283,6 +286,12 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -283,6 +286,12 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Acos(); return new Operation::Acos();
case Operation::ATAN: case Operation::ATAN:
return new Operation::Atan(); return new Operation::Atan();
case Operation::SINH:
return new Operation::Sinh();
case Operation::COSH:
return new Operation::Cosh();
case Operation::TANH:
return new Operation::Tanh();
case Operation::SQUARE: case Operation::SQUARE:
return new Operation::Square(); return new Operation::Square();
case Operation::CUBE: case Operation::CUBE:
......
...@@ -78,8 +78,8 @@ namespace OpenMM { ...@@ -78,8 +78,8 @@ namespace OpenMM {
* </pre></tt> * </pre></tt>
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan. All trigonometric functions are defined * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh. All trigonometric functions
* in radians, and log is the natural logarithm. * are defined in radians, and log is the natural logarithm.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and an interpolating or approximating spline is created from them. That function can then appear in expressions * values, and an interpolating or approximating spline is created from them. That function can then appear in expressions
......
...@@ -250,7 +250,7 @@ enum CudaNonbondedMethod ...@@ -250,7 +250,7 @@ enum CudaNonbondedMethod
enum ExpressionOp { enum ExpressionOp {
VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT, VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT,
GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH
}; };
template<int SIZE> template<int SIZE>
......
...@@ -243,6 +243,15 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio ...@@ -243,6 +243,15 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio
case Operation::ATAN: case Operation::ATAN:
exp.op[i] = ATAN; exp.op[i] = ATAN;
break; break;
case Operation::SINH:
exp.op[i] = SINH;
break;
case Operation::COSH:
exp.op[i] = COSH;
break;
case Operation::TANH:
exp.op[i] = TANH;
break;
case Operation::SQUARE: case Operation::SQUARE:
exp.op[i] = SQUARE; exp.op[i] = SQUARE;
break; break;
......
...@@ -234,9 +234,18 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float* ...@@ -234,9 +234,18 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
else if (op == ACOS) { else if (op == ACOS) {
STACK(stackPointer) = acos(STACK(stackPointer)); STACK(stackPointer) = acos(STACK(stackPointer));
} }
else /*if (op == ATAN)*/ { else if (op == ATAN) {
STACK(stackPointer) = atan(STACK(stackPointer)); STACK(stackPointer) = atan(STACK(stackPointer));
} }
else if (op == SINH) {
STACK(stackPointer) = sinh(STACK(stackPointer));
}
else if (op == COSH) {
STACK(stackPointer) = cosh(STACK(stackPointer));
}
else /*if (op == TANH)*/ {
STACK(stackPointer) = tanh(STACK(stackPointer));
}
} }
} }
} }
......
...@@ -155,6 +155,15 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -155,6 +155,15 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::ATAN: case Operation::ATAN:
out << "atan(" << getTempName(node.getChildren()[0], temps) << ")"; out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
break; break;
case Operation::SINH:
out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COSH:
out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::TANH:
out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SQUARE: case Operation::SQUARE:
{ {
string arg = getTempName(node.getChildren()[0], temps); string arg = getTempName(node.getChildren()[0], temps);
......
...@@ -158,10 +158,10 @@ void testCustomFunction(const string& expression, const string& equivalent) { ...@@ -158,10 +158,10 @@ void testCustomFunction(const string& expression, const string& equivalent) {
verifySameValue(deriv1, deriv2, 2.0, -3.0); verifySameValue(deriv1, deriv2, 2.0, -3.0);
ParsedExpression deriv3 = deriv1.differentiate("y"); ParsedExpression deriv3 = deriv1.differentiate("y");
ParsedExpression deriv4 = deriv2.differentiate("y"); ParsedExpression deriv4 = deriv2.differentiate("y");
verifySameValue(deriv1, deriv2, 1.0, 2.0); verifySameValue(deriv3, deriv4, 1.0, 2.0);
verifySameValue(deriv1, deriv2, 2.0, 3.0); verifySameValue(deriv3, deriv4, 2.0, 3.0);
verifySameValue(deriv1, deriv2, -2.0, 3.0); verifySameValue(deriv3, deriv4, -2.0, 3.0);
verifySameValue(deriv1, deriv2, 2.0, -3.0); verifySameValue(deriv3, deriv4, 2.0, -3.0);
delete functions["custom"]; delete functions["custom"];
} }
...@@ -204,14 +204,14 @@ int main() { ...@@ -204,14 +204,14 @@ int main() {
verifyDerivative("asin(x)", "1/sqrt(1-x^2)"); verifyDerivative("asin(x)", "1/sqrt(1-x^2)");
verifyDerivative("acos(x)", "-1/sqrt(1-x^2)"); verifyDerivative("acos(x)", "-1/sqrt(1-x^2)");
verifyDerivative("atan(x)", "1/(1+x^2)"); verifyDerivative("atan(x)", "1/(1+x^2)");
verifyDerivative("sinh(x)", "cosh(x)");
verifyDerivative("cosh(x)", "sinh(x)");
verifyDerivative("tanh(x)", "1/(cosh(x)^2)");
verifyDerivative("recip(x)", "-1/x^2"); verifyDerivative("recip(x)", "-1/x^2");
verifyDerivative("square(x)", "2*x"); verifyDerivative("square(x)", "2*x");
verifyDerivative("cube(x)", "3*x^2"); verifyDerivative("cube(x)", "3*x^2");
testCustomFunction("custom(x, y)/2", "x*y"); testCustomFunction("custom(x, y)/2", "x*y");
testCustomFunction("custom(x^2, 1)+custom(2, y-1)", "2*x^2+4*(y-1)"); testCustomFunction("custom(x^2, 1)+custom(2, y-1)", "2*x^2+4*(y-1)");
map<string, double> variables;
variables["x"] = 2.0;
variables["y"] = 10.0;
cout << Parser::parse("2*3*x").optimize() << endl; cout << Parser::parse("2*3*x").optimize() << endl;
cout << Parser::parse("1/(1+x)").optimize() << endl; cout << Parser::parse("1/(1+x)").optimize() << endl;
cout << Parser::parse("x^(1/2)").optimize() << endl; cout << Parser::parse("x^(1/2)").optimize() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment