Commit 0e683e8d authored by peastman's avatar peastman
Browse files

CPU GayBerneForce uses common neighbor list

parent 2e9c418a
...@@ -52,11 +52,6 @@ public: ...@@ -52,11 +52,6 @@ public:
*/ */
CpuGayBerneForce(const GayBerneForce& force); CpuGayBerneForce(const GayBerneForce& force);
/**
* Destructor.
*/
~CpuGayBerneForce();
/** /**
* Compute the interaction. * Compute the interaction.
* *
...@@ -72,7 +67,12 @@ public: ...@@ -72,7 +67,12 @@ public:
/** /**
* This routine contains the code executed by each thread. * This routine contains the code executed by each thread.
*/ */
void threadComputeForce(ThreadPool& threads, int threadIndex); void threadComputeForce(ThreadPool& threads, int threadIndex, CpuNeighborList* neighborList);
/**
* Get the exclusions being used by the force.
*/
const std::vector<std::set<int> >& getExclusions() const;
private: private:
struct ParticleInfo; struct ParticleInfo;
...@@ -86,7 +86,6 @@ private: ...@@ -86,7 +86,6 @@ private:
bool useSwitchingFunction; bool useSwitchingFunction;
std::vector<RealOpenMM> s; std::vector<RealOpenMM> s;
std::vector<Matrix> A, B, G; std::vector<Matrix> A, B, G;
CpuNeighborList* neighborList;
std::vector<double> threadEnergy; std::vector<double> threadEnergy;
std::vector<std::vector<RealVec> > threadTorque; std::vector<std::vector<RealVec> > threadTorque;
// The following variables are used to make information accessible to the individual threads. // The following variables are used to make information accessible to the individual threads.
......
...@@ -82,7 +82,7 @@ class CpuPlatform::PlatformData { ...@@ -82,7 +82,7 @@ class CpuPlatform::PlatformData {
public: public:
PlatformData(int numParticles, int numThreads); PlatformData(int numParticles, int numThreads);
~PlatformData(); ~PlatformData();
void requestNeighborList(double cutoffDistance, double padding, bool useExclusions, std::vector<std::set<int> >& exclusionList); void requestNeighborList(double cutoffDistance, double padding, bool useExclusions, const std::vector<std::set<int> >& exclusionList);
AlignedArray<float> posq; AlignedArray<float> posq;
std::vector<AlignedArray<float> > threadForce; std::vector<AlignedArray<float> > threadForce;
ThreadPool threads; ThreadPool threads;
......
...@@ -41,15 +41,16 @@ using namespace std; ...@@ -41,15 +41,16 @@ using namespace std;
class CpuGayBerneForce::ComputeTask : public ThreadPool::Task { class CpuGayBerneForce::ComputeTask : public ThreadPool::Task {
public: public:
ComputeTask(CpuGayBerneForce& owner) : owner(owner) { ComputeTask(CpuGayBerneForce& owner, CpuNeighborList* neighborList) : owner(owner), neighborList(neighborList) {
} }
void execute(ThreadPool& threads, int threadIndex) { void execute(ThreadPool& threads, int threadIndex) {
owner.threadComputeForce(threads, threadIndex); owner.threadComputeForce(threads, threadIndex, neighborList);
} }
CpuGayBerneForce& owner; CpuGayBerneForce& owner;
CpuNeighborList* neighborList;
}; };
CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NULL) { CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) {
// Record the force parameters. // Record the force parameters.
int numParticles = force.getNumParticles(); int numParticles = force.getNumParticles();
...@@ -85,8 +86,6 @@ CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NU ...@@ -85,8 +86,6 @@ CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NU
cutoffDistance = force.getCutoffDistance(); cutoffDistance = force.getCutoffDistance();
switchingDistance = force.getSwitchingDistance(); switchingDistance = force.getSwitchingDistance();
useSwitchingFunction = force.getUseSwitchingFunction(); useSwitchingFunction = force.getUseSwitchingFunction();
if (nonbondedMethod != GayBerneForce::NoCutoff)
neighborList = new CpuNeighborList(4);
// Allocate workspace for calculations. // Allocate workspace for calculations.
...@@ -103,9 +102,8 @@ CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NU ...@@ -103,9 +102,8 @@ CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NU
} }
} }
CpuGayBerneForce::~CpuGayBerneForce() { const vector<set<int> >& CpuGayBerneForce::getExclusions() const {
if (neighborList != NULL) return particleExclusions;
delete neighborList;
} }
RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, std::vector<RealVec>& forces, std::vector<AlignedArray<float> >& threadForce, RealVec* boxVectors, CpuPlatform::PlatformData& data) { RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, std::vector<RealVec>& forces, std::vector<AlignedArray<float> >& threadForce, RealVec* boxVectors, CpuPlatform::PlatformData& data) {
...@@ -114,12 +112,6 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st ...@@ -114,12 +112,6 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st
if (boxVectors[0][0] < minAllowedSize || boxVectors[1][1] < minAllowedSize || boxVectors[2][2] < minAllowedSize) if (boxVectors[0][0] < minAllowedSize || boxVectors[1][1] < minAllowedSize || boxVectors[2][2] < minAllowedSize)
throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff."); throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
} }
// Build the neighbor list.
int numParticles = particles.size();
if (nonbondedMethod != GayBerneForce::NoCutoff)
neighborList->computeNeighborList(numParticles, data.posq, particleExclusions, boxVectors, nonbondedMethod == GayBerneForce::CutoffPeriodic, cutoffDistance, data.threads);
// First find the orientations of the particles and compute the matrices we'll be needing. // First find the orientations of the particles and compute the matrices we'll be needing.
...@@ -140,7 +132,7 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st ...@@ -140,7 +132,7 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st
// Signal the threads to compute the pairwise interactions. // Signal the threads to compute the pairwise interactions.
ComputeTask task(*this); ComputeTask task(*this, data.neighborList);
threads.execute(task); threads.execute(task);
threads.waitForThreads(); threads.waitForThreads();
...@@ -162,7 +154,7 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st ...@@ -162,7 +154,7 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st
return energy; return energy;
} }
void CpuGayBerneForce::threadComputeForce(ThreadPool& threads, int threadIndex) { void CpuGayBerneForce::threadComputeForce(ThreadPool& threads, int threadIndex, CpuNeighborList* neighborList) {
int numParticles = particles.size(); int numParticles = particles.size();
int numThreads = threads.getNumThreads(); int numThreads = threads.getNumThreads();
threadEnergy[threadIndex] = 0; threadEnergy[threadIndex] = 0;
...@@ -198,14 +190,15 @@ void CpuGayBerneForce::threadComputeForce(ThreadPool& threads, int threadIndex) ...@@ -198,14 +190,15 @@ void CpuGayBerneForce::threadComputeForce(ThreadPool& threads, int threadIndex)
int blockIndex = gmx_atomic_fetch_add(reinterpret_cast<gmx_atomic_t*>(atomicCounter), 1); int blockIndex = gmx_atomic_fetch_add(reinterpret_cast<gmx_atomic_t*>(atomicCounter), 1);
if (blockIndex >= neighborList->getNumBlocks()) if (blockIndex >= neighborList->getNumBlocks())
break; break;
const int* blockAtom = &neighborList->getSortedAtoms()[4*blockIndex]; const int blockSize = neighborList->getBlockSize();
const int* blockAtom = &neighborList->getSortedAtoms()[blockSize*blockIndex];
const vector<int>& neighbors = neighborList->getBlockNeighbors(blockIndex); const vector<int>& neighbors = neighborList->getBlockNeighbors(blockIndex);
const vector<char>& exclusions = neighborList->getBlockExclusions(blockIndex); const vector<char>& exclusions = neighborList->getBlockExclusions(blockIndex);
for (int i = 0; i < (int) neighbors.size(); i++) { for (int i = 0; i < (int) neighbors.size(); i++) {
int first = neighbors[i]; int first = neighbors[i];
if (particles[first].sqrtEpsilon == 0.0f) if (particles[first].sqrtEpsilon == 0.0f)
continue; continue;
for (int k = 0; k < 4; k++) { for (int k = 0; k < blockSize; k++) {
if ((exclusions[i] & (1<<k)) == 0) { if ((exclusions[i] & (1<<k)) == 0) {
int second = blockAtom[k]; int second = blockAtom[k];
if (particles[second].sqrtEpsilon == 0.0f) if (particles[second].sqrtEpsilon == 0.0f)
......
...@@ -1212,6 +1212,10 @@ CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() { ...@@ -1212,6 +1212,10 @@ CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() {
void CpuCalcGayBerneForceKernel::initialize(const System& system, const GayBerneForce& force) { void CpuCalcGayBerneForceKernel::initialize(const System& system, const GayBerneForce& force) {
ixn = new CpuGayBerneForce(force); ixn = new CpuGayBerneForce(force);
data.isPeriodic = (force.getNonbondedMethod() == GayBerneForce::CutoffPeriodic); data.isPeriodic = (force.getNonbondedMethod() == GayBerneForce::CutoffPeriodic);
if (force.getNonbondedMethod() != GayBerneForce::NoCutoff) {
double cutoff = force.getCutoffDistance();
data.requestNeighborList(cutoff, 0.1*cutoff, true, ixn->getExclusions());
}
} }
double CpuCalcGayBerneForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CpuCalcGayBerneForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......
...@@ -156,7 +156,7 @@ CpuPlatform::PlatformData::~PlatformData() { ...@@ -156,7 +156,7 @@ CpuPlatform::PlatformData::~PlatformData() {
bool isVec8Supported(); bool isVec8Supported();
void CpuPlatform::PlatformData::requestNeighborList(double cutoffDistance, double padding, bool useExclusions, vector<set<int> >& exclusionList) { void CpuPlatform::PlatformData::requestNeighborList(double cutoffDistance, double padding, bool useExclusions, const vector<set<int> >& exclusionList) {
if (neighborList == NULL) if (neighborList == NULL)
neighborList = new CpuNeighborList(isVec8Supported() ? 8 : 4); neighborList = new CpuNeighborList(isVec8Supported() ? 8 : 4);
if (cutoffDistance > cutoff) if (cutoffDistance > cutoff)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment