Unverified Commit 065e34ab authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Optimized Context creation with complex CustomIntegrators (#4191)

parent ac6133bb
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2015-2016 Stanford University and the Authors. * * Portions copyright (c) 2015-2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -81,7 +81,7 @@ private: ...@@ -81,7 +81,7 @@ private:
static bool usesVariable(const Lepton::ExpressionTreeNode& node, const std::string& variable); static bool usesVariable(const Lepton::ExpressionTreeNode& node, const std::string& variable);
static void enumeratePaths(int firstStep, std::vector<int> steps, std::vector<int> jumps, const std::vector<int>& blockEnd, static void enumeratePaths(int firstStep, std::vector<int> steps, std::vector<int> jumps, const std::vector<int>& blockEnd,
const std::vector<CustomIntegrator::ComputationType>& stepType, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy, const std::vector<CustomIntegrator::ComputationType>& stepType, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy,
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth); const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth, const std::vector<bool>& isSignificant);
static void analyzeForceComputationsForPath(std::vector<int>& steps, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy, static void analyzeForceComputationsForPath(std::vector<int>& steps, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy,
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth); const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth);
static void validateDerivatives(const Lepton::ExpressionTreeNode& node, const std::vector<std::string>& derivNames); static void validateDerivatives(const Lepton::ExpressionTreeNode& node, const std::vector<std::string>& derivNames);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2015-2019 Stanford University and the Authors. * * Portions copyright (c) 2015-2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -188,6 +188,20 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -188,6 +188,20 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
if (blockStart.size() > 0) if (blockStart.size() > 0)
throw OpenMMException("CustomIntegrator: Missing EndBlock"); throw OpenMMException("CustomIntegrator: Missing EndBlock");
// Identify whether each block contains any operation that either invalidates forces,
// or requires forces or energy. These are the ones that are significant for the
// analysis that follows.
vector<bool> isSignificant(numSteps, false);
for (int step = 0; step < numSteps; step++) {
if (stepType[step] == CustomIntegrator::IfBlockStart || stepType[step] == CustomIntegrator::WhileBlockStart)
for (int i = step; i < blockEnd[step]; i++)
if (needsForces[i] || needsEnergy[i] || invalidatesForces[i]) {
isSignificant[step] = true;
break;
}
}
// If a step requires either forces or energy, and a later step will require the other one, it's most efficient // If a step requires either forces or energy, and a later step will require the other one, it's most efficient
// to compute both at the same time. Figure out whether we should do that. In principle it's easy: step through // to compute both at the same time. Figure out whether we should do that. In principle it's easy: step through
// the sequence of computations and see if the other one is used before the next time they get invalidated. // the sequence of computations and see if the other one is used before the next time they get invalidated.
...@@ -211,7 +225,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -211,7 +225,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
int numBlocks = blockEnd.size(); int numBlocks = blockEnd.size();
for (int i = 0; i < numBlocks; i++) for (int i = 0; i < numBlocks; i++)
blockEnd.push_back(blockEnd[i]+numSteps); blockEnd.push_back(blockEnd[i]+numSteps);
enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, alwaysInvalidatesForces, forceGroup, computeBoth); enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, alwaysInvalidatesForces, forceGroup, computeBoth, isSignificant);
// Make sure calls to deriv() all valid. // Make sure calls to deriv() all valid.
...@@ -224,7 +238,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -224,7 +238,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, vector<int> jumps, const vector<int>& blockEnd, void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, vector<int> jumps, const vector<int>& blockEnd,
const vector<CustomIntegrator::ComputationType>& stepType, const vector<bool>& needsForces, const vector<bool>& needsEnergy, const vector<CustomIntegrator::ComputationType>& stepType, const vector<bool>& needsForces, const vector<bool>& needsEnergy,
const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) { const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth, const vector<bool>& isSignificant) {
int step = firstStep; int step = firstStep;
int numSteps = stepType.size(); int numSteps = stepType.size();
while (step < 2*numSteps) { while (step < 2*numSteps) {
...@@ -237,23 +251,23 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, ...@@ -237,23 +251,23 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps,
jumps[step] = -1; jumps[step] = -1;
step = nextStep; step = nextStep;
} }
else if (stepType[index] == CustomIntegrator::IfBlockStart) { else if (stepType[index] == CustomIntegrator::IfBlockStart && isSignificant[index]) {
// Consider skipping the block. // Consider skipping the block.
enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth); enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth, isSignificant);
// Continue on to execute the block. // Continue on to execute the block.
step++; step++;
} }
else if (stepType[index] == CustomIntegrator::WhileBlockStart && jumps[step] != -2) { else if (stepType[index] == CustomIntegrator::WhileBlockStart && jumps[step] != -2 && isSignificant[index]) {
// Consider skipping the block. // Consider skipping the block.
enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth); enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth, isSignificant);
// Consider executing the block once. // Consider executing the block once.
enumeratePaths(step+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth); enumeratePaths(step+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth, isSignificant);
// Continue on to execute the block twice. // Continue on to execute the block twice.
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2020 Stanford University and the Authors. * * Portions copyright (c) 2008-2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#define _USE_MATH_DEFINES // Needed to get M_PI #define _USE_MATH_DEFINES // Needed to get M_PI
#endif #endif
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/internal/CustomIntegratorUtilities.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/AndersenThermostat.h" #include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h" #include "openmm/CustomAngleForce.h"
...@@ -1211,6 +1212,73 @@ void testSaveParameters() { ...@@ -1211,6 +1212,73 @@ void testSaveParameters() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6); ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
} }
void testAnalyzeComputations() {
System system;
system.addParticle(1.0);
CustomBondForce* bond = new CustomBondForce("scale*r");
bond->addGlobalParameter("scale", 2.0);
bond->setForceGroup(1);
system.addForce(bond);
// Create a complex integrator with lots of nested blocks and steps that use or invalidate
// forces or energies.
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("color", 1.5);
integrator.addPerDofVariable("z", 0);
integrator.addComputeGlobal("color", "energy"); // 0
integrator.beginIfBlock("color > 1.0"); // 1
integrator.addComputeGlobal("scale", "energy0"); // 2
integrator.endBlock(); // 3
integrator.beginIfBlock("scale < color"); // 4
integrator.addComputePerDof("v", "x"); // 5
integrator.endBlock(); // 6
integrator.addComputePerDof("z", "f1"); // 7
integrator.beginWhileBlock("energy2 > 0"); // 8
integrator.beginIfBlock("color = 1"); // 9
integrator.addComputePerDof("v", "2*z"); // 10
integrator.endBlock(); // 11
integrator.beginIfBlock("color = 2"); // 12
integrator.addComputeGlobal("color", "color+1"); // 13
integrator.addUpdateContextState(); // 14
integrator.endBlock(); // 15
integrator.endBlock(); // 16
integrator.addComputePerDof("x", "x+f"); // 17
// Call analyzeComputations() and see if the results are what we expect.
Context context(system, integrator, platform);
ContextImpl* contextImpl = *reinterpret_cast<ContextImpl**>(&context);
vector<vector<Lepton::ParsedExpression> > expressions;
vector<CustomIntegratorUtilities::Comparison> comparisons;
vector<int> blockEnd, forceGroup;
vector<bool> invalidatesForces, needsForces, needsEnergy, computeBoth;
map<string, Lepton::CustomFunction*> functions;
CustomIntegratorUtilities::analyzeComputations(*contextImpl, integrator, expressions, comparisons, blockEnd, invalidatesForces,
needsForces, needsEnergy, computeBoth, forceGroup, functions);
ASSERT_EQUAL(3, blockEnd[1]);
ASSERT_EQUAL(6, blockEnd[4]);
ASSERT_EQUAL(16, blockEnd[8]);
ASSERT_EQUAL(11, blockEnd[9]);
ASSERT_EQUAL(15, blockEnd[12]);
for (int i = 0; i < integrator.getNumComputations(); i++) {
ASSERT_EQUAL(i == 2 || i == 14 || i == 17, invalidatesForces[i]);
ASSERT_EQUAL(i == 7 || i == 17, needsForces[i]);
ASSERT_EQUAL(i == 0 || i == 2 || i == 8, needsEnergy[i]);
ASSERT_EQUAL(i == 17, computeBoth[i]);
if (needsForces[i] || needsEnergy[i]) {
int group = -1;
if (i == 2)
group = 0;
else if (i == 7)
group = 1;
else if (i == 8)
group = 2;
ASSERT_EQUAL(group, forceGroup[i]);
}
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -1241,6 +1309,7 @@ int main(int argc, char* argv[]) { ...@@ -1241,6 +1309,7 @@ int main(int argc, char* argv[]) {
testInitialTemperature(); testInitialTemperature();
testCheckpoint(); testCheckpoint();
testSaveParameters(); testSaveParameters();
testAnalyzeComputations();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment