Commit 9d779a39 authored by peastman's avatar peastman
Browse files

Reference CustomIntegrator avoids unnecessary force/energy computations

parent 10b1c7b2
/* Portions copyright (c) 2011-2016 Stanford University and Simbios. /* Portions copyright (c) 2011-2017 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -219,19 +219,26 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -219,19 +219,26 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
for (auto& global : globals) for (auto& global : globals)
expressionSet.setVariable(expressionSet.getVariableIndex(global.first), global.second); expressionSet.setVariable(expressionSet.getVariableIndex(global.first), global.second);
oldPos = atomCoordinates; oldPos = atomCoordinates;
map<int, double> groupEnergy;
map<int, vector<Vec3> > groupForces;
if (forcesAreValid)
groupForces[context.getLastForceGroups()] = forces;
// Loop over steps and execute them. // Loop over steps and execute them.
for (int step = 0; step < numSteps; ) { for (int step = 0; step < numSteps; ) {
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || context.getLastForceGroups() != forceGroupFlags[step])) { int flags = forceGroupFlags[step];
if ((needsForces[step] && groupForces.find(flags) == groupForces.end()) || (needsEnergy[step] && groupEnergy.find(flags) == groupEnergy.end())) {
// Recompute forces and/or energy. // Recompute forces and/or energy.
bool computeForce = needsForces[step] || computeBothForceAndEnergy[step]; bool computeForce = needsForces[step] || computeBothForceAndEnergy[step];
bool computeEnergy = needsEnergy[step] || computeBothForceAndEnergy[step]; bool computeEnergy = needsEnergy[step] || computeBothForceAndEnergy[step];
recordChangedParameters(context, globals); recordChangedParameters(context, globals);
double e = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]); double e = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]);
if (computeForce)
groupForces[flags] = forces;
if (computeEnergy) { if (computeEnergy) {
energy = e; groupEnergy[flags] = e;
context.getEnergyParameterDerivatives(energyParamDerivs); context.getEnergyParameterDerivatives(energyParamDerivs);
} }
forcesAreValid = true; forcesAreValid = true;
...@@ -239,6 +246,8 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -239,6 +246,8 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
// Execute the step. // Execute the step.
energy = (needsEnergy[step] ? groupEnergy[flags] : 0);
vector<Vec3>& stepForces = (needsForces[step] ? groupForces[flags] : forces);
int nextStep = step+1; int nextStep = step+1;
switch (stepType[step]) { switch (stepType[step]) {
case CustomIntegrator::ComputeGlobal: { case CustomIntegrator::ComputeGlobal: {
...@@ -262,11 +271,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -262,11 +271,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
} }
if (results == NULL) if (results == NULL)
throw OpenMMException("Illegal per-DOF output variable: "+stepVariable[step]); throw OpenMMException("Illegal per-DOF output variable: "+stepVariable[step]);
computePerDof(numberOfAtoms, *results, atomCoordinates, velocities, forces, masses, perDof, stepExpressions[step][0]); computePerDof(numberOfAtoms, *results, atomCoordinates, velocities, stepForces, masses, perDof, stepExpressions[step][0]);
break; break;
} }
case CustomIntegrator::ComputeSum: { case CustomIntegrator::ComputeSum: {
computePerDof(numberOfAtoms, sumBuffer, atomCoordinates, velocities, forces, masses, perDof, stepExpressions[step][0]); computePerDof(numberOfAtoms, sumBuffer, atomCoordinates, velocities, stepForces, masses, perDof, stepExpressions[step][0]);
double sum = 0.0; double sum = 0.0;
for (int j = 0; j < numberOfAtoms; j++) for (int j = 0; j < numberOfAtoms; j++)
if (masses[j] != 0.0) if (masses[j] != 0.0)
...@@ -308,8 +317,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -308,8 +317,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
break; break;
} }
} }
if (invalidatesForces[step]) if (invalidatesForces[step]) {
forcesAreValid = false; forcesAreValid = false;
groupForces.clear();
groupEnergy.clear();
}
step = nextStep; step = nextStep;
} }
ReferenceVirtualSites::computePositions(context.getSystem(), atomCoordinates); ReferenceVirtualSites::computePositions(context.getSystem(), atomCoordinates);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment