Unverified Commit 32400ee5 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1984 from peastman/cpu

LocalEnergyMinimizer switches to CPU if forces are getting clipped
parents 460b0d4f 4cc77852
...@@ -31,10 +31,12 @@ ...@@ -31,10 +31,12 @@
#include "openmm/LocalEnergyMinimizer.h" #include "openmm/LocalEnergyMinimizer.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "lbfgs.h"
#include "openmm/Platform.h" #include "openmm/Platform.h"
#include "openmm/VerletIntegrator.h"
#include "lbfgs.h"
#include <cmath> #include <cmath>
#include <sstream> #include <sstream>
#include <string>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
...@@ -44,26 +46,43 @@ using namespace std; ...@@ -44,26 +46,43 @@ using namespace std;
struct MinimizerData { struct MinimizerData {
Context& context; Context& context;
double k; double k;
MinimizerData(Context& context, double k) bool checkLargeForces;
: context(context), k(k) {} VerletIntegrator cpuIntegrator;
Context* cpuContext;
MinimizerData(Context& context, double k) : context(context), k(k), cpuIntegrator(1.0), cpuContext(NULL) {
string platformName = context.getPlatform().getName();
checkLargeForces = (platformName == "CUDA" || platformName == "OpenCL");
}
~MinimizerData() {
if (cpuContext != NULL)
delete cpuContext;
}
Context& getCpuContext() {
// Get an alternate context that runs on the CPU and doesn't place any limits
// on the magnitude of forces.
if (cpuContext == NULL) {
Platform* cpuPlatform;
try {
cpuPlatform = &Platform::getPlatformByName("CPU");
}
catch (...) {
cpuPlatform = &Platform::getPlatformByName("Reference");
}
cpuContext = new Context(context.getSystem(), cpuIntegrator, *cpuPlatform);
cpuContext->setState(context.getState(State::Positions | State::Velocities | State::Parameters));
}
return *cpuContext;
}
}; };
static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsfloatval_t *g, const int n, const lbfgsfloatval_t step) { static double computeForcesAndEnergy(Context& context, const vector<Vec3>& positions, lbfgsfloatval_t *g) {
MinimizerData* data = reinterpret_cast<MinimizerData*>(instance);
Context& context = data->context;
const System& system = context.getSystem();
int numParticles = system.getNumParticles();
// Compute the force and energy for this configuration.
vector<Vec3> positions(numParticles);
for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(x[3*i], x[3*i+1], x[3*i+2]);
context.setPositions(positions); context.setPositions(positions);
context.computeVirtualSites(); context.computeVirtualSites();
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
for (int i = 0; i < numParticles; i++) { const System& system = context.getSystem();
for (int i = 0; i < forces.size(); i++) {
if (system.getParticleMass(i) == 0) { if (system.getParticleMass(i) == 0) {
g[3*i] = 0.0; g[3*i] = 0.0;
g[3*i+1] = 0.0; g[3*i+1] = 0.0;
...@@ -75,7 +94,33 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsf ...@@ -75,7 +94,33 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsf
g[3*i+2] = -forces[i][2]; g[3*i+2] = -forces[i][2];
} }
} }
double energy = state.getPotentialEnergy(); return state.getPotentialEnergy();
}
static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsfloatval_t *g, const int n, const lbfgsfloatval_t step) {
MinimizerData* data = reinterpret_cast<MinimizerData*>(instance);
Context& context = data->context;
const System& system = context.getSystem();
int numParticles = system.getNumParticles();
// Compute the force and energy for this configuration.
vector<Vec3> positions(numParticles);
for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(x[3*i], x[3*i+1], x[3*i+2]);
double energy = computeForcesAndEnergy(context, positions, g);
if (data->checkLargeForces) {
// The CUDA and OpenCL platforms accumulate forces in fixed point, so they
// can't handle very large forces. Check for problematic forces (very large,
// infinite, or NaN) and if necessary recompute them on the CPU.
for (int i = 0; i < 3*numParticles; i++) {
if (!(fabs(g[i]) < 2e9)) {
energy = computeForcesAndEnergy(data->getCpuContext(), positions, g);
break;
}
}
}
// Add harmonic forces for any constraints. // Add harmonic forces for any constraints.
...@@ -143,11 +188,11 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI ...@@ -143,11 +188,11 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI
// Repeatedly minimize, steadily increasing the strength of the springs until all constraints are satisfied. // Repeatedly minimize, steadily increasing the strength of the springs until all constraints are satisfied.
double prevMaxError = 1e10; double prevMaxError = 1e10;
MinimizerData data(context, k);
while (true) { while (true) {
// Perform the minimization. // Perform the minimization.
lbfgsfloatval_t fx; lbfgsfloatval_t fx;
MinimizerData data(context, k);
lbfgs(numParticles*3, x, &fx, evaluate, NULL, &data, &param); lbfgs(numParticles*3, x, &fx, evaluate, NULL, &data, &param);
// Check whether all constraints are satisfied. // Check whether all constraints are satisfied.
...@@ -171,7 +216,7 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI ...@@ -171,7 +216,7 @@ void LocalEnergyMinimizer::minimize(Context& context, double tolerance, int maxI
if (maxError >= prevMaxError) if (maxError >= prevMaxError)
break; // Further tightening the springs doesn't seem to be helping, so just give up. break; // Further tightening the springs doesn't seem to be helping, so just give up.
prevMaxError = maxError; prevMaxError = maxError;
k *= 10; data.k *= 10;
if (maxError > 100*workingConstraintTol) { if (maxError > 100*workingConstraintTol) {
// We've gotten far enough from a valid state that we might have trouble getting // We've gotten far enough from a valid state that we might have trouble getting
// back, so reset to the original positions. // back, so reset to the original positions.
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2015 Stanford University and the Authors. * * Portions copyright (c) 2010-2018 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "openmm/VerletIntegrator.h" #include "openmm/VerletIntegrator.h"
#include "openmm/VirtualSite.h" #include "openmm/VirtualSite.h"
#include "sfmt/SFMT.h" #include "sfmt/SFMT.h"
#include <algorithm>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -196,6 +197,40 @@ void testVirtualSites() { ...@@ -196,6 +197,40 @@ void testVirtualSites() {
ASSERT(forceNorm < 2*tolerance); ASSERT(forceNorm < 2*tolerance);
} }
void testLargeForces() {
// Create a set of particles that are almost on top of each other so the initial
// forces are huge.
const int numParticles = 10;
System system;
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(1.0, 0.2, 1.0);
}
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*1e-10;
// Minimize it and verify that it didn't blow up.
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
LocalEnergyMinimizer::minimize(context, 1.0);
State state = context.getState(State::Positions);
double maxdist = 0.0;
for (int i = 0; i < numParticles; i++) {
Vec3 r = state.getPositions()[i];
maxdist = max(maxdist, sqrt(r.dot(r)));
}
ASSERT(maxdist > 0.1);
ASSERT(maxdist < 10.0);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -204,6 +239,7 @@ int main(int argc, char* argv[]) { ...@@ -204,6 +239,7 @@ int main(int argc, char* argv[]) {
testHarmonicBonds(); testHarmonicBonds();
testLargeSystem(); testLargeSystem();
testVirtualSites(); testVirtualSites();
testLargeForces();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment