Unverified Commit 71f4b3fc authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Virtual sites can depend on other virtual sites (#4348)

* Reference platform supports nested virtual sites

* Common platform supports nested virtual sites

* Fixed force distribution from nested virtual sites

* Fixed test failures
parent 162b7c37
......@@ -114,6 +114,6 @@ void ReferenceVerletDynamics::update(const OpenMM::System& system, vector<Vec3>&
}
}
ReferenceVirtualSites::computePositions(system, atomCoordinates);
getVirtualSites().computePositions(system, atomCoordinates);
incrementTimeStep();
}
......@@ -31,14 +31,40 @@
#include "ReferenceVirtualSites.h"
#include "openmm/VirtualSite.h"
#include "openmm/OpenMMException.h"
#include <cmath>
#include <set>
using namespace OpenMM;
using namespace std;
void ReferenceVirtualSites::computePositions(const OpenMM::System& system, vector<OpenMM::Vec3>& atomCoordinates) {
ReferenceVirtualSites::ReferenceVirtualSites(const System& system) {
set<int> sites;
for (int i = 0; i < system.getNumParticles(); i++)
if (system.isVirtualSite(i)) {
if (system.isVirtualSite(i))
sites.insert(i);
int remainingSites = 0;
while (sites.size() > 0) {
if (sites.size() == remainingSites)
throw OpenMMException("Virtual site definitions are circular");
for (auto index = sites.begin(); index != sites.end();) {
const VirtualSite& site = system.getVirtualSite(*index);
bool canCompute = true;
for (int i = 0; i < site.getNumParticles(); i++)
if (sites.find(site.getParticle(i)) != sites.end())
canCompute = false;
if (canCompute) {
order.push_back(*index);
index = sites.erase(index);
}
else
++index;
}
}
}
void ReferenceVirtualSites::computePositions(const OpenMM::System& system, vector<OpenMM::Vec3>& atomCoordinates) const {
for (int i : order) {
if (dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(i)) != NULL) {
// A two particle average.
......@@ -96,9 +122,9 @@ void ReferenceVirtualSites::computePositions(const OpenMM::System& system, vecto
}
}
void ReferenceVirtualSites::distributeForces(const OpenMM::System& system, const vector<OpenMM::Vec3>& atomCoordinates, vector<OpenMM::Vec3>& forces) {
for (int i = 0; i < system.getNumParticles(); i++)
if (system.isVirtualSite(i)) {
void ReferenceVirtualSites::distributeForces(const OpenMM::System& system, const vector<OpenMM::Vec3>& atomCoordinates, vector<OpenMM::Vec3>& forces) const {
for (auto iter = order.rbegin(); iter != order.rend(); ++iter) {
int i = *iter;
Vec3 f = forces[i];
if (dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(i)) != NULL) {
// A two particle average.
......
......@@ -61,6 +61,11 @@ static ReferenceConstraints& extractConstraints(ContextImpl& context) {
return *data->constraints;
}
static const ReferenceVirtualSites& extractVirtualSites(ContextImpl& context) {
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return *data->virtualSites;
}
static double computeShiftedKineticEnergy(ContextImpl& context, vector<double>& inverseMasses, double timeShift) {
const System& system = context.getSystem();
int numParticles = system.getNumParticles();
......@@ -374,7 +379,7 @@ void ReferenceIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, co
}
}
}
ReferenceVirtualSites::computePositions(context.getSystem(), pos);
extractVirtualSites(context).computePositions(context.getSystem(), pos);
data.time += integrator.getStepSize();
data.stepCount++;
}
......@@ -452,7 +457,7 @@ void ReferenceIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const D
// Update the positions of virtual sites and Drude particles.
ReferenceVirtualSites::computePositions(context.getSystem(), pos);
extractVirtualSites(context).computePositions(context.getSystem(), pos);
minimize(context, integrator.getMinimizationErrorTolerance());
data.time += integrator.getStepSize();
data.stepCount++;
......
......@@ -482,6 +482,36 @@ void testOverlappingSites() {
ASSERT_EQUAL_VEC(s1.getForces()[i], s2.getForces()[i], 1e-5);
}
/**
* Test virtual sites that depend on other virtual sites.
*/
void testNestedSites() {
System system;
system.addParticle(1.0);
for (int i = 0; i < 3; i++)
system.addParticle(0.0);
system.addParticle(1.0);
system.setVirtualSite(2, new TwoParticleAverageSite(0, 4, 0.5, 0.5));
system.setVirtualSite(1, new TwoParticleAverageSite(0, 2, 0.5, 0.5));
system.setVirtualSite(3, new TwoParticleAverageSite(2, 4, 0.5, 0.5));
CustomExternalForce* force = new CustomExternalForce("-c*x");
force->addPerParticleParameter("c");
force->addParticle(1, {1.0});
force->addParticle(3, {2.0});
system.addForce(force);
vector<Vec3> positions(5);
positions[4] = Vec3(0, 0, 4.0);
VerletIntegrator integrator(0.002);
Context context(system, integrator, platform);
context.setPositions(positions);
context.computeVirtualSites();
State state = context.getState(State::Positions | State::Forces);
for (int i = 0; i < 5; i++)
ASSERT_EQUAL_VEC(Vec3(0, 0, i), state.getPositions()[i], 1e-6);
ASSERT_EQUAL_VEC(Vec3(1*0.75 + 2*0.25, 0, 0), state.getForces()[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(1*0.25 + 2*0.75, 0, 0), state.getForces()[4], 1e-6);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -496,6 +526,7 @@ int main(int argc, char* argv[]) {
testLocalCoordinates(4);
testConservationLaws();
testOverlappingSites();
testNestedSites();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment