Commit be61ee5b authored by peastman's avatar peastman
Browse files

Centralized the forces, positions, and thread pool so they can be shared between kernels

parent 521f61ef
......@@ -37,17 +37,58 @@
#include "CpuNonbondedForce.h"
#include "openmm/kernels.h"
#include "openmm/System.h"
#include "openmm/internal/ThreadPool.h"
namespace OpenMM {
/**
* This kernel is invoked at the beginning and end of force and energy computations. It gives the
* Platform a chance to clear buffers and do other initialization at the beginning, and to do any
* necessary work at the end to determine the final results.
*/
class CpuCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel {
public:
CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context);
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* This is called at the beginning of each force/energy computation, before calcForcesAndEnergy() has been called on
* any ForceImpl.
*
* @param context the context in which to execute this kernel
* @param includeForce true if forces should be computed
* @param includeEnergy true if potential energy should be computed
* @param groups a set of bit flags for which force groups to include
*/
void beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups);
/**
* This is called at the end of each force/energy computation, after calcForcesAndEnergy() has been called on
* every ForceImpl.
*
* @param context the context in which to execute this kernel
* @param includeForce true if forces should be computed
* @param includeEnergy true if potential energy should be computed
* @param groups a set of bit flags for which force groups to include
* @return the potential energy of the system. This value is added to all values returned by ForceImpls'
* calcForcesAndEnergy() methods. That is, each force kernel may <i>either</i> return its contribution to the
* energy directly, <i>or</i> add it to an internal buffer so that it will be included here.
*/
double finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups);
private:
CpuPlatform::PlatformData& data;
Kernel referenceKernel;
};
/**
* This kernel is invoked by NonbondedForce to calculate the forces acting on the system.
*/
class CpuCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public:
CpuCalcNonbondedForceKernel(std::string name, const Platform& platform) : CalcNonbondedForceKernel(name, platform),
bonded14IndexArray(NULL), bonded14ParamArray(NULL), hasInitializedPme(false) {
CpuCalcNonbondedForceKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data) : CalcNonbondedForceKernel(name, platform),
data(data), bonded14IndexArray(NULL), bonded14ParamArray(NULL), hasInitializedPme(false) {
}
~CpuCalcNonbondedForceKernel();
/**
......@@ -77,6 +118,7 @@ public:
void copyParametersToContext(ContextImpl& context, const NonbondedForce& force);
private:
class PmeIO;
CpuPlatform::PlatformData& data;
int numParticles, num14;
int **bonded14IndexArray;
double **bonded14ParamArray;
......@@ -85,13 +127,10 @@ private:
bool useSwitchingFunction, useOptimizedPme, hasInitializedPme;
std::vector<std::set<int> > exclusions;
std::vector<std::pair<float, float> > particleParams;
std::vector<float> posq;
std::vector<float> forces;
std::vector<RealVec> lastPositions;
NonbondedMethod nonbondedMethod;
CpuNeighborList neighborList;
CpuNonbondedForce nonbonded;
ThreadPool threads;
Kernel optimizedPme;
};
......
......@@ -143,7 +143,7 @@ class CpuNonbondedForce {
--------------------------------------------------------------------------------------- */
void calculateDirectIxn(int numberOfAtoms, float* posq, const std::vector<RealVec>& atomCoordinates, const std::vector<std::pair<float, float> >& atomParameters,
const std::vector<std::set<int> >& exclusions, float* forces, float* totalEnergy, ThreadPool& threads);
const std::vector<std::set<int> >& exclusions, std::vector<std::vector<float> >& threadForce, float* totalEnergy, ThreadPool& threads);
/**
* This routine contains the code executed by each thread.
......@@ -165,7 +165,6 @@ private:
int meshDim[3];
std::vector<float> ewaldScaleTable;
float ewaldDX, ewaldDXInv;
std::vector<std::vector<float> > threadForce;
std::vector<double> threadEnergy;
// The following variables are used to make information accessible to the individual threads.
int numberOfAtoms;
......@@ -173,6 +172,7 @@ private:
RealVec const* atomCoordinates;
std::pair<float, float> const* atomParameters;
std::set<int> const* exclusions;
std::vector<std::vector<float> >* threadForce;
bool includeEnergy;
static const float TWO_OVER_SQRT_PI;
......
......@@ -33,7 +33,10 @@
* -------------------------------------------------------------------------- */
#include "ReferencePlatform.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ThreadPool.h"
#include "windowsExportCpu.h"
#include <map>
namespace OpenMM {
......@@ -43,6 +46,7 @@ namespace OpenMM {
class OPENMM_EXPORT_CPU CpuPlatform : public ReferencePlatform {
public:
class PlatformData;
CpuPlatform();
const std::string& getName() const {
static const std::string name = "CPU";
......@@ -51,6 +55,24 @@ public:
double getSpeed() const;
bool supportsDoublePrecision() const;
static bool isProcessorSupported();
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const;
/**
* We cannot use the standard mechanism for platform data, because that is already used by the superclass.
* Instead, we maintain a table of ContextImpls to PlatformDatas.
*/
static PlatformData& getPlatformData(ContextImpl& context);
private:
static std::map<ContextImpl*, PlatformData*> contextData;
};
class CpuPlatform::PlatformData {
public:
PlatformData(int numParticles);
std::vector<float> posq;
std::vector<std::vector<float> > threadForce;
ThreadPool threads;
bool isPeriodic;
};
} // namespace OpenMM
......
......@@ -38,8 +38,10 @@
using namespace OpenMM;
KernelImpl* CpuKernelFactory::createKernelImpl(std::string name, const Platform& platform, ContextImpl& context) const {
ReferencePlatform::PlatformData& data = *static_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
CpuPlatform::PlatformData& data = CpuPlatform::getPlatformData(context);
if (name == CalcForcesAndEnergyKernel::Name())
return new CpuCalcForcesAndEnergyKernel(name, platform, data, context);
if (name == CalcNonbondedForceKernel::Name())
return new CpuCalcNonbondedForceKernel(name, platform);
return new CpuCalcNonbondedForceKernel(name, platform, data);
throw OpenMMException((std::string("Tried to create kernel with illegal kernel name '") + name + "'").c_str());
}
......@@ -31,11 +31,14 @@
#include "CpuKernels.h"
#include "ReferenceBondForce.h"
#include "ReferenceKernelFactory.h"
#include "ReferenceKernels.h"
#include "ReferenceLJCoulomb14.h"
#include "openmm/Context.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h"
#include "RealVec.h"
using namespace OpenMM;
......@@ -61,6 +64,67 @@ static RealVec& extractBoxSize(ContextImpl& context) {
return *(RealVec*) data->periodicBoxSize;
}
CpuCalcForcesAndEnergyKernel::CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context) :
CalcForcesAndEnergyKernel(name, platform), data(data) {
// Create a Reference platform version of this kernel.
ReferenceKernelFactory referenceFactory;
referenceKernel = Kernel(referenceFactory.createKernelImpl(name, platform, context));
}
void CpuCalcForcesAndEnergyKernel::initialize(const System& system) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().initialize(system);
}
void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().beginComputation(context, includeForce, includeEnergy, groups);
// Convert the positions to single precision and apply periodic boundary conditions
vector<float>& posq = data.posq;
vector<RealVec>& posData = extractPositions(context);
RealVec boxSize = extractBoxSize(context);
float floatBoxSize[3] = {(float) boxSize[0], (float) boxSize[1], (float) boxSize[2]};
int numParticles = context.getSystem().getNumParticles();
if (data.isPeriodic)
for (int i = 0; i < numParticles; i++)
for (int j = 0; j < 3; j++) {
RealOpenMM x = posData[i][j];
double base = floor(x/boxSize[j])*boxSize[j];
posq[4*i+j] = (float) (x-base);
}
else
for (int i = 0; i < numParticles; i++) {
posq[4*i] = (float) posData[i][0];
posq[4*i+1] = (float) posData[i][1];
posq[4*i+2] = (float) posData[i][2];
}
// Clear the forces.
fvec4 zero(0.0f);
for (int i = 0; i < (int) data.threadForce.size(); i++)
for (int j = 0; j < numParticles; j++)
zero.store(&data.threadForce[i][j*4]);
}
double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
// Sum the forces from all the threads.
int numParticles = context.getSystem().getNumParticles();
int numThreads = data.threads.getNumThreads();
vector<RealVec>& forceData = extractForces(context);
for (int i = 0; i < numParticles; i++) {
fvec4 f(0.0f);
for (int j = 0; j < numThreads; j++)
f += fvec4(&data.threadForce[j][4*i]);
forceData[i][0] += f[0];
forceData[i][1] += f[1];
forceData[i][2] += f[2];
}
return referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().finishComputation(context, includeForce, includeEnergy, groups);
}
class CpuCalcNonbondedForceKernel::PmeIO : public CalcPmeReciprocalForceKernel::IO {
public:
PmeIO(float* posq, float* force, int numParticles) : posq(posq), force(force), numParticles(numParticles) {
......@@ -97,8 +161,6 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
// Identify which exceptions are 1-4 interactions.
numParticles = force.getNumParticles();
posq.resize(4*numParticles, 0);
forces.resize(4*numParticles, 0);
exclusions.resize(numParticles);
vector<int> nb14s;
for (int i = 0; i < force.getNumExceptions(); i++) {
......@@ -125,7 +187,7 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
for (int i = 0; i < numParticles; ++i) {
double charge, radius, depth;
force.getParticleParameters(i, charge, radius, depth);
posq[4*i+3] = (float) charge;
data.posq[4*i+3] = (float) charge;
particleParams[i] = make_pair((float) (0.5*radius), (float) (2.0*sqrt(depth)));
sumSquaredCharges += charge*charge;
}
......@@ -173,6 +235,7 @@ void CpuCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
else
dispersionCoefficient = 0.0;
lastPositions.resize(numParticles, Vec3(1e10, 1e10, 1e10));
data.isPeriodic = (nonbondedMethod == CutoffPeriodic || nonbondedMethod == Ewald || nonbondedMethod == PME);
}
double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) {
......@@ -192,32 +255,14 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
}
}
}
vector<float>& posq = data.posq;
vector<RealVec>& posData = extractPositions(context);
vector<RealVec>& forceData = extractForces(context);
RealVec boxSize = extractBoxSize(context);
float floatBoxSize[3] = {(float) boxSize[0], (float) boxSize[1], (float) boxSize[2]};
double energy = ewaldSelfEnergy;
bool periodic = (nonbondedMethod == CutoffPeriodic);
bool ewald = (nonbondedMethod == Ewald);
bool pme = (nonbondedMethod == PME);
// Convert the positions to single precision.
if (periodic || ewald || pme)
for (int i = 0; i < numParticles; i++)
for (int j = 0; j < 3; j++) {
RealOpenMM x = posData[i][j];
double base = floor(x/boxSize[j])*boxSize[j];
posq[4*i+j] = (float) (x-base);
}
else
for (int i = 0; i < numParticles; i++) {
posq[4*i] = (float) posData[i][0];
posq[4*i+1] = (float) posData[i][1];
posq[4*i+2] = (float) posData[i][2];
}
for (int i = 0; i < 4*numParticles; i++)
forces[i] = 0.0f;
if (nonbondedMethod != NoCutoff) {
// Determine whether we need to recompute the neighbor list.
......@@ -260,12 +305,12 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
}
}
if (needRecompute) {
neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, periodic || ewald || pme, nonbondedCutoff+padding, threads);
neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, data.isPeriodic, nonbondedCutoff+padding, data.threads);
lastPositions = posData;
}
nonbonded.setUseCutoff(nonbondedCutoff, neighborList, rfDielectric);
}
if (periodic || ewald || pme) {
if (data.isPeriodic) {
double minAllowedSize = 1.999999*nonbondedCutoff;
if (boxSize[0] < minAllowedSize || boxSize[1] < minAllowedSize || boxSize[2] < minAllowedSize)
throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
......@@ -279,10 +324,10 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
nonbonded.setUseSwitchingFunction(switchingDistance);
float nonbondedEnergy = 0;
if (includeDirect)
nonbonded.calculateDirectIxn(numParticles, &posq[0], posData, particleParams, exclusions, &forces[0], includeEnergy ? &nonbondedEnergy : NULL, threads);
nonbonded.calculateDirectIxn(numParticles, &posq[0], posData, particleParams, exclusions, data.threadForce, includeEnergy ? &nonbondedEnergy : NULL, data.threads);
if (includeReciprocal) {
if (useOptimizedPme) {
PmeIO io(&posq[0], &forces[0], numParticles);
PmeIO io(&posq[0], &data.threadForce[0][0], numParticles);
Vec3 periodicBoxSize(boxSize[0], boxSize[1], boxSize[2]);
optimizedPme.getAs<CalcPmeReciprocalForceKernel>().beginComputation(io, periodicBoxSize, includeEnergy);
optimizedPme.getAs<CalcPmeReciprocalForceKernel>().finishComputation(io);
......@@ -291,16 +336,11 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
nonbonded.calculateReciprocalIxn(numParticles, &posq[0], posData, particleParams, exclusions, forceData, includeEnergy ? &nonbondedEnergy : NULL);
}
energy += nonbondedEnergy;
for (int i = 0; i < numParticles; i++) {
forceData[i][0] += forces[4*i];
forceData[i][1] += forces[4*i+1];
forceData[i][2] += forces[4*i+2];
}
if (includeDirect) {
ReferenceBondForce refBondForce;
ReferenceLJCoulomb14 nonbonded14;
refBondForce.calculateForce(num14, bonded14IndexArray, posData, bonded14ParamArray, forceData, includeEnergy ? &energy : NULL, nonbonded14);
if (periodic || ewald || pme)
if (data.isPeriodic)
energy += dispersionCoefficient/(boxSize[0]*boxSize[1]*boxSize[2]);
}
return energy;
......@@ -326,7 +366,7 @@ void CpuCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context,
for (int i = 0; i < numParticles; ++i) {
double charge, radius, depth;
force.getParticleParameters(i, charge, radius, depth);
posq[4*i+3] = (float) charge;
data.posq[4*i+3] = (float) charge;
particleParams[i] = make_pair((float) (0.5*radius), (float) (2.0*sqrt(depth)));
sumSquaredCharges += charge*charge;
}
......
......@@ -292,7 +292,7 @@ void CpuNonbondedForce::calculateReciprocalIxn(int numberOfAtoms, float* posq, c
void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const vector<RealVec>& atomCoordinates, const vector<pair<float, float> >& atomParameters,
const vector<set<int> >& exclusions, float* forces, float* totalEnergy, ThreadPool& threads) {
const vector<set<int> >& exclusions, vector<vector<float> >& threadForce, float* totalEnergy, ThreadPool& threads) {
// Record the parameters for the threads.
this->numberOfAtoms = numberOfAtoms;
......@@ -300,9 +300,9 @@ void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const
this->atomCoordinates = &atomCoordinates[0];
this->atomParameters = &atomParameters[0];
this->exclusions = &exclusions[0];
this->threadForce = &threadForce;
includeEnergy = (totalEnergy != NULL);
threadEnergy.resize(threads.getNumThreads());
threadForce.resize(threads.getNumThreads());
// Signal the threads to start running and wait for them to finish.
......@@ -310,21 +310,15 @@ void CpuNonbondedForce::calculateDirectIxn(int numberOfAtoms, float* posq, const
threads.execute(task);
threads.waitForThreads();
// Combine the results from all the threads.
// Combine the energies from all the threads.
if (totalEnergy != NULL) {
double directEnergy = 0;
int numThreads = threads.getNumThreads();
for (int i = 0; i < numThreads; i++)
directEnergy += threadEnergy[i];
for (int i = 0; i < numberOfAtoms; i++) {
fvec4 f(forces+4*i);
for (int j = 0; j < numThreads; j++)
f += fvec4(&threadForce[j][4*i]);
f.store(forces+4*i);
}
if (totalEnergy != NULL)
*totalEnergy += (float) directEnergy;
}
}
void CpuNonbondedForce::threadComputeDirect(ThreadPool& threads, int threadIndex) {
......@@ -333,10 +327,7 @@ void CpuNonbondedForce::threadComputeDirect(ThreadPool& threads, int threadIndex
int numThreads = threads.getNumThreads();
threadEnergy[threadIndex] = 0;
double* energyPtr = (includeEnergy ? &threadEnergy[threadIndex] : NULL);
threadForce[threadIndex].resize(4*numberOfAtoms, 0.0f);
float* forces = &threadForce[threadIndex][0];
for (int i = 0; i < 4*numberOfAtoms; i++)
forces[i] = 0.0f;
float* forces = &(*threadForce)[threadIndex][0];
fvec4 boxSize(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2], 0);
fvec4 invBoxSize((1/periodicBoxSize[0]), (1/periodicBoxSize[1]), (1/periodicBoxSize[2]), 0);
if (ewald || pme) {
......
......@@ -35,6 +35,7 @@
#include "openmm/internal/hardware.h"
using namespace OpenMM;
using namespace std;
extern "C" OPENMM_EXPORT_CPU void registerPlatforms() {
// Only register this platform if the CPU supports SSE 4.1.
......@@ -43,8 +44,11 @@ extern "C" OPENMM_EXPORT_CPU void registerPlatforms() {
Platform::registerPlatform(new CpuPlatform());
}
map<ContextImpl*, CpuPlatform::PlatformData*> CpuPlatform::contextData;
CpuPlatform::CpuPlatform() {
CpuKernelFactory* factory = new CpuKernelFactory();
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(CalcNonbondedForceKernel::Name(), factory);
}
......@@ -67,3 +71,28 @@ bool CpuPlatform::isProcessorSupported() {
}
return false;
}
void CpuPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
ReferencePlatform::contextCreated(context, properties);
PlatformData* data = new PlatformData(context.getSystem().getNumParticles());
contextData[&context] = data;
}
void CpuPlatform::contextDestroyed(ContextImpl& context) const {
PlatformData* data = contextData[&context];
delete data;
contextData.erase(&context);
}
CpuPlatform::PlatformData& CpuPlatform::getPlatformData(ContextImpl& context) {
return *contextData[&context];
}
CpuPlatform::PlatformData::PlatformData(int numParticles) {
posq.resize(4*numParticles);
int numThreads = threads.getNumThreads();
threadForce.resize(numThreads);
for (int i = 0; i < numThreads; i++)
threadForce[i].resize(4*numParticles);
isPeriodic = false;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment