Commit 07d005fa authored by peastman's avatar peastman
Browse files

LocalEnergyMinimizer switches to CPU if forces are getting clipped

parent 7164109e
......@@ -31,10 +31,12 @@
#include "openmm/LocalEnergyMinimizer.h"
#include "openmm/OpenMMException.h"
#include "lbfgs.h"
#include "openmm/Platform.h"
#include "openmm/VerletIntegrator.h"
#include "lbfgs.h"
#include <cmath>
#include <sstream>
#include <string>
#include <vector>
#include <algorithm>
......@@ -44,26 +46,43 @@ using namespace std;
struct MinimizerData {
Context& context;
double k;
MinimizerData(Context& context, double k)
: context(context), k(k) {}
bool checkLargeForces;
VerletIntegrator cpuIntegrator;
Context* cpuContext;
MinimizerData(Context& context, double k) : context(context), k(k), cpuIntegrator(1.0), cpuContext(NULL) {
string platformName = context.getPlatform().getName();
checkLargeForces = (platformName == "CUDA" || platformName == "OpenCL");
}
~MinimizerData() {
if (cpuContext != NULL)
delete cpuContext;
}
Context& getCpuContext() {
// Get an alternate context that runs on the CPU and doesn't place any limits
// on the magnitude of forces.
if (cpuContext == NULL) {
Platform* cpuPlatform;
try {
cpuPlatform = &Platform::getPlatformByName("CPU");
}
catch (...) {
cpuPlatform = &Platform::getPlatformByName("Reference");
}
cpuContext = new Context(context.getSystem(), cpuIntegrator, *cpuPlatform);
cpuContext->setState(context.getState(State::Positions | State::Velocities | State::Parameters));
}
return *cpuContext;
}
};
static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsfloatval_t *g, const int n, const lbfgsfloatval_t step) {
MinimizerData* data = reinterpret_cast<MinimizerData*>(instance);
Context& context = data->context;
const System& system = context.getSystem();
int numParticles = system.getNumParticles();
// Compute the force and energy for this configuration.
vector<Vec3> positions(numParticles);
for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(x[3*i], x[3*i+1], x[3*i+2]);
static double computeForcesAndEnergy(Context& context, const vector<Vec3>& positions, lbfgsfloatval_t *g) {
context.setPositions(positions);
context.computeVirtualSites();
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
for (int i = 0; i < numParticles; i++) {
const System& system = context.getSystem();
for (int i = 0; i < forces.size(); i++) {
if (system.getParticleMass(i) == 0) {
g[3*i] = 0.0;
g[3*i+1] = 0.0;
......@@ -75,7 +94,33 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsf
g[3*i+2] = -forces[i][2];
}
}
double energy = state.getPotentialEnergy();
return state.getPotentialEnergy();
}
static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsfloatval_t *g, const int n, const lbfgsfloatval_t step) {
MinimizerData* data = reinterpret_cast<MinimizerData*>(instance);
Context& context = data->context;
const System& system = context.getSystem();
int numParticles = system.getNumParticles();
// Compute the force and energy for this configuration.
vector<Vec3> positions(numParticles);
for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(x[3*i], x[3*i+1], x[3*i+2]);
double energy = computeForcesAndEnergy(context, positions, g);
if (data->checkLargeForces) {
// The CUDA and OpenCL platforms accumulate forces in fixed point, so they
// can't handle very large forces. Check for problematic forces (very large,
// infinite, or NaN) and if necessary recompute them on the CPU.
for (int i = 0; i < 3*numParticles; i++) {
if (!(fabs(g[i]) < 2e9)) {
energy = computeForcesAndEnergy(data->getCpuContext(), positions, g);
break;
}
}
}
// Add harmonic forces for any constraints.
......@@ -143,11 +188,11 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI
// Repeatedly minimize, steadily increasing the strength of the springs until all constraints are satisfied.
double prevMaxError = 1e10;
MinimizerData data(context, k);
while (true) {
// Perform the minimization.
lbfgsfloatval_t fx;
MinimizerData data(context, k);
lbfgs(numParticles*3, x, &fx, evaluate, NULL, &data, &param);
// Check whether all constraints are satisfied.
......@@ -171,7 +216,7 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI
if (maxError >= prevMaxError)
break; // Further tightening the springs doesn't seem to be helping, so just give up.
prevMaxError = maxError;
k *= 10;
data.k *= 10;
if (maxError > 100*workingConstraintTol) {
// We've gotten far enough from a valid state that we might have trouble getting
// back, so reset to the original positions.
......
......@@ -196,6 +196,40 @@ void testVirtualSites() {
ASSERT(forceNorm < 2*tolerance);
}
void testLargeForces() {
// Create a set of particles that are almost on top of each other so the initial
// forces are huge.
const int numParticles = 10;
System system;
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(1.0, 0.2, 1.0);
}
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*1e-10;
// Minimize it and verify that it didn't blow up.
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
LocalEnergyMinimizer::minimize(context, 1.0);
State state = context.getState(State::Positions);
double maxdist = 0.0;
for (int i = 0; i < numParticles; i++) {
Vec3 r = state.getPositions()[i];
maxdist = max(maxdist, sqrt(r.dot(r)));
}
ASSERT(maxdist > 0.1);
ASSERT(maxdist < 10.0);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -204,6 +238,7 @@ int main(int argc, char* argv[]) {
testHarmonicBonds();
testLargeSystem();
testVirtualSites();
testLargeForces();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment