Commit 2f81944d authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented step function in custom expressions

parent 75bbd639
......@@ -62,7 +62,8 @@ public:
* can be used when processing or analyzing parsed expressions.
*/
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, SQUARE, CUBE, RECIPROCAL, ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT};
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, STEP, SQUARE, CUBE, RECIPROCAL,
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT};
/**
* Get the name of this Operation.
*/
......@@ -125,6 +126,7 @@ public:
class Sinh;
class Cosh;
class Tanh;
class Step;
class Square;
class Cube;
class Reciprocal;
......@@ -704,6 +706,28 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Step : public Operation {
public:
Step() {
}
std::string getName() const {
return "step";
}
Id getId() const {
return STEP;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Step();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return (args[0] >= 0.0 ? 1.0 : 0.0);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Square : public Operation {
public:
Square() {
......
......@@ -216,6 +216,10 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
childDerivs[0]);
}
ExpressionTreeNode Operation::Step::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
......
......@@ -296,6 +296,7 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["sinh"] = Operation::SINH;
opMap["cosh"] = Operation::COSH;
opMap["tanh"] = Operation::TANH;
opMap["step"] = Operation::STEP;
opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL;
......@@ -344,6 +345,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Cosh();
case Operation::TANH:
return new Operation::Tanh();
case Operation::STEP:
return new Operation::Step();
case Operation::SQUARE:
return new Operation::Square();
case Operation::CUBE:
......
......@@ -250,7 +250,7 @@ enum CudaNonbondedMethod
enum ExpressionOp {
VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT,
GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH
GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, STEP, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH
};
template<int SIZE>
......
......@@ -256,6 +256,9 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio
case Operation::TANH:
exp.op[i] = TANH;
break;
case Operation::STEP:
exp.op[i] = STEP;
break;
case Operation::SQUARE:
exp.op[i] = SQUARE;
break;
......
......@@ -137,10 +137,13 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
float temp = STACK(stackPointer);
STACK(stackPointer) *= temp;
}
else /*if (op == CUBE)*/ {
else if (op == CUBE) {
float temp = STACK(stackPointer);
STACK(stackPointer) *= temp*temp;
}
else /*if (op == STEP)*/ {
STACK(stackPointer) = (STACK(stackPointer) >= 0.0f ? 1.0f : 0.0f);
}
}
else {
if (op == SIN) {
......
......@@ -194,6 +194,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::TANH:
out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
break;
case Operation::SQUARE:
{
string arg = getTempName(node.getChildren()[0], temps);
......
......@@ -203,6 +203,7 @@ int main() {
verifyDerivative("sinh(x)", "cosh(x)");
verifyDerivative("cosh(x)", "sinh(x)");
verifyDerivative("tanh(x)", "1/(cosh(x)^2)");
verifyDerivative("step(x)*x+step(1-x)*2*x", "step(x)+step(1-x)*2");
verifyDerivative("recip(x)", "-1/x^2");
verifyDerivative("square(x)", "2*x");
verifyDerivative("cube(x)", "3*x^2");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment