Commit 00e2c6a3 authored by peastman's avatar peastman
Browse files

Support atan2() in custom expressions

parent 6e3c2142
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2016 Stanford University and the Authors. *
* Portions copyright (c) 2013-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -103,6 +103,7 @@ private:
#ifdef LEPTON_USE_JIT
void generateJitCode();
void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double));
void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double));
std::vector<double> constants;
asmjit::JitRuntime runtime;
#endif
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2018 Stanford University and the Authors. *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -63,7 +63,7 @@ public:
* can be used when processing or analyzing parsed expressions.
*/
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC, STEP, DELTA, SQUARE, CUBE, RECIPROCAL,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, ATAN2, SINH, COSH, TANH, ERF, ERFC, STEP, DELTA, SQUARE, CUBE, RECIPROCAL,
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT, MIN, MAX, ABS, FLOOR, CEIL, SELECT};
/**
* Get the name of this Operation.
......@@ -137,6 +137,7 @@ public:
class Asin;
class Acos;
class Atan;
class Atan2;
class Sinh;
class Cosh;
class Tanh;
......@@ -689,6 +690,28 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class LEPTON_EXPORT Operation::Atan2 : public Operation {
public:
Atan2() {
}
std::string getName() const {
return "atan2";
}
Id getId() const {
return ATAN2;
}
int getNumArguments() const {
return 2;
}
Operation* clone() const {
return new Atan2();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::atan2(args[0], args[1]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class LEPTON_EXPORT Operation::Sinh : public Operation {
public:
Sinh() {
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2016 Stanford University and the Authors. *
* Portions copyright (c) 2013-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -294,6 +294,9 @@ void CompiledExpression::generateJitCode() {
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow);
break;
case Operation::NEGATE:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]);
......@@ -325,6 +328,9 @@ void CompiledExpression::generateJitCode() {
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan);
break;
case Operation::ATAN2:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh);
break;
......@@ -400,4 +406,13 @@ void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86
call->setArg(0, arg);
call->setRet(0, dest);
}
void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) {
X86Gp fn = c.newIntPtr();
c.mov(fn, imm_ptr((void*) function));
CCFuncCall* call = c.call(fn, FuncSignature2<double, double, double>());
call->setArg(0, arg1);
call->setArg(1, arg2);
call->setRet(0, dest);
}
#endif
......@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -202,6 +202,16 @@ ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTr
childDerivs[0]);
}
ExpressionTreeNode Operation::Atan2::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Divide(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]),
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])),
ExpressionTreeNode(new Operation::Add(),
ExpressionTreeNode(new Operation::Square(), children[0]),
ExpressionTreeNode(new Operation::Square(), children[1])));
}
ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cosh(),
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -313,6 +313,7 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["asin"] = Operation::ASIN;
opMap["acos"] = Operation::ACOS;
opMap["atan"] = Operation::ATAN;
opMap["atan2"] = Operation::ATAN2;
opMap["sinh"] = Operation::SINH;
opMap["cosh"] = Operation::COSH;
opMap["tanh"] = Operation::TANH;
......@@ -368,6 +369,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Acos();
case Operation::ATAN:
return new Operation::Atan();
case Operation::ATAN2:
return new Operation::Atan2();
case Operation::SINH:
return new Operation::Sinh();
case Operation::COSH:
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -123,7 +123,7 @@ private:
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
void callFunction(std::stringstream& out, std::string singleFn, std::string doubleFn, const std::string& arg, const std::string& tempType);
void callFunction2(std::stringstream& out, std::string fn, const std::string& arg1, const std::string& arg2, const std::string& tempType);
void callFunction2(std::stringstream& out, std::string singleFn, std::string doubleFn, const std::string& arg1, const std::string& arg2, const std::string& tempType);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2018 Stanford University and the Authors. *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -484,6 +484,9 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
case Operation::ATAN:
callFunction(out, "atanf", "atan", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::ATAN2:
callFunction2(out, "atan2f", "atan2", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], tempType);
break;
case Operation::SINH:
callFunction(out, "sinh", "sinh", getTempName(node.getChildren()[0], temps), tempType);
break;
......@@ -619,10 +622,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
break;
}
case Operation::MIN:
callFunction2(out, "min", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
callFunction2(out, "min", "min", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
break;
case Operation::MAX:
callFunction2(out, "max", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
callFunction2(out, "max", "max", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
break;
case Operation::ABS:
callFunction(out, "fabs", "fabs", getTempName(node.getChildren()[0], temps), tempType);
......@@ -923,22 +926,17 @@ Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder(
void CudaExpressionUtilities::callFunction(stringstream& out, string singleFn, string doubleFn, const string& arg, const string& tempType) {
bool isDouble = (tempType[0] == 'd');
bool isVector = (tempType[tempType.size()-1] == '3');
if (isVector) {
if (isDouble)
out<<"make_double3("<<doubleFn<<"("<<arg<<".x), "<<doubleFn<<"("<<arg<<".y), "<<doubleFn<<"("<<arg<<".z))";
else
out<<"make_float3("<<singleFn<<"("<<arg<<".x), "<<singleFn<<"("<<arg<<".y), "<<singleFn<<"("<<arg<<".z))";
}
else {
if (isDouble)
out<<doubleFn<<"("<<arg<<")";
else
out<<singleFn<<"("<<arg<<")";
}
string fn = (isDouble ? doubleFn : singleFn);
if (isVector)
out<<"make_"<<tempType<<"("<<fn<<"("<<arg<<".x), "<<fn<<"("<<arg<<".y), "<<fn<<"("<<arg<<".z))";
else
out<<fn<<"("<<arg<<")";
}
void CudaExpressionUtilities::callFunction2(stringstream& out, string fn, const string& arg1, const string& arg2, const string& tempType) {
void CudaExpressionUtilities::callFunction2(stringstream& out, string singleFn, string doubleFn, const string& arg1, const string& arg2, const string& tempType) {
bool isDouble = (tempType[0] == 'd');
bool isVector = (tempType[tempType.size()-1] == '3');
string fn = (isDouble ? doubleFn : singleFn);
if (isVector) {
out<<"make_"<<tempType<<"(";
out<<fn<<"("<<arg1<<".x, "<<arg2<<".x), ";
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2018 Stanford University and the Authors. *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -464,6 +464,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::ATAN:
out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ATAN2:
out << "atan2(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::SINH:
out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Portions copyright (c) 2008-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -216,6 +216,21 @@ void testIllegalVariable() {
ASSERT(threwException);
}
void testAtan2() {
System system;
system.addParticle(1.0);
CustomExternalForce* force = new CustomExternalForce("atan2(x, y)");
force->addParticle(0);
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(1.5, -2.1, 1.2);
context.setPositions(positions);
State state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(atan2(positions[0][0], positions[0][1]), state.getPotentialEnergy(), 1e-5);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -226,6 +241,7 @@ int main(int argc, char* argv[]) {
testPeriodic();
testZeroPeriodicDistance();
testIllegalVariable();
testAtan2();
runPlatformTests();
}
catch(const exception& e) {
......
......@@ -265,6 +265,7 @@ int main() {
verifyEvaluation("ceil(x)", -2.1, 3.0, -2.0);
verifyEvaluation("select(x, 1.0, y)", 0.3, 2.0, 1.0);
verifyEvaluation("select(x, 1.0, y)", 0.0, 2.0, 2.0);
verifyEvaluation("atan2(x, y)", 3.0, 1.5, std::atan(2.0));
verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4");
......@@ -285,6 +286,7 @@ int main() {
verifyDerivative("asin(x)", "1/sqrt(1-x^2)");
verifyDerivative("acos(x)", "-1/sqrt(1-x^2)");
verifyDerivative("atan(x)", "1/(1+x^2)");
verifyDerivative("atan2(2*x,y)", "2*y/(4*x^2+y^2)");
verifyDerivative("sinh(x)", "cosh(x)");
verifyDerivative("cosh(x)", "sinh(x)");
verifyDerivative("tanh(x)", "1/(cosh(x)^2)");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment