"vscode:/vscode.git/clone" did not exist on "b6b719f11c05ae0d54b5554dc401188fd7b1f7c5"
Commit 10b1c7b2 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1824 from peastman/savedforces

Bug fix to custom integrator
parents b317bd38 897fb788
...@@ -197,8 +197,11 @@ public: ...@@ -197,8 +197,11 @@ public:
double calcForcesAndEnergy(bool includeForces, bool includeEnergy, int groups=0xFFFFFFFF); double calcForcesAndEnergy(bool includeForces, bool includeEnergy, int groups=0xFFFFFFFF);
/** /**
* Get the set of force group flags that were passed to the most recent call to calcForcesAndEnergy(). * Get the set of force group flags that were passed to the most recent call to calcForcesAndEnergy().
*
* Note that this returns a reference, so it's possible to modify it. Be very very cautious about
* doing that! Only do it if you're also modifying forces stored inside the context.
*/ */
int getLastForceGroups() const; int& getLastForceGroups();
/** /**
* Calculate the kinetic energy of the system (in kJ/mol). * Calculate the kinetic energy of the system (in kJ/mol).
*/ */
......
...@@ -301,7 +301,7 @@ double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy, ...@@ -301,7 +301,7 @@ double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy,
} }
} }
int ContextImpl::getLastForceGroups() const { int& ContextImpl::getLastForceGroups() {
return lastForceGroups; return lastForceGroups;
} }
......
...@@ -111,7 +111,8 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -111,7 +111,8 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
for (auto& param : force->getDefaultParameters()) for (auto& param : force->getDefaultParameters())
affectsForce.insert(param.first); affectsForce.insert(param.first);
for (int i = 0; i < numSteps; i++) for (int i = 0; i < numSteps; i++)
invalidatesForces[i] = (stepType[i] == CustomIntegrator::ConstrainPositions || affectsForce.find(stepVariable[i]) != affectsForce.end()); invalidatesForces[i] = (stepType[i] == CustomIntegrator::ConstrainPositions || stepType[i] == CustomIntegrator::UpdateContextState ||
affectsForce.find(stepVariable[i]) != affectsForce.end());
// Make a list of which steps require valid forces or energy to be known. // Make a list of which steps require valid forces or energy to be known.
......
...@@ -7604,13 +7604,15 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7604,13 +7604,15 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
int nextStep = step+1; int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups(); int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) { if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) { if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old // The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again. // forces in case we need them again.
cu.getForce().copyTo(*savedForces[lastForceGroups]); cu.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups); validSavedForces.insert(lastForceGroups);
} }
}
else else
validSavedForces.clear(); validSavedForces.clear();
...@@ -7623,6 +7625,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7623,6 +7625,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
// We can just restore the forces we saved earlier. // We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cu.getForce()); savedForces[forceGroupFlags[step]]->copyTo(cu.getForce());
context.getLastForceGroups() = forceGroupFlags[step];
} }
else { else {
recordChangedParameters(context); recordChangedParameters(context);
......
...@@ -7944,13 +7944,15 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7944,13 +7944,15 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
int nextStep = step+1; int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups(); int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) { if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) { if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old // The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again. // forces in case we need them again.
cl.getForce().copyTo(*savedForces[lastForceGroups]); cl.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups); validSavedForces.insert(lastForceGroups);
} }
}
else else
validSavedForces.clear(); validSavedForces.clear();
...@@ -7963,6 +7965,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7963,6 +7965,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
// We can just restore the forces we saved earlier. // We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cl.getForce()); savedForces[forceGroupFlags[step]]->copyTo(cl.getForce());
context.getLastForceGroups() = forceGroupFlags[step];
} }
else { else {
recordChangedParameters(context); recordChangedParameters(context);
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "openmm/AndersenThermostat.h" #include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h" #include "openmm/CustomAngleForce.h"
#include "openmm/CustomBondForce.h" #include "openmm/CustomBondForce.h"
#include "openmm/CustomExternalForce.h"
#include "openmm/CustomIntegrator.h" #include "openmm/CustomIntegrator.h"
#include "openmm/HarmonicBondForce.h" #include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
...@@ -921,6 +922,52 @@ void testTabulatedFunction() { ...@@ -921,6 +922,52 @@ void testTabulatedFunction() {
ASSERT_EQUAL_VEC(Vec3(12.0, 13.0, 14.0), values[0], 1e-5); ASSERT_EQUAL_VEC(Vec3(12.0, 13.0, 14.0), values[0], 1e-5);
} }
/**
* Test an integrator that alternates repeatedly between force groups.
*/
void testAlternatingGroups() {
System system;
system.addParticle(1.0);
CustomExternalForce* force1 = new CustomExternalForce("-0.5*x");
force1->addParticle(0);
system.addForce(force1);
CustomExternalForce* force2 = new CustomExternalForce("-0.8*y");
force2->addParticle(0);
force2->setForceGroup(1);
system.addForce(force2);
CustomIntegrator integrator(0.5);
integrator.addGlobalVariable("savede1", 0.0);
integrator.addGlobalVariable("savede2", 0.0);
integrator.addGlobalVariable("savede3", 0.0);
integrator.addGlobalVariable("savede4", 0.0);
integrator.addPerDofVariable("savedf1", 0.0);
integrator.addPerDofVariable("savedf2", 0.0);
integrator.addPerDofVariable("savedf3", 0.0);
integrator.addPerDofVariable("savedf4", 0.0);
integrator.addComputeGlobal("savede1", "energy0");
integrator.addComputeGlobal("savede2", "energy1");
integrator.addComputePerDof("savedf1", "f0");
integrator.addComputePerDof("savedf2", "f1");
integrator.addComputeGlobal("savede3", "energy0");
integrator.addComputeGlobal("savede4", "energy1");
integrator.addComputePerDof("savedf3", "f0");
integrator.addComputePerDof("savedf4", "f1");
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(1, 2, 3);
context.setPositions(positions);
integrator.step(1);
vector<Vec3> f;
for (int i = 0; i < 2; i++) {
ASSERT_EQUAL_TOL(-0.5*1, integrator.getGlobalVariable(2*i), 1e-5);
ASSERT_EQUAL_TOL(-0.8*2, integrator.getGlobalVariable(2*i+1), 1e-5);
integrator.getPerDofVariable(2*i, f);
ASSERT_EQUAL_VEC(Vec3(0.5, 0, 0), f[0], 1e-5);
integrator.getPerDofVariable(2*i+1, f);
ASSERT_EQUAL_VEC(Vec3(0, 0.8, 0), f[0], 1e-5);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -944,6 +991,7 @@ int main(int argc, char* argv[]) { ...@@ -944,6 +991,7 @@ int main(int argc, char* argv[]) {
testEnergyParameterDerivatives(); testEnergyParameterDerivatives();
testChangeDT(); testChangeDT();
testTabulatedFunction(); testTabulatedFunction();
testAlternatingGroups();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment