Commit 59b0c5fa authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented updateParametersInContext() for NonbondedForce

parent 895f8dac
...@@ -507,6 +507,13 @@ public: ...@@ -507,6 +507,13 @@ public:
* @return the potential energy due to the force * @return the potential energy due to the force
*/ */
virtual double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) = 0; virtual double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) = 0;
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the NonbondedForce to copy the parameters from
*/
virtual void copyParametersToContext(ContextImpl& context, const NonbondedForce& force) = 0;
}; };
/** /**
......
...@@ -225,7 +225,9 @@ public: ...@@ -225,7 +225,9 @@ public:
*/ */
void loadCheckpoint(std::istream& stream); void loadCheckpoint(std::istream& stream);
private: private:
friend class Force;
friend class Platform; friend class Platform;
ContextImpl& getImpl();
ContextImpl* impl; ContextImpl* impl;
std::map<std::string, std::string> properties; std::map<std::string, std::string> properties;
}; };
......
...@@ -36,6 +36,8 @@ ...@@ -36,6 +36,8 @@
namespace OpenMM { namespace OpenMM {
class Context;
class ContextImpl;
class ForceImpl; class ForceImpl;
/** /**
...@@ -84,6 +86,14 @@ protected: ...@@ -84,6 +86,14 @@ protected:
* The ForceImpl will be deleted automatically when the Context is deleted. * The ForceImpl will be deleted automatically when the Context is deleted.
*/ */
virtual ForceImpl* createImpl() = 0; virtual ForceImpl* createImpl() = 0;
/**
* Get the ForceImpl corresponding to this Force in a Context.
*/
ForceImpl& getImplInContext(Context& context);
/**
* Get the ContextImpl corresponding to a Context.
*/
ContextImpl& getContextImpl(Context& context);
private: private:
int forceGroup; int forceGroup;
}; };
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "Context.h"
#include "Force.h" #include "Force.h"
#include <map> #include <map>
#include <set> #include <set>
...@@ -280,6 +281,19 @@ public: ...@@ -280,6 +281,19 @@ public:
* that is specified for direct space. * that is specified for direct space.
*/ */
void setReciprocalSpaceForceGroup(int group); void setReciprocalSpaceForceGroup(int group);
/**
* Update the particle and exception parameters in a Context to match those stored in this Force object. This method
* provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() and setExceptionParameters() to modify this object's parameters, then call
* updateParametersInState() to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the parameters of particles and exceptions.
* All other aspects of the Force (the nonbonded method, the cutoff distance, etc.) are unaffected and can only be
* changed by reinitializing the Context. Furthermore, only the chargeProd, sigma, and epsilon values of an exception
* can be changed; the pair of particles involved in the exception cannot change. Finally, this method cannot be used
* to add new particles or exceptions, only to change the parameters of existing ones.
*/
void updateParametersInContext(Context& context);
protected: protected:
ForceImpl* createImpl(); ForceImpl* createImpl();
private: private:
......
...@@ -204,6 +204,10 @@ public: ...@@ -204,6 +204,10 @@ public:
* Get the list of ForceImpls belonging to this ContextImpl. * Get the list of ForceImpls belonging to this ContextImpl.
*/ */
const std::vector<ForceImpl*>& getForceImpls() const; const std::vector<ForceImpl*>& getForceImpls() const;
/**
* Get the list of ForceImpls belonging to this ContextImpl.
*/
std::vector<ForceImpl*>& getForceImpls();
/** /**
* Get the platform-specific data stored in this context. * Get the platform-specific data stored in this context.
*/ */
......
...@@ -63,6 +63,7 @@ public: ...@@ -63,6 +63,7 @@ public:
return std::map<std::string, double>(); // This force field doesn't define any parameters. return std::map<std::string, double>(); // This force field doesn't define any parameters.
} }
std::vector<std::string> getKernelNames(); std::vector<std::string> getKernelNames();
void updateParametersInContext(ContextImpl& context);
/** /**
* This is a utility routine that calculates the values to use for alpha and kmax when using * This is a utility routine that calculates the values to use for alpha and kmax when using
* Ewald summation. * Ewald summation.
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/ForceImpl.h"
#include <cmath> #include <cmath>
using namespace OpenMM; using namespace OpenMM;
...@@ -185,3 +186,7 @@ void Context::createCheckpoint(ostream& stream) { ...@@ -185,3 +186,7 @@ void Context::createCheckpoint(ostream& stream) {
void Context::loadCheckpoint(istream& stream) { void Context::loadCheckpoint(istream& stream) {
impl->loadCheckpoint(stream); impl->loadCheckpoint(stream);
} }
ContextImpl& Context::getImpl() {
return *impl;
}
...@@ -218,6 +218,10 @@ const vector<ForceImpl*>& ContextImpl::getForceImpls() const { ...@@ -218,6 +218,10 @@ const vector<ForceImpl*>& ContextImpl::getForceImpls() const {
return forceImpls; return forceImpls;
} }
vector<ForceImpl*>& ContextImpl::getForceImpls() {
return forceImpls;
}
void* ContextImpl::getPlatformData() { void* ContextImpl::getPlatformData() {
return platformData; return platformData;
} }
......
...@@ -29,10 +29,15 @@ ...@@ -29,10 +29,15 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "openmm/Context.h"
#include "openmm/Force.h" #include "openmm/Force.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ForceImpl.h"
#include <vector>
using namespace OpenMM; using namespace OpenMM;
using namespace std;
int Force::getForceGroup() const { int Force::getForceGroup() const {
return forceGroup; return forceGroup;
...@@ -43,3 +48,15 @@ void Force::setForceGroup(int group) { ...@@ -43,3 +48,15 @@ void Force::setForceGroup(int group) {
throw OpenMMException("Force group must be between 0 and 31"); throw OpenMMException("Force group must be between 0 and 31");
forceGroup = group; forceGroup = group;
} }
ForceImpl& Force::getImplInContext(Context& context) {
const vector<ForceImpl*>& impls = context.getImpl().getForceImpls();
for (int i = 0; i < (int) impls.size(); i++)
if (&impls[i]->getOwner() == this)
return *impls[i];
throw OpenMMException("getImplInContext: This Force is not present in the Context");
}
ContextImpl& Force::getContextImpl(Context& context) {
return context.getImpl();
}
...@@ -206,3 +206,7 @@ void NonbondedForce::setReciprocalSpaceForceGroup(int group) { ...@@ -206,3 +206,7 @@ void NonbondedForce::setReciprocalSpaceForceGroup(int group) {
throw OpenMMException("Force group must be between -1 and 31"); throw OpenMMException("Force group must be between -1 and 31");
recipForceGroup = group; recipForceGroup = group;
} }
void NonbondedForce::updateParametersInContext(Context& context) {
dynamic_cast<NonbondedForceImpl&>(getImplInContext(context)).updateParametersInContext(getContextImpl(context));
}
...@@ -219,3 +219,7 @@ double NonbondedForceImpl::calcDispersionCorrection(const System& system, const ...@@ -219,3 +219,7 @@ double NonbondedForceImpl::calcDispersionCorrection(const System& system, const
double cutoff = force.getCutoffDistance(); double cutoff = force.getCutoffDistance();
return 8*numParticles*numParticles*M_PI*(sum1/(9*pow(cutoff, 9))-sum2/(3*pow(cutoff, 3))); return 8*numParticles*numParticles*M_PI*(sum1/(9*pow(cutoff, 9))-sum2/(3*pow(cutoff, 3)));
} }
void NonbondedForceImpl::updateParametersInContext(ContextImpl& context) {
kernel.getAs<CalcNonbondedForceKernel>().copyParametersToContext(context, owner);
}
...@@ -894,6 +894,10 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -894,6 +894,10 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
return 0.0; return 0.0;
} }
void CudaCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
throw OpenMMException("CudaPlatform does not support copyParametersToContext");
}
class CudaCalcCustomNonbondedForceKernel::ForceInfo : public CudaForceInfo { class CudaCalcCustomNonbondedForceKernel::ForceInfo : public CudaForceInfo {
public: public:
ForceInfo(const CustomNonbondedForce& force) : force(force) { ForceInfo(const CustomNonbondedForce& force) : force(force) {
......
...@@ -509,6 +509,13 @@ public: ...@@ -509,6 +509,13 @@ public:
* @return the potential energy due to the force * @return the potential energy due to the force
*/ */
double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal); double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the NonbondedForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const NonbondedForce& force);
private: private:
class ForceInfo; class ForceInfo;
CudaPlatform::PlatformData& data; CudaPlatform::PlatformData& data;
......
...@@ -66,7 +66,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i ...@@ -66,7 +66,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i
} }
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) : OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL), system(system), time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL),
velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL), velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL),
bonded(NULL), nonbonded(NULL), thread(NULL) { bonded(NULL), nonbonded(NULL), thread(NULL) {
try { try {
...@@ -303,7 +303,7 @@ OpenCLContext::~OpenCLContext() { ...@@ -303,7 +303,7 @@ OpenCLContext::~OpenCLContext() {
delete thread; delete thread;
} }
void OpenCLContext::initialize(const System& system) { void OpenCLContext::initialize() {
for (int i = 0; i < numAtoms; i++) { for (int i = 0; i < numAtoms; i++) {
double mass = system.getParticleMass(i); double mass = system.getParticleMass(i);
(*velm)[i].w = (float) (mass == 0.0 ? 0.0 : 1.0/mass); (*velm)[i].w = (float) (mass == 0.0 ? 0.0 : 1.0/mass);
...@@ -331,7 +331,8 @@ void OpenCLContext::initialize(const System& system) { ...@@ -331,7 +331,8 @@ void OpenCLContext::initialize(const System& system) {
for (int i = 0; i < paddedNumAtoms; ++i) for (int i = 0; i < paddedNumAtoms; ++i)
(*atomIndex)[i] = i; (*atomIndex)[i] = i;
atomIndex->upload(); atomIndex->upload();
findMoleculeGroups(system); findMoleculeGroups();
moleculesInvalid = false;
nonbonded->initialize(system); nonbonded->initialize(system);
} }
...@@ -531,12 +532,6 @@ void OpenCLContext::tagAtomsInMolecule(int atom, int molecule, vector<int>& atom ...@@ -531,12 +532,6 @@ void OpenCLContext::tagAtomsInMolecule(int atom, int molecule, vector<int>& atom
tagAtomsInMolecule(atomBonds[atom][i], molecule, atomMolecule, atomBonds); tagAtomsInMolecule(atomBonds[atom][i], molecule, atomMolecule, atomBonds);
} }
struct OpenCLContext::Molecule {
vector<int> atoms;
vector<int> constraints;
vector<vector<int> > groups;
};
/** /**
* This class ensures that atom reordering doesn't break virtual sites. * This class ensures that atom reordering doesn't break virtual sites.
*/ */
...@@ -603,67 +598,72 @@ private: ...@@ -603,67 +598,72 @@ private:
}; };
void OpenCLContext::findMoleculeGroups(const System& system) { void OpenCLContext::findMoleculeGroups() {
// Add a ForceInfo that makes sure reordering doesn't break virtual sites. // The first time this is called, we need to identify all the molecules in the system.
addForce(new VirtualSiteInfo(system)); if (moleculeGroups.size() == 0) {
// Add a ForceInfo that makes sure reordering doesn't break virtual sites.
// First make a list of every other atom to which each atom is connect by a constraint or force group.
addForce(new VirtualSiteInfo(system));
vector<vector<int> > atomBonds(system.getNumParticles());
for (int i = 0; i < system.getNumConstraints(); i++) { // First make a list of every other atom to which each atom is connect by a constraint or force group.
int particle1, particle2;
double distance; vector<vector<int> > atomBonds(system.getNumParticles());
system.getConstraintParameters(i, particle1, particle2, distance); for (int i = 0; i < system.getNumConstraints(); i++) {
atomBonds[particle1].push_back(particle2); int particle1, particle2;
atomBonds[particle2].push_back(particle1); double distance;
} system.getConstraintParameters(i, particle1, particle2, distance);
for (int i = 0; i < (int) forces.size(); i++) { atomBonds[particle1].push_back(particle2);
for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) { atomBonds[particle2].push_back(particle1);
vector<int> particles; }
forces[i]->getParticlesInGroup(j, particles); for (int i = 0; i < (int) forces.size(); i++) {
for (int k = 0; k < (int) particles.size(); k++) for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) {
for (int m = 0; m < (int) particles.size(); m++) vector<int> particles;
if (k != m) forces[i]->getParticlesInGroup(j, particles);
atomBonds[particles[k]].push_back(particles[m]); for (int k = 0; k < (int) particles.size(); k++)
for (int m = 0; m < (int) particles.size(); m++)
if (k != m)
atomBonds[particles[k]].push_back(particles[m]);
}
} }
}
// Now tag atoms by which molecule they belong to. // Now tag atoms by which molecule they belong to.
vector<int> atomMolecule(numAtoms, -1); vector<int> atomMolecule(numAtoms, -1);
int numMolecules = 0; int numMolecules = 0;
for (int i = 0; i < numAtoms; i++) for (int i = 0; i < numAtoms; i++)
if (atomMolecule[i] == -1) if (atomMolecule[i] == -1)
tagAtomsInMolecule(i, numMolecules++, atomMolecule, atomBonds); tagAtomsInMolecule(i, numMolecules++, atomMolecule, atomBonds);
vector<vector<int> > atomIndices(numMolecules); vector<vector<int> > atomIndices(numMolecules);
for (int i = 0; i < numAtoms; i++) for (int i = 0; i < numAtoms; i++)
atomIndices[atomMolecule[i]].push_back(i); atomIndices[atomMolecule[i]].push_back(i);
// Construct a description of each molecule. // Construct a description of each molecule.
vector<Molecule> molecules(numMolecules); molecules.resize(numMolecules);
for (int i = 0; i < numMolecules; i++) { for (int i = 0; i < numMolecules; i++) {
molecules[i].atoms = atomIndices[i]; molecules[i].atoms = atomIndices[i];
molecules[i].groups.resize(forces.size()); molecules[i].groups.resize(forces.size());
}
for (int i = 0; i < system.getNumConstraints(); i++) {
int particle1, particle2;
double distance;
system.getConstraintParameters(i, particle1, particle2, distance);
molecules[atomMolecule[particle1]].constraints.push_back(i);
}
for (int i = 0; i < (int) forces.size(); i++)
for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) {
vector<int> particles;
forces[i]->getParticlesInGroup(j, particles);
molecules[atomMolecule[particles[0]]].groups[i].push_back(j);
} }
for (int i = 0; i < system.getNumConstraints(); i++) {
int particle1, particle2;
double distance;
system.getConstraintParameters(i, particle1, particle2, distance);
molecules[atomMolecule[particle1]].constraints.push_back(i);
}
for (int i = 0; i < (int) forces.size(); i++)
for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) {
vector<int> particles;
forces[i]->getParticlesInGroup(j, particles);
molecules[atomMolecule[particles[0]]].groups[i].push_back(j);
}
}
// Sort them into groups of identical molecules. // Sort them into groups of identical molecules.
vector<Molecule> uniqueMolecules; vector<Molecule> uniqueMolecules;
vector<vector<int> > moleculeInstances; vector<vector<int> > moleculeInstances;
vector<vector<int> > moleculeOffsets;
for (int molIndex = 0; molIndex < (int) molecules.size(); molIndex++) { for (int molIndex = 0; molIndex < (int) molecules.size(); molIndex++) {
Molecule& mol = molecules[molIndex]; Molecule& mol = molecules[molIndex];
...@@ -706,20 +706,24 @@ void OpenCLContext::findMoleculeGroups(const System& system) { ...@@ -706,20 +706,24 @@ void OpenCLContext::findMoleculeGroups(const System& system) {
identical = false; identical = false;
} }
if (identical) { if (identical) {
moleculeInstances[j].push_back(mol.atoms[0]); moleculeInstances[j].push_back(molIndex);
moleculeOffsets[j].push_back(mol.atoms[0]);
isNew = false; isNew = false;
} }
} }
if (isNew) { if (isNew) {
uniqueMolecules.push_back(mol); uniqueMolecules.push_back(mol);
moleculeInstances.push_back(vector<int>()); moleculeInstances.push_back(vector<int>());
moleculeInstances[moleculeInstances.size()-1].push_back(mol.atoms[0]); moleculeInstances[moleculeInstances.size()-1].push_back(molIndex);
moleculeOffsets.push_back(vector<int>());
moleculeOffsets[moleculeOffsets.size()-1].push_back(mol.atoms[0]);
} }
} }
moleculeGroups.resize(moleculeInstances.size()); moleculeGroups.resize(moleculeInstances.size());
for (int i = 0; i < (int) moleculeInstances.size(); i++) for (int i = 0; i < (int) moleculeInstances.size(); i++)
{ {
moleculeGroups[i].instances = moleculeInstances[i]; moleculeGroups[i].instances = moleculeInstances[i];
moleculeGroups[i].offsets = moleculeOffsets[i];
vector<int>& atoms = uniqueMolecules[i].atoms; vector<int>& atoms = uniqueMolecules[i].atoms;
moleculeGroups[i].atoms.resize(atoms.size()); moleculeGroups[i].atoms.resize(atoms.size());
for (int j = 0; j < (int) atoms.size(); j++) for (int j = 0; j < (int) atoms.size(); j++)
...@@ -727,9 +731,78 @@ void OpenCLContext::findMoleculeGroups(const System& system) { ...@@ -727,9 +731,78 @@ void OpenCLContext::findMoleculeGroups(const System& system) {
} }
} }
void OpenCLContext::invalidateMolecules() {
moleculesInvalid = true;
}
void OpenCLContext::validateMolecules() {
moleculesInvalid = false;
if (numAtoms == 0 || nonbonded == NULL || !nonbonded->getUseCutoff())
return;
bool valid = true;
for (int group = 0; valid && group < (int) moleculeGroups.size(); group++) {
MoleculeGroup& mol = moleculeGroups[group];
vector<int>& instances = mol.instances;
vector<int>& offsets = mol.offsets;
vector<int>& atoms = mol.atoms;
int numMolecules = instances.size();
Molecule& m1 = molecules[instances[0]];
int offset1 = offsets[0];
for (int j = 1; valid && j < numMolecules; j++) {
// See if the atoms are identical.
Molecule& m2 = molecules[instances[j]];
int offset2 = offsets[j];
for (int i = 0; i < (int) atoms.size() && valid; i++) {
for (int k = 0; k < (int) forces.size(); k++)
if (!forces[k]->areParticlesIdentical(atoms[i]+offset1, atoms[i]+offset2))
valid = false;
}
// See if the force groups are identical.
for (int i = 0; i < (int) forces.size() && valid; i++) {
for (int k = 0; k < (int) m1.groups[i].size() && valid; k++)
if (!forces[i]->areGroupsIdentical(m1.groups[i][k], m2.groups[i][k]))
valid = false;
}
}
}
if (valid)
return;
// The list of which molecules are identical is no longer valid. We need to restore the
// atoms to their original order, rebuild the list of identical molecules, and sort them
// again.
vector<mm_float4> newPosq(numAtoms);
vector<mm_float4> newVelm(numAtoms);
vector<mm_int4> newCellOffsets(numAtoms);
posq->download();
velm->download();
for (int i = 0; i < numAtoms; i++) {
int index = atomIndex->get(i);
newPosq[index] = posq->get(i);
newVelm[index] = velm->get(i);
newCellOffsets[index] = posCellOffsets[i];
}
for (int i = 0; i < numAtoms; i++) {
posq->set(i, newPosq[i]);
velm->set(i, newVelm[i]);
atomIndex->set(i, i);
posCellOffsets[i] = newCellOffsets[i];
}
posq->upload();
velm->upload();
atomIndex->upload();
findMoleculeGroups();
}
void OpenCLContext::reorderAtoms() { void OpenCLContext::reorderAtoms() {
if (numAtoms == 0 || nonbonded == NULL || !nonbonded->getUseCutoff()) if (numAtoms == 0 || nonbonded == NULL || !nonbonded->getUseCutoff())
return; return;
if (moleculesInvalid)
validateMolecules();
atomsWereReordered = true; atomsWereReordered = true;
// Find the range of positions and the number of bins along each axis. // Find the range of positions and the number of bins along each axis.
...@@ -767,7 +840,7 @@ void OpenCLContext::reorderAtoms() { ...@@ -767,7 +840,7 @@ void OpenCLContext::reorderAtoms() {
// Find the center of each molecule. // Find the center of each molecule.
MoleculeGroup& mol = moleculeGroups[group]; MoleculeGroup& mol = moleculeGroups[group];
int numMolecules = mol.instances.size(); int numMolecules = mol.offsets.size();
vector<int>& atoms = mol.atoms; vector<int>& atoms = mol.atoms;
vector<mm_float4> molPos(numMolecules); vector<mm_float4> molPos(numMolecules);
float invNumAtoms = 1.0f/atoms.size(); float invNumAtoms = 1.0f/atoms.size();
...@@ -776,7 +849,7 @@ void OpenCLContext::reorderAtoms() { ...@@ -776,7 +849,7 @@ void OpenCLContext::reorderAtoms() {
molPos[i].y = 0.0f; molPos[i].y = 0.0f;
molPos[i].z = 0.0f; molPos[i].z = 0.0f;
for (int j = 0; j < (int)atoms.size(); j++) { for (int j = 0; j < (int)atoms.size(); j++) {
int atom = atoms[j]+mol.instances[i]; int atom = atoms[j]+mol.offsets[i];
const mm_float4& pos = posq->get(atom); const mm_float4& pos = posq->get(atom);
molPos[i].x += pos.x; molPos[i].x += pos.x;
molPos[i].y += pos.y; molPos[i].y += pos.y;
...@@ -801,7 +874,7 @@ void OpenCLContext::reorderAtoms() { ...@@ -801,7 +874,7 @@ void OpenCLContext::reorderAtoms() {
molPos[i].y -= dy; molPos[i].y -= dy;
molPos[i].z -= dz; molPos[i].z -= dz;
for (int j = 0; j < (int) atoms.size(); j++) { for (int j = 0; j < (int) atoms.size(); j++) {
int atom = atoms[j]+mol.instances[i]; int atom = atoms[j]+mol.offsets[i];
mm_float4 p = posq->get(atom); mm_float4 p = posq->get(atom);
p.x -= dx; p.x -= dx;
p.y -= dy; p.y -= dy;
...@@ -854,8 +927,8 @@ void OpenCLContext::reorderAtoms() { ...@@ -854,8 +927,8 @@ void OpenCLContext::reorderAtoms() {
for (int i = 0; i < numMolecules; i++) { for (int i = 0; i < numMolecules; i++) {
for (int j = 0; j < (int)atoms.size(); j++) { for (int j = 0; j < (int)atoms.size(); j++) {
int oldIndex = mol.instances[molBins[i].second]+atoms[j]; int oldIndex = mol.offsets[molBins[i].second]+atoms[j];
int newIndex = mol.instances[i]+atoms[j]; int newIndex = mol.offsets[i]+atoms[j];
originalIndex[newIndex] = atomIndex->get(oldIndex); originalIndex[newIndex] = atomIndex->get(oldIndex);
newPosq[newIndex] = posq->get(oldIndex); newPosq[newIndex] = posq->get(oldIndex);
newVelm[newIndex] = velm->get(oldIndex); newVelm[newIndex] = velm->get(oldIndex);
......
...@@ -152,7 +152,7 @@ public: ...@@ -152,7 +152,7 @@ public:
* This is called to initialize internal data structures after all Forces in the system * This is called to initialize internal data structures after all Forces in the system
* have been initialized. * have been initialized.
*/ */
void initialize(const System& system); void initialize();
/** /**
* Add an OpenCLForce to this context. * Add an OpenCLForce to this context.
*/ */
...@@ -479,12 +479,31 @@ public: ...@@ -479,12 +479,31 @@ public:
std::vector<ReorderListener*>& getReorderListeners() { std::vector<ReorderListener*>& getReorderListeners() {
return reorderListeners; return reorderListeners;
} }
/**
* Mark that the current molecule definitions (and hence the atom order) may be invalid.
* This should be called whenever force field parameters change. It will cause the definitions
* and order to be revalidated the next to reorderAtoms() is called.
*/
void invalidateMolecules();
/**
* Get whether the current molecule definitions are valid.
*/
bool getMoleculesAreInvalid() {
return moleculesInvalid;
}
private: private:
struct Molecule; struct Molecule;
struct MoleculeGroup; struct MoleculeGroup;
class VirtualSiteInfo; class VirtualSiteInfo;
void findMoleculeGroups(const System& system); void findMoleculeGroups();
static void tagAtomsInMolecule(int atom, int molecule, std::vector<int>& atomMolecule, std::vector<std::vector<int> >& atomBonds); static void tagAtomsInMolecule(int atom, int molecule, std::vector<int>& atomMolecule, std::vector<std::vector<int> >& atomBonds);
/**
* Ensure that all molecules marked as "identical" really are identical. This should be
* called whenever force field parameters change. If necessary, it will rebuild the list
* of molecules and resort the atoms.
*/
void validateMolecules();
const System& system;
double time; double time;
OpenCLPlatform::PlatformData& platformData; OpenCLPlatform::PlatformData& platformData;
int deviceIndex; int deviceIndex;
...@@ -497,7 +516,7 @@ private: ...@@ -497,7 +516,7 @@ private:
int numThreadBlocks; int numThreadBlocks;
int numForceBuffers; int numForceBuffers;
int simdWidth; int simdWidth;
bool supports64BitGlobalAtomics, supportsDoublePrecision, atomsWereReordered; bool supports64BitGlobalAtomics, supportsDoublePrecision, atomsWereReordered, moleculesInvalid;
mm_float4 periodicBoxSize; mm_float4 periodicBoxSize;
mm_float4 invPeriodicBoxSize; mm_float4 invPeriodicBoxSize;
std::string defaultOptimizationOptions; std::string defaultOptimizationOptions;
...@@ -515,6 +534,7 @@ private: ...@@ -515,6 +534,7 @@ private:
cl::Kernel reduceFloat4Kernel; cl::Kernel reduceFloat4Kernel;
cl::Kernel reduceForcesKernel; cl::Kernel reduceForcesKernel;
std::vector<OpenCLForceInfo*> forces; std::vector<OpenCLForceInfo*> forces;
std::vector<Molecule> molecules;
std::vector<MoleculeGroup> moleculeGroups; std::vector<MoleculeGroup> moleculeGroups;
std::vector<mm_int4> posCellOffsets; std::vector<mm_int4> posCellOffsets;
OpenCLArray<mm_float4>* posq; OpenCLArray<mm_float4>* posq;
...@@ -533,9 +553,16 @@ private: ...@@ -533,9 +553,16 @@ private:
WorkThread* thread; WorkThread* thread;
}; };
struct OpenCLContext::Molecule {
std::vector<int> atoms;
std::vector<int> constraints;
std::vector<std::vector<int> > groups;
};
struct OpenCLContext::MoleculeGroup { struct OpenCLContext::MoleculeGroup {
std::vector<int> atoms; std::vector<int> atoms;
std::vector<int> instances; std::vector<int> instances;
std::vector<int> offsets;
}; };
/** /**
......
...@@ -98,7 +98,7 @@ void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, boo ...@@ -98,7 +98,7 @@ void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, boo
OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities(); OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
bool includeNonbonded = ((groups&(1<<nb.getForceGroup())) != 0); bool includeNonbonded = ((groups&(1<<nb.getForceGroup())) != 0);
cl.setAtomsWereReordered(false); cl.setAtomsWereReordered(false);
if (nb.getUseCutoff() && includeNonbonded && cl.getComputeForceCount()%100 == 0) { if (nb.getUseCutoff() && includeNonbonded && (cl.getComputeForceCount()%100 == 0 || cl.getMoleculesAreInvalid())) {
cl.reorderAtoms(); cl.reorderAtoms();
nb.updateNeighborListSize(); nb.updateNeighborListSize();
cl.setComputeForceCount(cl.getComputeForceCount()+1); cl.setComputeForceCount(cl.getComputeForceCount()+1);
...@@ -1058,8 +1058,8 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -1058,8 +1058,8 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
vector<mm_float2> sigmaEpsilonVector(numParticles); vector<mm_float2> sigmaEpsilonVector(numParticles);
vector<vector<int> > exclusionList(numParticles); vector<vector<int> > exclusionList(numParticles);
double sumSquaredCharges = 0.0; double sumSquaredCharges = 0.0;
bool hasCoulomb = false; hasCoulomb = false;
bool hasLJ = false; hasLJ = false;
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
double charge, sigma, epsilon; double charge, sigma, epsilon;
force.getParticleParameters(i, charge, sigma, epsilon); force.getParticleParameters(i, charge, sigma, epsilon);
...@@ -1095,7 +1095,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -1095,7 +1095,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(system, force); dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(system, force);
else else
dispersionCoefficient = 0.0; dispersionCoefficient = 0.0;
double alpha = 0; alpha = 0;
if (force.getNonbondedMethod() == NonbondedForce::Ewald) { if (force.getNonbondedMethod() == NonbondedForce::Ewald) {
// Compute the Ewald parameters. // Compute the Ewald parameters.
...@@ -1367,6 +1367,77 @@ double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -1367,6 +1367,77 @@ double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
return energy; return energy;
} }
void OpenCLCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
// Make sure the new parameters are acceptable.
if (force.getNumParticles() != cl.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed");
if (!hasCoulomb || !hasLJ) {
for (int i = 0; i < force.getNumParticles(); i++) {
double charge, sigma, epsilon;
force.getParticleParameters(i, charge, sigma, epsilon);
if (!hasCoulomb && charge != 0.0)
throw OpenMMException("updateParametersInContext: The nonbonded force kernel does not include Coulomb interactions, because all charges were originally 0");
if (!hasLJ && epsilon != 0.0)
throw OpenMMException("updateParametersInContext: The nonbonded force kernel does not include Lennard-Jones interactions, because all epsilons were originally 0");
}
}
vector<int> exceptions;
for (int i = 0; i < force.getNumExceptions(); i++) {
int particle1, particle2;
double chargeProd, sigma, epsilon;
force.getExceptionParameters(i, particle1, particle2, chargeProd, sigma, epsilon);
if (chargeProd != 0.0 || epsilon != 0.0)
exceptions.push_back(i);
}
int numContexts = cl.getPlatformData().contexts.size();
int startIndex = cl.getContextIndex()*exceptions.size()/numContexts;
int endIndex = (cl.getContextIndex()+1)*exceptions.size()/numContexts;
int numExceptions = endIndex-startIndex;
if ((exceptionParams == NULL && numExceptions > 0) || (exceptionParams != NULL && numExceptions != exceptionParams->getSize()))
throw OpenMMException("updateParametersInContext: The number of non-excluded exceptions has changed");
// Record the per-particle parameters.
OpenCLArray<mm_float4>& posq = cl.getPosq();
posq.download();
vector<mm_float2> sigmaEpsilonVector(force.getNumParticles());
double sumSquaredCharges = 0.0;
OpenCLArray<cl_int>& order = cl.getAtomIndex();
for (int i = 0; i < force.getNumParticles(); i++) {
int index = order[i];
double charge, sigma, epsilon;
force.getParticleParameters(index, charge, sigma, epsilon);
posq[i].w = (float) charge;
sigmaEpsilonVector[index] = mm_float2((float) (0.5*sigma), (float) (2.0*sqrt(epsilon)));
sumSquaredCharges += charge*charge;
}
posq.upload();
sigmaEpsilon->upload(sigmaEpsilonVector);
// Record the exceptions.
if (numExceptions > 0) {
vector<vector<int> > atoms(numExceptions, vector<int>(2));
vector<mm_float4> exceptionParamsVector(numExceptions);
for (int i = 0; i < numExceptions; i++) {
double chargeProd, sigma, epsilon;
force.getExceptionParameters(exceptions[startIndex+i], atoms[i][0], atoms[i][1], chargeProd, sigma, epsilon);
exceptionParamsVector[i] = mm_float4((float) (ONE_4PI_EPS0*chargeProd), (float) sigma, (float) (4.0*epsilon), 0.0f);
}
exceptionParams->upload(exceptionParamsVector);
}
// Compute other values.
NonbondedForce::NonbondedMethod method = force.getNonbondedMethod();
if (method == NonbondedForce::Ewald || method == NonbondedForce::PME)
ewaldSelfEnergy = (cl.getContextIndex() == 0 ? -ONE_4PI_EPS0*alpha*sumSquaredCharges/sqrt(M_PI) : 0.0);
if (force.getUseDispersionCorrection() && cl.getContextIndex() == 0 && (method == NonbondedForce::CutoffPeriodic || method == NonbondedForce::Ewald || method == NonbondedForce::PME))
dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(context.getSystem(), force);
cl.invalidateMolecules();
}
class OpenCLCustomNonbondedForceInfo : public OpenCLForceInfo { class OpenCLCustomNonbondedForceInfo : public OpenCLForceInfo {
public: public:
OpenCLCustomNonbondedForceInfo(int requiredBuffers, const CustomNonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) { OpenCLCustomNonbondedForceInfo(int requiredBuffers, const CustomNonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
......
...@@ -522,6 +522,13 @@ public: ...@@ -522,6 +522,13 @@ public:
* @return the potential energy due to the force * @return the potential energy due to the force
*/ */
double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal); double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the NonbondedForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const NonbondedForce& force);
private: private:
struct SortTrait { struct SortTrait {
typedef mm_int2 DataType; typedef mm_int2 DataType;
...@@ -560,8 +567,9 @@ private: ...@@ -560,8 +567,9 @@ private:
cl::Kernel pmeConvolutionKernel; cl::Kernel pmeConvolutionKernel;
cl::Kernel pmeInterpolateForceKernel; cl::Kernel pmeInterpolateForceKernel;
std::map<std::string, std::string> pmeDefines; std::map<std::string, std::string> pmeDefines;
double ewaldSelfEnergy, dispersionCoefficient; double ewaldSelfEnergy, dispersionCoefficient, alpha;
int interpolateForceThreads; int interpolateForceThreads;
bool hasCoulomb, hasLJ;
static const int PmeOrder = 5; static const int PmeOrder = 5;
}; };
......
...@@ -522,6 +522,11 @@ double OpenCLParallelCalcNonbondedForceKernel::execute(ContextImpl& context, boo ...@@ -522,6 +522,11 @@ double OpenCLParallelCalcNonbondedForceKernel::execute(ContextImpl& context, boo
return 0.0; return 0.0;
} }
void OpenCLParallelCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
for (int i = 0; i < (int) kernels.size(); i++)
getKernel(i).copyParametersToContext(context, force);
}
class OpenCLParallelCalcCustomNonbondedForceKernel::Task : public OpenCLContext::WorkTask { class OpenCLParallelCalcCustomNonbondedForceKernel::Task : public OpenCLContext::WorkTask {
public: public:
Task(ContextImpl& context, OpenCLCalcCustomNonbondedForceKernel& kernel, bool includeForce, Task(ContextImpl& context, OpenCLCalcCustomNonbondedForceKernel& kernel, bool includeForce,
......
...@@ -363,6 +363,13 @@ public: ...@@ -363,6 +363,13 @@ public:
* @return the potential energy due to the force * @return the potential energy due to the force
*/ */
double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal); double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the NonbondedForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const NonbondedForce& force);
private: private:
class Task; class Task;
OpenCLPlatform::PlatformData& data; OpenCLPlatform::PlatformData& data;
......
...@@ -147,7 +147,7 @@ OpenCLPlatform::PlatformData::~PlatformData() { ...@@ -147,7 +147,7 @@ OpenCLPlatform::PlatformData::~PlatformData() {
void OpenCLPlatform::PlatformData::initializeContexts(const System& system) { void OpenCLPlatform::PlatformData::initializeContexts(const System& system) {
for (int i = 0; i < (int) contexts.size(); i++) for (int i = 0; i < (int) contexts.size(); i++)
contexts[i]->initialize(system); contexts[i]->initialize();
} }
void OpenCLPlatform::PlatformData::syncContexts() { void OpenCLPlatform::PlatformData::syncContexts() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment