Commit f30ddfa4 authored by Peter Eastman's avatar Peter Eastman
Browse files

LocalEnergyMinimizer works with virtual sites

parent a566a074
...@@ -208,6 +208,30 @@ public: ...@@ -208,6 +208,30 @@ public:
virtual void apply(ContextImpl& context, double tol) = 0; virtual void apply(ContextImpl& context, double tol) = 0;
}; };
/**
* This kernel recomputes the positions of virtual sites.
*/
class VirtualSitesKernel : public KernelImpl {
public:
static std::string Name() {
return "VirtualSites";
}
VirtualSitesKernel(std::string name, const Platform& platform) : KernelImpl(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
virtual void initialize(const System& system) = 0;
/**
* Compute the virtual site locations.
*
* @param context the context in which to execute this kernel
*/
virtual void computePositions(ContextImpl& context) = 0;
};
/** /**
* This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system.
*/ */
......
...@@ -171,11 +171,19 @@ public: ...@@ -171,11 +171,19 @@ public:
*/ */
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
/** /**
* Update the positions of particles so that all distance constraints are satisfied. * Update the positions of particles so that all distance constraints are satisfied. This also recomputes
* the locations of all virtual sites.
* *
* @param tol the distance tolerance within which constraints must be satisfied. * @param tol the distance tolerance within which constraints must be satisfied.
*/ */
void applyConstraints(double tol); void applyConstraints(double tol);
/**
* Recompute the locations of all virtual sites. There is rarely a reason to call
* this, since virtual sites are also updated by applyConstraints(). This is only
* for the rare situations when you want to enforce virtual sites but <i>not</i>
* constraints.
*/
void computeVirtualSites();
/** /**
* When a Context is created, it may cache information about the System being simulated * When a Context is created, it may cache information about the System being simulated
* and the Force objects contained in it. This means that, if the System or Forces are then * and the Force objects contained in it. This means that, if the System or Forces are then
......
...@@ -162,11 +162,19 @@ public: ...@@ -162,11 +162,19 @@ public:
*/ */
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
/** /**
* Update the positions of particles so that all distance constraints are satisfied. * Update the positions of particles so that all distance constraints are satisfied. This also recomputes
* the locations of all virtual sites.
* *
* @param tol the distance tolerance within which constraints must be satisfied. * @param tol the distance tolerance within which constraints must be satisfied.
*/ */
void applyConstraints(double tol); void applyConstraints(double tol);
/**
* Recompute the locations of all virtual sites. There is rarely a reason to call
* this, since virtual sites are also updated by applyConstraints(). This is only
* for the rare situations when you want to enforce virtual sites but <i>not</i>
* constraints.
*/
void computeVirtualSites();
/** /**
* Recalculate all of the forces in the system and/or the potential energy of the system (in kJ/mol). * Recalculate all of the forces in the system and/or the potential energy of the system (in kJ/mol).
* After calling this, use getForces() to retrieve the forces that were calculated. * After calling this, use getForces() to retrieve the forces that were calculated.
...@@ -217,7 +225,7 @@ private: ...@@ -217,7 +225,7 @@ private:
mutable std::vector<std::vector<int> > molecules; mutable std::vector<std::vector<int> > molecules;
bool hasInitializedForces; bool hasInitializedForces;
Platform* platform; Platform* platform;
Kernel initializeForcesKernel, kineticEnergyKernel, updateStateDataKernel, applyConstraintsKernel; Kernel initializeForcesKernel, kineticEnergyKernel, updateStateDataKernel, applyConstraintsKernel, virtualSitesKernel;
void* platformData; void* platformData;
}; };
......
...@@ -166,6 +166,10 @@ void Context::applyConstraints(double tol) { ...@@ -166,6 +166,10 @@ void Context::applyConstraints(double tol) {
impl->applyConstraints(tol); impl->applyConstraints(tol);
} }
void Context::computeVirtualSites() {
impl->computeVirtualSites();
}
void Context::reinitialize() { void Context::reinitialize() {
System& system = impl->getSystem(); System& system = impl->getSystem();
Integrator& integrator = impl->getIntegrator(); Integrator& integrator = impl->getIntegrator();
......
...@@ -105,6 +105,8 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, ...@@ -105,6 +105,8 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator,
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).initialize(system); dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).initialize(system);
applyConstraintsKernel = platform->createKernel(ApplyConstraintsKernel::Name(), *this); applyConstraintsKernel = platform->createKernel(ApplyConstraintsKernel::Name(), *this);
dynamic_cast<ApplyConstraintsKernel&>(applyConstraintsKernel.getImpl()).initialize(system); dynamic_cast<ApplyConstraintsKernel&>(applyConstraintsKernel.getImpl()).initialize(system);
virtualSitesKernel = platform->createKernel(VirtualSitesKernel::Name(), *this);
dynamic_cast<VirtualSitesKernel&>(virtualSitesKernel.getImpl()).initialize(system);
Vec3 periodicBoxVectors[3]; Vec3 periodicBoxVectors[3];
system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]); system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setPeriodicBoxVectors(*this, periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]); dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setPeriodicBoxVectors(*this, periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
...@@ -185,6 +187,10 @@ void ContextImpl::applyConstraints(double tol) { ...@@ -185,6 +187,10 @@ void ContextImpl::applyConstraints(double tol) {
dynamic_cast<ApplyConstraintsKernel&>(applyConstraintsKernel.getImpl()).apply(*this, tol); dynamic_cast<ApplyConstraintsKernel&>(applyConstraintsKernel.getImpl()).apply(*this, tol);
} }
void ContextImpl::computeVirtualSites() {
dynamic_cast<VirtualSitesKernel&>(virtualSitesKernel.getImpl()).computePositions(*this);
}
double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy) { double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy) {
CalcForcesAndEnergyKernel& kernel = dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl()); CalcForcesAndEnergyKernel& kernel = dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl());
double energy = 0.0; double energy = 0.0;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2011 Stanford University and the Authors. * * Portions copyright (c) 2010-2012 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -59,13 +59,21 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsf ...@@ -59,13 +59,21 @@ static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsf
for (int i = 0; i < numParticles; i++) for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(x[3*i], x[3*i+1], x[3*i+2]); positions[i] = Vec3(x[3*i], x[3*i+1], x[3*i+2]);
context.setPositions(positions); context.setPositions(positions);
context.computeVirtualSites();
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
if (system.isVirtualSite(i)) {
g[3*i] = 0.0;
g[3*i+1] = 0.0;
g[3*i+2] = 0.0;
}
else {
g[3*i] = -forces[i][0]; g[3*i] = -forces[i][0];
g[3*i+1] = -forces[i][1]; g[3*i+1] = -forces[i][1];
g[3*i+2] = -forces[i][2]; g[3*i+2] = -forces[i][2];
} }
}
double energy = state.getPotentialEnergy(); double energy = state.getPotentialEnergy();
// Add harmonic forces for any constraints. // Add harmonic forces for any constraints.
......
...@@ -39,6 +39,8 @@ OPENMMCUDA_EXPORT KernelImpl* CudaKernelFactory::createKernelImpl(std::string na ...@@ -39,6 +39,8 @@ OPENMMCUDA_EXPORT KernelImpl* CudaKernelFactory::createKernelImpl(std::string na
return new CudaUpdateStateDataKernel(name, platform, data); return new CudaUpdateStateDataKernel(name, platform, data);
if (name == ApplyConstraintsKernel::Name()) if (name == ApplyConstraintsKernel::Name())
return new CudaApplyConstraintsKernel(name, platform, data); return new CudaApplyConstraintsKernel(name, platform, data);
if (name == VirtualSitesKernel::Name())
return new CudaVirtualSitesKernel(name, platform);
if (name == CalcHarmonicBondForceKernel::Name()) if (name == CalcHarmonicBondForceKernel::Name())
return new CudaCalcHarmonicBondForceKernel(name, platform, data, context.getSystem()); return new CudaCalcHarmonicBondForceKernel(name, platform, data, context.getSystem());
if (name == CalcCustomBondForceKernel::Name()) if (name == CalcCustomBondForceKernel::Name())
......
...@@ -185,6 +185,12 @@ void CudaApplyConstraintsKernel::apply(ContextImpl& context, double tol) { ...@@ -185,6 +185,12 @@ void CudaApplyConstraintsKernel::apply(ContextImpl& context, double tol) {
kApplyConstraints(data.gpu); kApplyConstraints(data.gpu);
} }
void CudaVirtualSitesKernel::initialize(const System& system) {
}
void CudaVirtualSitesKernel::computePositions(ContextImpl& context) {
}
class CudaCalcHarmonicBondForceKernel::ForceInfo : public CudaForceInfo { class CudaCalcHarmonicBondForceKernel::ForceInfo : public CudaForceInfo {
public: public:
ForceInfo(const HarmonicBondForce& force) : force(force) { ForceInfo(const HarmonicBondForce& force) : force(force) {
......
...@@ -183,6 +183,27 @@ private: ...@@ -183,6 +183,27 @@ private:
CudaPlatform::PlatformData& data; CudaPlatform::PlatformData& data;
}; };
/**
* This kernel recomputes the positions of virtual sites.
*/
class CudaVirtualSitesKernel : public VirtualSitesKernel {
public:
CudaVirtualSitesKernel(std::string name, const Platform& platform) : VirtualSitesKernel(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Compute the virtual site locations.
*
* @param context the context in which to execute this kernel
*/
void computePositions(ContextImpl& context);
};
/** /**
* This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system.
*/ */
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "kernels/gputypes.h" #include "kernels/gputypes.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/OpenMMException.h"
#include "openmm/System.h" #include "openmm/System.h"
#include <sstream> #include <sstream>
...@@ -48,6 +49,7 @@ CudaPlatform::CudaPlatform() { ...@@ -48,6 +49,7 @@ CudaPlatform::CudaPlatform() {
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory); registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(UpdateStateDataKernel::Name(), factory); registerKernelFactory(UpdateStateDataKernel::Name(), factory);
registerKernelFactory(ApplyConstraintsKernel::Name(), factory); registerKernelFactory(ApplyConstraintsKernel::Name(), factory);
registerKernelFactory(VirtualSitesKernel::Name(), factory);
registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory);
registerKernelFactory(CalcCustomBondForceKernel::Name(), factory); registerKernelFactory(CalcCustomBondForceKernel::Name(), factory);
registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory);
...@@ -93,6 +95,10 @@ void CudaPlatform::setPropertyValue(Context& context, const string& property, co ...@@ -93,6 +95,10 @@ void CudaPlatform::setPropertyValue(Context& context, const string& property, co
} }
void CudaPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void CudaPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
System& system = context.getSystem();
for (int i = 0; i < system.getNumParticles(); i++)
if (system.isVirtualSite(i))
throw OpenMMException("CudaPlatform does not support virtual sites");
unsigned int device = 0; unsigned int device = 0;
const string& devicePropValue = (properties.find(CudaDevice()) == properties.end() ? const string& devicePropValue = (properties.find(CudaDevice()) == properties.end() ?
getPropertyDefaultValue(CudaDevice()) : properties.find(CudaDevice())->second); getPropertyDefaultValue(CudaDevice()) : properties.find(CudaDevice())->second);
......
...@@ -70,6 +70,8 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo ...@@ -70,6 +70,8 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo
return new OpenCLUpdateStateDataKernel(name, platform, cl); return new OpenCLUpdateStateDataKernel(name, platform, cl);
if (name == ApplyConstraintsKernel::Name()) if (name == ApplyConstraintsKernel::Name())
return new OpenCLApplyConstraintsKernel(name, platform, cl); return new OpenCLApplyConstraintsKernel(name, platform, cl);
if (name == VirtualSitesKernel::Name())
return new OpenCLVirtualSitesKernel(name, platform, cl);
if (name == CalcHarmonicBondForceKernel::Name()) if (name == CalcHarmonicBondForceKernel::Name())
return new OpenCLCalcHarmonicBondForceKernel(name, platform, cl, context.getSystem()); return new OpenCLCalcHarmonicBondForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomBondForceKernel::Name()) if (name == CalcCustomBondForceKernel::Name())
......
...@@ -225,6 +225,13 @@ void OpenCLApplyConstraintsKernel::apply(ContextImpl& context, double tol) { ...@@ -225,6 +225,13 @@ void OpenCLApplyConstraintsKernel::apply(ContextImpl& context, double tol) {
cl.getIntegrationUtilities().computeVirtualSites(); cl.getIntegrationUtilities().computeVirtualSites();
} }
void OpenCLVirtualSitesKernel::initialize(const System& system) {
}
void OpenCLVirtualSitesKernel::computePositions(ContextImpl& context) {
cl.getIntegrationUtilities().computeVirtualSites();
}
class OpenCLBondForceInfo : public OpenCLForceInfo { class OpenCLBondForceInfo : public OpenCLForceInfo {
public: public:
OpenCLBondForceInfo(int requiredBuffers, const HarmonicBondForce& force) : OpenCLForceInfo(requiredBuffers), force(force) { OpenCLBondForceInfo(int requiredBuffers, const HarmonicBondForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
......
...@@ -178,6 +178,29 @@ private: ...@@ -178,6 +178,29 @@ private:
OpenCLContext& cl; OpenCLContext& cl;
}; };
/**
* This kernel recomputes the positions of virtual sites.
*/
class OpenCLVirtualSitesKernel : public VirtualSitesKernel {
public:
OpenCLVirtualSitesKernel(std::string name, const Platform& platform, OpenCLContext& cl) : VirtualSitesKernel(name, platform), cl(cl) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Compute the virtual site locations.
*
* @param context the context in which to execute this kernel
*/
void computePositions(ContextImpl& context);
private:
OpenCLContext& cl;
};
/** /**
* This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system.
*/ */
......
...@@ -48,6 +48,7 @@ OpenCLPlatform::OpenCLPlatform() { ...@@ -48,6 +48,7 @@ OpenCLPlatform::OpenCLPlatform() {
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory); registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(UpdateStateDataKernel::Name(), factory); registerKernelFactory(UpdateStateDataKernel::Name(), factory);
registerKernelFactory(ApplyConstraintsKernel::Name(), factory); registerKernelFactory(ApplyConstraintsKernel::Name(), factory);
registerKernelFactory(VirtualSitesKernel::Name(), factory);
registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory);
registerKernelFactory(CalcCustomBondForceKernel::Name(), factory); registerKernelFactory(CalcCustomBondForceKernel::Name(), factory);
registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory);
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "openmm/LocalEnergyMinimizer.h" #include "openmm/LocalEnergyMinimizer.h"
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
#include "openmm/VerletIntegrator.h" #include "openmm/VerletIntegrator.h"
#include "openmm/VirtualSite.h"
#include "sfmt/SFMT.h" #include "sfmt/SFMT.h"
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -132,10 +133,77 @@ void testLargeSystem() { ...@@ -132,10 +133,77 @@ void testLargeSystem() {
ASSERT(forceNorm < 3*tolerance); ASSERT(forceNorm < 3*tolerance);
} }
void testVirtualSites() {
const int numMolecules = 50;
const int numParticles = numMolecules*3;
const double cutoff = 2.0;
const double boxSize = 5.0;
const double tolerance = 5;
System system;
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
NonbondedForce* nonbonded = new NonbondedForce();
nonbonded->setCutoffDistance(cutoff);
nonbonded->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
system.addForce(nonbonded);
// Create a cloud of molecules.
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
vector<Vec3> positions(numParticles);
for (int i = 0; i < numMolecules; i++) {
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(0.0);
nonbonded->addParticle(-1.0, 0.2, 0.2);
nonbonded->addParticle(0.5, 0.2, 0.2);
nonbonded->addParticle(0.5, 0.2, 0.2);
positions[3*i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
positions[3*i+1] = Vec3(positions[3*i][0]+1.0, positions[3*i][1], positions[3*i][2]);
positions[3*i+2] = Vec3();
system.addConstraint(3*i, 3*i+1, 1.0);
system.setVirtualSite(3*i+2, new TwoParticleAverageSite(3*i, 3*i+1, 0.5, 0.5));
}
// Minimize it and verify that the energy has decreased.
OpenCLPlatform platform;
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
State initialState = context.getState(State::Forces | State::Energy);
LocalEnergyMinimizer::minimize(context, tolerance);
State finalState = context.getState(State::Forces | State::Energy | State::Positions);
ASSERT(finalState.getPotentialEnergy() < initialState.getPotentialEnergy());
// Compute the force magnitude, subtracting off any component parallel to a constraint, and
// check that it satisfies the requested tolerance.
double forceNorm = 0.0;
for (int i = 0; i < numParticles; i += 3) {
Vec3 dir = finalState.getPositions()[i+1]-finalState.getPositions()[i];
double distance = sqrt(dir.dot(dir));
dir *= 1.0/distance;
Vec3 f = finalState.getForces()[i];
f -= dir*dir.dot(f);
forceNorm += f.dot(f);
f = finalState.getForces()[i+1];
f -= dir*dir.dot(f);
forceNorm += f.dot(f);
// Check the virtual site location.
ASSERT_EQUAL_VEC((finalState.getPositions()[i+1]+finalState.getPositions()[i])*0.5, finalState.getPositions()[i+2], 1e-5);
}
forceNorm = sqrt(forceNorm/(4*numMolecules));
ASSERT(forceNorm < 3*tolerance);
}
int main() { int main() {
try { try {
testHarmonicBonds(); testHarmonicBonds();
testLargeSystem(); testLargeSystem();
testVirtualSites();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -44,6 +44,8 @@ KernelImpl* ReferenceKernelFactory::createKernelImpl(std::string name, const Pla ...@@ -44,6 +44,8 @@ KernelImpl* ReferenceKernelFactory::createKernelImpl(std::string name, const Pla
return new ReferenceUpdateStateDataKernel(name, platform, data); return new ReferenceUpdateStateDataKernel(name, platform, data);
if (name == ApplyConstraintsKernel::Name()) if (name == ApplyConstraintsKernel::Name())
return new ReferenceApplyConstraintsKernel(name, platform, data); return new ReferenceApplyConstraintsKernel(name, platform, data);
if (name == VirtualSitesKernel::Name())
return new ReferenceVirtualSitesKernel(name, platform);
if (name == CalcNonbondedForceKernel::Name()) if (name == CalcNonbondedForceKernel::Name())
return new ReferenceCalcNonbondedForceKernel(name, platform); return new ReferenceCalcNonbondedForceKernel(name, platform);
if (name == CalcCustomNonbondedForceKernel::Name()) if (name == CalcCustomNonbondedForceKernel::Name())
......
...@@ -278,6 +278,14 @@ void ReferenceApplyConstraintsKernel::apply(ContextImpl& context, double tol) { ...@@ -278,6 +278,14 @@ void ReferenceApplyConstraintsKernel::apply(ContextImpl& context, double tol) {
ReferenceVirtualSites::computePositions(context.getSystem(), positions); ReferenceVirtualSites::computePositions(context.getSystem(), positions);
} }
void ReferenceVirtualSitesKernel::initialize(const System& system) {
}
void ReferenceVirtualSitesKernel::computePositions(ContextImpl& context) {
vector<RealVec>& positions = extractPositions(context);
ReferenceVirtualSites::computePositions(context.getSystem(), positions);
}
ReferenceCalcHarmonicBondForceKernel::~ReferenceCalcHarmonicBondForceKernel() { ReferenceCalcHarmonicBondForceKernel::~ReferenceCalcHarmonicBondForceKernel() {
disposeIntArray(bondIndexArray, numBonds); disposeIntArray(bondIndexArray, numBonds);
disposeRealArray(bondParamArray, numBonds); disposeRealArray(bondParamArray, numBonds);
......
...@@ -201,6 +201,27 @@ private: ...@@ -201,6 +201,27 @@ private:
int numConstraints; int numConstraints;
}; };
/**
* This kernel recomputes the positions of virtual sites.
*/
class ReferenceVirtualSitesKernel : public VirtualSitesKernel {
public:
ReferenceVirtualSitesKernel(std::string name, const Platform& platform) : VirtualSitesKernel(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Compute the virtual site locations.
*
* @param context the context in which to execute this kernel
*/
void computePositions(ContextImpl& context);
};
/** /**
* This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system.
*/ */
......
...@@ -45,6 +45,7 @@ ReferencePlatform::ReferencePlatform() { ...@@ -45,6 +45,7 @@ ReferencePlatform::ReferencePlatform() {
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory); registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(UpdateStateDataKernel::Name(), factory); registerKernelFactory(UpdateStateDataKernel::Name(), factory);
registerKernelFactory(ApplyConstraintsKernel::Name(), factory); registerKernelFactory(ApplyConstraintsKernel::Name(), factory);
registerKernelFactory(VirtualSitesKernel::Name(), factory);
registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory);
registerKernelFactory(CalcCustomBondForceKernel::Name(), factory); registerKernelFactory(CalcCustomBondForceKernel::Name(), factory);
registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory);
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "openmm/LocalEnergyMinimizer.h" #include "openmm/LocalEnergyMinimizer.h"
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
#include "openmm/VerletIntegrator.h" #include "openmm/VerletIntegrator.h"
#include "openmm/VirtualSite.h"
#include "sfmt/SFMT.h" #include "sfmt/SFMT.h"
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -113,7 +114,7 @@ void testLargeSystem() { ...@@ -113,7 +114,7 @@ void testLargeSystem() {
State finalState = context.getState(State::Forces | State::Energy | State::Positions); State finalState = context.getState(State::Forces | State::Energy | State::Positions);
ASSERT(finalState.getPotentialEnergy() < initialState.getPotentialEnergy()); ASSERT(finalState.getPotentialEnergy() < initialState.getPotentialEnergy());
// Compute the force magnitude, substracting off any component parallel to a constraint, and // Compute the force magnitude, subtracting off any component parallel to a constraint, and
// check that it satisfies the requested tolerance. // check that it satisfies the requested tolerance.
double forceNorm = 0.0; double forceNorm = 0.0;
...@@ -132,10 +133,77 @@ void testLargeSystem() { ...@@ -132,10 +133,77 @@ void testLargeSystem() {
ASSERT(forceNorm < 3*tolerance); ASSERT(forceNorm < 3*tolerance);
} }
void testVirtualSites() {
const int numMolecules = 50;
const int numParticles = numMolecules*3;
const double cutoff = 2.0;
const double boxSize = 5.0;
const double tolerance = 5;
System system;
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
NonbondedForce* nonbonded = new NonbondedForce();
nonbonded->setCutoffDistance(cutoff);
nonbonded->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
system.addForce(nonbonded);
// Create a cloud of molecules.
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
vector<Vec3> positions(numParticles);
for (int i = 0; i < numMolecules; i++) {
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(0.0);
nonbonded->addParticle(-1.0, 0.2, 0.2);
nonbonded->addParticle(0.5, 0.2, 0.2);
nonbonded->addParticle(0.5, 0.2, 0.2);
positions[3*i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
positions[3*i+1] = Vec3(positions[3*i][0]+1.0, positions[3*i][1], positions[3*i][2]);
positions[3*i+2] = Vec3();
system.addConstraint(3*i, 3*i+1, 1.0);
system.setVirtualSite(3*i+2, new TwoParticleAverageSite(3*i, 3*i+1, 0.5, 0.5));
}
// Minimize it and verify that the energy has decreased.
ReferencePlatform platform;
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
State initialState = context.getState(State::Forces | State::Energy);
LocalEnergyMinimizer::minimize(context, tolerance);
State finalState = context.getState(State::Forces | State::Energy | State::Positions);
ASSERT(finalState.getPotentialEnergy() < initialState.getPotentialEnergy());
// Compute the force magnitude, subtracting off any component parallel to a constraint, and
// check that it satisfies the requested tolerance.
double forceNorm = 0.0;
for (int i = 0; i < numParticles; i += 3) {
Vec3 dir = finalState.getPositions()[i+1]-finalState.getPositions()[i];
double distance = sqrt(dir.dot(dir));
dir *= 1.0/distance;
Vec3 f = finalState.getForces()[i];
f -= dir*dir.dot(f);
forceNorm += f.dot(f);
f = finalState.getForces()[i+1];
f -= dir*dir.dot(f);
forceNorm += f.dot(f);
// Check the virtual site location.
ASSERT_EQUAL_VEC((finalState.getPositions()[i+1]+finalState.getPositions()[i])*0.5, finalState.getPositions()[i+2], 1e-5);
}
forceNorm = sqrt(forceNorm/(4*numMolecules));
ASSERT(forceNorm < 3*tolerance);
}
int main() { int main() {
try { try {
testHarmonicBonds(); testHarmonicBonds();
testLargeSystem(); testLargeSystem();
testVirtualSites();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment