Unverified Commit dde4228b authored by Emilio Gallicchio's avatar Emilio Gallicchio Committed by GitHub
Browse files

ATMForce: ignore invalid energy term when the energy expression does not depend on it (#4834)



* reset overflowed state energies at the alchemical endpoints

* address formatting, complete clash test

* Fixed indentation

---------
Co-authored-by: default avatarPeter Eastman <peter.eastman@gmail.com>
parent 7bdf028c
...@@ -47,6 +47,7 @@ ...@@ -47,6 +47,7 @@
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include "lepton/Parser.h" #include "lepton/Parser.h"
#include <cmath> #include <cmath>
#include <limits>
#include <map> #include <map>
#include <set> #include <set>
#include <sstream> #include <sstream>
...@@ -138,19 +139,33 @@ double ATMForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForce ...@@ -138,19 +139,33 @@ double ATMForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForce
state0Energy = innerContextImpl0.calcForcesAndEnergy(includeForces, true); state0Energy = innerContextImpl0.calcForcesAndEnergy(includeForces, true);
state1Energy = innerContextImpl1.calcForcesAndEnergy(includeForces, true); state1Energy = innerContextImpl1.calcForcesAndEnergy(includeForces, true);
// Compute the alchemical energy and forces. // set global parameters for energy expression
for (int i = 0; i < globalParameterNames.size(); i++) for (int i = 0; i < globalParameterNames.size(); i++)
globalValues[i] = context.getParameter(globalParameterNames[i]); globalValues[i] = context.getParameter(globalParameterNames[i]);
// Protect against overflow when the hybrid potential function does
// not depend on u0 or u1 and their values are unbounded; typically at the endstates
double dEdu0 = u0DerivExpression.evaluate();
double dEdu1 = u1DerivExpression.evaluate();
double epsi = std::numeric_limits<float>::min();
double maxEnergy = std::numeric_limits<float>::max();
if(fabs(dEdu0) < epsi && (isnan(state0Energy) || isinf(state0Energy)))
state0Energy = maxEnergy;
if(fabs(dEdu1) < epsi && (isnan(state1Energy) || isinf(state1Energy)))
state1Energy = maxEnergy;
// Compute the alchemical energy and forces.
combinedEnergy = energyExpression.evaluate(); combinedEnergy = energyExpression.evaluate();
if (includeForces) { if (includeForces) {
double dEdu0 = u0DerivExpression.evaluate();
double dEdu1 = u1DerivExpression.evaluate();
map<string, double> energyParamDerivs; map<string, double> energyParamDerivs;
for (int i = 0; i < paramDerivExpressions.size(); i++) for (int i = 0; i < paramDerivExpressions.size(); i++)
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate(); energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate();
kernel.getAs<CalcATMForceKernel>().applyForces(context, innerContextImpl0, innerContextImpl1, dEdu0, dEdu1, energyParamDerivs); kernel.getAs<CalcATMForceKernel>().applyForces(context, innerContextImpl0, innerContextImpl1, dEdu0, dEdu1, energyParamDerivs);
} }
return (includeEnergy ? combinedEnergy : 0.0); return (includeEnergy ? combinedEnergy : 0.0);
} }
......
...@@ -2972,8 +2972,18 @@ void ReferenceCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& ...@@ -2972,8 +2972,18 @@ void ReferenceCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl&
vector<Vec3>& force = extractForces(context); vector<Vec3>& force = extractForces(context);
vector<Vec3>& force0 = extractForces(innerContext0); vector<Vec3>& force0 = extractForces(innerContext0);
vector<Vec3>& force1 = extractForces(innerContext1); vector<Vec3>& force1 = extractForces(innerContext1);
for (int i = 0; i < force.size(); i++)
force[i] += dEdu0*force0[i] + dEdu1*force1[i]; //update forces and
//protects from infinite forces when the hybrid potential does
//not depend on u1 or u0, typically at the endpoints
double epsi = std::numeric_limits<float>::min();
for (int i = 0; i < force.size(); i++) {
if (fabs(dEdu0) > epsi)
force[i] += dEdu0*force0[i];
if (fabs(dEdu1) > epsi)
force[i] += dEdu1*force1[i];
}
map<string, double>& derivs = extractEnergyParameterDerivatives(context); map<string, double>& derivs = extractEnergyParameterDerivatives(context);
for (auto deriv : energyParamDerivs) for (auto deriv : energyParamDerivs)
derivs[deriv.first] += deriv.second; derivs[deriv.first] += deriv.second;
......
...@@ -110,7 +110,6 @@ void test2Particles() { ...@@ -110,7 +110,6 @@ void test2Particles() {
} }
} }
void test2Particles2Displacement0() { void test2Particles2Displacement0() {
// A pair of particles tethered by an harmonic bond. // A pair of particles tethered by an harmonic bond.
// Displace the second one to test energy and forces at different lambda values // Displace the second one to test energy and forces at different lambda values
...@@ -194,6 +193,7 @@ double softCoreFunc(double u, double umax, double ub, double a, double& df) { ...@@ -194,6 +193,7 @@ double softCoreFunc(double u, double umax, double ub, double a, double& df) {
return usc; return usc;
} }
void test2ParticlesSoftCore() { void test2ParticlesSoftCore() {
// Similar to test2Particles() but employing a soft-core function // Similar to test2Particles() but employing a soft-core function
...@@ -241,7 +241,10 @@ void test2ParticlesSoftCore() { ...@@ -241,7 +241,10 @@ void test2ParticlesSoftCore() {
ASSERT_EQUAL_VEC(Vec3(-lmbd*df*displ[0], 0.0, 0.0), state.getForces()[1], 1e-6); ASSERT_EQUAL_VEC(Vec3(-lmbd*df*displ[0], 0.0, 0.0), state.getForces()[1], 1e-6);
} }
void testNonbonded() { void testNonbonded() {
// Tests a system with a nonbonded Force
System system; System system;
double u0, u1, energy; double u0, u1, energy;
double lambda = 0.5; double lambda = 0.5;
...@@ -296,11 +299,61 @@ void testNonbonded() { ...@@ -296,11 +299,61 @@ void testNonbonded() {
double epot2 = state2.getPotentialEnergy(); double epot2 = state2.getPotentialEnergy();
atm->getPerturbationEnergy(context2, u1, u0, energy); atm->getPerturbationEnergy(context2, u1, u0, energy);
double epert2 = u1 - u0; double epert2 = u1 - u0;
ASSERT_EQUAL_TOL(epert1, epert2, 1e-3); ASSERT_EQUAL_TOL(epert1, epert2, 1e-3);
} }
void testNonbondedwithEndpointClash() {
// Similar to testNonbonded() but at the initial alchemical state
// and with an invalid potential at the final state due to a clash
// between two particles.
System system;
double u0, u1, energy;
double lambda = 0.0; //U(lambda) = u0; it does not depend on u1
double width = 4.0;
system.setDefaultPeriodicBoxVectors(Vec3(width, 0, 0), Vec3(0, width, 0), Vec3(0, 0, width));
NonbondedForce* nbforce = new NonbondedForce();
nbforce->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
nbforce->setCutoffDistance(0.7);
ATMForce* atm = new ATMForce(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0);
double spacing = width/6.0;
double offset = spacing/5.0;
vector<Vec3> positions;
for (int i = 0; i < 6; i++)
for (int j = 0; j < 6; j++)
for (int k = 0; k < 6; k++) {
positions.push_back(Vec3(spacing*i+offset, spacing*j+offset, spacing*k+offset));
system.addParticle(10.0);
nbforce->addParticle(0, 0.3, 1.0);
atm->addParticle(Vec3(0,0,0));
}
//places first particle almost on top of another particle in displaced system
atm->setParticleParameters(0, Vec3(spacing+1.e-4, 0, 0), Vec3(0.0, 0, 0));
atm->addForce(nbforce);
system.addForce(atm);
LangevinMiddleIntegrator integrator(300, 1.0, 0.004);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setParameter(ATMForce::Lambda1(), lambda);
context.setParameter(ATMForce::Lambda2(), lambda);
State state = context.getState(State::Energy, false);
double epot = state.getPotentialEnergy();
ASSERT(!isnan(epot) && !isinf(epot));
atm->getPerturbationEnergy(context, u1, u0, energy);
double epert = u1 - u0;
ASSERT(!isnan(energy) && !isinf(energy));
ASSERT(!isnan(epert) && !isinf(epert));
integrator.step(10);
state = context.getState(State::Energy | State::Positions, false);
vector<Vec3> positions2 = state.getPositions();
ASSERT(fabs(positions[0][0] - positions2[0][0]) < width);
}
void testParticlesCustomExpressionLinear() { void testParticlesCustomExpressionLinear() {
// Similar to test2Particles() but employing a custom alchemical energy expression // Similar to test2Particles() but employing a custom alchemical energy expression
...@@ -546,6 +599,7 @@ int main(int argc, char* argv[]) { ...@@ -546,6 +599,7 @@ int main(int argc, char* argv[]) {
test2Particles2Displacement0(); test2Particles2Displacement0();
test2ParticlesSoftCore(); test2ParticlesSoftCore();
testNonbonded(); testNonbonded();
testNonbondedwithEndpointClash();
testParticlesCustomExpressionLinear(); testParticlesCustomExpressionLinear();
testParticlesCustomExpressionSoftplus(); testParticlesCustomExpressionSoftplus();
testLargeSystem(); testLargeSystem();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment