Commit 66bc28f5 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1826 from peastman/optimizeintegrator

CustomIntegrator avoids unnecessary force/energy computations
parents fb607c7f 6a0e1bd5
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <utility>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -250,26 +251,23 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, ...@@ -250,26 +251,23 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps,
void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& steps, const vector<bool>& needsForces, const vector<bool>& needsEnergy, void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& steps, const vector<bool>& needsForces, const vector<bool>& needsEnergy,
const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) { const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) {
vector<int> candidatePoints; vector<pair<int, int> > candidatePoints;
int currentGroup = -1;
for (int step : steps) { for (int step : steps) {
if (invalidatesForces[step] || ((needsForces[step] || needsEnergy[step]) && forceGroup[step] != currentGroup)) { if (invalidatesForces[step]) {
// Forces and energies are invalidated at this step, or it changes to a different force group, // Forces and energies are invalidated at this step, so anything from this point on won't affect what we do at earlier steps.
// so anything from this point on won't affect what we do at earlier steps.
candidatePoints.clear(); candidatePoints.clear();
} }
if (needsForces[step] || needsEnergy[step]) { if (needsForces[step] || needsEnergy[step]) {
// See if this step affects what we do at earlier points. // See if this step affects what we do at earlier points.
for (int candidate : candidatePoints) for (auto candidate : candidatePoints)
if ((needsForces[candidate] && needsEnergy[step]) || (needsEnergy[candidate] && needsForces[step])) if (candidate.second == forceGroup[step] && ((needsForces[candidate.first] && needsEnergy[step]) || (needsEnergy[candidate.first] && needsForces[step])))
computeBoth[candidate] = true; computeBoth[candidate.first] = true;
// Add this to the list of candidates that might be affected by later steps. // Add this to the list of candidates that might be affected by later steps.
candidatePoints.push_back(step); candidatePoints.push_back(make_pair(step, forceGroup[step]));
currentGroup = forceGroup[step];
} }
} }
} }
......
...@@ -1507,6 +1507,7 @@ private: ...@@ -1507,6 +1507,7 @@ private:
CudaArray* randomSeed; CudaArray* randomSeed;
CudaArray* perDofEnergyParamDerivs; CudaArray* perDofEnergyParamDerivs;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
std::map<int, double> savedEnergy;
std::map<int, CudaArray*> savedForces; std::map<int, CudaArray*> savedForces;
std::set<int> validSavedForces; std::set<int> validSavedForces;
CudaParameterSet* perDofValues; CudaParameterSet* perDofValues;
......
...@@ -7594,6 +7594,8 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7594,6 +7594,8 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
CudaIntegrationUtilities& integration = cu.getIntegrationUtilities(); CudaIntegrationUtilities& integration = cu.getIntegrationUtilities();
int numAtoms = cu.getNumAtoms(); int numAtoms = cu.getNumAtoms();
int numSteps = integrator.getNumComputations(); int numSteps = integrator.getNumComputations();
if (!forcesAreValid)
savedEnergy.clear();
// Loop over computation steps in the integrator and execute them. // Loop over computation steps in the integrator and execute them.
...@@ -7602,8 +7604,11 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7602,8 +7604,11 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
CUdeviceptr posCorrection = (cu.getUseMixedPrecision() ? cu.getPosqCorrection().getDevicePointer() : 0); CUdeviceptr posCorrection = (cu.getUseMixedPrecision() ? cu.getPosqCorrection().getDevicePointer() : 0);
for (int step = 0; step < numSteps; ) { for (int step = 0; step < numSteps; ) {
int nextStep = step+1; int nextStep = step+1;
int forceGroups = forceGroupFlags[step];
int lastForceGroups = context.getLastForceGroups(); int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) { bool haveForces = (!needsForces[step] || (forcesAreValid && lastForceGroups == forceGroups));
bool haveEnergy = (!needsEnergy[step] || savedEnergy.find(forceGroups) != savedEnergy.end());
if (!haveForces || !haveEnergy) {
if (forcesAreValid) { if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) { if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old // The forces are still valid. We just need a different force group right now. Save the old
...@@ -7621,16 +7626,16 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7621,16 +7626,16 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]); bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]);
bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]); bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]);
if (!computeEnergy && validSavedForces.find(forceGroupFlags[step]) != validSavedForces.end()) { if (!computeEnergy && validSavedForces.find(forceGroups) != validSavedForces.end()) {
// We can just restore the forces we saved earlier. // We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cu.getForce()); savedForces[forceGroups]->copyTo(cu.getForce());
context.getLastForceGroups() = forceGroupFlags[step]; context.getLastForceGroups() = forceGroups;
} }
else { else {
recordChangedParameters(context); recordChangedParameters(context);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]); energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
energyFloat = (float) energy; savedEnergy[forceGroups] = energy;
if (needsEnergyParamDerivs) { if (needsEnergyParamDerivs) {
context.getEnergyParameterDerivatives(energyParamDerivs); context.getEnergyParameterDerivatives(energyParamDerivs);
if (perDofEnergyParamDerivNames.size() > 0) { if (perDofEnergyParamDerivNames.size() > 0) {
...@@ -7649,6 +7654,10 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7649,6 +7654,10 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
} }
forcesAreValid = true; forcesAreValid = true;
} }
if (needsEnergy[step]) {
energy = savedEnergy[forceGroups];
energyFloat = (float) energy;
}
if (needsGlobals[step] && !deviceGlobalsAreCurrent) { if (needsGlobals[step] && !deviceGlobalsAreCurrent) {
// Upload the global values to the device. // Upload the global values to the device.
...@@ -7725,8 +7734,10 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7725,8 +7734,10 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
if (blockEnd[step] != -1) if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block. nextStep = blockEnd[step]; // Return to the start of a while block.
} }
if (invalidatesForces[step]) if (invalidatesForces[step]) {
forcesAreValid = false; forcesAreValid = false;
savedEnergy.clear();
}
step = nextStep; step = nextStep;
} }
recordChangedParameters(context); recordChangedParameters(context);
......
...@@ -1494,6 +1494,7 @@ private: ...@@ -1494,6 +1494,7 @@ private:
OpenCLArray* randomSeed; OpenCLArray* randomSeed;
OpenCLArray* perDofEnergyParamDerivs; OpenCLArray* perDofEnergyParamDerivs;
std::vector<OpenCLArray*> tabulatedFunctions; std::vector<OpenCLArray*> tabulatedFunctions;
std::map<int, double> savedEnergy;
std::map<int, OpenCLArray*> savedForces; std::map<int, OpenCLArray*> savedForces;
std::set<int> validSavedForces; std::set<int> validSavedForces;
OpenCLParameterSet* perDofValues; OpenCLParameterSet* perDofValues;
......
...@@ -7937,13 +7937,18 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7937,13 +7937,18 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities(); OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
int numAtoms = cl.getNumAtoms(); int numAtoms = cl.getNumAtoms();
int numSteps = integrator.getNumComputations(); int numSteps = integrator.getNumComputations();
if (!forcesAreValid)
savedEnergy.clear();
// Loop over computation steps in the integrator and execute them. // Loop over computation steps in the integrator and execute them.
for (int step = 0; step < numSteps; ) { for (int step = 0; step < numSteps; ) {
int nextStep = step+1; int nextStep = step+1;
int forceGroups = forceGroupFlags[step];
int lastForceGroups = context.getLastForceGroups(); int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) { bool haveForces = (!needsForces[step] || (forcesAreValid && lastForceGroups == forceGroups));
bool haveEnergy = (!needsEnergy[step] || savedEnergy.find(forceGroups) != savedEnergy.end());
if (!haveForces || !haveEnergy) {
if (forcesAreValid) { if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) { if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old // The forces are still valid. We just need a different force group right now. Save the old
...@@ -7961,15 +7966,16 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7961,15 +7966,16 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]); bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]);
bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]); bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]);
if (!computeEnergy && validSavedForces.find(forceGroupFlags[step]) != validSavedForces.end()) { if (!computeEnergy && validSavedForces.find(forceGroups) != validSavedForces.end()) {
// We can just restore the forces we saved earlier. // We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cl.getForce()); savedForces[forceGroups]->copyTo(cl.getForce());
context.getLastForceGroups() = forceGroupFlags[step]; context.getLastForceGroups() = forceGroups;
} }
else { else {
recordChangedParameters(context); recordChangedParameters(context);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]); energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
savedEnergy[forceGroups] = energy;
if (needsEnergyParamDerivs) { if (needsEnergyParamDerivs) {
context.getEnergyParameterDerivatives(energyParamDerivs); context.getEnergyParameterDerivatives(energyParamDerivs);
if (perDofEnergyParamDerivNames.size() > 0) { if (perDofEnergyParamDerivNames.size() > 0) {
...@@ -7988,6 +7994,8 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7988,6 +7994,8 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
forcesAreValid = true; forcesAreValid = true;
} }
} }
if (needsEnergy[step])
energy = savedEnergy[forceGroups];
if (needsGlobals[step] && !deviceGlobalsAreCurrent) { if (needsGlobals[step] && !deviceGlobalsAreCurrent) {
// Upload the global values to the device. // Upload the global values to the device.
...@@ -8067,8 +8075,10 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -8067,8 +8075,10 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (blockEnd[step] != -1) if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block. nextStep = blockEnd[step]; // Return to the start of a while block.
} }
if (invalidatesForces[step]) if (invalidatesForces[step]) {
forcesAreValid = false; forcesAreValid = false;
savedEnergy.clear();
}
step = nextStep; step = nextStep;
} }
recordChangedParameters(context); recordChangedParameters(context);
......
/* Portions copyright (c) 2011-2016 Stanford University and Simbios. /* Portions copyright (c) 2011-2017 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -219,19 +219,26 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -219,19 +219,26 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
for (auto& global : globals) for (auto& global : globals)
expressionSet.setVariable(expressionSet.getVariableIndex(global.first), global.second); expressionSet.setVariable(expressionSet.getVariableIndex(global.first), global.second);
oldPos = atomCoordinates; oldPos = atomCoordinates;
map<int, double> groupEnergy;
map<int, vector<Vec3> > groupForces;
if (forcesAreValid)
groupForces[context.getLastForceGroups()] = forces;
// Loop over steps and execute them. // Loop over steps and execute them.
for (int step = 0; step < numSteps; ) { for (int step = 0; step < numSteps; ) {
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || context.getLastForceGroups() != forceGroupFlags[step])) { int flags = forceGroupFlags[step];
if ((needsForces[step] && groupForces.find(flags) == groupForces.end()) || (needsEnergy[step] && groupEnergy.find(flags) == groupEnergy.end())) {
// Recompute forces and/or energy. // Recompute forces and/or energy.
bool computeForce = needsForces[step] || computeBothForceAndEnergy[step]; bool computeForce = needsForces[step] || computeBothForceAndEnergy[step];
bool computeEnergy = needsEnergy[step] || computeBothForceAndEnergy[step]; bool computeEnergy = needsEnergy[step] || computeBothForceAndEnergy[step];
recordChangedParameters(context, globals); recordChangedParameters(context, globals);
double e = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]); double e = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]);
if (computeForce)
groupForces[flags] = forces;
if (computeEnergy) { if (computeEnergy) {
energy = e; groupEnergy[flags] = e;
context.getEnergyParameterDerivatives(energyParamDerivs); context.getEnergyParameterDerivatives(energyParamDerivs);
} }
forcesAreValid = true; forcesAreValid = true;
...@@ -239,6 +246,8 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -239,6 +246,8 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
// Execute the step. // Execute the step.
energy = (needsEnergy[step] ? groupEnergy[flags] : 0);
vector<Vec3>& stepForces = (needsForces[step] ? groupForces[flags] : forces);
int nextStep = step+1; int nextStep = step+1;
switch (stepType[step]) { switch (stepType[step]) {
case CustomIntegrator::ComputeGlobal: { case CustomIntegrator::ComputeGlobal: {
...@@ -262,11 +271,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -262,11 +271,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
} }
if (results == NULL) if (results == NULL)
throw OpenMMException("Illegal per-DOF output variable: "+stepVariable[step]); throw OpenMMException("Illegal per-DOF output variable: "+stepVariable[step]);
computePerDof(numberOfAtoms, *results, atomCoordinates, velocities, forces, masses, perDof, stepExpressions[step][0]); computePerDof(numberOfAtoms, *results, atomCoordinates, velocities, stepForces, masses, perDof, stepExpressions[step][0]);
break; break;
} }
case CustomIntegrator::ComputeSum: { case CustomIntegrator::ComputeSum: {
computePerDof(numberOfAtoms, sumBuffer, atomCoordinates, velocities, forces, masses, perDof, stepExpressions[step][0]); computePerDof(numberOfAtoms, sumBuffer, atomCoordinates, velocities, stepForces, masses, perDof, stepExpressions[step][0]);
double sum = 0.0; double sum = 0.0;
for (int j = 0; j < numberOfAtoms; j++) for (int j = 0; j < numberOfAtoms; j++)
if (masses[j] != 0.0) if (masses[j] != 0.0)
...@@ -308,8 +317,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -308,8 +317,11 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
break; break;
} }
} }
if (invalidatesForces[step]) if (invalidatesForces[step]) {
forcesAreValid = false; forcesAreValid = false;
groupForces.clear();
groupEnergy.clear();
}
step = nextStep; step = nextStep;
} }
ReferenceVirtualSites::computePositions(context.getSystem(), atomCoordinates); ReferenceVirtualSites::computePositions(context.getSystem(), atomCoordinates);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment