Commit 9695c4cf authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bug in CustomIntegrator

parent a351c396
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2015-2016 Stanford University and the Authors. * * Portions copyright (c) 2015-2019 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -83,6 +83,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -83,6 +83,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
forceGroup.resize(numSteps, -2); forceGroup.resize(numSteps, -2);
vector<CustomIntegrator::ComputationType> stepType(numSteps); vector<CustomIntegrator::ComputationType> stepType(numSteps);
vector<string> stepVariable(numSteps); vector<string> stepVariable(numSteps);
vector<bool> alwaysInvalidatesForces(numSteps, false);
map<string, Lepton::CustomFunction*> customFunctions = functions; map<string, Lepton::CustomFunction*> customFunctions = functions;
Lepton::PlaceholderFunction fn1(1), fn2(2), fn3(3); Lepton::PlaceholderFunction fn1(1), fn2(2), fn3(3);
customFunctions["deriv"] = &fn2; customFunctions["deriv"] = &fn2;
...@@ -120,9 +121,10 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -120,9 +121,10 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
for (auto force : context.getForceImpls()) for (auto force : context.getForceImpls())
for (auto& param : force->getDefaultParameters()) for (auto& param : force->getDefaultParameters())
affectsForce.insert(param.first); affectsForce.insert(param.first);
for (int i = 0; i < numSteps; i++) for (int i = 0; i < numSteps; i++) {
invalidatesForces[i] = (stepType[i] == CustomIntegrator::ConstrainPositions || stepType[i] == CustomIntegrator::UpdateContextState || alwaysInvalidatesForces[i] = (stepType[i] == CustomIntegrator::ConstrainPositions || affectsForce.find(stepVariable[i]) != affectsForce.end());
affectsForce.find(stepVariable[i]) != affectsForce.end()); invalidatesForces[i] = (alwaysInvalidatesForces[i] || stepType[i] == CustomIntegrator::UpdateContextState);
}
// Make a list of which steps require valid forces or energy to be known. // Make a list of which steps require valid forces or energy to be known.
...@@ -200,10 +202,16 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -200,10 +202,16 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
// or don't. For each "while" block there are three possibilities: don't execute it; execute it and then // or don't. For each "while" block there are three possibilities: don't execute it; execute it and then
// continue on; or execute it and then jump back to the beginning. I'm assuming the number of blocks will // continue on; or execute it and then jump back to the beginning. I'm assuming the number of blocks will
// always remain small. Otherwise, this could become very expensive! // always remain small. Otherwise, this could become very expensive!
//
// We also need to consider two full passes through the algorithm. That way, we detect if a step at the beginning
// means a step at the end should compute both forces and energy.
vector<int> jumps(numSteps, -1); vector<int> jumps(2*numSteps, -1);
vector<int> stepsInPath; vector<int> stepsInPath;
enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth); int numBlocks = blockEnd.size();
for (int i = 0; i < numBlocks; i++)
blockEnd.push_back(blockEnd[i]+numSteps);
enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, alwaysInvalidatesForces, forceGroup, computeBoth);
// Make sure calls to deriv() all valid. // Make sure calls to deriv() all valid.
...@@ -219,8 +227,9 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, ...@@ -219,8 +227,9 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps,
const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) { const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) {
int step = firstStep; int step = firstStep;
int numSteps = stepType.size(); int numSteps = stepType.size();
while (step < numSteps) { while (step < 2*numSteps) {
steps.push_back(step); steps.push_back(step);
int index = step % stepType.size();
if (jumps[step] > 0) { if (jumps[step] > 0) {
// Follow the jump and remove it from the list. // Follow the jump and remove it from the list.
...@@ -228,7 +237,7 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, ...@@ -228,7 +237,7 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps,
jumps[step] = -1; jumps[step] = -1;
step = nextStep; step = nextStep;
} }
else if (stepType[step] == CustomIntegrator::IfBlockStart) { else if (stepType[index] == CustomIntegrator::IfBlockStart) {
// Consider skipping the block. // Consider skipping the block.
enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth); enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth);
...@@ -237,7 +246,7 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, ...@@ -237,7 +246,7 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps,
step++; step++;
} }
else if (stepType[step] == CustomIntegrator::WhileBlockStart && jumps[step] != -2) { else if (stepType[index] == CustomIntegrator::WhileBlockStart && jumps[step] != -2) {
// Consider skipping the block. // Consider skipping the block.
enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth); enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth);
...@@ -262,21 +271,24 @@ void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& ste ...@@ -262,21 +271,24 @@ void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& ste
const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) { const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) {
vector<pair<int, int> > candidatePoints; vector<pair<int, int> > candidatePoints;
for (int step : steps) { for (int step : steps) {
if (invalidatesForces[step]) { int index = step % computeBoth.size();
if (invalidatesForces[index]) {
// Forces and energies are invalidated at this step, so anything from this point on won't affect what we do at earlier steps. // Forces and energies are invalidated at this step, so anything from this point on won't affect what we do at earlier steps.
candidatePoints.clear(); candidatePoints.clear();
} }
if (needsForces[step] || needsEnergy[step]) { if (needsForces[index] || needsEnergy[index]) {
// See if this step affects what we do at earlier points. // See if this step affects what we do at earlier points.
for (auto candidate : candidatePoints) for (auto candidate : candidatePoints) {
if (candidate.second == forceGroup[step] && ((needsForces[candidate.first] && needsEnergy[step]) || (needsEnergy[candidate.first] && needsForces[step]))) int candidateIndex = candidate.first % computeBoth.size();
computeBoth[candidate.first] = true; if (candidate.second == forceGroup[index] && ((needsForces[candidateIndex] && needsEnergy[index]) || (needsEnergy[candidateIndex] && needsForces[index])))
computeBoth[candidateIndex] = true;
}
// Add this to the list of candidates that might be affected by later steps. // Add this to the list of candidates that might be affected by later steps.
candidatePoints.push_back(make_pair(step, forceGroup[step])); candidatePoints.push_back(make_pair(step, forceGroup[index]));
} }
} }
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2018 Stanford University and the Authors. * * Portions copyright (c) 2008-2019 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -1081,6 +1081,53 @@ void testVectorFunctions() { ...@@ -1081,6 +1081,53 @@ void testVectorFunctions() {
ASSERT_EQUAL_TOL(sumy, integrator.getGlobalVariable(0), 1e-5); ASSERT_EQUAL_TOL(sumy, integrator.getGlobalVariable(0), 1e-5);
} }
/**
* This test records energies at multiple points during the step and checks that
* they're correct.
*/
void testRecordEnergy() {
const int numParticles = 8;
System system;
CustomIntegrator integrator(0.002);
integrator.addGlobalVariable("startEnergy", 0);
integrator.addGlobalVariable("endEnergy", 0);
integrator.addUpdateContextState();
integrator.addComputePerDof("v", "v+0.5*dt*f/m");
integrator.addComputeGlobal("startEnergy", "energy");
integrator.addComputePerDof("x", "x+dt*v");
integrator.addComputeGlobal("endEnergy", "energy");
integrator.addConstrainPositions();
integrator.addComputePerDof("v", "v+0.5*dt*f/m");
NonbondedForce* forceField = new NonbondedForce();
for (int i = 0; i < numParticles; ++i) {
system.addParticle(i%2 == 0 ? 5.0 : 10.0);
forceField->addParticle((i%2 == 0 ? 0.2 : -0.2), 0.5, 5.0);
}
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(numParticles);
vector<Vec3> velocities(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; ++i) {
positions[i] = Vec3(i/2, (i+1)/2, 0);
velocities[i] = Vec3(genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5);
}
context.setPositions(positions);
context.setVelocities(velocities);
// Simulate it and see whether the energies are recorded correctly.
for (int i = 0; i < 10; ++i) {
double startEnergy = context.getState(State::Energy).getPotentialEnergy();
integrator.step(1);
double endEnergy = context.getState(State::Energy).getPotentialEnergy();
ASSERT_EQUAL_TOL(startEnergy, integrator.getGlobalVariable(0), 1e-6);
ASSERT_EQUAL_TOL(endEnergy, integrator.getGlobalVariable(1), 1e-6);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -1107,6 +1154,7 @@ int main(int argc, char* argv[]) { ...@@ -1107,6 +1154,7 @@ int main(int argc, char* argv[]) {
testAlternatingGroups(); testAlternatingGroups();
testUpdateContextState(); testUpdateContextState();
testVectorFunctions(); testVectorFunctions();
testRecordEnergy();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment