Commit c300870a authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bug in switching functions with CustomNonbondedForce

parent 04b1a060
......@@ -2271,7 +2271,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
map<string, Lepton::ParsedExpression> forceExpressions;
forceExpressions["tempEnergy += "] = energyExpression;
forceExpressions["real customEnergy = "] = energyExpression;
forceExpressions["tempForce -= "] = forceExpression;
// Create the kernels.
......
......@@ -14,8 +14,10 @@ if (!isExcluded) {
#endif
COMPUTE_FORCE
#if USE_SWITCH
tempForce = tempForce*switchValue - tempEnergy*switchDeriv;
tempEnergy *= switchValue;
tempForce = tempForce*switchValue - customEnergy*switchDeriv;
tempEnergy += customEnergy*switchValue;
#else
tempEnergy += customEnergy;
#endif
dEdR += tempForce*invR;
}
......@@ -2331,7 +2331,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
map<string, Lepton::ParsedExpression> forceExpressions;
forceExpressions["tempEnergy += "] = energyExpression;
forceExpressions["real customEnergy = "] = energyExpression;
forceExpressions["tempForce -= "] = forceExpression;
// Create the kernels.
......
......@@ -14,8 +14,10 @@ if (!isExcluded) {
#endif
COMPUTE_FORCE
#if USE_SWITCH
tempForce = tempForce*switchValue - tempEnergy*switchDeriv;
tempEnergy *= switchValue;
tempForce = tempForce*switchValue - customEnergy*switchDeriv;
tempEnergy += customEnergy*switchValue;
#else
tempEnergy += customEnergy;
#endif
dEdR += tempForce*invR;
}
......@@ -1034,6 +1034,65 @@ void testMultipleCutoffs() {
}
}
void testMultipleSwitches() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
// Add multiple CustomNonbondedForces, one with a switching function and one without.
CustomNonbondedForce* nonbonded1 = new CustomNonbondedForce("2*r");
nonbonded1->addParticle(vector<double>());
nonbonded1->addParticle(vector<double>());
nonbonded1->setNonbondedMethod(CustomNonbondedForce::CutoffNonPeriodic);
nonbonded1->setCutoffDistance(1.0);
system.addForce(nonbonded1);
CustomNonbondedForce* nonbonded2 = new CustomNonbondedForce("3*r");
nonbonded2->addParticle(vector<double>());
nonbonded2->addParticle(vector<double>());
nonbonded2->setNonbondedMethod(CustomNonbondedForce::CutoffNonPeriodic);
nonbonded2->setCutoffDistance(1.0);
nonbonded2->setSwitchingDistance(0.5);
nonbonded2->setUseSwitchingFunction(true);
nonbonded2->setForceGroup(1);
system.addForce(nonbonded2);
Context context(system, integrator, platform);
double r = 0.8;
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
positions[1] = Vec3(0, r, 0);
context.setPositions(positions);
double t = (r-0.5)/(1.0-0.5);
double switchValue = 1+t*t*t*(-10+t*(15-t*6));
double switchDeriv = t*t*(-30+t*(60-t*30))/(1.0-0.5);
double e1 = 2.0*r;
double e2 = 3.0*r*switchValue;
double f1 = 2.0;
double f2 = 3.0*switchValue + 3.0*switchDeriv*r;
// Check the first force.
State state = context.getState(State::Forces | State::Energy, false, 1);
ASSERT_EQUAL_VEC(Vec3(0, f1, 0), state.getForces()[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -f1, 0), state.getForces()[1], TOL);
ASSERT_EQUAL_TOL(e1, state.getPotentialEnergy(), TOL);
// Check the second force.
state = context.getState(State::Forces | State::Energy, false, 2);
ASSERT_EQUAL_VEC(Vec3(0, f2, 0), state.getForces()[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -f2, 0), state.getForces()[1], TOL);
ASSERT_EQUAL_TOL(e2, state.getPotentialEnergy(), TOL);
// Check the sum of both forces.
state = context.getState(State::Forces | State::Energy);
ASSERT_EQUAL_VEC(Vec3(0, f1+f2, 0), state.getForces()[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -f1-f2, 0), state.getForces()[1], TOL);
ASSERT_EQUAL_TOL(e1+e2, state.getPotentialEnergy(), TOL);
}
void testIllegalVariable() {
System system;
system.addParticle(1.0);
......@@ -1154,6 +1213,7 @@ int main(int argc, char* argv[]) {
testInteractionGroupLongRangeCorrection();
testInteractionGroupTabulatedFunction();
testMultipleCutoffs();
testMultipleSwitches();
testIllegalVariable();
testEnergyParameterDerivatives();
testEnergyParameterDerivatives2();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment