Commit 5c24a611 authored by Peter Eastman's avatar Peter Eastman
Browse files

erf() and erfc() can appear in expressions

parent 8af0ac1c
...@@ -62,7 +62,7 @@ public: ...@@ -62,7 +62,7 @@ public:
* can be used when processing or analyzing parsed expressions. * can be used when processing or analyzing parsed expressions.
*/ */
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG, enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, STEP, SQUARE, CUBE, RECIPROCAL, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC, STEP, SQUARE, CUBE, RECIPROCAL,
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT}; ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT};
/** /**
* Get the name of this Operation. * Get the name of this Operation.
...@@ -139,6 +139,8 @@ public: ...@@ -139,6 +139,8 @@ public:
class Sinh; class Sinh;
class Cosh; class Cosh;
class Tanh; class Tanh;
class Erf;
class Erfc;
class Step; class Step;
class Square; class Square;
class Cube; class Cube;
...@@ -740,6 +742,46 @@ public: ...@@ -740,6 +742,46 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Erf : public Operation {
public:
Erf() {
}
std::string getName() const {
return "erf";
}
Id getId() const {
return ERF;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Erf();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const;
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Erfc : public Operation {
public:
Erfc() {
}
std::string getName() const {
return "erfc";
}
Id getId() const {
return ERFC;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Erfc();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const;
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Step : public Operation { class Operation::Step : public Operation {
public: public:
Step() { Step() {
......
#ifndef LEPTON_MSVC_ERFC_H_
#define LEPTON_MSVC_ERFC_H_
/*
* At least up to version 8 (VC++ 2005), Microsoft does not support the
* standard C99 erf() and erfc() functions. For now we're including these
* definitions for an MSVC compilation; if these are added later then
* the #ifdef below should change to compare _MSC_VER with a particular
* version level.
*/
#ifdef _MSC_VER
/***************************
* erf.cpp
* author: Steve Strand
* written: 29-Jan-04
***************************/
#include <cmath>
#define M_PI 3.14159265358979323846264338327950288
static const double rel_error= 1E-12; //calculate 12 significant figures
//you can adjust rel_error to trade off between accuracy and speed
//but don't ask for > 15 figures (assuming usual 52 bit mantissa in a double)
static double erfc(double x);
static double erf(double x)
//erf(x) = 2/sqrt(pi)*integral(exp(-t^2),t,0,x)
// = 2/sqrt(pi)*[x - x^3/3 + x^5/5*2! - x^7/7*3! + ...]
// = 1-erfc(x)
{
static const double two_sqrtpi= 1.128379167095512574; // 2/sqrt(pi)
if (fabs(x) > 2.2) {
return 1.0 - erfc(x); //use continued fraction when fabs(x) > 2.2
}
double sum= x, term= x, xsqr= x*x;
int j= 1;
do {
term*= xsqr/j;
sum-= term/(2*j+1);
++j;
term*= xsqr/j;
sum+= term/(2*j+1);
++j;
} while (fabs(term)/sum > rel_error);
return two_sqrtpi*sum;
}
static double erfc(double x)
//erfc(x) = 2/sqrt(pi)*integral(exp(-t^2),t,x,inf)
// = exp(-x^2)/sqrt(pi) * [1/x+ (1/2)/x+ (2/2)/x+ (3/2)/x+ (4/2)/x+ ...]
// = 1-erf(x)
//expression inside [] is a continued fraction so '+' means add to denominator only
{
static const double one_sqrtpi= 0.564189583547756287; // 1/sqrt(pi)
if (fabs(x) < 2.2) {
return 1.0 - erf(x); //use series when fabs(x) < 2.2
}
// Don't look for x==0 here!
if (x < 0) { //continued fraction only valid for x>0
return 2.0 - erfc(-x);
}
double a=1, b=x; //last two convergent numerators
double c=x, d=x*x+0.5; //last two convergent denominators
double q1, q2= b/d; //last two convergents (a/c and b/d)
double n= 1.0, t;
do {
t= a*n+b*x;
a= b;
b= t;
t= c*n+d*x;
c= d;
d= t;
n+= 0.5;
q1= q2;
q2= b/d;
} while (fabs(q1-q2)/q2 > rel_error);
return one_sqrtpi*exp(-x*x)*q2;
}
#endif // _MSC_VER
#endif // LEPTON_MSVC_ERFC_H_
...@@ -32,10 +32,19 @@ ...@@ -32,10 +32,19 @@
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
#include "MSVC_erfc.h"
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
double Operation::Erf::evaluate(double* args, const map<string, double>& variables) const {
return erf(args[0]);
}
double Operation::Erfc::evaluate(double* args, const map<string, double>& variables) const {
return erfc(args[0]);
}
ExpressionTreeNode Operation::Constant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Constant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
} }
...@@ -216,6 +225,26 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr ...@@ -216,6 +225,26 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))),
ExpressionTreeNode(new Operation::Exp(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Square(), children[0])))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Erfc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))),
ExpressionTreeNode(new Operation::Exp(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Square(), children[0])))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Step::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Step::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
} }
......
...@@ -310,6 +310,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -310,6 +310,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["sinh"] = Operation::SINH; opMap["sinh"] = Operation::SINH;
opMap["cosh"] = Operation::COSH; opMap["cosh"] = Operation::COSH;
opMap["tanh"] = Operation::TANH; opMap["tanh"] = Operation::TANH;
opMap["erf"] = Operation::ERF;
opMap["erfc"] = Operation::ERFC;
opMap["step"] = Operation::STEP; opMap["step"] = Operation::STEP;
opMap["square"] = Operation::SQUARE; opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE; opMap["cube"] = Operation::CUBE;
...@@ -359,6 +361,10 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -359,6 +361,10 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Cosh(); return new Operation::Cosh();
case Operation::TANH: case Operation::TANH:
return new Operation::Tanh(); return new Operation::Tanh();
case Operation::ERF:
return new Operation::Erf();
case Operation::ERFC:
return new Operation::Erfc();
case Operation::STEP: case Operation::STEP:
return new Operation::Step(); return new Operation::Step();
case Operation::SQUARE: case Operation::SQUARE:
......
...@@ -250,7 +250,7 @@ enum CudaNonbondedMethod ...@@ -250,7 +250,7 @@ enum CudaNonbondedMethod
enum ExpressionOp { enum ExpressionOp {
VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT, VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT,
GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, STEP, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, STEP, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC
}; };
template<int SIZE> template<int SIZE>
......
...@@ -256,6 +256,12 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio ...@@ -256,6 +256,12 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio
case Operation::TANH: case Operation::TANH:
exp.op[i] = TANH; exp.op[i] = TANH;
break; break;
case Operation::ERF:
exp.op[i] = ERF;
break;
case Operation::ERFC:
exp.op[i] = ERFC;
break;
case Operation::STEP: case Operation::STEP:
exp.op[i] = STEP; exp.op[i] = STEP;
break; break;
......
...@@ -179,9 +179,15 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float* ...@@ -179,9 +179,15 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
else if (op == COSH) { else if (op == COSH) {
STACK(stackPointer) = cosh(STACK(stackPointer)); STACK(stackPointer) = cosh(STACK(stackPointer));
} }
else /*if (op == TANH)*/ { else if (op == TANH) {
STACK(stackPointer) = tanh(STACK(stackPointer)); STACK(stackPointer) = tanh(STACK(stackPointer));
} }
else if (op == ERF) {
STACK(stackPointer) = erf(STACK(stackPointer));
}
else /*if (op == ERFC)*/ {
STACK(stackPointer) = erfc(STACK(stackPointer));
}
} }
} }
} }
......
...@@ -194,6 +194,12 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -194,6 +194,12 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::TANH: case Operation::TANH:
out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")"; out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
break; break;
case Operation::ERF:
out << "erf(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ERFC:
out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::STEP: case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f"; out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
break; break;
......
...@@ -199,6 +199,7 @@ int main() { ...@@ -199,6 +199,7 @@ int main() {
verifyEvaluation("x/(1/y)", 1.0, 4.0, 4.0); verifyEvaluation("x/(1/y)", 1.0, 4.0, 4.0);
verifyEvaluation("x*w; w = 5", 3.0, 1.0, 15.0); verifyEvaluation("x*w; w = 5", 3.0, 1.0, 15.0);
verifyEvaluation("a+b^2;a=x-b;b=3*y", 2.0, 3.0, 74.0); verifyEvaluation("a+b^2;a=x-b;b=3*y", 2.0, 3.0, 74.0);
verifyEvaluation("erf(x)+erfc(x)", 2.0, 3.0, 1.0);
verifyInvalidExpression("1..2"); verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3"); verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4"); verifyInvalidExpression("5++4");
...@@ -222,6 +223,8 @@ int main() { ...@@ -222,6 +223,8 @@ int main() {
verifyDerivative("sinh(x)", "cosh(x)"); verifyDerivative("sinh(x)", "cosh(x)");
verifyDerivative("cosh(x)", "sinh(x)"); verifyDerivative("cosh(x)", "sinh(x)");
verifyDerivative("tanh(x)", "1/(cosh(x)^2)"); verifyDerivative("tanh(x)", "1/(cosh(x)^2)");
verifyDerivative("erf(x)", "1.12837916709551*exp(-x^2)");
verifyDerivative("erfc(x)", "-1.12837916709551*exp(-x^2)");
verifyDerivative("step(x)*x+step(1-x)*2*x", "step(x)+step(1-x)*2"); verifyDerivative("step(x)*x+step(1-x)*2*x", "step(x)+step(1-x)*2");
verifyDerivative("recip(x)", "-1/x^2"); verifyDerivative("recip(x)", "-1/x^2");
verifyDerivative("square(x)", "2*x"); verifyDerivative("square(x)", "2*x");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment