Unverified Commit 609fc413 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Simplified derivatives of min() and max() (#3240)

* Simplified derivatives of min() and max()

* Fixed compilation error in kernel
parent 654c6c9c
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. * * Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -393,19 +393,13 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<Exp ...@@ -393,19 +393,13 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<Exp
ExpressionTreeNode Operation::Min::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Min::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode step(new Operation::Step(),
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
return ExpressionTreeNode(new Operation::Subtract(), return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[1], childDerivs[0]});
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0],
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
} }
ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode step(new Operation::Step(),
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
return ExpressionTreeNode(new Operation::Subtract(), return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[0], childDerivs[1]});
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1],
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
} }
ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
...@@ -427,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector<ExpressionTr ...@@ -427,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Select::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Select::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
vector<ExpressionTreeNode> derivChildren; return ExpressionTreeNode(new Operation::Select(), {children[0], childDerivs[1], childDerivs[2]});
derivChildren.push_back(children[0]);
derivChildren.push_back(childDerivs[1]);
derivChildren.push_back(childDerivs[2]);
return ExpressionTreeNode(new Operation::Select(), derivChildren);
} }
...@@ -3280,7 +3280,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3280,7 +3280,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1]; gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2])) if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2]; gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cc.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp"); compute << cc.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "gradtemp_"+is);
} }
for (int i = 1; i < numComputedValues; i++) for (int i = 1; i < numComputedValues; i++)
compute << "force -= deriv"<<energyDerivs->getParameterSuffix(i)<<"*dV"<<i<<"dR;\n"; compute << "force -= deriv"<<energyDerivs->getParameterSuffix(i)<<"*dV"<<i<<"dR;\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment