Unverified Commit b247e7d8 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Report status during minimization (#4207)

* Implemented MinimizationReporter

* PythonAPI for MinimizationReporter

* Improved test case

* SWIG fix

* Reporter returns a bool instead of throwing an exception
parent 9a0db725
......@@ -62,6 +62,7 @@ that they aren't important!
:maxdepth: 2
generated/LocalEnergyMinimizer
generated/MinimizationReporter
generated/NoseHooverChain
generated/OpenMMException
generated/Vec3
......@@ -33,9 +33,79 @@
* -------------------------------------------------------------------------- */
#include "Context.h"
#include <map>
#include <string>
#include <vector>
namespace OpenMM {
/**
* A MinimizationReporter can be passed to LocalEnergyMinimizer::minimize() to provide
* periodic information on the progress of minimization, and to give you the chance to
* stop minimization early. Define a subclass that overrides report() and implement it
* to take whatever action you want.
*
* To correctly interpret the information passed to the reporter, you need to know a bit
* about how the minimizer works. The L-BFGS algorithm used by the minimizer does not
* support constraints. The minimizer therefore replaces all constraints with harmonic
* restraints, then performs unconstrained minimization of a combined objective function
* that is the sum of the system's potential energy and the restraint energy. Once
* minimization completes, it checks whether all constraints are satisfied to an acceptable
* tolerance. It not, it increases the strength of the harmonic restraints and performs
* additional minimization. If the error in constrained distances is especially large,
* it may choose to throw out all work that has been done so far and start over with
* stronger restraints. This has several important consequences.
*
* <ul>
* <li>The objective function being minimized not actually the same as the potential energy.</li>
* <li>The objective function and the potential energy can both increase between iterations.</li>
* <li>The total number of iterations performed could be larger than the number specified
* by the maxIterations argument, if that many iterations leaves unacceptable constraint errors.</li>
* <li>All work is provisional. It is possible for the minimizer to throw it out and start over.</li>
* </ul>
*/
class OPENMM_EXPORT MinimizationReporter {
public:
MinimizationReporter() {
}
virtual ~MinimizationReporter() {
}
/**
* This is called after each iteration to provide information about the current status
* of minimization. It receives the current particle coordinates, the gradient of the
* objective function with respect to them, and a set of useful statistics. In particular,
* args contains these values:
*
* "system energy": the current potential energy of the system
*
* "restraint energy": the energy of the harmonic restraints
*
* "restraint strength": the force constant of the restraints (in kJ/mol/nm^2)
*
* "max constraint error": the maximum relative error in the length of any constraint
*
* If this function returns true, it will cause the L-BFGS optimizer to immediately
* exit. If all constrained distances are sufficiently close to their target values,
* minimize() will return. If any constraint error is unacceptably large, it will instead
* cause the minimizer to immediately increase the strength of the harmonic restraints and
* perform additional optimization.
*
* @param iteration the index of the current iteration. This refers to the current call
* to the L-BFGS optimizer. Each time the minimizer increases the restraint
* strength, the iteration index is reset to 0.
* @param x the current particle positions in flattened order: the three coordinates
* of the first particle, then the three coordinates of the second particle, etc.
* @param grad the current gradient of the objective function (potential energy plus
* restraint energy) with respect to the particle coordinates, in flattened
* order
* @param args additional statistics described above about the current state of minimization
* @return whether to immediately stop minimization
*/
virtual bool report(int iteration, const std::vector<double>& x, const std::vector<double>& grad, std::map<std::string, double>& args) {
return false;
}
};
/**
* Given a Context, this class searches for a new set of particle positions that represent
* a local minimum of the potential energy. The search is performed with the L-BFGS algorithm.
......@@ -62,8 +132,10 @@ public:
* @param maxIterations the maximum number of iterations to perform. If this is 0, minimation is continued
* until the results converge without regard to how many iterations it takes. The
* default value is 0.
* @param reporter an optional MinimizationReporter to invoke after each iteration. This can be used
* to monitor the progress of minimization or to stop minimization early.
*/
static void minimize(Context& context, double tolerance = 10, int maxIterations = 0);
static void minimize(Context& context, double tolerance = 10, int maxIterations = 0, MinimizationReporter* reporter = NULL);
};
} // namespace OpenMM
......
......@@ -46,10 +46,12 @@ using namespace std;
struct MinimizerData {
Context& context;
double k;
MinimizationReporter* reporter;
bool checkLargeForces;
VerletIntegrator cpuIntegrator;
Context* cpuContext;
MinimizerData(Context& context, double k) : context(context), k(k), cpuIntegrator(1.0), cpuContext(NULL) {
MinimizerData(Context& context, double k, MinimizationReporter* reporter) :
context(context), k(k), reporter(reporter), cpuIntegrator(1.0), cpuContext(NULL) {
string platformName = context.getPlatform().getName();
checkLargeForces = (platformName == "CUDA" || platformName == "OpenCL" || platformName == "HIP" || platformName == "Metal");
}
......@@ -151,7 +153,49 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsf
return energy;
}
void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxIterations) {
static int report(void *instance, const lbfgsfloatval_t *x, const lbfgsfloatval_t *g, const lbfgsfloatval_t fx,
const lbfgsfloatval_t xnorm, const lbfgsfloatval_t gnorm, const lbfgsfloatval_t step, int n, int iteration, int ls) {
// Copy over the positions and gradients.
vector<double> xout(n), gradout(n);
for (int i = 0; i < n; i++) {
xout[i] = x[i];
gradout[i] = g[i];
}
// Compute the other arguments passed to the reporter.
MinimizerData* data = reinterpret_cast<MinimizerData*>(instance);
Context& context = data->context;
const System& system = context.getSystem();
double restraintEnergy = 0.0, maxError = 0.0;
double k = data->k;
for (int i = 0; i < system.getNumConstraints(); i++) {
int p1, p2;
double distance;
system.getConstraintParameters(i, p1, p2, distance);
Vec3 delta(x[3*p1]-x[3*p2], x[3*p1+1]-x[3*p2+1], x[3*p1+2]-x[3*p2+2]);
double r2 = delta.dot(delta);
double r = sqrt(r2);
double dr = r-distance;
restraintEnergy += 0.5*k*dr*dr;
maxError = max(maxError, fabs(dr)/distance);
}
map<string, double> args;
args["restraint energy"] = restraintEnergy;
args["system energy"] = fx-restraintEnergy;
args["restraint strength"] = k;
args["max constraint error"] = maxError;
// Invoke the reporter.
MinimizationReporter* reporter = reinterpret_cast<MinimizationReporter*>(data->reporter);
if (reporter->report(iteration-1, xout, gradout, args))
return 1;
return 0;
}
void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxIterations, MinimizationReporter* reporter) {
const System& system = context.getSystem();
int numParticles = system.getNumParticles();
double constraintTol = context.getIntegrator().getConstraintTolerance();
......@@ -192,12 +236,13 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI
// Repeatedly minimize, steadily increasing the strength of the springs until all constraints are satisfied.
double prevMaxError = 1e10;
MinimizerData data(context, k);
MinimizerData data(context, k, reporter);
while (true) {
// Perform the minimization.
lbfgsfloatval_t fx;
lbfgs(numParticles*3, x, &fx, evaluate, NULL, &data, &param);
lbfgs_progress_t reportFn = (reporter == NULL ? NULL : report);
lbfgs(numParticles*3, x, &fx, evaluate, reportFn, &data, &param);
// Check whether all constraints are satisfied.
......@@ -210,7 +255,7 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI
system.getConstraintParameters(i, particle1, particle2, distance);
Vec3 delta = positions[particle2]-positions[particle1];
double r = sqrt(delta.dot(delta));
double error = fabs(r-distance);
double error = fabs(r-distance)/distance;
if (error > maxError)
maxError = error;
}
......
......@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2010-2020 Stanford University and the Authors. *
* Portions copyright (c) 2010-2023 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -32,6 +32,7 @@
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h"
#include "openmm/CustomExternalForce.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/LocalEnergyMinimizer.h"
#include "openmm/NonbondedForce.h"
......@@ -295,6 +296,54 @@ void testMasslessParticles() {
ASSERT_EQUAL_TOL(1.05, sqrt(delta.dot(delta)), 1e-4);
}
void testReporter() {
const int numParticles = 30;
System system;
CustomExternalForce* force = new CustomExternalForce("sin(5*x)+cos(2*y)*(sin(3*z)+1.5)");
system.addForce(force);
vector<Vec3> positions;
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
force->addParticle(i);
positions.push_back(Vec3(0.5*i, 0.3*sin(i), 0.2*cos(i)));
if (i > 0)
system.addConstraint(i-1, i, 1.0);
}
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
context.applyConstraints(1e-5);
class Reporter : public MinimizationReporter {
public:
int lastIter = 0;
double lastK = 0, lastEnergy = 0;
bool canceled = false, success = true;
bool report(int iteration, const vector<double>& x, const vector<double>& grad, map<string, double>& args) {
double k = args["restraint strength"];
if (iteration > 0)
success &= (iteration == lastIter+1 && k == lastK) | (iteration == 0 && k > lastK);
if (canceled)
success &= (iteration == 0 && k > lastK);
lastEnergy = args["system energy"];
if (iteration > 300 && args["max constraint error"] > 1e-4) {
canceled = true;
return true;
}
canceled = false;
lastIter = iteration;
lastK = k;
return false;
}
};
Reporter reporter;
LocalEnergyMinimizer::minimize(context, 1.0, 0, &reporter);
ASSERT(reporter.success);
State state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(state.getPotentialEnergy(), reporter.lastEnergy, 1e-5);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -306,6 +355,7 @@ int main(int argc, char* argv[]) {
testLargeForces();
testForceGroups();
testMasslessParticles();
testReporter();
runPlatformTests();
}
catch(const exception& e) {
......
......@@ -123,7 +123,7 @@ class Simulation(object):
def currentStep(self, step):
self.context.setStepCount(step)
def minimizeEnergy(self, tolerance=10*unit.kilojoules_per_mole/unit.nanometer, maxIterations=0):
def minimizeEnergy(self, tolerance=10*unit.kilojoules_per_mole/unit.nanometer, maxIterations=0, reporter=None):
"""Perform a local energy minimization on the system.
Parameters
......@@ -136,8 +136,11 @@ class Simulation(object):
The maximum number of iterations to perform. If this is 0,
minimization is continued until the results converge without regard
to how many iterations it takes.
reporter : MinimizationReporter = None
an optional reporter to invoke after each iteration. This can be used to monitor the progress
of minimization or to stop minimization early.
"""
mm.LocalEnergyMinimizer.minimize(self.context, tolerance, maxIterations)
mm.LocalEnergyMinimizer.minimize(self.context, tolerance, maxIterations, reporter)
def step(self, steps):
"""Advance the simulation by integrating a specified number of time steps."""
......
%module openmm
%module(directors="1") openmm
%include "factory.i"
%include "std_string.i"
......
......@@ -39,17 +39,26 @@
}
%exception OpenMM::LocalEnergyMinimizer::minimize {
PyThreadState* _savePythonThreadState = PyEval_SaveThread();
bool releaseGIL = (nobjs < 4 || swig_obj[3] == Py_None);
PyThreadState* _savePythonThreadState = (releaseGIL ? PyEval_SaveThread() : nullptr);
try {
$action
} catch (std::exception &e) {
PyEval_RestoreThread(_savePythonThreadState);
PyObject* mm = PyImport_AddModule("openmm");
PyObject* openmm_exception = PyObject_GetAttrString(mm, "OpenMMException");
PyErr_SetString(openmm_exception, const_cast<char*>(e.what()));
return NULL;
}
PyEval_RestoreThread(_savePythonThreadState);
catch (std::exception &e) {
if (releaseGIL)
PyEval_RestoreThread(_savePythonThreadState);
if (dynamic_cast<Swig::DirectorException*>(&e) != NULL) {
SWIG_fail;
}
else {
PyObject* mm = PyImport_AddModule("openmm");
PyObject* openmm_exception = PyObject_GetAttrString(mm, "OpenMMException");
PyErr_SetString(openmm_exception, const_cast<char*>(e.what()));
return NULL;
}
}
if (releaseGIL)
PyEval_RestoreThread(_savePythonThreadState);
}
%exception OpenMM::Context::setVelocitiesToTemperature {
......
......@@ -9,3 +9,5 @@
%include pythonprepend.i
%include pythonappend.i
%include typemaps.i
%feature("director") OpenMM::MinimizationReporter;
......@@ -197,6 +197,33 @@ class TestSimulation(unittest.TestCase):
simulation.step(500)
def testMinimizationReporter(self):
"""Test invoking a reporter during minimization."""
pdb = PDBFile('systems/alanine-dipeptide-implicit.pdb')
ff = ForceField('amber99sb.xml', 'tip3p.xml')
system = ff.createSystem(pdb.topology)
integrator = LangevinIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)
simulation = Simulation(pdb.topology, system, integrator)
simulation.context.setPositions(pdb.positions)
class Reporter(MinimizationReporter):
lastIteration = -1
error = False
def report(self, iteration, x, grad, args):
if iteration != self.lastIteration+1:
self.error = True
self.lastIteration = iteration
if iteration == 10:
return True
if iteration > 10:
self.error = True
return False
reporter = Reporter()
simulation.minimizeEnergy(reporter=reporter)
assert not reporter.error
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment