Commit 150d943a authored by peastman's avatar peastman
Browse files

Optimizations to CUDA version of CustomManyParticleForce

parent bddaf4e7
...@@ -464,6 +464,46 @@ void testTypeFilters() { ...@@ -464,6 +464,46 @@ void testTypeFilters() {
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
} }
void testLargeSystem() {
int gridSize = 8;
int numParticles = gridSize*gridSize*gridSize;
double boxSize = 3.0;
double spacing = boxSize/gridSize;
CpuPlatform platform;
CustomManyParticleForce* force = new CustomManyParticleForce(3,
"C*(1+3*cos(theta1)*cos(theta2)*cos(theta3))/(r12*r13*r23)^3;"
"theta1=angle(p1,p2,p3); theta2=angle(p2,p3,p1); theta3=angle(p3,p1,p2);"
"r12=distance(p1,p2); r13=distance(p1,p3); r23=distance(p2,p3)");
force->addGlobalParameter("C", 1.5);
force->setNonbondedMethod(CustomManyParticleForce::CutoffPeriodic);
force->setCutoffDistance(0.6);
vector<double> params;
vector<Vec3> positions;
System system;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < gridSize; i++)
for (int j = 0; j < gridSize; j++)
for (int k = 0; k < gridSize; k++) {
force->addParticle(params);
positions.push_back(Vec3((i+0.4*genrand_real2(sfmt))*spacing, (j+0.4*genrand_real2(sfmt))*spacing, (k+0.4*genrand_real2(sfmt))*spacing));
system.addParticle(1.0);
}
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
system.addForce(force);
VerletIntegrator integrator1(0.001);
VerletIntegrator integrator2(0.001);
Context context1(system, integrator1, Platform::getPlatformByName("Reference"));
Context context2(system, integrator2, platform);
context1.setPositions(positions);
context2.setPositions(positions);
State state1 = context1.getState(State::Forces | State::Energy);
State state2 = context2.getState(State::Forces | State::Energy);
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), 1e-4);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-4);
}
int main() { int main() {
try { try {
testNoCutoff(); testNoCutoff();
...@@ -474,6 +514,7 @@ int main() { ...@@ -474,6 +514,7 @@ int main() {
testParameters(); testParameters();
testTabulatedFunctions(); testTabulatedFunctions();
testTypeFilters(); testTypeFilters();
testLargeSystem();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -964,7 +964,7 @@ private: ...@@ -964,7 +964,7 @@ private:
CudaContext& cu; CudaContext& cu;
bool hasInitializedKernel; bool hasInitializedKernel;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
int maxNeighborPairs, forceWorkgroupSize; int maxNeighborPairs, forceWorkgroupSize, findNeighborsWorkgroupSize;
CudaParameterSet* params; CudaParameterSet* params;
CudaArray* globals; CudaArray* globals;
CudaArray* particleTypes; CudaArray* particleTypes;
......
...@@ -4476,6 +4476,7 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con ...@@ -4476,6 +4476,7 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
int particlesPerSet = force.getNumParticlesPerSet(); int particlesPerSet = force.getNumParticlesPerSet();
nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod()); nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
forceWorkgroupSize = 128; forceWorkgroupSize = 128;
findNeighborsWorkgroupSize = 128;
// Record parameter values. // Record parameter values.
...@@ -4791,29 +4792,35 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con ...@@ -4791,29 +4792,35 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
permute<<"int atom"<<(i+1)<<" = particleSet[particleOrder["<<numTypes<<"*order+"<<i<<"]];\n"; permute<<"int atom"<<(i+1)<<" = particleSet[particleOrder["<<numTypes<<"*order+"<<i<<"]];\n";
else else
permute<<"int atom"<<(i+1)<<" = p"<<(i+1)<<";\n"; permute<<"int atom"<<(i+1)<<" = p"<<(i+1)<<";\n";
loadData<<"real4 pos"<<(i+1)<<" = posq[atom"<<(i+1)<<"];\n"; loadData<<"real3 pos"<<(i+1)<<" = trim(posq[atom"<<(i+1)<<"]);\n";
for (int j = 0; j < (int) params->getBuffers().size(); j++) for (int j = 0; j < (int) params->getBuffers().size(); j++)
loadData<<params->getBuffers()[j].getType()<<" params"<<(j+1)<<(i+1)<<" = global_params"<<(j+1)<<"[atom"<<(i+1)<<"];\n"; loadData<<params->getBuffers()[j].getType()<<" params"<<(j+1)<<(i+1)<<" = global_params"<<(j+1)<<"[atom"<<(i+1)<<"];\n";
} }
for (int i = 2; i < particlesPerSet; i++) { for (int i = 2; i < particlesPerSet; i++) {
if (i > 2) if (i > 2)
isValidCombination<<" && "; isValidCombination<<" && ";
isValidCombination<<"p"<<(i+1)<<">p"<<i; isValidCombination<<"a"<<(i+1)<<">a"<<i;
} }
atomsForCombination<<"int tempIndex = index;\n"; atomsForCombination<<"int tempIndex = index;\n";
for (int i = 1; i < particlesPerSet; i++) { for (int i = 1; i < particlesPerSet; i++) {
if (i > 1) if (i > 1)
numCombinations<<"*"; numCombinations<<"*";
numCombinations<<"numNeighbors"; numCombinations<<"numNeighbors";
atomsForCombination<<"int a"<<(i+1)<<" = 1+tempIndex%numNeighbors;\n";
if (i < particlesPerSet-1)
atomsForCombination<<"tempIndex /= numNeighbors;\n";
}
if (particlesPerSet > 2)
atomsForCombination<<"a2 = (a3%2 == 0 ? a2 : numNeighbors-a2+1);\n";
for (int i = 1; i < particlesPerSet; i++) {
if (nonbondedMethod == NoCutoff) if (nonbondedMethod == NoCutoff)
atomsForCombination<<"int p"<<(i+1)<<" = p1+1+tempIndex%numNeighbors;\n"; atomsForCombination<<"int p"<<(i+1)<<" = p1+a"<<(i+1)<<";\n";
else else
atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor+tempIndex%numNeighbors];\n"; atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor-1+a"<<(i+1)<<"];\n";
atomsForCombination<<"tempIndex /= numNeighbors;\n";
} }
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
for (int i = 1; i < particlesPerSet; i++) for (int i = 1; i < particlesPerSet; i++)
verifyCutoff<<"real4 pos"<<(i+1)<<" = posq[p"<<(i+1)<<"];\n"; verifyCutoff<<"real3 pos"<<(i+1)<<" = trim(posq[p"<<(i+1)<<"]);\n";
for (int i = 1; i < particlesPerSet; i++) for (int i = 1; i < particlesPerSet; i++)
for (int j = i+1; j < particlesPerSet; j++) for (int j = i+1; j < particlesPerSet; j++)
verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize).w < CUTOFF_SQUARED);\n"; verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize).w < CUTOFF_SQUARED);\n";
...@@ -4866,13 +4873,15 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con ...@@ -4866,13 +4873,15 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
defines["CUTOFF_SQUARED"] = cu.doubleToString(force.getCutoffDistance()*force.getCutoffDistance()); defines["CUTOFF_SQUARED"] = cu.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
defines["TILE_SIZE"] = cu.intToString(CudaContext::TileSize); defines["TILE_SIZE"] = cu.intToString(CudaContext::TileSize);
defines["NUM_BLOCKS"] = cu.intToString(cu.getNumAtomBlocks()); defines["NUM_BLOCKS"] = cu.intToString(cu.getNumAtomBlocks());
// std::cout << cu.replaceStrings(CudaKernelSources::vectorOps+CudaKernelSources::customManyParticle, replacements)<< std::endl; defines["FIND_NEIGHBORS_WORKGROUP_SIZE"] = cu.intToString(findNeighborsWorkgroupSize);
CUmodule module = cu.createModule(cu.replaceStrings(CudaKernelSources::vectorOps+CudaKernelSources::customManyParticle, replacements), defines); CUmodule module = cu.createModule(cu.replaceStrings(CudaKernelSources::vectorOps+CudaKernelSources::customManyParticle, replacements), defines);
forceKernel = cu.getKernel(module, "computeInteraction"); forceKernel = cu.getKernel(module, "computeInteraction");
blockBoundsKernel = cu.getKernel(module, "findBlockBounds"); blockBoundsKernel = cu.getKernel(module, "findBlockBounds");
neighborsKernel = cu.getKernel(module, "findNeighbors"); neighborsKernel = cu.getKernel(module, "findNeighbors");
startIndicesKernel = cu.getKernel(module, "computeNeighborStartIndices"); startIndicesKernel = cu.getKernel(module, "computeNeighborStartIndices");
copyPairsKernel = cu.getKernel(module, "copyPairsToNeighborList"); copyPairsKernel = cu.getKernel(module, "copyPairsToNeighborList");
cuFuncSetCacheConfig(forceKernel, CU_FUNC_CACHE_PREFER_L1);
cuFuncSetCacheConfig(neighborsKernel, CU_FUNC_CACHE_PREFER_L1);
} }
double CudaCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CudaCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -4964,7 +4973,7 @@ double CudaCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool ...@@ -4964,7 +4973,7 @@ double CudaCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool
int* numPairs = (int*) cu.getPinnedBuffer(); int* numPairs = (int*) cu.getPinnedBuffer();
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
cu.executeKernel(blockBoundsKernel, &blockBoundsArgs[0], cu.getNumAtomBlocks()); cu.executeKernel(blockBoundsKernel, &blockBoundsArgs[0], cu.getNumAtomBlocks());
cu.executeKernel(neighborsKernel, &neighborsArgs[0], cu.getNumAtoms()); cu.executeKernel(neighborsKernel, &neighborsArgs[0], cu.getNumAtoms(), findNeighborsWorkgroupSize);
// We need to make sure there was enough memory for the neighbor list. Download the // We need to make sure there was enough memory for the neighbor list. Download the
// information asynchronously so kernels can be running at the same time. // information asynchronously so kernels can be running at the same time.
......
...@@ -18,7 +18,7 @@ inline __device__ real3 trim(real4 v) { ...@@ -18,7 +18,7 @@ inline __device__ real3 trim(real4 v) {
* Compute the difference between two vectors, taking periodic boundary conditions into account * Compute the difference between two vectors, taking periodic boundary conditions into account
* and setting the fourth component to the squared magnitude. * and setting the fourth component to the squared magnitude.
*/ */
inline __device__ real4 delta(real4 vec1, real4 vec2, real4 periodicBoxSize, real4 invPeriodicBoxSize) { inline __device__ real4 delta(real3 vec1, real3 vec2, real4 periodicBoxSize, real4 invPeriodicBoxSize) {
real4 result = make_real4(vec1.x-vec2.x, vec1.y-vec2.y, vec1.z-vec2.z, 0.0f); real4 result = make_real4(vec1.x-vec2.x, vec1.y-vec2.y, vec1.z-vec2.z, 0.0f);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
result.x -= floor(result.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; result.x -= floor(result.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
...@@ -95,6 +95,7 @@ extern "C" __global__ void computeInteraction( ...@@ -95,6 +95,7 @@ extern "C" __global__ void computeInteraction(
// Loop over particles to be the first one in the set. // Loop over particles to be the first one in the set.
for (int p1 = blockIdx.x; p1 < NUM_ATOMS; p1 += gridDim.x) { for (int p1 = blockIdx.x; p1 < NUM_ATOMS; p1 += gridDim.x) {
const int a1 = 0;
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
int firstNeighbor = neighborStartIndex[p1]; int firstNeighbor = neighborStartIndex[p1];
int numNeighbors = neighborStartIndex[p1+1]-firstNeighbor; int numNeighbors = neighborStartIndex[p1+1]-firstNeighbor;
...@@ -178,41 +179,58 @@ extern "C" __global__ void findNeighbors(real4 periodicBoxSize, real4 invPeriodi ...@@ -178,41 +179,58 @@ extern "C" __global__ void findNeighbors(real4 periodicBoxSize, real4 invPeriodi
, int* __restrict__ exclusions, int* __restrict__ exclusionStartIndex , int* __restrict__ exclusions, int* __restrict__ exclusionStartIndex
#endif #endif
) { ) {
__shared__ real3 positionCache[FIND_NEIGHBORS_WORKGROUP_SIZE];
int indexInWarp = threadIdx.x%32;
for (int atom1 = blockIdx.x*blockDim.x+threadIdx.x; atom1 < NUM_ATOMS; atom1 += blockDim.x*gridDim.x) { for (int atom1 = blockIdx.x*blockDim.x+threadIdx.x; atom1 < NUM_ATOMS; atom1 += blockDim.x*gridDim.x) {
// Load data for this atom. // Load data for this atom. Note that all threads in a warp are processing atoms from the same block.
real4 pos1 = posq[atom1]; real3 pos1 = trim(posq[atom1]);
int block1 = atom1/TILE_SIZE; int block1 = atom1/TILE_SIZE;
real4 blockCenter1 = blockCenter[block1]; real4 blockCenter1 = blockCenter[block1];
real4 blockSize1 = blockBoundingBox[block1]; real4 blockSize1 = blockBoundingBox[block1];
int totalNeighborsForAtom1 = 0; int totalNeighborsForAtom1 = 0;
// Loop over atom blocks to search for neighbors. // Loop over atom blocks to search for neighbors. The threads in a warp compare block1 against 32
// other blocks in parallel.
for (int block2 = block1; block2 < NUM_BLOCKS; block2++) {
real4 blockCenter2 = blockCenter[block2]; for (int block2Base = block1; block2Base < NUM_BLOCKS; block2Base += 32) {
real4 blockSize2 = blockBoundingBox[block2]; int block2 = block2Base+indexInWarp;
real4 blockDelta = blockCenter1-blockCenter2; bool includeBlock2 = (block2 < NUM_BLOCKS);
if (includeBlock2) {
real4 blockCenter2 = blockCenter[block2];
real4 blockSize2 = blockBoundingBox[block2];
real4 blockDelta = blockCenter1-blockCenter2;
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
blockDelta.x -= floor(blockDelta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; blockDelta.x -= floor(blockDelta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
blockDelta.y -= floor(blockDelta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; blockDelta.y -= floor(blockDelta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
blockDelta.z -= floor(blockDelta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z; blockDelta.z -= floor(blockDelta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSize1.x-blockSize2.x); blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSize1.x-blockSize2.x);
blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSize1.y-blockSize2.y); blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSize1.y-blockSize2.y);
blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSize1.z-blockSize2.z); blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSize1.z-blockSize2.z);
if (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < CUTOFF_SQUARED) { includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < CUTOFF_SQUARED);
}
// Loop over any blocks we identified as potentially containing neighbors.
int includeBlockFlags = __ballot(includeBlock2);
while (includeBlockFlags != 0) {
int i = __ffs(includeBlockFlags)-1;
includeBlockFlags &= includeBlockFlags-1;
int block2 = block2Base+i;
// Loop over atoms in this block. // Loop over atoms in this block.
int start = block2*TILE_SIZE; int start = block2*TILE_SIZE;
int end = (block2+1)*TILE_SIZE;
int included[TILE_SIZE]; int included[TILE_SIZE];
int numIncluded = 0; int numIncluded = 0;
for (int atom2 = start; atom2 < end; atom2++) { positionCache[threadIdx.x] = trim(posq[start+indexInWarp]);
real4 pos2 = posq[atom2]; for (int j = 0; j < 32; j++) {
int atom2 = start+j;
real3 pos2 = positionCache[threadIdx.x-indexInWarp+j];
// Decide whether to include this atom pair in the neighbor list. // Decide whether to include this atom pair in the neighbor list.
real4 atomDelta = delta(pos1, pos2, periodicBoxSize, invPeriodicBoxSize); real4 atomDelta = delta(pos1, pos2, periodicBoxSize, invPeriodicBoxSize);
bool includeAtom = (atom2 > atom1 && atom2 < NUM_ATOMS && atomDelta.w < CUTOFF_SQUARED); bool includeAtom = (atom2 > atom1 && atom2 < NUM_ATOMS && atomDelta.w < CUTOFF_SQUARED);
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
...@@ -222,14 +240,14 @@ extern "C" __global__ void findNeighbors(real4 periodicBoxSize, real4 invPeriodi ...@@ -222,14 +240,14 @@ extern "C" __global__ void findNeighbors(real4 periodicBoxSize, real4 invPeriodi
if (includeAtom) if (includeAtom)
included[numIncluded++] = atom2; included[numIncluded++] = atom2;
} }
// If we found any neighbors, store them to the neighbor list. // If we found any neighbors, store them to the neighbor list.
if (numIncluded > 0) { if (numIncluded > 0) {
int baseIndex = atomicAdd(numNeighborPairs, numIncluded); int baseIndex = atomicAdd(numNeighborPairs, numIncluded);
if (baseIndex+numIncluded <= maxNeighborPairs) if (baseIndex+numIncluded <= maxNeighborPairs)
for (int i = 0; i < numIncluded; i++) for (int j = 0; j < numIncluded; j++)
neighborPairs[baseIndex+i] = make_int2(atom1, included[i]); neighborPairs[baseIndex+j] = make_int2(atom1, included[j]);
totalNeighborsForAtom1 += numIncluded; totalNeighborsForAtom1 += numIncluded;
} }
} }
......
...@@ -461,6 +461,45 @@ void testTypeFilters() { ...@@ -461,6 +461,45 @@ void testTypeFilters() {
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
} }
void testLargeSystem() {
int gridSize = 8;
int numParticles = gridSize*gridSize*gridSize;
double boxSize = 3.0;
double spacing = boxSize/gridSize;
CustomManyParticleForce* force = new CustomManyParticleForce(3,
"C*(1+3*cos(theta1)*cos(theta2)*cos(theta3))/(r12*r13*r23)^3;"
"theta1=angle(p1,p2,p3); theta2=angle(p2,p3,p1); theta3=angle(p3,p1,p2);"
"r12=distance(p1,p2); r13=distance(p1,p3); r23=distance(p2,p3)");
force->addGlobalParameter("C", 1.5);
force->setNonbondedMethod(CustomManyParticleForce::CutoffPeriodic);
force->setCutoffDistance(0.6);
vector<double> params;
vector<Vec3> positions;
System system;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < gridSize; i++)
for (int j = 0; j < gridSize; j++)
for (int k = 0; k < gridSize; k++) {
force->addParticle(params);
positions.push_back(Vec3((i+0.4*genrand_real2(sfmt))*spacing, (j+0.4*genrand_real2(sfmt))*spacing, (k+0.4*genrand_real2(sfmt))*spacing));
system.addParticle(1.0);
}
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
system.addForce(force);
VerletIntegrator integrator1(0.001);
VerletIntegrator integrator2(0.001);
Context context1(system, integrator1, Platform::getPlatformByName("Reference"));
Context context2(system, integrator2, platform);
context1.setPositions(positions);
context2.setPositions(positions);
State state1 = context1.getState(State::Forces | State::Energy);
State state2 = context2.getState(State::Forces | State::Energy);
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), 1e-4);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-4);
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -473,6 +512,7 @@ int main(int argc, char* argv[]) { ...@@ -473,6 +512,7 @@ int main(int argc, char* argv[]) {
testParameters(); testParameters();
testTabulatedFunctions(); testTabulatedFunctions();
testTypeFilters(); testTypeFilters();
testLargeSystem();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment