Commit be61ee5b authored by peastman's avatar peastman
Browse files

Centralized the forces, positions, and thread pool so they can be shared between kernels

parent 521f61ef
...@@ -37,17 +37,58 @@ ...@@ -37,17 +37,58 @@
#include "CpuNonbondedForce.h" #include "CpuNonbondedForce.h"
#include "openmm/kernels.h" #include "openmm/kernels.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/internal/ThreadPool.h"
namespace OpenMM { namespace OpenMM {
/**
* This kernel is invoked at the beginning and end of force and energy computations. It gives the
* Platform a chance to clear buffers and do other initialization at the beginning, and to do any
* necessary work at the end to determine the final results.
*/
class CpuCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel {
public:
CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context);
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* This is called at the beginning of each force/energy computation, before calcForcesAndEnergy() has been called on
* any ForceImpl.
*
* @param context the context in which to execute this kernel
* @param includeForce true if forces should be computed
* @param includeEnergy true if potential energy should be computed
* @param groups a set of bit flags for which force groups to include
*/
void beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups);
/**
* This is called at the end of each force/energy computation, after calcForcesAndEnergy() has been called on
* every ForceImpl.
*
* @param context the context in which to execute this kernel
* @param includeForce true if forces should be computed
* @param includeEnergy true if potential energy should be computed
* @param groups a set of bit flags for which force groups to include
* @return the potential energy of the system. This value is added to all values returned by ForceImpls'
* calcForcesAndEnergy() methods. That is, each force kernel may <i>either</i> return its contribution to the
* energy directly, <i>or</i> add it to an internal buffer so that it will be included here.
*/
double finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups);
private:
CpuPlatform::PlatformData& data;
Kernel referenceKernel;
};
/** /**
* This kernel is invoked by NonbondedForce to calculate the forces acting on the system. * This kernel is invoked by NonbondedForce to calculate the forces acting on the system.
*/ */
class CpuCalcNonbondedForceKernel : public CalcNonbondedForceKernel { class CpuCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public: public:
CpuCalcNonbondedForceKernel(std::string name, const Platform& platform) : CalcNonbondedForceKernel(name, platform), CpuCalcNonbondedForceKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data) : CalcNonbondedForceKernel(name, platform),
bonded14IndexArray(NULL), bonded14ParamArray(NULL), hasInitializedPme(false) { data(data), bonded14IndexArray(NULL), bonded14ParamArray(NULL), hasInitializedPme(false) {
} }
~CpuCalcNonbondedForceKernel(); ~CpuCalcNonbondedForceKernel();
/** /**
...@@ -77,6 +118,7 @@ public: ...@@ -77,6 +118,7 @@ public:
void copyParametersToContext(ContextImpl& context, const NonbondedForce& force); void copyParametersToContext(ContextImpl& context, const NonbondedForce& force);
private: private:
class PmeIO; class PmeIO;
CpuPlatform::PlatformData& data;
int numParticles, num14; int numParticles, num14;
int **bonded14IndexArray; int **bonded14IndexArray;
double **bonded14ParamArray; double **bonded14ParamArray;
...@@ -85,13 +127,10 @@ private: ...@@ -85,13 +127,10 @@ private:
bool useSwitchingFunction, useOptimizedPme, hasInitializedPme; bool useSwitchingFunction, useOptimizedPme, hasInitializedPme;
std::vector<std::set<int> > exclusions; std::vector<std::set<int> > exclusions;
std::vector<std::pair<float, float> > particleParams; std::vector<std::pair<float, float> > particleParams;
std::vector<float> posq;
std::vector<float> forces;
std::vector<RealVec> lastPositions; std::vector<RealVec> lastPositions;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
CpuNeighborList neighborList; CpuNeighborList neighborList;
CpuNonbondedForce nonbonded; CpuNonbondedForce nonbonded;
ThreadPool threads;
Kernel optimizedPme; Kernel optimizedPme;
}; };
......
...@@ -143,7 +143,7 @@ class CpuNonbondedForce { ...@@ -143,7 +143,7 @@ class CpuNonbondedForce {
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateDirectIxn(int numberOfAtoms, float* posq, const std::vector<RealVec>& atomCoordinates, const std::vector<std::pair<float, float> >& atomParameters, void calculateDirectIxn(int numberOfAtoms, float* posq, const std::vector<RealVec>& atomCoordinates, const std::vector<std::pair<float, float> >& atomParameters,
const std::vector<std::set<int> >& exclusions, float* forces, float* totalEnergy, ThreadPool& threads); const std::vector<std::set<int> >& exclusions, std::vector<std::vector<float> >& threadForce, float* totalEnergy, ThreadPool& threads);
/** /**
* This routine contains the code executed by each thread. * This routine contains the code executed by each thread.
...@@ -165,7 +165,6 @@ private: ...@@ -165,7 +165,6 @@ private:
int meshDim[3]; int meshDim[3];
std::vector<float> ewaldScaleTable; std::vector<float> ewaldScaleTable;
float ewaldDX, ewaldDXInv; float ewaldDX, ewaldDXInv;
std::vector<std::vector<float> > threadForce;
std::vector<double> threadEnergy; std::vector<double> threadEnergy;
// The following variables are used to make information accessible to the individual threads. // The following variables are used to make information accessible to the individual threads.
int numberOfAtoms; int numberOfAtoms;
...@@ -173,6 +172,7 @@ private: ...@@ -173,6 +172,7 @@ private:
RealVec const* atomCoordinates; RealVec const* atomCoordinates;
std::pair<float, float> const* atomParameters; std::pair<float, float> const* atomParameters;
std::set<int> const* exclusions; std::set<int> const* exclusions;
std::vector<std::vector<float> >* threadForce;
bool includeEnergy; bool includeEnergy;
static const float TWO_OVER_SQRT_PI; static const float TWO_OVER_SQRT_PI;
......
...@@ -33,7 +33,10 @@ ...@@ -33,7 +33,10 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "ReferencePlatform.h" #include "ReferencePlatform.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ThreadPool.h"
#include "windowsExportCpu.h" #include "windowsExportCpu.h"
#include <map>
namespace OpenMM { namespace OpenMM {
...@@ -43,6 +46,7 @@ namespace OpenMM { ...@@ -43,6 +46,7 @@ namespace OpenMM {
class OPENMM_EXPORT_CPU CpuPlatform : public ReferencePlatform { class OPENMM_EXPORT_CPU CpuPlatform : public ReferencePlatform {
public: public:
class PlatformData;
CpuPlatform(); CpuPlatform();
const std::string& getName() const { const std::string& getName() const {
static const std::string name = "CPU"; static const std::string name = "CPU";
...@@ -51,6 +55,24 @@ public: ...@@ -51,6 +55,24 @@ public:
double getSpeed() const; double getSpeed() const;
bool supportsDoublePrecision() const; bool supportsDoublePrecision() const;
static bool isProcessorSupported(); static bool isProcessorSupported();
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const;
/**
* We cannot use the standard mechanism for platform data, because that is already used by the superclass.
* Instead, we maintain a table of ContextImpls to PlatformDatas.
*/
static PlatformData& getPlatformData(ContextImpl& context);
private:
static std::map<ContextImpl*, PlatformData*> contextData;
};
class CpuPlatform::PlatformData {
public:
PlatformData(int numParticles);
std::vector<float> posq;
std::vector<std::vector<float> > threadForce;
ThreadPool threads;
bool isPeriodic;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -38,8 +38,10 @@ ...@@ -38,8 +38,10 @@
using namespace OpenMM; using namespace OpenMM;
KernelImpl* CpuKernelFactory::createKernelImpl(std::string name, const Platform& platform, ContextImpl& context) const { KernelImpl* CpuKernelFactory::createKernelImpl(std::string name, const Platform& platform, ContextImpl& context) const {
ReferencePlatform::PlatformData& data = *static_cast<ReferencePlatform::PlatformData*>(context.getPlatformData()); CpuPlatform::PlatformData& data = CpuPlatform::getPlatformData(context);
if (name == CalcForcesAndEnergyKernel::Name())
return new CpuCalcForcesAndEnergyKernel(name, platform, data, context);
if (name == CalcNonbondedForceKernel::Name()) if (name == CalcNonbondedForceKernel::Name())
return new CpuCalcNonbondedForceKernel(name, platform); return new CpuCalcNonbondedForceKernel(name, platform, data);
throw OpenMMException((std::string("Tried to create kernel with illegal kernel name '") + name + "'").c_str()); throw OpenMMException((std::string("Tried to create kernel with illegal kernel name '") + name + "'").c_str());
} }
...@@ -31,11 +31,14 @@ ...@@ -31,11 +31,14 @@
#include "CpuKernels.h" #include "CpuKernels.h"
#include "ReferenceBondForce.h" #include "ReferenceBondForce.h"
#include "ReferenceKernelFactory.h"
#include "ReferenceKernels.h"
#include "ReferenceLJCoulomb14.h" #include "ReferenceLJCoulomb14.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h"
#include "RealVec.h" #include "RealVec.h"
using namespace OpenMM; using namespace OpenMM;
...@@ -61,6 +64,67 @@ static RealVec& extractBoxSize(ContextImpl& context) { ...@@ -61,6 +64,67 @@ static RealVec& extractBoxSize(ContextImpl& context) {
return *(RealVec*) data->periodicBoxSize; return *(RealVec*) data->periodicBoxSize;
} }
CpuCalcForcesAndEnergyKernel::CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context) :
CalcForcesAndEnergyKernel(name, platform), data(data) {
// Create a Reference platform version of this kernel.
ReferenceKernelFactory referenceFactory;
referenceKernel = Kernel(referenceFactory.createKernelImpl(name, platform, context));
}
void CpuCalcForcesAndEnergyKernel::initialize(const System& system) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().initialize(system);
}
void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().beginComputation(context, includeForce, includeEnergy, groups);
// Convert the positions to single precision and apply periodic boundary conditions
vector<float>& posq = data.posq;
vector<RealVec>& posData = extractPositions(context);
RealVec boxSize = extractBoxSize(context);
float floatBoxSize[3] = {(float) boxSize[0], (float) boxSize[1], (float) boxSize[2]};
int numParticles = context.getSystem().getNumParticles();
if (data.isPeriodic)
for (int i = 0; i < numParticles; i++)
for (int j = 0; j < 3; j++) {
RealOpenMM x = posData[i][j];
double base = floor(x/boxSize[j])*boxSize[j];
posq[4*i+j] = (float) (x-base);
}
else
for (int i = 0; i < numParticles; i++) {
posq[4*i] = (float) posData[i][0];
posq[4*i+1] = (float) posData[i][1];
posq[4*i+2] = (float) posData[i][2];
}
// Clear the forces.
fvec4 zero(0.0f);
for (int i = 0; i < (int) data.threadForce.size(); i++)
for (int j = 0; j < numParticles; j++)
zero.store(&data.threadForce[i][j*4]);
}
double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
// Sum the forces from all the threads.
int numParticles = context.getSystem().getNumParticles();
int numThreads = data.threads.getNumThreads();
vector<RealVec>& forceData = extractForces(context);
for (int i = 0; i < numParticles; i++) {
fvec4 f(0.0f);
for (int j = 0; j < numThreads; j++)
f += fvec4(&data.threadForce[j][4*i]);
forceData[i][0] += f[0];
forceData[i][1] += f[1];
forceData[i][2] += f[2];
}
return referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().finishComputation(context, includeForce, includeEnergy, groups);
}
class CpuCalcNonbondedForceKernel::PmeIO : public CalcPmeReciprocalForceKernel::IO { class CpuCalcNonbondedForceKernel::PmeIO : public CalcPmeReciprocalForceKernel::IO {
public: public:
PmeIO(float* posq, float* force, int numParticles) : posq(posq), force(force), numParticles(numParticles) { PmeIO(float* posq, float* force, int numParticles) : posq(posq), force(force), numParticles(numParticles) {
...@@ -97,8 +161,6 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond ...@@ -97,8 +161,6 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
// Identify which exceptions are 1-4 interactions. // Identify which exceptions are 1-4 interactions.
numParticles = force.getNumParticles(); numParticles = force.getNumParticles();
posq.resize(4*numParticles, 0);
forces.resize(4*numParticles, 0);
exclusions.resize(numParticles); exclusions.resize(numParticles);
vector<int> nb14s; vector<int> nb14s;
for (int i = 0; i < force.getNumExceptions(); i++) { for (int i = 0; i < force.getNumExceptions(); i++) {
...@@ -125,7 +187,7 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond ...@@ -125,7 +187,7 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
double charge, radius, depth; double charge, radius, depth;
force.getParticleParameters(i, charge, radius, depth); force.getParticleParameters(i, charge, radius, depth);
posq[4*i+3] = (float) charge; data.posq[4*i+3] = (float) charge;
particleParams[i] = make_pair((float) (0.5*radius), (float) (2.0*sqrt(depth))); particleParams[i] = make_pair((float) (0.5*radius), (float) (2.0*sqrt(depth)));
sumSquaredCharges += charge*charge; sumSquaredCharges += charge*charge;
} }
...@@ -173,6 +235,7 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond ...@@ -173,6 +235,7 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
else else
dispersionCoefficient = 0.0; dispersionCoefficient = 0.0;
lastPositions.resize(numParticles, Vec3(1e10, 1e10, 1e10)); lastPositions.resize(numParticles, Vec3(1e10, 1e10, 1e10));
data.isPeriodic = (nonbondedMethod == CutoffPeriodic || nonbondedMethod == Ewald || nonbondedMethod == PME);
} }
double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) { double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) {
...@@ -192,32 +255,14 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -192,32 +255,14 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
} }
} }
} }
vector<float>& posq = data.posq;
vector<RealVec>& posData = extractPositions(context); vector<RealVec>& posData = extractPositions(context);
vector<RealVec>& forceData = extractForces(context); vector<RealVec>& forceData = extractForces(context);
RealVec boxSize = extractBoxSize(context); RealVec boxSize = extractBoxSize(context);
float floatBoxSize[3] = {(float) boxSize[0], (float) boxSize[1], (float) boxSize[2]}; float floatBoxSize[3] = {(float) boxSize[0], (float) boxSize[1], (float) boxSize[2]};
double energy = ewaldSelfEnergy; double energy = ewaldSelfEnergy;
bool periodic = (nonbondedMethod == CutoffPeriodic);
bool ewald = (nonbondedMethod == Ewald); bool ewald = (nonbondedMethod == Ewald);
bool pme = (nonbondedMethod == PME); bool pme = (nonbondedMethod == PME);
// Convert the positions to single precision.
if (periodic || ewald || pme)
for (int i = 0; i < numParticles; i++)
for (int j = 0; j < 3; j++) {
RealOpenMM x = posData[i][j];
double base = floor(x/boxSize[j])*boxSize[j];
posq[4*i+j] = (float) (x-base);
}
else
for (int i = 0; i < numParticles; i++) {
posq[4*i] = (float) posData[i][0];
posq[4*i+1] = (float) posData[i][1];
posq[4*i+2] = (float) posData[i][2];
}
for (int i = 0; i < 4*numParticles; i++)
forces[i] = 0.0f;
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
// Determine whether we need to recompute the neighbor list. // Determine whether we need to recompute the neighbor list.
...@@ -260,12 +305,12 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -260,12 +305,12 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
} }
} }
if (needRecompute) { if (needRecompute) {
neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, periodic || ewald || pme, nonbondedCutoff+padding, threads); neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, data.isPeriodic, nonbondedCutoff+padding, data.threads);
lastPositions = posData; lastPositions = posData;
} }
nonbonded.setUseCutoff(nonbondedCutoff, neighborList, rfDielectric); nonbonded.setUseCutoff(nonbondedCutoff, neighborList, rfDielectric);
} }
if (periodic || ewald || pme) { if (data.isPeriodic) {
double minAllowedSize = 1.999999*nonbondedCutoff; double minAllowedSize = 1.999999*nonbondedCutoff;
if (boxSize[0] < minAllowedSize || boxSize[1] < minAllowedSize || boxSize[2] < minAllowedSize) if (boxSize[0] < minAllowedSize || boxSize[1] < minAllowedSize || boxSize[2] < minAllowedSize)
throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff."); throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
...@@ -279,10 +324,10 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -279,10 +324,10 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
nonbonded.setUseSwitchingFunction(switchingDistance); nonbonded.setUseSwitchingFunction(switchingDistance);
float nonbondedEnergy = 0; float nonbondedEnergy = 0;
if (includeDirect) if (includeDirect)
nonbonded.calculateDirectIxn(numParticles, &posq[0], posData, particleParams, exclusions, &forces[0], includeEnergy ? &nonbondedEnergy : NULL, threads); nonbonded.calculateDirectIxn(numParticles, &posq[0], posData, particleParams, exclusions, data.threadForce, includeEnergy ? &nonbondedEnergy : NULL, data.threads);
if (includeReciprocal) { if (includeReciprocal) {
if (useOptimizedPme) { if (useOptimizedPme) {
PmeIO io(&posq[0], &forces[0], numParticles); PmeIO io(&posq[0], &data.threadForce[0][0], numParticles);
Vec3 periodicBoxSize(boxSize[0], boxSize[1], boxSize[2]); Vec3 periodicBoxSize(boxSize[0], boxSize[1], boxSize[2]);
optimizedPme.getAs<CalcPmeReciprocalForceKernel>().beginComputation(io, periodicBoxSize, includeEnergy); optimizedPme.getAs<CalcPmeReciprocalForceKernel>().beginComputation(io, periodicBoxSize, includeEnergy);
optimizedPme.getAs<CalcPmeReciprocalForceKernel>().finishComputation(io); optimizedPme.getAs<CalcPmeReciprocalForceKernel>().finishComputation(io);
...@@ -291,16 +336,11 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -291,16 +336,11 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
nonbonded.calculateReciprocalIxn(numParticles, &posq[0], posData, particleParams, exclusions, forceData, includeEnergy ? &nonbondedEnergy : NULL); nonbonded.calculateReciprocalIxn(numParticles, &posq[0], posData, particleParams, exclusions, forceData, includeEnergy ? &nonbondedEnergy : NULL);
} }
energy += nonbondedEnergy; energy += nonbondedEnergy;
for (int i = 0; i < numParticles; i++) {
forceData[i][0] += forces[4*i];
forceData[i][1] += forces[4*i+1];
forceData[i][2] += forces[4*i+2];
}
if (includeDirect) { if (includeDirect) {
ReferenceBondForce refBondForce; ReferenceBondForce refBondForce;
ReferenceLJCoulomb14 nonbonded14; ReferenceLJCoulomb14 nonbonded14;
refBondForce.calculateForce(num14, bonded14IndexArray, posData, bonded14ParamArray, forceData, includeEnergy ? &energy : NULL, nonbonded14); refBondForce.calculateForce(num14, bonded14IndexArray, posData, bonded14ParamArray, forceData, includeEnergy ? &energy : NULL, nonbonded14);
if (periodic || ewald || pme) if (data.isPeriodic)
energy += dispersionCoefficient/(boxSize[0]*boxSize[1]*boxSize[2]); energy += dispersionCoefficient/(boxSize[0]*boxSize[1]*boxSize[2]);
} }
return energy; return energy;
...@@ -326,7 +366,7 @@ void CpuCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, ...@@ -326,7 +366,7 @@ void CpuCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context,
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
double charge, radius, depth; double charge, radius, depth;
force.getParticleParameters(i, charge, radius, depth); force.getParticleParameters(i, charge, radius, depth);
posq[4*i+3] = (float) charge; data.posq[4*i+3] = (float) charge;
particleParams[i] = make_pair((float) (0.5*radius), (float) (2.0*sqrt(depth))); particleParams[i] = make_pair((float) (0.5*radius), (float) (2.0*sqrt(depth)));
sumSquaredCharges += charge*charge; sumSquaredCharges += charge*charge;
} }
......
...@@ -292,7 +292,7 @@ void CpuNonbondedForce::calculateReciprocalIxn(int numberOfAtoms, float* posq, c ...@@ -292,7 +292,7 @@ void CpuNonbondedForce::calculateReciprocalIxn(int numberOfAtoms, float* posq, c
void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const vector<RealVec>& atomCoordinates, const vector<pair<float, float> >& atomParameters, void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const vector<RealVec>& atomCoordinates, const vector<pair<float, float> >& atomParameters,
const vector<set<int> >& exclusions, float* forces, float* totalEnergy, ThreadPool& threads) { const vector<set<int> >& exclusions, vector<vector<float> >& threadForce, float* totalEnergy, ThreadPool& threads) {
// Record the parameters for the threads. // Record the parameters for the threads.
this->numberOfAtoms = numberOfAtoms; this->numberOfAtoms = numberOfAtoms;
...@@ -300,9 +300,9 @@ void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const ...@@ -300,9 +300,9 @@ void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const
this->atomCoordinates = &atomCoordinates[0]; this->atomCoordinates = &atomCoordinates[0];
this->atomParameters = &atomParameters[0]; this->atomParameters = &atomParameters[0];
this->exclusions = &exclusions[0]; this->exclusions = &exclusions[0];
this->threadForce = &threadForce;
includeEnergy = (totalEnergy != NULL); includeEnergy = (totalEnergy != NULL);
threadEnergy.resize(threads.getNumThreads()); threadEnergy.resize(threads.getNumThreads());
threadForce.resize(threads.getNumThreads());
// Signal the threads to start running and wait for them to finish. // Signal the threads to start running and wait for them to finish.
...@@ -310,21 +310,15 @@ void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const ...@@ -310,21 +310,15 @@ void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const
threads.execute(task); threads.execute(task);
threads.waitForThreads(); threads.waitForThreads();
// Combine the results from all the threads. // Combine the energies from all the threads.
double directEnergy = 0; if (totalEnergy != NULL) {
int numThreads = threads.getNumThreads(); double directEnergy = 0;
for (int i = 0; i < numThreads; i++) int numThreads = threads.getNumThreads();
directEnergy += threadEnergy[i]; for (int i = 0; i < numThreads; i++)
for (int i = 0; i < numberOfAtoms; i++) { directEnergy += threadEnergy[i];
fvec4 f(forces+4*i);
for (int j = 0; j < numThreads; j++)
f += fvec4(&threadForce[j][4*i]);
f.store(forces+4*i);
}
if (totalEnergy != NULL)
*totalEnergy += (float) directEnergy; *totalEnergy += (float) directEnergy;
}
} }
void CpuNonbondedForce::threadComputeDirect(ThreadPool& threads, int threadIndex) { void CpuNonbondedForce::threadComputeDirect(ThreadPool& threads, int threadIndex) {
...@@ -333,10 +327,7 @@ void CpuNonbondedForce::threadComputeDirect(ThreadPool& threads, int threadIndex ...@@ -333,10 +327,7 @@ void CpuNonbondedForce::threadComputeDirect(ThreadPool& threads, int threadIndex
int numThreads = threads.getNumThreads(); int numThreads = threads.getNumThreads();
threadEnergy[threadIndex] = 0; threadEnergy[threadIndex] = 0;
double* energyPtr = (includeEnergy ? &threadEnergy[threadIndex] : NULL); double* energyPtr = (includeEnergy ? &threadEnergy[threadIndex] : NULL);
threadForce[threadIndex].resize(4*numberOfAtoms, 0.0f); float* forces = &(*threadForce)[threadIndex][0];
float* forces = &threadForce[threadIndex][0];
for (int i = 0; i < 4*numberOfAtoms; i++)
forces[i] = 0.0f;
fvec4 boxSize(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2], 0); fvec4 boxSize(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2], 0);
fvec4 invBoxSize((1/periodicBoxSize[0]), (1/periodicBoxSize[1]), (1/periodicBoxSize[2]), 0); fvec4 invBoxSize((1/periodicBoxSize[0]), (1/periodicBoxSize[1]), (1/periodicBoxSize[2]), 0);
if (ewald || pme) { if (ewald || pme) {
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "openmm/internal/hardware.h" #include "openmm/internal/hardware.h"
using namespace OpenMM; using namespace OpenMM;
using namespace std;
extern "C" OPENMM_EXPORT_CPU void registerPlatforms() { extern "C" OPENMM_EXPORT_CPU void registerPlatforms() {
// Only register this platform if the CPU supports SSE 4.1. // Only register this platform if the CPU supports SSE 4.1.
...@@ -43,8 +44,11 @@ extern "C" OPENMM_EXPORT_CPU void registerPlatforms() { ...@@ -43,8 +44,11 @@ extern "C" OPENMM_EXPORT_CPU void registerPlatforms() {
Platform::registerPlatform(new CpuPlatform()); Platform::registerPlatform(new CpuPlatform());
} }
map<ContextImpl*, CpuPlatform::PlatformData*> CpuPlatform::contextData;
CpuPlatform::CpuPlatform() { CpuPlatform::CpuPlatform() {
CpuKernelFactory* factory = new CpuKernelFactory(); CpuKernelFactory* factory = new CpuKernelFactory();
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(CalcNonbondedForceKernel::Name(), factory); registerKernelFactory(CalcNonbondedForceKernel::Name(), factory);
} }
...@@ -67,3 +71,28 @@ bool CpuPlatform::isProcessorSupported() { ...@@ -67,3 +71,28 @@ bool CpuPlatform::isProcessorSupported() {
} }
return false; return false;
} }
void CpuPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
ReferencePlatform::contextCreated(context, properties);
PlatformData* data = new PlatformData(context.getSystem().getNumParticles());
contextData[&context] = data;
}
void CpuPlatform::contextDestroyed(ContextImpl& context) const {
PlatformData* data = contextData[&context];
delete data;
contextData.erase(&context);
}
CpuPlatform::PlatformData& CpuPlatform::getPlatformData(ContextImpl& context) {
return *contextData[&context];
}
CpuPlatform::PlatformData::PlatformData(int numParticles) {
posq.resize(4*numParticles);
int numThreads = threads.getNumThreads();
threadForce.resize(numThreads);
for (int i = 0; i < numThreads; i++)
threadForce[i].resize(4*numParticles);
isPeriodic = false;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment