Commit d279014c authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimization: introduced a kernel to clear forces at the start of each time...

Optimization: introduced a kernel to clear forces at the start of each time step rather than relying on the stream
parent b920a178
...@@ -52,6 +52,30 @@ ...@@ -52,6 +52,30 @@
namespace OpenMM { namespace OpenMM {
/**
* This kernel is invoked at the start of each force evaluation to clear the forces.
*/
class InitializeForcesKernel : public KernelImpl {
public:
static std::string Name() {
return "InitializeForces";
}
InitializeForcesKernel(std::string name, const Platform& platform) : KernelImpl(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
virtual void initialize(const System& system) = 0;
/**
* Execute the kernel.
*
* @param context the context in which to execute this kernel
*/
virtual double execute(OpenMMContextImpl& context) = 0;
};
/** /**
* This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system.
*/ */
......
...@@ -166,7 +166,7 @@ private: ...@@ -166,7 +166,7 @@ private:
std::map<std::string, double> parameters; std::map<std::string, double> parameters;
Platform* platform; Platform* platform;
Stream positions, velocities, forces; Stream positions, velocities, forces;
Kernel kineticEnergyKernel; Kernel initializeForcesKernel, kineticEnergyKernel;
void* platformData; void* platformData;
}; };
......
...@@ -52,6 +52,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ ...@@ -52,6 +52,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ
{ {
vector<string> kernelNames; vector<string> kernelNames;
kernelNames.push_back(CalcKineticEnergyKernel::Name()); kernelNames.push_back(CalcKineticEnergyKernel::Name());
kernelNames.push_back(InitializeForcesKernel::Name());
for (int i = 0; i < system.getNumForces(); ++i) { for (int i = 0; i < system.getNumForces(); ++i) {
forceImpls.push_back(system.getForce(i).createImpl()); forceImpls.push_back(system.getForce(i).createImpl());
map<string, double> forceParameters = forceImpls[forceImpls.size()-1]->getDefaultParameters(); map<string, double> forceParameters = forceImpls[forceImpls.size()-1]->getDefaultParameters();
...@@ -66,10 +67,9 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ ...@@ -66,10 +67,9 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ
else if (!platform->supportsKernels(kernelNames)) else if (!platform->supportsKernels(kernelNames))
throw OpenMMException("Specified a Platform for an OpenMMContext which does not support all required kernels"); throw OpenMMException("Specified a Platform for an OpenMMContext which does not support all required kernels");
platform->contextCreated(*this); platform->contextCreated(*this);
initializeForcesKernel = platform->createKernel(InitializeForcesKernel::Name(), *this);
dynamic_cast<InitializeForcesKernel&>(initializeForcesKernel.getImpl()).initialize(system);
kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this); kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this);
vector<double> masses(system.getNumParticles());
for (size_t i = 0; i < masses.size(); ++i)
masses[i] = system.getParticleMass(i);
dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).initialize(system); dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).initialize(system);
for (size_t i = 0; i < forceImpls.size(); ++i) for (size_t i = 0; i < forceImpls.size(); ++i)
forceImpls[i]->initialize(*this); forceImpls[i]->initialize(*this);
...@@ -100,8 +100,7 @@ void OpenMMContextImpl::setParameter(std::string name, double value) { ...@@ -100,8 +100,7 @@ void OpenMMContextImpl::setParameter(std::string name, double value) {
} }
void OpenMMContextImpl::calcForces() { void OpenMMContextImpl::calcForces() {
double zero[] = {0.0, 0.0, 0.0}; dynamic_cast<InitializeForcesKernel&>(initializeForcesKernel.getImpl()).execute(*this);
forces.fillWithValue(zero);
for (int i = 0; i < (int) forceImpls.size(); ++i) for (int i = 0; i < (int) forceImpls.size(); ++i)
forceImpls[i]->calcForces(*this, forces); forceImpls[i]->calcForces(*this, forces);
} }
......
...@@ -37,6 +37,8 @@ using namespace OpenMM; ...@@ -37,6 +37,8 @@ using namespace OpenMM;
KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const { KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const {
CudaPlatform::PlatformData& data = *static_cast<CudaPlatform::PlatformData*>(context.getPlatformData()); CudaPlatform::PlatformData& data = *static_cast<CudaPlatform::PlatformData*>(context.getPlatformData());
if (name == InitializeForcesKernel::Name())
return new CudaInitializeForcesKernel(name, platform);
if (name == CalcHarmonicBondForceKernel::Name()) if (name == CalcHarmonicBondForceKernel::Name())
return new CudaCalcHarmonicBondForceKernel(name, platform, data, context.getSystem()); return new CudaCalcHarmonicBondForceKernel(name, platform, data, context.getSystem());
if (name == CalcHarmonicAngleForceKernel::Name()) if (name == CalcHarmonicAngleForceKernel::Name())
......
...@@ -79,6 +79,12 @@ static double calcEnergy(OpenMMContextImpl& context, System& system) { ...@@ -79,6 +79,12 @@ static double calcEnergy(OpenMMContextImpl& context, System& system) {
return refContext.getState(State::Energy).getPotentialEnergy(); return refContext.getState(State::Energy).getPotentialEnergy();
} }
void CudaInitializeForcesKernel::initialize(const System& system) {
}
double CudaInitializeForcesKernel::execute(OpenMMContextImpl& context) {
}
CudaCalcHarmonicBondForceKernel::~CudaCalcHarmonicBondForceKernel() { CudaCalcHarmonicBondForceKernel::~CudaCalcHarmonicBondForceKernel() {
} }
......
...@@ -45,6 +45,26 @@ class CudaVerletDynamics; ...@@ -45,6 +45,26 @@ class CudaVerletDynamics;
namespace OpenMM { namespace OpenMM {
/**
* This kernel is invoked at the start of each force evaluation to clear the forces.
*/
class CudaInitializeForcesKernel : public InitializeForcesKernel {
public:
CudaInitializeForcesKernel(std::string name, const Platform& platform) : InitializeForcesKernel(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Execute the kernel.
*
* @param context the context in which to execute this kernel
*/
double execute(OpenMMContextImpl& context);
};
/** /**
* This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system.
......
...@@ -46,6 +46,7 @@ extern "C" void initOpenMMPlugin() { ...@@ -46,6 +46,7 @@ extern "C" void initOpenMMPlugin() {
CudaPlatform::CudaPlatform() { CudaPlatform::CudaPlatform() {
CudaKernelFactory* factory = new CudaKernelFactory(); CudaKernelFactory* factory = new CudaKernelFactory();
registerKernelFactory(InitializeForcesKernel::Name(), factory);
registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory);
registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory);
registerKernelFactory(CalcPeriodicTorsionForceKernel::Name(), factory); registerKernelFactory(CalcPeriodicTorsionForceKernel::Name(), factory);
......
...@@ -36,7 +36,9 @@ ...@@ -36,7 +36,9 @@
using namespace OpenMM; using namespace OpenMM;
KernelImpl* ReferenceKernelFactory::createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const { KernelImpl* ReferenceKernelFactory::createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const {
if (name == CalcNonbondedForceKernel::Name()) if (name == InitializeForcesKernel::Name())
return new ReferenceInitializeForcesKernel(name, platform);
else if (name == CalcNonbondedForceKernel::Name())
return new ReferenceCalcNonbondedForceKernel(name, platform); return new ReferenceCalcNonbondedForceKernel(name, platform);
else if (name == CalcHarmonicBondForceKernel::Name()) else if (name == CalcHarmonicBondForceKernel::Name())
return new ReferenceCalcHarmonicBondForceKernel(name, platform); return new ReferenceCalcHarmonicBondForceKernel(name, platform);
......
...@@ -104,6 +104,14 @@ void disposeRealArray(RealOpenMM** array, int size) { ...@@ -104,6 +104,14 @@ void disposeRealArray(RealOpenMM** array, int size) {
} }
} }
void ReferenceInitializeForcesKernel::initialize(const System& system) {
}
double ReferenceInitializeForcesKernel::execute(OpenMMContextImpl& context) {
double zero[] = {0.0, 0.0, 0.0};
context.getForces().fillWithValue(zero);
}
ReferenceCalcHarmonicBondForceKernel::~ReferenceCalcHarmonicBondForceKernel() { ReferenceCalcHarmonicBondForceKernel::~ReferenceCalcHarmonicBondForceKernel() {
disposeIntArray(bondIndexArray, numBonds); disposeIntArray(bondIndexArray, numBonds);
disposeRealArray(bondParamArray, numBonds); disposeRealArray(bondParamArray, numBonds);
......
...@@ -45,6 +45,27 @@ class ReferenceVerletDynamics; ...@@ -45,6 +45,27 @@ class ReferenceVerletDynamics;
namespace OpenMM { namespace OpenMM {
/**
* This kernel is invoked at the start of each force evaluation to clear the forces.
*/
class ReferenceInitializeForcesKernel : public InitializeForcesKernel {
public:
ReferenceInitializeForcesKernel(std::string name, const Platform& platform) : InitializeForcesKernel(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Execute the kernel.
*
* @param context the context in which to execute this kernel
*/
double execute(OpenMMContextImpl& context);
};
/** /**
* This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicBondForce to calculate the forces acting on the system and the energy of the system.
*/ */
......
...@@ -38,6 +38,7 @@ using namespace OpenMM; ...@@ -38,6 +38,7 @@ using namespace OpenMM;
ReferencePlatform::ReferencePlatform() { ReferencePlatform::ReferencePlatform() {
ReferenceKernelFactory* factory = new ReferenceKernelFactory(); ReferenceKernelFactory* factory = new ReferenceKernelFactory();
registerKernelFactory(InitializeForcesKernel::Name(), factory);
registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicBondForceKernel::Name(), factory);
registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory);
registerKernelFactory(CalcPeriodicTorsionForceKernel::Name(), factory); registerKernelFactory(CalcPeriodicTorsionForceKernel::Name(), factory);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment