Commit 87542608 authored by Peter Eastman's avatar Peter Eastman
Browse files

Restructured API for periodic box size so it is now queried from State rather than Context

parent 607f2b6a
...@@ -173,6 +173,22 @@ public: ...@@ -173,6 +173,22 @@ public:
* @param forces on exit, this contains the forces * @param forces on exit, this contains the forces
*/ */
virtual void getForces(ContextImpl& context, std::vector<Vec3>& forces) = 0; virtual void getForces(ContextImpl& context, std::vector<Vec3>& forces) = 0;
/**
* Get the current periodic box vectors.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
virtual void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const = 0;
/**
* Set the current periodic box vectors.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
virtual void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const = 0;
}; };
/** /**
......
...@@ -153,18 +153,6 @@ public: ...@@ -153,18 +153,6 @@ public:
* @param value the value of the parameter * @param value the value of the parameter
*/ */
void setParameter(const std::string& name, double value); void setParameter(const std::string& name, double value);
/**
* Get the vectors defining the axes of the periodic box (measured in nm). They will affect
* any Force that uses periodic boundary conditions.
*
* Currently, only rectangular boxes are supported. This means that a, b, and c must be aligned with the
* x, y, and z axes respectively. Future releases may support arbitrary triclinic boxes.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const;
/** /**
* Set the vectors defining the axes of the periodic box (measured in nm). They will affect * Set the vectors defining the axes of the periodic box (measured in nm). They will affect
* any Force that uses periodic boundary conditions. * any Force that uses periodic boundary conditions.
...@@ -176,7 +164,7 @@ public: ...@@ -176,7 +164,7 @@ public:
* @param b the vector defining the second edge of the periodic box * @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box * @param c the vector defining the third edge of the periodic box
*/ */
void setPeriodicBoxVectors(Vec3 a, Vec3 b, Vec3 c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
/** /**
* When a Context is created, it may cache information about the System being simulated * When a Context is created, it may cache information about the System being simulated
* and the Force objects contained in it. This means that, if the System or Forces are then * and the Force objects contained in it. This means that, if the System or Forces are then
...@@ -191,7 +179,6 @@ private: ...@@ -191,7 +179,6 @@ private:
friend class Platform; friend class Platform;
ContextImpl* impl; ContextImpl* impl;
std::map<std::string, std::string> properties; std::map<std::string, std::string> properties;
Vec3 periodicBoxVectors[3];
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -81,6 +81,14 @@ public: ...@@ -81,6 +81,14 @@ public:
* Get the total potential energy of the system. If this State does not contain energies, this will throw an exception. * Get the total potential energy of the system. If this State does not contain energies, this will throw an exception.
*/ */
double getPotentialEnergy() const; double getPotentialEnergy() const;
/**
* Get the vectors defining the axes of the periodic box (measured in nm).
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const;
/** /**
* Get a map containing the values of all parameters. If this State does not contain parameters, this will throw an exception. * Get a map containing the values of all parameters. If this State does not contain parameters, this will throw an exception.
*/ */
...@@ -93,11 +101,13 @@ private: ...@@ -93,11 +101,13 @@ private:
std::vector<Vec3>& updForces(); std::vector<Vec3>& updForces();
std::map<std::string, double>& updParameters(); std::map<std::string, double>& updParameters();
void setEnergy(double ke, double pe); void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
DataType types; DataType types;
double time, ke, pe; double time, ke, pe;
std::vector<Vec3> positions; std::vector<Vec3> positions;
std::vector<Vec3> velocities; std::vector<Vec3> velocities;
std::vector<Vec3> forces; std::vector<Vec3> forces;
Vec3 periodicBoxVectors[3];
std::map<std::string, double> parameters; std::map<std::string, double> parameters;
}; };
......
...@@ -189,7 +189,7 @@ public: ...@@ -189,7 +189,7 @@ public:
* @param b the vector defining the second edge of the periodic box * @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box * @param c the vector defining the third edge of the periodic box
*/ */
void setDefaultPeriodicBoxVectors(Vec3 a, Vec3 b, Vec3 c); void setDefaultPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
private: private:
class ConstraintInfo; class ConstraintInfo;
Vec3 periodicBoxVectors[3]; Vec3 periodicBoxVectors[3];
......
...@@ -133,6 +133,30 @@ public: ...@@ -133,6 +133,30 @@ public:
* @param value the value of the parameter * @param value the value of the parameter
*/ */
void setParameter(std::string name, double value); void setParameter(std::string name, double value);
/**
* Get the vectors defining the axes of the periodic box (measured in nm). They will affect
* any Force that uses periodic boundary conditions.
*
* Currently, only rectangular boxes are supported. This means that a, b, and c must be aligned with the
* x, y, and z axes respectively. Future releases may support arbitrary triclinic boxes.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c);
/**
* Set the vectors defining the axes of the periodic box (measured in nm). They will affect
* any Force that uses periodic boundary conditions.
*
* Currently, only rectangular boxes are supported. This means that a, b, and c must be aligned with the
* x, y, and z axes respectively. Future releases may support arbitrary triclinic boxes.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
/** /**
* Recalculate all of the forces in the system. After calling this, use getForces() to retrieve * Recalculate all of the forces in the system. After calling this, use getForces() to retrieve
* the forces that were calculated. * the forces that were calculated.
......
...@@ -37,17 +37,14 @@ using namespace OpenMM; ...@@ -37,17 +37,14 @@ using namespace OpenMM;
using namespace std; using namespace std;
Context::Context(System& system, Integrator& integrator) : properties(map<string, string>()) { Context::Context(System& system, Integrator& integrator) : properties(map<string, string>()) {
system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
impl = new ContextImpl(*this, system, integrator, 0, properties); impl = new ContextImpl(*this, system, integrator, 0, properties);
} }
Context::Context(System& system, Integrator& integrator, Platform& platform) : properties(map<string, string>()) { Context::Context(System& system, Integrator& integrator, Platform& platform) : properties(map<string, string>()) {
system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
impl = new ContextImpl(*this, system, integrator, &platform, properties); impl = new ContextImpl(*this, system, integrator, &platform, properties);
} }
Context::Context(System& system, Integrator& integrator, Platform& platform, const map<string, string>& properties) : properties(properties) { Context::Context(System& system, Integrator& integrator, Platform& platform, const map<string, string>& properties) : properties(properties) {
system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
impl = new ContextImpl(*this, system, integrator, &platform, properties); impl = new ContextImpl(*this, system, integrator, &platform, properties);
} }
...@@ -82,6 +79,9 @@ Platform& Context::getPlatform() { ...@@ -82,6 +79,9 @@ Platform& Context::getPlatform() {
State Context::getState(int types) const { State Context::getState(int types) const {
State state(impl->getTime(), impl->getSystem().getNumParticles(), State::DataType(types)); State state(impl->getTime(), impl->getSystem().getNumParticles(), State::DataType(types));
Vec3 periodicBoxSize[3];
impl->getPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
state.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
if (types&State::Energy) if (types&State::Energy)
state.setEnergy(impl->calcKineticEnergy(), impl->calcPotentialEnergy()); state.setEnergy(impl->calcKineticEnergy(), impl->calcPotentialEnergy());
if (types&State::Forces) { if (types&State::Forces) {
...@@ -123,29 +123,14 @@ void Context::setParameter(const string& name, double value) { ...@@ -123,29 +123,14 @@ void Context::setParameter(const string& name, double value) {
impl->setParameter(name, value); impl->setParameter(name, value);
} }
void Context::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const { void Context::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
a = periodicBoxVectors[0]; impl->setPeriodicBoxVectors(a, b, c);
b = periodicBoxVectors[1];
c = periodicBoxVectors[2];
}
void Context::setPeriodicBoxVectors(Vec3 a, Vec3 b, Vec3 c) {
if (a[1] != 0.0 || a[2] != 0.0)
throw OpenMMException("First periodic box vector must be parallel to x.");
if (b[0] != 0.0 || b[2] != 0.0)
throw OpenMMException("Second periodic box vector must be parallel to y.");
if (c[0] != 0.0 || c[1] != 0.0)
throw OpenMMException("Third periodic box vector must be parallel to z.");
periodicBoxVectors[0] = a;
periodicBoxVectors[1] = b;
periodicBoxVectors[2] = c;
} }
void Context::reinitialize() { void Context::reinitialize() {
System& system = impl->getSystem(); System& system = impl->getSystem();
Integrator& integrator = impl->getIntegrator(); Integrator& integrator = impl->getIntegrator();
Platform& platform = impl->getPlatform(); Platform& platform = impl->getPlatform();
system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
delete impl; delete impl;
impl = new ContextImpl(*this, system, integrator, &platform, properties); impl = new ContextImpl(*this, system, integrator, &platform, properties);
} }
...@@ -73,6 +73,9 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, ...@@ -73,6 +73,9 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator,
dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).initialize(system); dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).initialize(system);
updateStateDataKernel = platform->createKernel(UpdateStateDataKernel::Name(), *this); updateStateDataKernel = platform->createKernel(UpdateStateDataKernel::Name(), *this);
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).initialize(system); dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).initialize(system);
Vec3 periodicBoxVectors[3];
system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setPeriodicBoxVectors(*this, periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
for (size_t i = 0; i < forceImpls.size(); ++i) for (size_t i = 0; i < forceImpls.size(); ++i)
forceImpls[i]->initialize(*this); forceImpls[i]->initialize(*this);
integrator.initialize(*this); integrator.initialize(*this);
...@@ -125,6 +128,20 @@ void ContextImpl::setParameter(std::string name, double value) { ...@@ -125,6 +128,20 @@ void ContextImpl::setParameter(std::string name, double value) {
parameters[name] = value; parameters[name] = value;
} }
void ContextImpl::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).getPeriodicBoxVectors(*this, a, b, c);
}
void ContextImpl::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
if (a[1] != 0.0 || a[2] != 0.0)
throw OpenMMException("First periodic box vector must be parallel to x.");
if (b[0] != 0.0 || b[2] != 0.0)
throw OpenMMException("Second periodic box vector must be parallel to y.");
if (c[0] != 0.0 || c[1] != 0.0)
throw OpenMMException("Third periodic box vector must be parallel to z.");
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setPeriodicBoxVectors(*this, a, b, c);
}
void ContextImpl::calcForces() { void ContextImpl::calcForces() {
CalcForcesAndEnergyKernel& kernel = dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl()); CalcForcesAndEnergyKernel& kernel = dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl());
kernel.beginForceComputation(*this); kernel.beginForceComputation(*this);
......
...@@ -52,7 +52,7 @@ void MonteCarloBarostatImpl::initialize(ContextImpl& context) { ...@@ -52,7 +52,7 @@ void MonteCarloBarostatImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(ApplyMonteCarloBarostatKernel::Name(), context); kernel = context.getPlatform().createKernel(ApplyMonteCarloBarostatKernel::Name(), context);
dynamic_cast<ApplyMonteCarloBarostatKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); dynamic_cast<ApplyMonteCarloBarostatKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
Vec3 box[3]; Vec3 box[3];
context.getOwner().getPeriodicBoxVectors(box[0], box[1], box[2]); context.getPeriodicBoxVectors(box[0], box[1], box[2]);
double volume = box[0][0]*box[1][1]*box[2][2]; double volume = box[0][0]*box[1][1]*box[2][2];
volumeScale = 0.01*volume; volumeScale = 0.01*volume;
init_gen_rand(owner.getRandomNumberSeed(), random); init_gen_rand(owner.getRandomNumberSeed(), random);
...@@ -70,7 +70,7 @@ void MonteCarloBarostatImpl::updateContextState(ContextImpl& context) { ...@@ -70,7 +70,7 @@ void MonteCarloBarostatImpl::updateContextState(ContextImpl& context) {
// Modify the periodic box size. // Modify the periodic box size.
Vec3 box[3]; Vec3 box[3];
context.getOwner().getPeriodicBoxVectors(box[0], box[1], box[2]); context.getPeriodicBoxVectors(box[0], box[1], box[2]);
double volume = box[0][0]*box[1][1]*box[2][2]; double volume = box[0][0]*box[1][1]*box[2][2];
double deltaVolume = volumeScale*2*(genrand_real2(random)-0.5); double deltaVolume = volumeScale*2*(genrand_real2(random)-0.5);
double newVolume = volume+deltaVolume; double newVolume = volume+deltaVolume;
......
...@@ -63,6 +63,11 @@ double State::getPotentialEnergy() const { ...@@ -63,6 +63,11 @@ double State::getPotentialEnergy() const {
throw OpenMMException("Invoked getPotentialEnergy() on a State which does not contain energies."); throw OpenMMException("Invoked getPotentialEnergy() on a State which does not contain energies.");
return pe; return pe;
} }
void State::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const {
a = periodicBoxVectors[0];
b = periodicBoxVectors[1];
c = periodicBoxVectors[2];
}
const map<string, double>& State::getParameters() const { const map<string, double>& State::getParameters() const {
if ((types&Parameters) == 0) if ((types&Parameters) == 0)
throw OpenMMException("Invoked getParameters() on a State which does not contain parameters."); throw OpenMMException("Invoked getParameters() on a State which does not contain parameters.");
...@@ -88,3 +93,9 @@ void State::setEnergy(double kinetic, double potential) { ...@@ -88,3 +93,9 @@ void State::setEnergy(double kinetic, double potential) {
ke = kinetic; ke = kinetic;
pe = potential; pe = potential;
} }
void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
periodicBoxVectors[0] = a;
periodicBoxVectors[1] = b;
periodicBoxVectors[2] = c;
}
...@@ -69,7 +69,7 @@ void System::getDefaultPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const { ...@@ -69,7 +69,7 @@ void System::getDefaultPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const {
c = periodicBoxVectors[2]; c = periodicBoxVectors[2];
} }
void System::setDefaultPeriodicBoxVectors(Vec3 a, Vec3 b, Vec3 c) { void System::setDefaultPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
if (a[1] != 0.0 || a[2] != 0.0) if (a[1] != 0.0 || a[2] != 0.0)
throw OpenMMException("First periodic box vector must be parallel to x."); throw OpenMMException("First periodic box vector must be parallel to x.");
if (b[0] != 0.0 || b[2] != 0.0) if (b[0] != 0.0 || b[2] != 0.0)
......
...@@ -45,13 +45,6 @@ void CudaCalcForcesAndEnergyKernel::initialize(const System& system) { ...@@ -45,13 +45,6 @@ void CudaCalcForcesAndEnergyKernel::initialize(const System& system) {
void CudaCalcForcesAndEnergyKernel::beginForceComputation(ContextImpl& context) { void CudaCalcForcesAndEnergyKernel::beginForceComputation(ContextImpl& context) {
_gpuContext* gpu = data.gpu; _gpuContext* gpu = data.gpu;
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
float boxx = boxVectors[0][0], boxy = boxVectors[1][1], boxz = boxVectors[2][2];
if (boxx != gpu->sim.periodicBoxSizeX || boxy != gpu->sim.periodicBoxSizeY || boxz != gpu->sim.periodicBoxSizeZ) {
gpuSetPeriodicBoxSize(gpu, boxx, boxy, boxz);
gpuSetConstants(gpu);
}
if (data.nonbondedMethod != NO_CUTOFF && data.computeForceCount%100 == 0) if (data.nonbondedMethod != NO_CUTOFF && data.computeForceCount%100 == 0)
gpuReorderAtoms(gpu); gpuReorderAtoms(gpu);
data.computeForceCount++; data.computeForceCount++;
...@@ -84,13 +77,6 @@ void CudaCalcForcesAndEnergyKernel::finishForceComputation(ContextImpl& context) ...@@ -84,13 +77,6 @@ void CudaCalcForcesAndEnergyKernel::finishForceComputation(ContextImpl& context)
void CudaCalcForcesAndEnergyKernel::beginEnergyComputation(ContextImpl& context) { void CudaCalcForcesAndEnergyKernel::beginEnergyComputation(ContextImpl& context) {
_gpuContext* gpu = data.gpu; _gpuContext* gpu = data.gpu;
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
float boxx = boxVectors[0][0], boxy = boxVectors[1][1], boxz = boxVectors[2][2];
if (boxx != gpu->sim.periodicBoxSizeX || boxy != gpu->sim.periodicBoxSizeY || boxz != gpu->sim.periodicBoxSizeZ) {
gpuSetPeriodicBoxSize(gpu, boxx, boxy, boxz);
gpuSetConstants(gpu);
}
if (data.nonbondedMethod != NO_CUTOFF && data.stepCount%100 == 0) if (data.nonbondedMethod != NO_CUTOFF && data.stepCount%100 == 0)
gpuReorderAtoms(gpu); gpuReorderAtoms(gpu);
data.stepCount++; data.stepCount++;
...@@ -197,6 +183,19 @@ void CudaUpdateStateDataKernel::getForces(ContextImpl& context, std::vector<Vec3 ...@@ -197,6 +183,19 @@ void CudaUpdateStateDataKernel::getForces(ContextImpl& context, std::vector<Vec3
} }
} }
void CudaUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
_gpuContext* gpu = data.gpu;
a = Vec3(gpu->sim.periodicBoxSizeX, 0, 0);
b = Vec3(0, gpu->sim.periodicBoxSizeY, 0);
c = Vec3(0, 0, gpu->sim.periodicBoxSizeZ);
}
void CudaUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const {
_gpuContext* gpu = data.gpu;
gpuSetPeriodicBoxSize(gpu, a[0], b[1], c[2]);
gpuSetConstants(gpu);
}
CudaCalcHarmonicBondForceKernel::~CudaCalcHarmonicBondForceKernel() { CudaCalcHarmonicBondForceKernel::~CudaCalcHarmonicBondForceKernel() {
} }
......
...@@ -149,6 +149,22 @@ public: ...@@ -149,6 +149,22 @@ public:
* @param forces on exit, this contains the forces * @param forces on exit, this contains the forces
*/ */
void getForces(ContextImpl& context, std::vector<Vec3>& forces); void getForces(ContextImpl& context, std::vector<Vec3>& forces);
/**
* Get the current periodic box vectors.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const;
/**
* Set the current periodic box vectors.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const;
private: private:
CudaPlatform::PlatformData& data; CudaPlatform::PlatformData& data;
}; };
......
...@@ -57,12 +57,12 @@ void testChangingBoxSize() { ...@@ -57,12 +57,12 @@ void testChangingBoxSize() {
LangevinIntegrator integrator(300.0, 1.0, 0.01); LangevinIntegrator integrator(300.0, 1.0, 0.01);
Context context(system, integrator, platform); Context context(system, integrator, platform);
Vec3 x, y, z; Vec3 x, y, z;
context.getPeriodicBoxVectors(x, y, z); context.getState(0).getPeriodicBoxVectors(x, y, z);
ASSERT_EQUAL_VEC(Vec3(4, 0, 0), x, 0); ASSERT_EQUAL_VEC(Vec3(4, 0, 0), x, 0);
ASSERT_EQUAL_VEC(Vec3(0, 5, 0), y, 0); ASSERT_EQUAL_VEC(Vec3(0, 5, 0), y, 0);
ASSERT_EQUAL_VEC(Vec3(0, 0, 6), z, 0); ASSERT_EQUAL_VEC(Vec3(0, 0, 6), z, 0);
context.setPeriodicBoxVectors(Vec3(7, 0, 0), Vec3(0, 8, 0), Vec3(0, 0, 9)); context.setPeriodicBoxVectors(Vec3(7, 0, 0), Vec3(0, 8, 0), Vec3(0, 0, 9));
context.getPeriodicBoxVectors(x, y, z); context.getState(0).getPeriodicBoxVectors(x, y, z);
ASSERT_EQUAL_VEC(Vec3(7, 0, 0), x, 0); ASSERT_EQUAL_VEC(Vec3(7, 0, 0), x, 0);
ASSERT_EQUAL_VEC(Vec3(0, 8, 0), y, 0); ASSERT_EQUAL_VEC(Vec3(0, 8, 0), y, 0);
ASSERT_EQUAL_VEC(Vec3(0, 0, 9), z, 0); ASSERT_EQUAL_VEC(Vec3(0, 0, 9), z, 0);
...@@ -110,7 +110,7 @@ void testIdealGas() { ...@@ -110,7 +110,7 @@ void testIdealGas() {
double volume = 0.0; double volume = 0.0;
for (int j = 0; j < steps; ++j) { for (int j = 0; j < steps; ++j) {
Vec3 box[3]; Vec3 box[3];
context.getPeriodicBoxVectors(box[0], box[1], box[2]); context.getState(0).getPeriodicBoxVectors(box[0], box[1], box[2]);
volume += box[0][0]*box[1][1]*box[2][2]; volume += box[0][0]*box[1][1]*box[2][2];
ASSERT_EQUAL_TOL(0.5*box[0][0], box[1][1], 1e-5); ASSERT_EQUAL_TOL(0.5*box[0][0], box[1][1], 1e-5);
ASSERT_EQUAL_TOL(2*box[0][0], box[2][2], 1e-5); ASSERT_EQUAL_TOL(2*box[0][0], box[2][2], 1e-5);
...@@ -243,7 +243,7 @@ void testWater() { ...@@ -243,7 +243,7 @@ void testWater() {
double volume = 0.0; double volume = 0.0;
for (int j = 0; j < steps; ++j) { for (int j = 0; j < steps; ++j) {
Vec3 box[3]; Vec3 box[3];
context.getPeriodicBoxVectors(box[0], box[1], box[2]); context.getState(0).getPeriodicBoxVectors(box[0], box[1], box[2]);
volume += box[0][0]*box[1][1]*box[2][2]; volume += box[0][0]*box[1][1]*box[2][2];
integrator.step(frequency); integrator.step(frequency);
} }
......
...@@ -70,9 +70,6 @@ void OpenCLCalcForcesAndEnergyKernel::initialize(const System& system) { ...@@ -70,9 +70,6 @@ void OpenCLCalcForcesAndEnergyKernel::initialize(const System& system) {
} }
void OpenCLCalcForcesAndEnergyKernel::beginForceComputation(ContextImpl& context) { void OpenCLCalcForcesAndEnergyKernel::beginForceComputation(ContextImpl& context) {
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
cl.setPeriodicBoxSize(boxVectors[0][0], boxVectors[1][1], boxVectors[2][2]);
if (cl.getNonbondedUtilities().getUseCutoff() && cl.getComputeForceCount()%100 == 0) if (cl.getNonbondedUtilities().getUseCutoff() && cl.getComputeForceCount()%100 == 0)
cl.reorderAtoms(); cl.reorderAtoms();
cl.setComputeForceCount(cl.getComputeForceCount()+1); cl.setComputeForceCount(cl.getComputeForceCount()+1);
...@@ -86,9 +83,6 @@ void OpenCLCalcForcesAndEnergyKernel::finishForceComputation(ContextImpl& contex ...@@ -86,9 +83,6 @@ void OpenCLCalcForcesAndEnergyKernel::finishForceComputation(ContextImpl& contex
} }
void OpenCLCalcForcesAndEnergyKernel::beginEnergyComputation(ContextImpl& context) { void OpenCLCalcForcesAndEnergyKernel::beginEnergyComputation(ContextImpl& context) {
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
cl.setPeriodicBoxSize(boxVectors[0][0], boxVectors[1][1], boxVectors[2][2]);
if (cl.getNonbondedUtilities().getUseCutoff() && cl.getComputeForceCount()%100 == 0) if (cl.getNonbondedUtilities().getUseCutoff() && cl.getComputeForceCount()%100 == 0)
cl.reorderAtoms(); cl.reorderAtoms();
cl.setComputeForceCount(cl.getComputeForceCount()+1); cl.setComputeForceCount(cl.getComputeForceCount()+1);
...@@ -185,6 +179,17 @@ void OpenCLUpdateStateDataKernel::getForces(ContextImpl& context, std::vector<Ve ...@@ -185,6 +179,17 @@ void OpenCLUpdateStateDataKernel::getForces(ContextImpl& context, std::vector<Ve
} }
} }
void OpenCLUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
mm_float4 box = cl.getPeriodicBoxSize();
a = Vec3(box.x, 0, 0);
b = Vec3(0, box.y, 0);
c = Vec3(0, 0, box.z);
}
void OpenCLUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const {
cl.setPeriodicBoxSize(a[0], b[1], c[2]);
}
class OpenCLBondForceInfo : public OpenCLForceInfo { class OpenCLBondForceInfo : public OpenCLForceInfo {
public: public:
OpenCLBondForceInfo(int requiredBuffers, const HarmonicBondForce& force) : OpenCLForceInfo(requiredBuffers), force(force) { OpenCLBondForceInfo(int requiredBuffers, const HarmonicBondForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
......
...@@ -144,6 +144,22 @@ public: ...@@ -144,6 +144,22 @@ public:
* @param forces on exit, this contains the forces * @param forces on exit, this contains the forces
*/ */
void getForces(ContextImpl& context, std::vector<Vec3>& forces); void getForces(ContextImpl& context, std::vector<Vec3>& forces);
/**
* Get the current periodic box vectors.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const;
/**
* Set the current periodic box vectors.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const;
private: private:
OpenCLContext& cl; OpenCLContext& cl;
}; };
......
...@@ -57,12 +57,12 @@ void testChangingBoxSize() { ...@@ -57,12 +57,12 @@ void testChangingBoxSize() {
LangevinIntegrator integrator(300.0, 1.0, 0.01); LangevinIntegrator integrator(300.0, 1.0, 0.01);
Context context(system, integrator, platform); Context context(system, integrator, platform);
Vec3 x, y, z; Vec3 x, y, z;
context.getPeriodicBoxVectors(x, y, z); context.getState(0).getPeriodicBoxVectors(x, y, z);
ASSERT_EQUAL_VEC(Vec3(4, 0, 0), x, 0); ASSERT_EQUAL_VEC(Vec3(4, 0, 0), x, 0);
ASSERT_EQUAL_VEC(Vec3(0, 5, 0), y, 0); ASSERT_EQUAL_VEC(Vec3(0, 5, 0), y, 0);
ASSERT_EQUAL_VEC(Vec3(0, 0, 6), z, 0); ASSERT_EQUAL_VEC(Vec3(0, 0, 6), z, 0);
context.setPeriodicBoxVectors(Vec3(7, 0, 0), Vec3(0, 8, 0), Vec3(0, 0, 9)); context.setPeriodicBoxVectors(Vec3(7, 0, 0), Vec3(0, 8, 0), Vec3(0, 0, 9));
context.getPeriodicBoxVectors(x, y, z); context.getState(0).getPeriodicBoxVectors(x, y, z);
ASSERT_EQUAL_VEC(Vec3(7, 0, 0), x, 0); ASSERT_EQUAL_VEC(Vec3(7, 0, 0), x, 0);
ASSERT_EQUAL_VEC(Vec3(0, 8, 0), y, 0); ASSERT_EQUAL_VEC(Vec3(0, 8, 0), y, 0);
ASSERT_EQUAL_VEC(Vec3(0, 0, 9), z, 0); ASSERT_EQUAL_VEC(Vec3(0, 0, 9), z, 0);
...@@ -110,7 +110,7 @@ void testIdealGas() { ...@@ -110,7 +110,7 @@ void testIdealGas() {
double volume = 0.0; double volume = 0.0;
for (int j = 0; j < steps; ++j) { for (int j = 0; j < steps; ++j) {
Vec3 box[3]; Vec3 box[3];
context.getPeriodicBoxVectors(box[0], box[1], box[2]); context.getState(0).getPeriodicBoxVectors(box[0], box[1], box[2]);
volume += box[0][0]*box[1][1]*box[2][2]; volume += box[0][0]*box[1][1]*box[2][2];
ASSERT_EQUAL_TOL(0.5*box[0][0], box[1][1], 1e-5); ASSERT_EQUAL_TOL(0.5*box[0][0], box[1][1], 1e-5);
ASSERT_EQUAL_TOL(2*box[0][0], box[2][2], 1e-5); ASSERT_EQUAL_TOL(2*box[0][0], box[2][2], 1e-5);
...@@ -243,7 +243,7 @@ void testWater() { ...@@ -243,7 +243,7 @@ void testWater() {
double volume = 0.0; double volume = 0.0;
for (int j = 0; j < steps; ++j) { for (int j = 0; j < steps; ++j) {
Vec3 box[3]; Vec3 box[3];
context.getPeriodicBoxVectors(box[0], box[1], box[2]); context.getState(0).getPeriodicBoxVectors(box[0], box[1], box[2]);
volume += box[0][0]*box[1][1]*box[2][2]; volume += box[0][0]*box[1][1]*box[2][2];
integrator.step(frequency); integrator.step(frequency);
} }
......
...@@ -66,6 +66,7 @@ public: ...@@ -66,6 +66,7 @@ public:
void* positions; void* positions;
void* velocities; void* velocities;
void* forces; void* forces;
void* periodicBoxSize;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -137,6 +137,11 @@ static RealOpenMM** extractForces(ContextImpl& context) { ...@@ -137,6 +137,11 @@ static RealOpenMM** extractForces(ContextImpl& context) {
return (RealOpenMM**) data->forces; return (RealOpenMM**) data->forces;
} }
static RealOpenMM* extractBoxSize(ContextImpl& context) {
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return (RealOpenMM*) data->periodicBoxSize;
}
static void findAnglesForCCMA(const System& system, vector<ReferenceCCMAAlgorithm::AngleInfo>& angles) { static void findAnglesForCCMA(const System& system, vector<ReferenceCCMAAlgorithm::AngleInfo>& angles) {
for (int i = 0; i < system.getNumForces(); i++) { for (int i = 0; i < system.getNumForces(); i++) {
const HarmonicAngleForce* force = dynamic_cast<const HarmonicAngleForce*>(&system.getForce(i)); const HarmonicAngleForce* force = dynamic_cast<const HarmonicAngleForce*>(&system.getForce(i));
...@@ -229,6 +234,20 @@ void ReferenceUpdateStateDataKernel::getForces(ContextImpl& context, std::vector ...@@ -229,6 +234,20 @@ void ReferenceUpdateStateDataKernel::getForces(ContextImpl& context, std::vector
forces[i] = Vec3(forceData[i][0], forceData[i][1], forceData[i][2]); forces[i] = Vec3(forceData[i][0], forceData[i][1], forceData[i][2]);
} }
void ReferenceUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
RealOpenMM* box = extractBoxSize(context);
a = Vec3(box[0], 0, 0);
b = Vec3(0, box[1], 0);
c = Vec3(0, 0, box[2]);
}
void ReferenceUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const {
RealOpenMM* box = extractBoxSize(context);
box[0] = (RealOpenMM) a[0];
box[1] = (RealOpenMM) b[1];
box[2] = (RealOpenMM) c[2];
}
ReferenceCalcHarmonicBondForceKernel::~ReferenceCalcHarmonicBondForceKernel() { ReferenceCalcHarmonicBondForceKernel::~ReferenceCalcHarmonicBondForceKernel() {
disposeIntArray(bondIndexArray, numBonds); disposeIntArray(bondIndexArray, numBonds);
disposeRealArray(bondParamArray, numBonds); disposeRealArray(bondParamArray, numBonds);
...@@ -687,18 +706,12 @@ void ReferenceCalcNonbondedForceKernel::executeForces(ContextImpl& context) { ...@@ -687,18 +706,12 @@ void ReferenceCalcNonbondedForceKernel::executeForces(ContextImpl& context) {
bool periodic = (nonbondedMethod == CutoffPeriodic); bool periodic = (nonbondedMethod == CutoffPeriodic);
bool ewald = (nonbondedMethod == Ewald); bool ewald = (nonbondedMethod == Ewald);
bool pme = (nonbondedMethod == PME); bool pme = (nonbondedMethod == PME);
RealOpenMM periodicBoxSize[3];
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
Vec3 boxVectors[3]; computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, (periodic || ewald || pme) ? extractBoxSize(context) : NULL, nonbondedCutoff, 0.0);
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, (periodic || ewald || pme) ? periodicBoxSize : NULL, nonbondedCutoff, 0.0);
clj.setUseCutoff(nonbondedCutoff, *neighborList, rfDielectric); clj.setUseCutoff(nonbondedCutoff, *neighborList, rfDielectric);
} }
if (periodic || ewald || pme) if (periodic || ewald || pme)
clj.setPeriodic(periodicBoxSize); clj.setPeriodic(extractBoxSize(context));
if (ewald) if (ewald)
clj.setUseEwald(ewaldAlpha, kmax[0], kmax[1], kmax[2]); clj.setUseEwald(ewaldAlpha, kmax[0], kmax[1], kmax[2]);
if (pme) if (pme)
...@@ -717,18 +730,12 @@ double ReferenceCalcNonbondedForceKernel::executeEnergy(ContextImpl& context) { ...@@ -717,18 +730,12 @@ double ReferenceCalcNonbondedForceKernel::executeEnergy(ContextImpl& context) {
bool periodic = (nonbondedMethod == CutoffPeriodic); bool periodic = (nonbondedMethod == CutoffPeriodic);
bool ewald = (nonbondedMethod == Ewald); bool ewald = (nonbondedMethod == Ewald);
bool pme = (nonbondedMethod == PME); bool pme = (nonbondedMethod == PME);
RealOpenMM periodicBoxSize[3];
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
Vec3 boxVectors[3]; computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, (periodic || ewald || pme) ? extractBoxSize(context) : NULL, nonbondedCutoff, 0.0);
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, (periodic || ewald || pme) ? periodicBoxSize : NULL, nonbondedCutoff, 0.0);
clj.setUseCutoff(nonbondedCutoff, *neighborList, rfDielectric); clj.setUseCutoff(nonbondedCutoff, *neighborList, rfDielectric);
} }
if (periodic || ewald || pme) if (periodic || ewald || pme)
clj.setPeriodic(periodicBoxSize); clj.setPeriodic(extractBoxSize(context));
if (ewald) if (ewald)
clj.setUseEwald(ewaldAlpha, kmax[0], kmax[1], kmax[2]); clj.setUseEwald(ewaldAlpha, kmax[0], kmax[1], kmax[2]);
if (pme) if (pme)
...@@ -882,18 +889,12 @@ void ReferenceCalcCustomNonbondedForceKernel::executeForces(ContextImpl& context ...@@ -882,18 +889,12 @@ void ReferenceCalcCustomNonbondedForceKernel::executeForces(ContextImpl& context
RealOpenMM** forceData = extractForces(context); RealOpenMM** forceData = extractForces(context);
ReferenceCustomNonbondedIxn ixn(energyExpression, forceExpression, parameterNames); ReferenceCustomNonbondedIxn ixn(energyExpression, forceExpression, parameterNames);
bool periodic = (nonbondedMethod == CutoffPeriodic); bool periodic = (nonbondedMethod == CutoffPeriodic);
RealOpenMM periodicBoxSize[3];
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
Vec3 boxVectors[3]; computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? extractBoxSize(context) : NULL, nonbondedCutoff, 0.0);
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? periodicBoxSize : NULL, nonbondedCutoff, 0.0);
ixn.setUseCutoff(nonbondedCutoff, *neighborList); ixn.setUseCutoff(nonbondedCutoff, *neighborList);
} }
if (periodic) if (periodic)
ixn.setPeriodic(periodicBoxSize); ixn.setPeriodic(extractBoxSize(context));
map<string, double> globalParameters; map<string, double> globalParameters;
for (int i = 0; i < (int) globalParameterNames.size(); i++) for (int i = 0; i < (int) globalParameterNames.size(); i++)
globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]); globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]);
...@@ -906,18 +907,12 @@ double ReferenceCalcCustomNonbondedForceKernel::executeEnergy(ContextImpl& conte ...@@ -906,18 +907,12 @@ double ReferenceCalcCustomNonbondedForceKernel::executeEnergy(ContextImpl& conte
RealOpenMM energy = 0; RealOpenMM energy = 0;
ReferenceCustomNonbondedIxn ixn(energyExpression, forceExpression, parameterNames); ReferenceCustomNonbondedIxn ixn(energyExpression, forceExpression, parameterNames);
bool periodic = (nonbondedMethod == CutoffPeriodic); bool periodic = (nonbondedMethod == CutoffPeriodic);
RealOpenMM periodicBoxSize[3];
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
Vec3 boxVectors[3]; computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? extractBoxSize(context) : NULL, nonbondedCutoff, 0.0);
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? periodicBoxSize : NULL, nonbondedCutoff, 0.0);
ixn.setUseCutoff(nonbondedCutoff, *neighborList); ixn.setUseCutoff(nonbondedCutoff, *neighborList);
} }
if (periodic) if (periodic)
ixn.setPeriodic(periodicBoxSize); ixn.setPeriodic(extractBoxSize(context));
map<string, double> globalParameters; map<string, double> globalParameters;
for (int i = 0; i < (int) globalParameterNames.size(); i++) for (int i = 0; i < (int) globalParameterNames.size(); i++)
globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]); globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]);
...@@ -960,30 +955,16 @@ void ReferenceCalcGBSAOBCForceKernel::initialize(const System& system, const GBS ...@@ -960,30 +955,16 @@ void ReferenceCalcGBSAOBCForceKernel::initialize(const System& system, const GBS
void ReferenceCalcGBSAOBCForceKernel::executeForces(ContextImpl& context) { void ReferenceCalcGBSAOBCForceKernel::executeForces(ContextImpl& context) {
RealOpenMM** posData = extractPositions(context); RealOpenMM** posData = extractPositions(context);
RealOpenMM** forceData = extractForces(context); RealOpenMM** forceData = extractForces(context);
if (isPeriodic) { if (isPeriodic)
Vec3 boxVectors[3]; obc->getObcParameters()->setPeriodic(extractBoxSize(context));
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
RealOpenMM periodicBoxSize[3];
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
obc->getObcParameters()->setPeriodic(periodicBoxSize);
}
obc->computeImplicitSolventForces(posData, &charges[0], forceData, 1); obc->computeImplicitSolventForces(posData, &charges[0], forceData, 1);
} }
double ReferenceCalcGBSAOBCForceKernel::executeEnergy(ContextImpl& context) { double ReferenceCalcGBSAOBCForceKernel::executeEnergy(ContextImpl& context) {
RealOpenMM** posData = extractPositions(context); RealOpenMM** posData = extractPositions(context);
RealOpenMM** forceData = allocateRealArray(context.getSystem().getNumParticles(), 3); RealOpenMM** forceData = allocateRealArray(context.getSystem().getNumParticles(), 3);
if (isPeriodic) { if (isPeriodic)
Vec3 boxVectors[3]; obc->getObcParameters()->setPeriodic(extractBoxSize(context));
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
RealOpenMM periodicBoxSize[3];
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
obc->getObcParameters()->setPeriodic(periodicBoxSize);
}
obc->computeImplicitSolventForces(posData, &charges[0], forceData, 1); obc->computeImplicitSolventForces(posData, &charges[0], forceData, 1);
disposeRealArray(forceData, context.getSystem().getNumParticles()); disposeRealArray(forceData, context.getSystem().getNumParticles());
return obc->getEnergy(); return obc->getEnergy();
...@@ -1026,15 +1007,8 @@ void ReferenceCalcGBVIForceKernel::executeForces(ContextImpl& context) { ...@@ -1026,15 +1007,8 @@ void ReferenceCalcGBVIForceKernel::executeForces(ContextImpl& context) {
RealOpenMM** posData = extractPositions(context); RealOpenMM** posData = extractPositions(context);
RealOpenMM** forceData = extractForces(context); RealOpenMM** forceData = extractForces(context);
RealOpenMM* bornRadii = new RealOpenMM[context.getSystem().getNumParticles()]; RealOpenMM* bornRadii = new RealOpenMM[context.getSystem().getNumParticles()];
if (isPeriodic) { if (isPeriodic)
Vec3 boxVectors[3]; gbvi->getGBVIParameters()->setPeriodic(extractBoxSize(context));
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
RealOpenMM periodicBoxSize[3];
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
gbvi->getGBVIParameters()->setPeriodic(periodicBoxSize);
}
gbvi->computeBornRadii(posData, bornRadii, NULL ); gbvi->computeBornRadii(posData, bornRadii, NULL );
gbvi->computeBornForces(bornRadii, posData, &charges[0], forceData); gbvi->computeBornForces(bornRadii, posData, &charges[0], forceData);
delete[] bornRadii; delete[] bornRadii;
...@@ -1043,15 +1017,8 @@ void ReferenceCalcGBVIForceKernel::executeForces(ContextImpl& context) { ...@@ -1043,15 +1017,8 @@ void ReferenceCalcGBVIForceKernel::executeForces(ContextImpl& context) {
double ReferenceCalcGBVIForceKernel::executeEnergy(ContextImpl& context) { double ReferenceCalcGBVIForceKernel::executeEnergy(ContextImpl& context) {
RealOpenMM** posData = extractPositions(context); RealOpenMM** posData = extractPositions(context);
RealOpenMM* bornRadii = new RealOpenMM[context.getSystem().getNumParticles()]; RealOpenMM* bornRadii = new RealOpenMM[context.getSystem().getNumParticles()];
if (isPeriodic) { if (isPeriodic)
Vec3 boxVectors[3]; gbvi->getGBVIParameters()->setPeriodic(extractBoxSize(context));
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
RealOpenMM periodicBoxSize[3];
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
gbvi->getGBVIParameters()->setPeriodic(periodicBoxSize);
}
gbvi->computeBornRadii(posData, bornRadii, NULL ); gbvi->computeBornRadii(posData, bornRadii, NULL );
RealOpenMM energy = gbvi->computeBornEnergy(bornRadii ,posData, &charges[0]); RealOpenMM energy = gbvi->computeBornEnergy(bornRadii ,posData, &charges[0]);
delete[] bornRadii; delete[] bornRadii;
...@@ -1184,17 +1151,10 @@ void ReferenceCalcCustomGBForceKernel::executeForces(ContextImpl& context) { ...@@ -1184,17 +1151,10 @@ void ReferenceCalcCustomGBForceKernel::executeForces(ContextImpl& context) {
ReferenceCustomGBIxn ixn(valueExpressions, valueDerivExpressions, valueGradientExpressions, valueNames, valueTypes, energyExpressions, ReferenceCustomGBIxn ixn(valueExpressions, valueDerivExpressions, valueGradientExpressions, valueNames, valueTypes, energyExpressions,
energyDerivExpressions, energyGradientExpressions, energyTypes, particleParameterNames); energyDerivExpressions, energyGradientExpressions, energyTypes, particleParameterNames);
bool periodic = (nonbondedMethod == CutoffPeriodic); bool periodic = (nonbondedMethod == CutoffPeriodic);
RealOpenMM periodicBoxSize[3]; if (periodic)
if (periodic) { ixn.setPeriodic(extractBoxSize(context));
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
ixn.setPeriodic(periodicBoxSize);
}
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? periodicBoxSize : NULL, nonbondedCutoff, 0.0); computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? extractBoxSize(context) : NULL, nonbondedCutoff, 0.0);
ixn.setUseCutoff(nonbondedCutoff, *neighborList); ixn.setUseCutoff(nonbondedCutoff, *neighborList);
} }
map<string, double> globalParameters; map<string, double> globalParameters;
...@@ -1210,17 +1170,10 @@ double ReferenceCalcCustomGBForceKernel::executeEnergy(ContextImpl& context) { ...@@ -1210,17 +1170,10 @@ double ReferenceCalcCustomGBForceKernel::executeEnergy(ContextImpl& context) {
ReferenceCustomGBIxn ixn(valueExpressions, valueDerivExpressions, valueGradientExpressions, valueNames, valueTypes, energyExpressions, ReferenceCustomGBIxn ixn(valueExpressions, valueDerivExpressions, valueGradientExpressions, valueNames, valueTypes, energyExpressions,
energyDerivExpressions, energyGradientExpressions, energyTypes, particleParameterNames); energyDerivExpressions, energyGradientExpressions, energyTypes, particleParameterNames);
bool periodic = (nonbondedMethod == CutoffPeriodic); bool periodic = (nonbondedMethod == CutoffPeriodic);
RealOpenMM periodicBoxSize[3]; if (periodic)
if (periodic) { ixn.setPeriodic(extractBoxSize(context));
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
ixn.setPeriodic(periodicBoxSize);
}
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? periodicBoxSize : NULL, nonbondedCutoff, 0.0); computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, periodic ? extractBoxSize(context) : NULL, nonbondedCutoff, 0.0);
ixn.setUseCutoff(nonbondedCutoff, *neighborList); ixn.setUseCutoff(nonbondedCutoff, *neighborList);
} }
map<string, double> globalParameters; map<string, double> globalParameters;
...@@ -1389,15 +1342,8 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1389,15 +1342,8 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
void ReferenceCalcCustomHbondForceKernel::executeForces(ContextImpl& context) { void ReferenceCalcCustomHbondForceKernel::executeForces(ContextImpl& context) {
RealOpenMM** posData = extractPositions(context); RealOpenMM** posData = extractPositions(context);
RealOpenMM** forceData = extractForces(context); RealOpenMM** forceData = extractForces(context);
if (isPeriodic) { if (isPeriodic)
RealOpenMM periodicBoxSize[3]; ixn->setPeriodic(extractBoxSize(context));
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
ixn->setPeriodic(periodicBoxSize);
}
map<string, double> globalParameters; map<string, double> globalParameters;
for (int i = 0; i < (int) globalParameterNames.size(); i++) for (int i = 0; i < (int) globalParameterNames.size(); i++)
globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]); globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]);
...@@ -1407,15 +1353,8 @@ void ReferenceCalcCustomHbondForceKernel::executeForces(ContextImpl& context) { ...@@ -1407,15 +1353,8 @@ void ReferenceCalcCustomHbondForceKernel::executeForces(ContextImpl& context) {
double ReferenceCalcCustomHbondForceKernel::executeEnergy(ContextImpl& context) { double ReferenceCalcCustomHbondForceKernel::executeEnergy(ContextImpl& context) {
RealOpenMM** posData = extractPositions(context); RealOpenMM** posData = extractPositions(context);
RealOpenMM** forceData = allocateRealArray(numParticles, 3); RealOpenMM** forceData = allocateRealArray(numParticles, 3);
if (isPeriodic) { if (isPeriodic)
RealOpenMM periodicBoxSize[3]; ixn->setPeriodic(extractBoxSize(context));
Vec3 boxVectors[3];
context.getOwner().getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
periodicBoxSize[0] = (RealOpenMM) boxVectors[0][0];
periodicBoxSize[1] = (RealOpenMM) boxVectors[1][1];
periodicBoxSize[2] = (RealOpenMM) boxVectors[2][2];
ixn->setPeriodic(periodicBoxSize);
}
RealOpenMM energy = 0; RealOpenMM energy = 0;
map<string, double> globalParameters; map<string, double> globalParameters;
for (int i = 0; i < (int) globalParameterNames.size(); i++) for (int i = 0; i < (int) globalParameterNames.size(); i++)
...@@ -1770,9 +1709,7 @@ void ReferenceApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& conte ...@@ -1770,9 +1709,7 @@ void ReferenceApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& conte
if (barostat == NULL) if (barostat == NULL)
barostat = new ReferenceMonteCarloBarostat(context.getSystem().getNumParticles(), context.getMolecules()); barostat = new ReferenceMonteCarloBarostat(context.getSystem().getNumParticles(), context.getMolecules());
RealOpenMM** posData = extractPositions(context); RealOpenMM** posData = extractPositions(context);
Vec3 box[3]; RealOpenMM* boxSize = extractBoxSize(context);
context.getOwner().getPeriodicBoxVectors(box[0], box[1], box[2]);
RealOpenMM boxSize[] = {box[0][0], box[1][1], box[2][2]};
barostat->applyBarostat(posData, boxSize, scale); barostat->applyBarostat(posData, boxSize, scale);
} }
......
...@@ -156,6 +156,22 @@ public: ...@@ -156,6 +156,22 @@ public:
* @param forces on exit, this contains the forces * @param forces on exit, this contains the forces
*/ */
void getForces(ContextImpl& context, std::vector<Vec3>& forces); void getForces(ContextImpl& context, std::vector<Vec3>& forces);
/**
* Get the current periodic box vectors.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const;
/**
* Set the current periodic box vectors.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const;
private: private:
ReferencePlatform::PlatformData& data; ReferencePlatform::PlatformData& data;
}; };
......
...@@ -92,6 +92,7 @@ ReferencePlatform::PlatformData::PlatformData(int numParticles) : time(0.0), ste ...@@ -92,6 +92,7 @@ ReferencePlatform::PlatformData::PlatformData(int numParticles) : time(0.0), ste
this->positions = positions; this->positions = positions;
this->velocities = velocities; this->velocities = velocities;
this->forces = forces; this->forces = forces;
periodicBoxSize = new RealOpenMM[3];
} }
ReferencePlatform::PlatformData::~PlatformData() { ReferencePlatform::PlatformData::~PlatformData() {
...@@ -106,4 +107,5 @@ ReferencePlatform::PlatformData::~PlatformData() { ...@@ -106,4 +107,5 @@ ReferencePlatform::PlatformData::~PlatformData() {
delete[] positions; delete[] positions;
delete[] velocities; delete[] velocities;
delete[] forces; delete[] forces;
delete[] periodicBoxSize;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment