Commit 0e683e8d authored by peastman's avatar peastman
Browse files

CPU GayBerneForce uses common neighbor list

parent 2e9c418a
......@@ -52,11 +52,6 @@ public:
*/
CpuGayBerneForce(const GayBerneForce& force);
/**
* Destructor.
*/
~CpuGayBerneForce();
/**
* Compute the interaction.
*
......@@ -72,7 +67,12 @@ public:
/**
* This routine contains the code executed by each thread.
*/
void threadComputeForce(ThreadPool& threads, int threadIndex);
void threadComputeForce(ThreadPool& threads, int threadIndex, CpuNeighborList* neighborList);
/**
* Get the exclusions being used by the force.
*/
const std::vector<std::set<int> >& getExclusions() const;
private:
struct ParticleInfo;
......@@ -86,7 +86,6 @@ private:
bool useSwitchingFunction;
std::vector<RealOpenMM> s;
std::vector<Matrix> A, B, G;
CpuNeighborList* neighborList;
std::vector<double> threadEnergy;
std::vector<std::vector<RealVec> > threadTorque;
// The following variables are used to make information accessible to the individual threads.
......
......@@ -82,7 +82,7 @@ class CpuPlatform::PlatformData {
public:
PlatformData(int numParticles, int numThreads);
~PlatformData();
void requestNeighborList(double cutoffDistance, double padding, bool useExclusions, std::vector<std::set<int> >& exclusionList);
void requestNeighborList(double cutoffDistance, double padding, bool useExclusions, const std::vector<std::set<int> >& exclusionList);
AlignedArray<float> posq;
std::vector<AlignedArray<float> > threadForce;
ThreadPool threads;
......
......@@ -41,15 +41,16 @@ using namespace std;
class CpuGayBerneForce::ComputeTask : public ThreadPool::Task {
public:
ComputeTask(CpuGayBerneForce& owner) : owner(owner) {
ComputeTask(CpuGayBerneForce& owner, CpuNeighborList* neighborList) : owner(owner), neighborList(neighborList) {
}
void execute(ThreadPool& threads, int threadIndex) {
owner.threadComputeForce(threads, threadIndex);
owner.threadComputeForce(threads, threadIndex, neighborList);
}
CpuGayBerneForce& owner;
CpuNeighborList* neighborList;
};
CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NULL) {
CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) {
// Record the force parameters.
int numParticles = force.getNumParticles();
......@@ -85,8 +86,6 @@ CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NU
cutoffDistance = force.getCutoffDistance();
switchingDistance = force.getSwitchingDistance();
useSwitchingFunction = force.getUseSwitchingFunction();
if (nonbondedMethod != GayBerneForce::NoCutoff)
neighborList = new CpuNeighborList(4);
// Allocate workspace for calculations.
......@@ -103,9 +102,8 @@ CpuGayBerneForce::CpuGayBerneForce(const GayBerneForce& force) : neighborList(NU
}
}
CpuGayBerneForce::~CpuGayBerneForce() {
if (neighborList != NULL)
delete neighborList;
const vector<set<int> >& CpuGayBerneForce::getExclusions() const {
return particleExclusions;
}
RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, std::vector<RealVec>& forces, std::vector<AlignedArray<float> >& threadForce, RealVec* boxVectors, CpuPlatform::PlatformData& data) {
......@@ -114,12 +112,6 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st
if (boxVectors[0][0] < minAllowedSize || boxVectors[1][1] < minAllowedSize || boxVectors[2][2] < minAllowedSize)
throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
}
// Build the neighbor list.
int numParticles = particles.size();
if (nonbondedMethod != GayBerneForce::NoCutoff)
neighborList->computeNeighborList(numParticles, data.posq, particleExclusions, boxVectors, nonbondedMethod == GayBerneForce::CutoffPeriodic, cutoffDistance, data.threads);
// First find the orientations of the particles and compute the matrices we'll be needing.
......@@ -140,7 +132,7 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st
// Signal the threads to compute the pairwise interactions.
ComputeTask task(*this);
ComputeTask task(*this, data.neighborList);
threads.execute(task);
threads.waitForThreads();
......@@ -162,7 +154,7 @@ RealOpenMM CpuGayBerneForce::calculateForce(const vector<RealVec>& positions, st
return energy;
}
void CpuGayBerneForce::threadComputeForce(ThreadPool& threads, int threadIndex) {
void CpuGayBerneForce::threadComputeForce(ThreadPool& threads, int threadIndex, CpuNeighborList* neighborList) {
int numParticles = particles.size();
int numThreads = threads.getNumThreads();
threadEnergy[threadIndex] = 0;
......@@ -198,14 +190,15 @@ void CpuGayBerneForce::threadComputeForce(ThreadPool& threads, int threadIndex)
int blockIndex = gmx_atomic_fetch_add(reinterpret_cast<gmx_atomic_t*>(atomicCounter), 1);
if (blockIndex >= neighborList->getNumBlocks())
break;
const int* blockAtom = &neighborList->getSortedAtoms()[4*blockIndex];
const int blockSize = neighborList->getBlockSize();
const int* blockAtom = &neighborList->getSortedAtoms()[blockSize*blockIndex];
const vector<int>& neighbors = neighborList->getBlockNeighbors(blockIndex);
const vector<char>& exclusions = neighborList->getBlockExclusions(blockIndex);
for (int i = 0; i < (int) neighbors.size(); i++) {
int first = neighbors[i];
if (particles[first].sqrtEpsilon == 0.0f)
continue;
for (int k = 0; k < 4; k++) {
for (int k = 0; k < blockSize; k++) {
if ((exclusions[i] & (1<<k)) == 0) {
int second = blockAtom[k];
if (particles[second].sqrtEpsilon == 0.0f)
......
......@@ -1212,6 +1212,10 @@ CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() {
void CpuCalcGayBerneForceKernel::initialize(const System& system, const GayBerneForce& force) {
ixn = new CpuGayBerneForce(force);
data.isPeriodic = (force.getNonbondedMethod() == GayBerneForce::CutoffPeriodic);
if (force.getNonbondedMethod() != GayBerneForce::NoCutoff) {
double cutoff = force.getCutoffDistance();
data.requestNeighborList(cutoff, 0.1*cutoff, true, ixn->getExclusions());
}
}
double CpuCalcGayBerneForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......
......@@ -156,7 +156,7 @@ CpuPlatform::PlatformData::~PlatformData() {
bool isVec8Supported();
void CpuPlatform::PlatformData::requestNeighborList(double cutoffDistance, double padding, bool useExclusions, vector<set<int> >& exclusionList) {
void CpuPlatform::PlatformData::requestNeighborList(double cutoffDistance, double padding, bool useExclusions, const vector<set<int> >& exclusionList) {
if (neighborList == NULL)
neighborList = new CpuNeighborList(isVec8Supported() ? 8 : 4);
if (cutoffDistance > cutoff)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment