Unverified Commit 71f4b3fc authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Virtual sites can depend on other virtual sites (#4348)

* Reference platform supports nested virtual sites

* Common platform supports nested virtual sites

* Fixed force distribution from nested virtual sites

* Fixed test failures
parent 162b7c37
...@@ -114,6 +114,6 @@ void ReferenceVerletDynamics::update(const OpenMM::System& system, vector<Vec3>& ...@@ -114,6 +114,6 @@ void ReferenceVerletDynamics::update(const OpenMM::System& system, vector<Vec3>&
} }
} }
ReferenceVirtualSites::computePositions(system, atomCoordinates); getVirtualSites().computePositions(system, atomCoordinates);
incrementTimeStep(); incrementTimeStep();
} }
...@@ -31,14 +31,40 @@ ...@@ -31,14 +31,40 @@
#include "ReferenceVirtualSites.h" #include "ReferenceVirtualSites.h"
#include "openmm/VirtualSite.h" #include "openmm/VirtualSite.h"
#include "openmm/OpenMMException.h"
#include <cmath> #include <cmath>
#include <set>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
void ReferenceVirtualSites::computePositions(const OpenMM::System& system, vector<OpenMM::Vec3>& atomCoordinates) { ReferenceVirtualSites::ReferenceVirtualSites(const System& system) {
set<int> sites;
for (int i = 0; i < system.getNumParticles(); i++) for (int i = 0; i < system.getNumParticles(); i++)
if (system.isVirtualSite(i)) { if (system.isVirtualSite(i))
sites.insert(i);
int remainingSites = 0;
while (sites.size() > 0) {
if (sites.size() == remainingSites)
throw OpenMMException("Virtual site definitions are circular");
for (auto index = sites.begin(); index != sites.end();) {
const VirtualSite& site = system.getVirtualSite(*index);
bool canCompute = true;
for (int i = 0; i < site.getNumParticles(); i++)
if (sites.find(site.getParticle(i)) != sites.end())
canCompute = false;
if (canCompute) {
order.push_back(*index);
index = sites.erase(index);
}
else
++index;
}
}
}
void ReferenceVirtualSites::computePositions(const OpenMM::System& system, vector<OpenMM::Vec3>& atomCoordinates) const {
for (int i : order) {
if (dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(i)) != NULL) { if (dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(i)) != NULL) {
// A two particle average. // A two particle average.
...@@ -96,9 +122,9 @@ void ReferenceVirtualSites::computePositions(const OpenMM::System& system, vecto ...@@ -96,9 +122,9 @@ void ReferenceVirtualSites::computePositions(const OpenMM::System& system, vecto
} }
} }
void ReferenceVirtualSites::distributeForces(const OpenMM::System& system, const vector<OpenMM::Vec3>& atomCoordinates, vector<OpenMM::Vec3>& forces) { void ReferenceVirtualSites::distributeForces(const OpenMM::System& system, const vector<OpenMM::Vec3>& atomCoordinates, vector<OpenMM::Vec3>& forces) const {
for (int i = 0; i < system.getNumParticles(); i++) for (auto iter = order.rbegin(); iter != order.rend(); ++iter) {
if (system.isVirtualSite(i)) { int i = *iter;
Vec3 f = forces[i]; Vec3 f = forces[i];
if (dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(i)) != NULL) { if (dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(i)) != NULL) {
// A two particle average. // A two particle average.
......
...@@ -61,6 +61,11 @@ static ReferenceConstraints& extractConstraints(ContextImpl& context) { ...@@ -61,6 +61,11 @@ static ReferenceConstraints& extractConstraints(ContextImpl& context) {
return *data->constraints; return *data->constraints;
} }
static const ReferenceVirtualSites& extractVirtualSites(ContextImpl& context) {
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return *data->virtualSites;
}
static double computeShiftedKineticEnergy(ContextImpl& context, vector<double>& inverseMasses, double timeShift) { static double computeShiftedKineticEnergy(ContextImpl& context, vector<double>& inverseMasses, double timeShift) {
const System& system = context.getSystem(); const System& system = context.getSystem();
int numParticles = system.getNumParticles(); int numParticles = system.getNumParticles();
...@@ -374,7 +379,7 @@ void ReferenceIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, co ...@@ -374,7 +379,7 @@ void ReferenceIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, co
} }
} }
} }
ReferenceVirtualSites::computePositions(context.getSystem(), pos); extractVirtualSites(context).computePositions(context.getSystem(), pos);
data.time += integrator.getStepSize(); data.time += integrator.getStepSize();
data.stepCount++; data.stepCount++;
} }
...@@ -452,7 +457,7 @@ void ReferenceIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const D ...@@ -452,7 +457,7 @@ void ReferenceIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const D
// Update the positions of virtual sites and Drude particles. // Update the positions of virtual sites and Drude particles.
ReferenceVirtualSites::computePositions(context.getSystem(), pos); extractVirtualSites(context).computePositions(context.getSystem(), pos);
minimize(context, integrator.getMinimizationErrorTolerance()); minimize(context, integrator.getMinimizationErrorTolerance());
data.time += integrator.getStepSize(); data.time += integrator.getStepSize();
data.stepCount++; data.stepCount++;
......
...@@ -482,6 +482,36 @@ void testOverlappingSites() { ...@@ -482,6 +482,36 @@ void testOverlappingSites() {
ASSERT_EQUAL_VEC(s1.getForces()[i], s2.getForces()[i], 1e-5); ASSERT_EQUAL_VEC(s1.getForces()[i], s2.getForces()[i], 1e-5);
} }
/**
* Test virtual sites that depend on other virtual sites.
*/
void testNestedSites() {
System system;
system.addParticle(1.0);
for (int i = 0; i < 3; i++)
system.addParticle(0.0);
system.addParticle(1.0);
system.setVirtualSite(2, new TwoParticleAverageSite(0, 4, 0.5, 0.5));
system.setVirtualSite(1, new TwoParticleAverageSite(0, 2, 0.5, 0.5));
system.setVirtualSite(3, new TwoParticleAverageSite(2, 4, 0.5, 0.5));
CustomExternalForce* force = new CustomExternalForce("-c*x");
force->addPerParticleParameter("c");
force->addParticle(1, {1.0});
force->addParticle(3, {2.0});
system.addForce(force);
vector<Vec3> positions(5);
positions[4] = Vec3(0, 0, 4.0);
VerletIntegrator integrator(0.002);
Context context(system, integrator, platform);
context.setPositions(positions);
context.computeVirtualSites();
State state = context.getState(State::Positions | State::Forces);
for (int i = 0; i < 5; i++)
ASSERT_EQUAL_VEC(Vec3(0, 0, i), state.getPositions()[i], 1e-6);
ASSERT_EQUAL_VEC(Vec3(1*0.75 + 2*0.25, 0, 0), state.getForces()[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(1*0.25 + 2*0.75, 0, 0), state.getForces()[4], 1e-6);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -496,6 +526,7 @@ int main(int argc, char* argv[]) { ...@@ -496,6 +526,7 @@ int main(int argc, char* argv[]) {
testLocalCoordinates(4); testLocalCoordinates(4);
testConservationLaws(); testConservationLaws();
testOverlappingSites(); testOverlappingSites();
testNestedSites();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment