Commit bb7cba08 authored by Peter Eastman
Browse files

Improved GB performance in mixed precision

parent 73c4302d
......@@ -2794,6 +2794,7 @@ double CudaCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeFor
force1Args.push_back(&cu.getEnergyBuffer().getDevicePointer());
force1Args.push_back(&cu.getPosq().getDevicePointer());
force1Args.push_back(&bornRadii->getDevicePointer());
force1Args.push_back(NULL);
if (nb.getUseCutoff()) {
force1Args.push_back(&nb.getInteractingTiles().getDevicePointer());
force1Args.push_back(&nb.getInteractionCount().getDevicePointer());
......@@ -2813,13 +2814,14 @@ double CudaCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeFor
reduceBornSumKernel = cu.getKernel(module, "reduceBornSum");
reduceBornForceKernel = cu.getKernel(module, "reduceBornForce");
}
force1Args[5] = &includeEnergy;
if (nb.getUseCutoff()) {
if (maxTiles < nb.getInteractingTiles().getSize()) {
maxTiles = nb.getInteractingTiles().getSize();
computeSumArgs[3] = &nb.getInteractingTiles().getDevicePointer();
force1Args[5] = &nb.getInteractingTiles().getDevicePointer();
force1Args[6] = &nb.getInteractingTiles().getDevicePointer();
computeSumArgs[13] = &nb.getInteractingAtoms().getDevicePointer();
force1Args[15] = &nb.getInteractingAtoms().getDevicePointer();
force1Args[16] = &nb.getInteractingAtoms().getDevicePointer();
}
}
cu.executeKernel(computeBornSumKernel, &computeSumArgs[0], nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
......@@ -3754,6 +3756,7 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo
pairEnergyArgs.push_back(&cu.getPosq().getDevicePointer());
pairEnergyArgs.push_back(&cu.getNonbondedUtilities().getExclusions().getDevicePointer());
pairEnergyArgs.push_back(&cu.getNonbondedUtilities().getExclusionTiles().getDevicePointer());
pairEnergyArgs.push_back(NULL);
if (nb.getUseCutoff()) {
pairEnergyArgs.push_back(&nb.getInteractingTiles().getDevicePointer());
pairEnergyArgs.push_back(&nb.getInteractionCount().getDevicePointer());
......@@ -3832,13 +3835,14 @@ double CudaCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeFo
if (changed)
globals->upload(globalParamValues);
}
pairEnergyArgs[5] = &includeEnergy;
if (nb.getUseCutoff()) {
if (maxTiles < nb.getInteractingTiles().getSize()) {
maxTiles = nb.getInteractingTiles().getSize();
pairValueArgs[4] = &nb.getInteractingTiles().getDevicePointer();
pairEnergyArgs[5] = &nb.getInteractingTiles().getDevicePointer();
pairEnergyArgs[6] = &nb.getInteractingTiles().getDevicePointer();
pairValueArgs[14] = &nb.getInteractingAtoms().getDevicePointer();
pairEnergyArgs[15] = &nb.getInteractingAtoms().getDevicePointer();
pairEnergyArgs[16] = &nb.getInteractingAtoms().getDevicePointer();
}
}
cu.executeKernel(pairValueKernel, &pairValueArgs[0], nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
......
......@@ -14,7 +14,7 @@ typedef struct {
* Compute a force based on pair interactions.
*/
extern "C" __global__ void computeN2Energy(unsigned long long* __restrict__ forceBuffers, mixed* __restrict__ energyBuffer,
const real4* __restrict__ posq, const unsigned int* __restrict__ exclusions, const ushort2* __restrict__ exclusionTiles,
const real4* __restrict__ posq, const unsigned int* __restrict__ exclusions, const ushort2* __restrict__ exclusionTiles, bool needEnergy,
#ifdef USE_CUTOFF
const int* __restrict__ tiles, const unsigned int* __restrict__ interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, unsigned int maxTiles, const real4* __restrict__ blockCenter,
......@@ -78,7 +78,8 @@ extern "C" __global__ void computeN2Energy(unsigned long long* __restrict__ forc
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += 0.5f*tempEnergy;
if (needEnergy)
energy += 0.5f*tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......@@ -130,7 +131,8 @@ extern "C" __global__ void computeN2Energy(unsigned long long* __restrict__ forc
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......@@ -274,7 +276,8 @@ extern "C" __global__ void computeN2Energy(unsigned long long* __restrict__ forc
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......@@ -318,7 +321,8 @@ extern "C" __global__ void computeN2Energy(unsigned long long* __restrict__ forc
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......
......@@ -400,7 +400,7 @@ typedef struct {
*/
extern "C" __global__ void computeGBSAForce1(unsigned long long* __restrict__ forceBuffers, unsigned long long* __restrict__ global_bornForce,
mixed* __restrict__ energyBuffer, const real4* __restrict__ posq, const real* __restrict__ global_bornRadii,
mixed* __restrict__ energyBuffer, const real4* __restrict__ posq, const real* __restrict__ global_bornRadii, bool needEnergy,
#ifdef USE_CUTOFF
const int* __restrict__ tiles, const unsigned int* __restrict__ interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, unsigned int maxTiles, const real4* __restrict__ blockCenter,
......@@ -465,7 +465,8 @@ extern "C" __global__ void computeGBSAForce1(unsigned long long* __restrict__ fo
if (atom1 != y*TILE_SIZE+j)
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += 0.5f*tempEnergy;
if (needEnergy)
energy += 0.5f*tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......@@ -519,7 +520,8 @@ extern "C" __global__ void computeGBSAForce1(unsigned long long* __restrict__ fo
#ifdef USE_CUTOFF
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......@@ -667,7 +669,8 @@ extern "C" __global__ void computeGBSAForce1(unsigned long long* __restrict__ fo
#ifdef USE_CUTOFF
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......@@ -716,7 +719,8 @@ extern "C" __global__ void computeGBSAForce1(unsigned long long* __restrict__ fo
#ifdef USE_CUTOFF
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta *= dEdR;
force.x -= delta.x;
force.y -= delta.y;
......
......@@ -2872,6 +2872,7 @@ double OpenCLCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeF
force1Kernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
force1Kernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
force1Kernel.setArg<cl::Buffer>(index++, bornRadii->getDeviceBuffer());
index++; // Whether to include energy.
if (nb.getUseCutoff()) {
force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
......@@ -2907,17 +2908,18 @@ double OpenCLCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeF
reduceBornForceKernel.setArg<cl::Buffer>(index++, bornRadii->getDeviceBuffer());
reduceBornForceKernel.setArg<cl::Buffer>(index++, obcChain->getDeviceBuffer());
}
force1Kernel.setArg<cl_int>(5, includeEnergy);
if (nb.getUseCutoff()) {
setPeriodicBoxArgs(cl, computeBornSumKernel, 5);
setPeriodicBoxArgs(cl, force1Kernel, 7);
setPeriodicBoxArgs(cl, force1Kernel, 8);
if (maxTiles < nb.getInteractingTiles().getSize()) {
maxTiles = nb.getInteractingTiles().getSize();
computeBornSumKernel.setArg<cl::Buffer>(3, nb.getInteractingTiles().getDeviceBuffer());
computeBornSumKernel.setArg<cl_uint>(10, maxTiles);
computeBornSumKernel.setArg<cl::Buffer>(13, nb.getInteractingAtoms().getDeviceBuffer());
force1Kernel.setArg<cl::Buffer>(5, nb.getInteractingTiles().getDeviceBuffer());
force1Kernel.setArg<cl_uint>(12, maxTiles);
force1Kernel.setArg<cl::Buffer>(15, nb.getInteractingAtoms().getDeviceBuffer());
force1Kernel.setArg<cl::Buffer>(6, nb.getInteractingTiles().getDeviceBuffer());
force1Kernel.setArg<cl_uint>(13, maxTiles);
force1Kernel.setArg<cl::Buffer>(16, nb.getInteractingAtoms().getDeviceBuffer());
}
}
cl.executeKernel(computeBornSumKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
......@@ -3933,6 +3935,7 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*4*elementSize, NULL);
pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getNonbondedUtilities().getExclusions().getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getNonbondedUtilities().getExclusionTiles().getDeviceBuffer());
index++; // Whether to include energy.
if (nb.getUseCutoff()) {
pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
......@@ -4029,17 +4032,18 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
if (changed)
globals->upload(globalParamValues);
}
pairEnergyKernel.setArg<cl_int>(7, includeEnergy);
if (nb.getUseCutoff()) {
setPeriodicBoxArgs(cl, pairValueKernel, 8);
setPeriodicBoxArgs(cl, pairEnergyKernel, 9);
setPeriodicBoxArgs(cl, pairEnergyKernel, 10);
if (maxTiles < nb.getInteractingTiles().getSize()) {
maxTiles = nb.getInteractingTiles().getSize();
pairValueKernel.setArg<cl::Buffer>(6, nb.getInteractingTiles().getDeviceBuffer());
pairValueKernel.setArg<cl_uint>(13, maxTiles);
pairValueKernel.setArg<cl::Buffer>(16, nb.getInteractingAtoms().getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(7, nb.getInteractingTiles().getDeviceBuffer());
pairEnergyKernel.setArg<cl_uint>(14, maxTiles);
pairEnergyKernel.setArg<cl::Buffer>(17, nb.getInteractingAtoms().getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(8, nb.getInteractingTiles().getDeviceBuffer());
pairEnergyKernel.setArg<cl_uint>(15, maxTiles);
pairEnergyKernel.setArg<cl::Buffer>(18, nb.getInteractingAtoms().getDeviceBuffer());
}
}
cl.executeKernel(pairValueKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
......
......@@ -18,7 +18,7 @@ __kernel void computeN2Energy(
#endif
__global mixed* restrict energyBuffer, __local real4* restrict local_force,
__global const real4* restrict posq, __local real4* restrict local_posq, __global const unsigned int* restrict exclusions,
__global const ushort2* exclusionTiles,
__global const ushort2* exclusionTiles, int needEnergy,
#ifdef USE_CUTOFF
__global const int* restrict tiles, __global const unsigned int* restrict interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, unsigned int maxTiles, __global const real4* restrict blockCenter,
......@@ -82,7 +82,8 @@ __kernel void computeN2Energy(
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += 0.5f*tempEnergy;
if (needEnergy)
energy += 0.5f*tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
#ifdef USE_CUTOFF
......@@ -133,7 +134,8 @@ __kernel void computeN2Energy(
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
atom2 = tbx+tj;
......@@ -289,7 +291,8 @@ __kernel void computeN2Energy(
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
atom2 = tbx+tj;
......@@ -328,7 +331,8 @@ __kernel void computeN2Energy(
COMPUTE_INTERACTION
dEdR /= -r;
}
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
atom2 = tbx+tj;
......
......@@ -385,7 +385,7 @@ __kernel void computeGBSAForce1(
#else
__global real4* restrict forceBuffers, __global real* restrict global_bornForce,
#endif
__global mixed* restrict energyBuffer, __global const real4* restrict posq, __global const real* restrict global_bornRadii,
__global mixed* restrict energyBuffer, __global const real4* restrict posq, __global const real* restrict global_bornRadii, int needEnergy,
#ifdef USE_CUTOFF
__global const int* restrict tiles, __global const unsigned int* restrict interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, unsigned int maxTiles, __global const real4* restrict blockCenter,
......@@ -452,7 +452,8 @@ __kernel void computeGBSAForce1(
if (atom1 != y*TILE_SIZE+j)
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += 0.5f*tempEnergy;
if (needEnergy)
energy += 0.5f*tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
#ifdef USE_CUTOFF
......@@ -506,7 +507,8 @@ __kernel void computeGBSAForce1(
#ifdef USE_CUTOFF
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
localData[tbx+tj].fx += delta.x;
......@@ -669,7 +671,8 @@ __kernel void computeGBSAForce1(
#ifdef USE_CUTOFF
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
localData[tbx+tj].fx += delta.x;
......@@ -717,7 +720,8 @@ __kernel void computeGBSAForce1(
#ifdef USE_CUTOFF
tempEnergy -= scaledChargeProduct/CUTOFF;
#endif
energy += tempEnergy;
if (needEnergy)
energy += tempEnergy;
delta.xyz *= dEdR;
force.xyz -= delta.xyz;
localData[tbx+tj].fx += delta.x;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment