Commit b8cba79f authored by Peter Eastman's avatar Peter Eastman
Browse files

Restrict the number of threads in GK kernels based on available shared memory

parent 549d62bd
...@@ -1644,12 +1644,27 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::initialize(const System& syst ...@@ -1644,12 +1644,27 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::initialize(const System& syst
} }
params->upload(paramsVector); params->upload(paramsVector);
// Select the number of threads for each kernel.
double computeBornSumThreadMemory = 4*elementSize+3*sizeof(float);
double gkForceThreadMemory = 24*elementSize;
double chainRuleThreadMemory = 10*elementSize;
double ediffThreadMemory = 28*elementSize+2*sizeof(float)+3*sizeof(int)/(double) cu.TileSize;
int maxThreads = cu.getNonbondedUtilities().getForceThreadBlockSize();
computeBornSumThreads = min(maxThreads, cu.computeThreadBlockSize(computeBornSumThreadMemory));
gkForceThreads = min(maxThreads, cu.computeThreadBlockSize(gkForceThreadMemory));
chainRuleThreads = min(maxThreads, cu.computeThreadBlockSize(chainRuleThreadMemory));
ediffThreads = min(maxThreads, cu.computeThreadBlockSize(ediffThreadMemory));
// Create the kernels. // Create the kernels.
map<string, string> defines; map<string, string> defines;
defines["NUM_ATOMS"] = cu.intToString(cu.getNumAtoms()); defines["NUM_ATOMS"] = cu.intToString(cu.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = cu.intToString(paddedNumAtoms); defines["PADDED_NUM_ATOMS"] = cu.intToString(paddedNumAtoms);
defines["THREAD_BLOCK_SIZE"] = cu.intToString(nb.getForceThreadBlockSize()); defines["BORN_SUM_THREAD_BLOCK_SIZE"] = cu.intToString(computeBornSumThreads);
defines["GK_FORCE_THREAD_BLOCK_SIZE"] = cu.intToString(gkForceThreads);
defines["CHAIN_RULE_THREAD_BLOCK_SIZE"] = cu.intToString(chainRuleThreads);
defines["EDIFF_THREAD_BLOCK_SIZE"] = cu.intToString(ediffThreads);
defines["NUM_BLOCKS"] = cu.intToString(cu.getNumAtomBlocks()); defines["NUM_BLOCKS"] = cu.intToString(cu.getNumAtomBlocks());
defines["GK_C"] = cu.doubleToString(2.455); defines["GK_C"] = cu.doubleToString(2.455);
double solventDielectric = force.getSolventDielectric(); double solventDielectric = force.getSolventDielectric();
...@@ -1710,10 +1725,9 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::computeBornRadii() { ...@@ -1710,10 +1725,9 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::computeBornRadii() {
CudaNonbondedUtilities& nb = cu.getNonbondedUtilities(); CudaNonbondedUtilities& nb = cu.getNonbondedUtilities();
int numTiles = nb.getNumTiles(); int numTiles = nb.getNumTiles();
int numForceThreadBlocks = nb.getNumForceThreadBlocks(); int numForceThreadBlocks = nb.getNumForceThreadBlocks();
int forceThreadBlockSize = nb.getForceThreadBlockSize();
void* computeBornSumArgs[] = {&bornSum->getDevicePointer(), &cu.getPosq().getDevicePointer(), void* computeBornSumArgs[] = {&bornSum->getDevicePointer(), &cu.getPosq().getDevicePointer(),
&params->getDevicePointer(), &numTiles}; &params->getDevicePointer(), &numTiles};
cu.executeKernel(computeBornSumKernel, computeBornSumArgs, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize); cu.executeKernel(computeBornSumKernel, computeBornSumArgs, numForceThreadBlocks*computeBornSumThreads, computeBornSumThreads);
void* reduceBornSumArgs[] = {&bornSum->getDevicePointer(), &params->getDevicePointer(), &bornRadii->getDevicePointer()}; void* reduceBornSumArgs[] = {&bornSum->getDevicePointer(), &params->getDevicePointer(), &bornRadii->getDevicePointer()};
cu.executeKernel(reduceBornSumKernel, reduceBornSumArgs, cu.getNumAtoms()); cu.executeKernel(reduceBornSumKernel, reduceBornSumArgs, cu.getNumAtoms());
} }
...@@ -1724,7 +1738,6 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation(CudaArray& ...@@ -1724,7 +1738,6 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation(CudaArray&
int startTileIndex = nb.getStartTileIndex(); int startTileIndex = nb.getStartTileIndex();
int numTileIndices = nb.getNumTiles(); int numTileIndices = nb.getNumTiles();
int numForceThreadBlocks = nb.getNumForceThreadBlocks(); int numForceThreadBlocks = nb.getNumForceThreadBlocks();
int forceThreadBlockSize = nb.getForceThreadBlockSize();
// Compute the GK force. // Compute the GK force.
...@@ -1732,7 +1745,7 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation(CudaArray& ...@@ -1732,7 +1745,7 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation(CudaArray&
&cu.getPosq().getDevicePointer(), &startTileIndex, &numTileIndices, &labFrameDipoles.getDevicePointer(), &cu.getPosq().getDevicePointer(), &startTileIndex, &numTileIndices, &labFrameDipoles.getDevicePointer(),
&labFrameQuadrupoles.getDevicePointer(), &inducedDipoleS->getDevicePointer(), &inducedDipolePolarS->getDevicePointer(), &labFrameQuadrupoles.getDevicePointer(), &inducedDipoleS->getDevicePointer(), &inducedDipolePolarS->getDevicePointer(),
&bornRadii->getDevicePointer(), &bornForce->getDevicePointer()}; &bornRadii->getDevicePointer(), &bornForce->getDevicePointer()};
cu.executeKernel(gkForceKernel, gkForceArgs, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize); cu.executeKernel(gkForceKernel, gkForceArgs, numForceThreadBlocks*gkForceThreads, gkForceThreads);
// Compute the surface area force. // Compute the surface area force.
...@@ -1745,14 +1758,14 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation(CudaArray& ...@@ -1745,14 +1758,14 @@ void CudaCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation(CudaArray&
void* chainRuleArgs[] = {&cu.getForce().getDevicePointer(), &cu.getPosq().getDevicePointer(), &startTileIndex, &numTileIndices, void* chainRuleArgs[] = {&cu.getForce().getDevicePointer(), &cu.getPosq().getDevicePointer(), &startTileIndex, &numTileIndices,
&params->getDevicePointer(), &bornRadii->getDevicePointer(), &bornForce->getDevicePointer()}; &params->getDevicePointer(), &bornRadii->getDevicePointer(), &bornForce->getDevicePointer()};
cu.executeKernel(chainRuleKernel, chainRuleArgs, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize); cu.executeKernel(chainRuleKernel, chainRuleArgs, numForceThreadBlocks*chainRuleThreads, chainRuleThreads);
void* ediffArgs[] = {&cu.getForce().getDevicePointer(), &torque.getDevicePointer(), &cu.getEnergyBuffer().getDevicePointer(), void* ediffArgs[] = {&cu.getForce().getDevicePointer(), &torque.getDevicePointer(), &cu.getEnergyBuffer().getDevicePointer(),
&cu.getPosq().getDevicePointer(), &nb.getExclusionIndices().getDevicePointer(), &nb.getExclusionRowIndices().getDevicePointer(), &cu.getPosq().getDevicePointer(), &nb.getExclusionIndices().getDevicePointer(), &nb.getExclusionRowIndices().getDevicePointer(),
&covalentFlags.getDevicePointer(), &polarizationGroupFlags.getDevicePointer(), &startTileIndex, &numTileIndices, &covalentFlags.getDevicePointer(), &polarizationGroupFlags.getDevicePointer(), &startTileIndex, &numTileIndices,
&labFrameDipoles.getDevicePointer(), &labFrameQuadrupoles.getDevicePointer(), &inducedDipole.getDevicePointer(), &labFrameDipoles.getDevicePointer(), &labFrameQuadrupoles.getDevicePointer(), &inducedDipole.getDevicePointer(),
&inducedDipolePolar.getDevicePointer(), &inducedDipoleS->getDevicePointer(), &inducedDipolePolarS->getDevicePointer(), &inducedDipolePolar.getDevicePointer(), &inducedDipoleS->getDevicePointer(), &inducedDipolePolarS->getDevicePointer(),
&dampingAndThole.getDevicePointer()}; &dampingAndThole.getDevicePointer()};
cu.executeKernel(ediffKernel, ediffArgs, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize); cu.executeKernel(ediffKernel, ediffArgs, numForceThreadBlocks*ediffThreads, ediffThreads);
} }
/* -------------------------------------------------------------------------- * /* -------------------------------------------------------------------------- *
......
...@@ -424,6 +424,7 @@ private: ...@@ -424,6 +424,7 @@ private:
CudaContext& cu; CudaContext& cu;
System& system; System& system;
bool includeSurfaceArea; bool includeSurfaceArea;
int computeBornSumThreads, gkForceThreads, chainRuleThreads, ediffThreads;
CudaArray* params; CudaArray* params;
CudaArray* bornSum; CudaArray* bornSum;
CudaArray* bornRadii; CudaArray* bornRadii;
......
#define TILE_SIZE 32 #define TILE_SIZE 32
#define WARPS_PER_GROUP (THREAD_BLOCK_SIZE/TILE_SIZE)
/** /**
* Reduce the Born sums to compute the Born radii. * Reduce the Born sums to compute the Born radii.
...@@ -93,7 +92,7 @@ extern "C" __global__ void computeBornSum(unsigned long long* __restrict__ bornS ...@@ -93,7 +92,7 @@ extern "C" __global__ void computeBornSum(unsigned long long* __restrict__ bornS
unsigned int pos = warp*numTiles/totalWarps; unsigned int pos = warp*numTiles/totalWarps;
unsigned int end = (warp+1)*numTiles/totalWarps; unsigned int end = (warp+1)*numTiles/totalWarps;
unsigned int lasty = 0xFFFFFFFF; unsigned int lasty = 0xFFFFFFFF;
__shared__ AtomData1 localData[THREAD_BLOCK_SIZE]; __shared__ AtomData1 localData[BORN_SUM_THREAD_BLOCK_SIZE];
do { do {
// Extract the coordinates of this tile // Extract the coordinates of this tile
const unsigned int tgx = threadIdx.x & (TILE_SIZE-1); const unsigned int tgx = threadIdx.x & (TILE_SIZE-1);
...@@ -227,7 +226,7 @@ extern "C" __global__ void computeGKForces( ...@@ -227,7 +226,7 @@ extern "C" __global__ void computeGKForces(
unsigned int pos = startTileIndex+warp*numTiles/totalWarps; unsigned int pos = startTileIndex+warp*numTiles/totalWarps;
unsigned int end = startTileIndex+(warp+1)*numTiles/totalWarps; unsigned int end = startTileIndex+(warp+1)*numTiles/totalWarps;
real energy = 0; real energy = 0;
__shared__ AtomData2 localData[THREAD_BLOCK_SIZE]; __shared__ AtomData2 localData[GK_FORCE_THREAD_BLOCK_SIZE];
do { do {
// Extract the coordinates of this tile // Extract the coordinates of this tile
...@@ -466,7 +465,7 @@ extern "C" __global__ void computeChainRuleForce( ...@@ -466,7 +465,7 @@ extern "C" __global__ void computeChainRuleForce(
const unsigned int numTiles = numTileIndices; const unsigned int numTiles = numTileIndices;
unsigned int pos = startTileIndex+warp*numTiles/totalWarps; unsigned int pos = startTileIndex+warp*numTiles/totalWarps;
unsigned int end = startTileIndex+(warp+1)*numTiles/totalWarps; unsigned int end = startTileIndex+(warp+1)*numTiles/totalWarps;
__shared__ AtomData3 localData[THREAD_BLOCK_SIZE]; __shared__ AtomData3 localData[CHAIN_RULE_THREAD_BLOCK_SIZE];
do { do {
// Extract the coordinates of this tile // Extract the coordinates of this tile
...@@ -551,7 +550,7 @@ typedef struct { ...@@ -551,7 +550,7 @@ typedef struct {
real3 pos, force, dipole, inducedDipole, inducedDipolePolar, inducedDipoleS, inducedDipolePolarS; real3 pos, force, dipole, inducedDipole, inducedDipolePolar, inducedDipoleS, inducedDipolePolarS;
real q, quadrupoleXX, quadrupoleXY, quadrupoleXZ; real q, quadrupoleXX, quadrupoleXY, quadrupoleXZ;
real quadrupoleYY, quadrupoleYZ, quadrupoleZZ; real quadrupoleYY, quadrupoleYZ, quadrupoleZZ;
float thole, damp, padding; float thole, damp;
} AtomData4; } AtomData4;
__device__ void computeOneEDiffInteractionF1(AtomData4& atom1, volatile AtomData4& atom2, float dScale, float pScale, real& outputEnergy, real3& outputForce); __device__ void computeOneEDiffInteractionF1(AtomData4& atom1, volatile AtomData4& atom2, float dScale, float pScale, real& outputEnergy, real3& outputForce);
...@@ -618,9 +617,9 @@ extern "C" __global__ void computeEDiffForce( ...@@ -618,9 +617,9 @@ extern "C" __global__ void computeEDiffForce(
unsigned int pos = startTileIndex+warp*numTiles/totalWarps; unsigned int pos = startTileIndex+warp*numTiles/totalWarps;
unsigned int end = startTileIndex+(warp+1)*numTiles/totalWarps; unsigned int end = startTileIndex+(warp+1)*numTiles/totalWarps;
real energy = 0; real energy = 0;
__shared__ AtomData4 localData[THREAD_BLOCK_SIZE]; __shared__ AtomData4 localData[EDIFF_THREAD_BLOCK_SIZE];
__shared__ unsigned int exclusionRange[2*WARPS_PER_GROUP]; __shared__ unsigned int exclusionRange[2*(EDIFF_THREAD_BLOCK_SIZE/TILE_SIZE)];
__shared__ int exclusionIndex[WARPS_PER_GROUP]; __shared__ int exclusionIndex[EDIFF_THREAD_BLOCK_SIZE/TILE_SIZE];
do { do {
// Extract the coordinates of this tile // Extract the coordinates of this tile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment