Commit 897fb788 authored by Peter Eastman's avatar Peter Eastman
Browse files

Bug fix to custom integrator

parent b317bd38
......@@ -197,8 +197,11 @@ public:
double calcForcesAndEnergy(bool includeForces, bool includeEnergy, int groups=0xFFFFFFFF);
/**
* Get the set of force group flags that were passed to the most recent call to calcForcesAndEnergy().
*
* Note that this returns a reference, so it's possible to modify it. Be very very cautious about
* doing that! Only do it if you're also modifying forces stored inside the context.
*/
int getLastForceGroups() const;
int& getLastForceGroups();
/**
* Calculate the kinetic energy of the system (in kJ/mol).
*/
......
......@@ -301,7 +301,7 @@ double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy,
}
}
int ContextImpl::getLastForceGroups() const {
int& ContextImpl::getLastForceGroups() {
return lastForceGroups;
}
......
......@@ -111,7 +111,8 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
for (auto& param : force->getDefaultParameters())
affectsForce.insert(param.first);
for (int i = 0; i < numSteps; i++)
invalidatesForces[i] = (stepType[i] == CustomIntegrator::ConstrainPositions || affectsForce.find(stepVariable[i]) != affectsForce.end());
invalidatesForces[i] = (stepType[i] == CustomIntegrator::ConstrainPositions || stepType[i] == CustomIntegrator::UpdateContextState ||
affectsForce.find(stepVariable[i]) != affectsForce.end());
// Make a list of which steps require valid forces or energy to be known.
......
......@@ -7604,12 +7604,14 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
cu.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
cu.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
}
}
else
validSavedForces.clear();
......@@ -7623,6 +7625,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
// We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cu.getForce());
context.getLastForceGroups() = forceGroupFlags[step];
}
else {
recordChangedParameters(context);
......
......@@ -7944,12 +7944,14 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
cl.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
cl.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
}
}
else
validSavedForces.clear();
......@@ -7963,6 +7965,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
// We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cl.getForce());
context.getLastForceGroups() = forceGroupFlags[step];
}
else {
recordChangedParameters(context);
......
......@@ -37,6 +37,7 @@
#include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h"
#include "openmm/CustomBondForce.h"
#include "openmm/CustomExternalForce.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h"
......@@ -921,6 +922,52 @@ void testTabulatedFunction() {
ASSERT_EQUAL_VEC(Vec3(12.0, 13.0, 14.0), values[0], 1e-5);
}
/**
* Test an integrator that alternates repeatedly between force groups.
*/
void testAlternatingGroups() {
System system;
system.addParticle(1.0);
CustomExternalForce* force1 = new CustomExternalForce("-0.5*x");
force1->addParticle(0);
system.addForce(force1);
CustomExternalForce* force2 = new CustomExternalForce("-0.8*y");
force2->addParticle(0);
force2->setForceGroup(1);
system.addForce(force2);
CustomIntegrator integrator(0.5);
integrator.addGlobalVariable("savede1", 0.0);
integrator.addGlobalVariable("savede2", 0.0);
integrator.addGlobalVariable("savede3", 0.0);
integrator.addGlobalVariable("savede4", 0.0);
integrator.addPerDofVariable("savedf1", 0.0);
integrator.addPerDofVariable("savedf2", 0.0);
integrator.addPerDofVariable("savedf3", 0.0);
integrator.addPerDofVariable("savedf4", 0.0);
integrator.addComputeGlobal("savede1", "energy0");
integrator.addComputeGlobal("savede2", "energy1");
integrator.addComputePerDof("savedf1", "f0");
integrator.addComputePerDof("savedf2", "f1");
integrator.addComputeGlobal("savede3", "energy0");
integrator.addComputeGlobal("savede4", "energy1");
integrator.addComputePerDof("savedf3", "f0");
integrator.addComputePerDof("savedf4", "f1");
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(1, 2, 3);
context.setPositions(positions);
integrator.step(1);
vector<Vec3> f;
for (int i = 0; i < 2; i++) {
ASSERT_EQUAL_TOL(-0.5*1, integrator.getGlobalVariable(2*i), 1e-5);
ASSERT_EQUAL_TOL(-0.8*2, integrator.getGlobalVariable(2*i+1), 1e-5);
integrator.getPerDofVariable(2*i, f);
ASSERT_EQUAL_VEC(Vec3(0.5, 0, 0), f[0], 1e-5);
integrator.getPerDofVariable(2*i+1, f);
ASSERT_EQUAL_VEC(Vec3(0, 0.8, 0), f[0], 1e-5);
}
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -944,6 +991,7 @@ int main(int argc, char* argv[]) {
testEnergyParameterDerivatives();
testChangeDT();
testTabulatedFunction();
testAlternatingGroups();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment