Commit 2f81944d authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented step function in custom expressions

parent 75bbd639
...@@ -62,7 +62,8 @@ public: ...@@ -62,7 +62,8 @@ public:
* can be used when processing or analyzing parsed expressions. * can be used when processing or analyzing parsed expressions.
*/ */
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG, enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, SQUARE, CUBE, RECIPROCAL, ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT}; SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, STEP, SQUARE, CUBE, RECIPROCAL,
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT};
/** /**
* Get the name of this Operation. * Get the name of this Operation.
*/ */
...@@ -125,6 +126,7 @@ public: ...@@ -125,6 +126,7 @@ public:
class Sinh; class Sinh;
class Cosh; class Cosh;
class Tanh; class Tanh;
class Step;
class Square; class Square;
class Cube; class Cube;
class Reciprocal; class Reciprocal;
...@@ -704,6 +706,28 @@ public: ...@@ -704,6 +706,28 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Step : public Operation {
public:
Step() {
}
std::string getName() const {
return "step";
}
Id getId() const {
return STEP;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Step();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return (args[0] >= 0.0 ? 1.0 : 0.0);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Square : public Operation { class Operation::Square : public Operation {
public: public:
Square() { Square() {
......
...@@ -216,6 +216,10 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr ...@@ -216,6 +216,10 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Step::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(2.0), ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
......
...@@ -296,6 +296,7 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -296,6 +296,7 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["sinh"] = Operation::SINH; opMap["sinh"] = Operation::SINH;
opMap["cosh"] = Operation::COSH; opMap["cosh"] = Operation::COSH;
opMap["tanh"] = Operation::TANH; opMap["tanh"] = Operation::TANH;
opMap["step"] = Operation::STEP;
opMap["square"] = Operation::SQUARE; opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE; opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL; opMap["recip"] = Operation::RECIPROCAL;
...@@ -344,6 +345,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -344,6 +345,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Cosh(); return new Operation::Cosh();
case Operation::TANH: case Operation::TANH:
return new Operation::Tanh(); return new Operation::Tanh();
case Operation::STEP:
return new Operation::Step();
case Operation::SQUARE: case Operation::SQUARE:
return new Operation::Square(); return new Operation::Square();
case Operation::CUBE: case Operation::CUBE:
......
...@@ -250,7 +250,7 @@ enum CudaNonbondedMethod ...@@ -250,7 +250,7 @@ enum CudaNonbondedMethod
enum ExpressionOp { enum ExpressionOp {
VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT, VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT,
GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, STEP, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH
}; };
template<int SIZE> template<int SIZE>
......
...@@ -256,6 +256,9 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio ...@@ -256,6 +256,9 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio
case Operation::TANH: case Operation::TANH:
exp.op[i] = TANH; exp.op[i] = TANH;
break; break;
case Operation::STEP:
exp.op[i] = STEP;
break;
case Operation::SQUARE: case Operation::SQUARE:
exp.op[i] = SQUARE; exp.op[i] = SQUARE;
break; break;
......
...@@ -137,10 +137,13 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float* ...@@ -137,10 +137,13 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(stackPointer) *= temp; STACK(stackPointer) *= temp;
} }
else /*if (op == CUBE)*/ { else if (op == CUBE) {
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(stackPointer) *= temp*temp; STACK(stackPointer) *= temp*temp;
} }
else /*if (op == STEP)*/ {
STACK(stackPointer) = (STACK(stackPointer) >= 0.0f ? 1.0f : 0.0f);
}
} }
else { else {
if (op == SIN) { if (op == SIN) {
......
...@@ -194,6 +194,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -194,6 +194,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::TANH: case Operation::TANH:
out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")"; out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
break; break;
case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
break;
case Operation::SQUARE: case Operation::SQUARE:
{ {
string arg = getTempName(node.getChildren()[0], temps); string arg = getTempName(node.getChildren()[0], temps);
......
...@@ -203,6 +203,7 @@ int main() { ...@@ -203,6 +203,7 @@ int main() {
verifyDerivative("sinh(x)", "cosh(x)"); verifyDerivative("sinh(x)", "cosh(x)");
verifyDerivative("cosh(x)", "sinh(x)"); verifyDerivative("cosh(x)", "sinh(x)");
verifyDerivative("tanh(x)", "1/(cosh(x)^2)"); verifyDerivative("tanh(x)", "1/(cosh(x)^2)");
verifyDerivative("step(x)*x+step(1-x)*2*x", "step(x)+step(1-x)*2");
verifyDerivative("recip(x)", "-1/x^2"); verifyDerivative("recip(x)", "-1/x^2");
verifyDerivative("square(x)", "2*x"); verifyDerivative("square(x)", "2*x");
verifyDerivative("cube(x)", "3*x^2"); verifyDerivative("cube(x)", "3*x^2");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment