Unverified Commit 4e742529 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Improved performance of CustomHbondForce on large systems (#4451)

* Improved performance of CustomHbondForce on large systems

* Fixed CUDA compilation errors
parent 86988b90
......@@ -893,7 +893,7 @@ public:
private:
class ForceInfo;
int numDonors, numAcceptors;
bool hasInitializedKernel;
bool hasInitializedKernel, useBoundingBoxes;
ComputeContext& cc;
ForceInfo* info;
ComputeParameterSet* donorParams;
......@@ -903,12 +903,13 @@ private:
ComputeArray acceptors;
ComputeArray donorExclusions;
ComputeArray acceptorExclusions;
ComputeArray donorBlockCenter, donorBlockSize, acceptorBlockCenter, acceptorBlockSize;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system;
ComputeKernel kernel;
ComputeKernel blockBoundsKernel, forceKernel;
};
/**
......
......@@ -4107,6 +4107,19 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
info = new ForceInfo(force);
cc.addForce(info);
// Decide whether to use bounding boxes to accelerate the calculation.
int numDonorBlocks = (numDonors+31)/32;
int numAcceptorBlocks = (numAcceptors+31)/32;
useBoundingBoxes = (force.getNonbondedMethod() != CustomHbondForce::NoCutoff && numDonorBlocks*numAcceptorBlocks > cc.getNumThreadBlocks());
if (useBoundingBoxes) {
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
donorBlockCenter.initialize(cc, numDonorBlocks, 4*elementSize, "donorBlockCenter");
donorBlockSize.initialize(cc, numDonorBlocks, 4*elementSize, "donorBlockSize");
acceptorBlockCenter.initialize(cc, numAcceptorBlocks, 4*elementSize, "acceptorBlockCenter");
acceptorBlockSize.initialize(cc, numAcceptorBlocks, 4*elementSize, "acceptorBlockSize");
}
// Record exclusions.
vector<mm_int4> donorExclusionVector(numDonors, mm_int4(-1, -1, -1, -1));
......@@ -4349,10 +4362,10 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
defines["NUM_DONORS"] = cc.intToString(numDonors);
defines["NUM_ACCEPTORS"] = cc.intToString(numAcceptors);
defines["NUM_DONOR_BLOCKS"] = cc.intToString((numDonors+31)/32);
defines["NUM_ACCEPTOR_BLOCKS"] = cc.intToString((numAcceptors+31)/32);
defines["NUM_DONOR_BLOCKS"] = cc.intToString(numDonorBlocks);
defines["NUM_ACCEPTOR_BLOCKS"] = cc.intToString(numAcceptorBlocks);
defines["M_PI"] = cc.doubleToString(M_PI);
defines["THREAD_BLOCK_SIZE"] = "64";
defines["THREAD_BLOCK_SIZE"] = "128";
if (force.getNonbondedMethod() != CustomHbondForce::NoCutoff) {
defines["USE_CUTOFF"] = "1";
defines["CUTOFF_SQUARED"] = cc.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
......@@ -4361,8 +4374,11 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
defines["USE_PERIODIC"] = "1";
if (force.getNumExclusions() > 0)
defines["USE_EXCLUSIONS"] = "1";
if (useBoundingBoxes)
defines["USE_BOUNDING_BOXES"] = "1";
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customHbondForce, replacements), defines);
kernel = program->createKernel("computeHbondForces");
blockBoundsKernel = program->createKernel("findBlockBounds");
forceKernel = program->createKernel("computeHbondForces");
}
double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......@@ -4382,27 +4398,48 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
}
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel->addArg(cc.getLongForceBuffer());
kernel->addArg(cc.getEnergyBuffer());
kernel->addArg(cc.getPosq());
kernel->addArg(donorExclusions);
kernel->addArg(donors);
kernel->addArg(acceptors);
if (useBoundingBoxes) {
blockBoundsKernel->addArg(donors);
blockBoundsKernel->addArg(acceptors);
for (int i = 0; i < 5; i++)
blockBoundsKernel->addArg(); // Periodic box size arguments are set when the kernel is executed.
blockBoundsKernel->addArg(cc.getPosq());
blockBoundsKernel->addArg(donorBlockCenter);
blockBoundsKernel->addArg(donorBlockSize);
blockBoundsKernel->addArg(acceptorBlockCenter);
blockBoundsKernel->addArg(acceptorBlockSize);
}
forceKernel->addArg(cc.getLongForceBuffer());
forceKernel->addArg(cc.getEnergyBuffer());
forceKernel->addArg(cc.getPosq());
forceKernel->addArg(donorExclusions);
forceKernel->addArg(donors);
forceKernel->addArg(acceptors);
for (int i = 0; i < 5; i++)
kernel->addArg(); // Periodic box size arguments are set when the kernel is executed.
forceKernel->addArg(); // Periodic box size arguments are set when the kernel is executed.
if (useBoundingBoxes) {
forceKernel->addArg(donorBlockCenter);
forceKernel->addArg(donorBlockSize);
forceKernel->addArg(acceptorBlockCenter);
forceKernel->addArg(acceptorBlockSize);
}
if (globals.isInitialized())
kernel->addArg(globals);
forceKernel->addArg(globals);
for (auto& parameter : donorParams->getParameterInfos())
kernel->addArg(parameter.getArray());
forceKernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos())
kernel->addArg(parameter.getArray());
forceKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctionArrays)
kernel->addArg(function);
forceKernel->addArg(function);
}
if (useBoundingBoxes) {
setPeriodicBoxArgs(cc, blockBoundsKernel, 2);
blockBoundsKernel->execute(max(numDonors, numAcceptors));
}
setPeriodicBoxArgs(cc, kernel, 6);
setPeriodicBoxArgs(cc, forceKernel, 6);
int numDonorBlocks = (numDonors+31)/32;
int numAcceptorBlocks = (numAcceptors+31)/32;
kernel->execute(numDonorBlocks*numAcceptorBlocks*32, cc.getIsCPU() ? 32 : 64);
forceKernel->execute(numDonorBlocks*numAcceptorBlocks*32, cc.getIsCPU() ? 32 : 128);
return 0.0;
}
......
/**
 * Compute an axis-aligned bounding box for a run of donor or acceptor groups.
 *
 * Only the first atom of each group (the .x index of each entry in atoms) is
 * considered.  The box is written out as a midpoint and a half-extent per axis.
 *
 * When USE_PERIODIC is defined, the first position is passed through
 * APPLY_PERIODIC_TO_POS, and every later position is passed through
 * APPLY_PERIODIC_TO_POS_WITH_CENTER using the midpoint of the box accumulated
 * so far.
 *
 * @param posq        particle positions, indexed by atom index
 * @param atoms       the groups to enclose; atoms[0] is read unconditionally,
 *                    so numGroups must be at least 1
 * @param numGroups   number of entries of atoms to process
 * @param periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ
 *                    not referenced directly in this function; presumably consumed
 *                    by name inside the periodic macros, so they must keep these
 *                    names (NOTE(review): confirm against the macro definitions)
 * @param center      output: 0.5*(max+min) of the enclosed positions
 * @param blockSize   output: 0.5*(max-min) of the enclosed positions
 *
 * The w components of the running min/max are reset to 0 on every loop
 * iteration, but are not reset when numGroups is 1.
 */
DEVICE void findBoundingBox(GLOBAL const real4* RESTRICT posq, GLOBAL const int4* RESTRICT atoms, int numGroups,
        real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
        GLOBAL real4* center, GLOBAL real4* blockSize) {
    // Seed the box with the position of the first group.
    real4 groupPos = posq[atoms[0].x];
#ifdef USE_PERIODIC
    APPLY_PERIODIC_TO_POS(groupPos)
#endif
    real4 boxMin = groupPos;
    real4 boxMax = groupPos;

    // Grow the box to take in each remaining group.
    for (int group = 1; group < numGroups; group++) {
        groupPos = posq[atoms[group].x];
#ifdef USE_PERIODIC
        // Wrap relative to the midpoint of the box built so far.
        real4 boxMidpoint = 0.5f*(boxMax+boxMin);
        APPLY_PERIODIC_TO_POS_WITH_CENTER(groupPos, boxMidpoint)
#endif
        boxMin.x = min(boxMin.x, groupPos.x);
        boxMin.y = min(boxMin.y, groupPos.y);
        boxMin.z = min(boxMin.z, groupPos.z);
        boxMin.w = 0;
        boxMax.x = max(boxMax.x, groupPos.x);
        boxMax.y = max(boxMax.y, groupPos.y);
        boxMax.z = max(boxMax.z, groupPos.z);
        boxMax.w = 0;
    }

    // Store the half-extents first, then the midpoint.
    *blockSize = 0.5f*(boxMax-boxMin);
    *center = 0.5f*(boxMax+boxMin);
}
/**
 * Compute a bounding box for every block of 32 donors and every block of 32
 * acceptors.
 *
 * Each work item starts at block GLOBAL_ID and advances by GLOBAL_SIZE, first
 * over the NUM_DONOR_BLOCKS donor blocks and then over the NUM_ACCEPTOR_BLOCKS
 * acceptor blocks.  The last block of each kind may hold fewer than 32 groups,
 * so the group count is clamped to what remains.
 *
 * The centers and half-extents are written to donorBlockCenter/donorBlockSize
 * and acceptorBlockCenter/acceptorBlockSize, one entry per block.
 */
KERNEL void findBlockBounds(GLOBAL const int4* RESTRICT donorAtoms, GLOBAL const int4* RESTRICT acceptorAtoms,
        real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
        GLOBAL const real4* RESTRICT posq, GLOBAL real4* RESTRICT donorBlockCenter, GLOBAL real4* RESTRICT donorBlockSize,
        GLOBAL real4* RESTRICT acceptorBlockCenter, GLOBAL real4* RESTRICT acceptorBlockSize) {
    // Donor blocks.
    for (int block = GLOBAL_ID; block < NUM_DONOR_BLOCKS; block += GLOBAL_SIZE) {
        const int firstDonor = block*32;
        const int donorCount = min(32, NUM_DONORS-firstDonor);
        findBoundingBox(posq, donorAtoms+firstDonor, donorCount,
                periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ,
                donorBlockCenter+block, donorBlockSize+block);
    }

    // Acceptor blocks.
    for (int block = GLOBAL_ID; block < NUM_ACCEPTOR_BLOCKS; block += GLOBAL_SIZE) {
        const int firstAcceptor = block*32;
        const int acceptorCount = min(32, NUM_ACCEPTORS-firstAcceptor);
        findBoundingBox(posq, acceptorAtoms+firstAcceptor, acceptorCount,
                periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ,
                acceptorBlockCenter+block, acceptorBlockSize+block);
    }
}
/**
* Compute the difference between two vectors, optionally taking periodic boundary conditions into account
* and setting the fourth component to the squared magnitude.
......@@ -68,6 +106,10 @@ KERNEL void computeHbondForces(
GLOBAL mixed* RESTRICT energyBuffer, GLOBAL const real4* RESTRICT posq, GLOBAL const int4* RESTRICT exclusions,
GLOBAL const int4* RESTRICT donorAtoms, GLOBAL const int4* RESTRICT acceptorAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
#ifdef USE_BOUNDING_BOXES
, GLOBAL real4* RESTRICT donorBlockCenter, GLOBAL real4* RESTRICT donorBlockSize,
GLOBAL real4* RESTRICT acceptorBlockCenter, GLOBAL real4* RESTRICT acceptorBlockSize
#endif
PARAMETER_ARGUMENTS) {
const unsigned int totalWarps = GLOBAL_SIZE/32;
const unsigned int warp = GLOBAL_ID/32;
......@@ -76,15 +118,28 @@ KERNEL void computeHbondForces(
LOCAL AcceptorData localData[THREAD_BLOCK_SIZE];
mixed energy = 0;
for (int tile = warp; tile < NUM_DONOR_BLOCKS*NUM_ACCEPTOR_BLOCKS; tile += totalWarps) {
int donorStart = (tile/NUM_ACCEPTOR_BLOCKS)*32;
int acceptorStart = (tile%NUM_ACCEPTOR_BLOCKS)*32;
int donorBlock = tile/NUM_ACCEPTOR_BLOCKS;
int acceptorBlock = tile%NUM_ACCEPTOR_BLOCKS;
#ifdef USE_BOUNDING_BOXES
real4 blockDelta = donorBlockCenter[donorBlock]-acceptorBlockCenter[acceptorBlock];
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(blockDelta)
#endif
real4 donorSize = donorBlockSize[donorBlock];
real4 acceptorSize = acceptorBlockSize[acceptorBlock];
blockDelta.x = max((real) 0, fabs(blockDelta.x)-donorSize.x-acceptorSize.x);
blockDelta.y = max((real) 0, fabs(blockDelta.y)-donorSize.y-acceptorSize.y);
blockDelta.z = max((real) 0, fabs(blockDelta.z)-donorSize.z-acceptorSize.z);
if (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z >= CUTOFF_SQUARED)
continue;
#endif
// Load information about the donor this thread will compute forces on.
real3 f1 = make_real3(0);
real3 f2 = make_real3(0);
real3 f3 = make_real3(0);
int donorIndex = donorStart+indexInWarp;
int donorIndex = donorBlock*32 + indexInWarp;
int4 atoms, exclusionIndices;
real3 d1, d2, d3;
if (donorIndex < NUM_DONORS) {
......@@ -105,6 +160,7 @@ KERNEL void computeHbondForces(
localData[LOCAL_ID].f1 = make_real3(0);
localData[LOCAL_ID].f2 = make_real3(0);
localData[LOCAL_ID].f3 = make_real3(0);
int acceptorStart = acceptorBlock*32;
int blockSize = min(32, NUM_ACCEPTORS-acceptorStart);
int4 atoms2 = (indexInWarp < blockSize ? acceptorAtoms[acceptorStart+indexInWarp] : make_int4(-1));
if (indexInWarp < blockSize) {
......
......@@ -34,6 +34,7 @@
#include "openmm/CustomHbondForce.h"
#include "openmm/HarmonicAngleForce.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h"
#include "openmm/PeriodicTorsionForce.h"
#include "openmm/System.h"
#include "openmm/VerletIntegrator.h"
......@@ -309,10 +310,17 @@ void testParameters() {
ASSERT_EQUAL_TOL(2*(2*1.8+2.1)+2*(2*1.5+2.1), state.getPotentialEnergy(), TOL);
}
void testLargeSystem() {
void testLargeSystem(CustomHbondForce::NonbondedMethod method) {
int numParticles = 5000;
double boxSize = 3.0;
double cutoff = 1.0;
System system;
CustomHbondForce* custom = new CustomHbondForce("distance(d1,a1)^2");
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
CustomHbondForce* custom = new CustomHbondForce("(distance(d1,a1)-1)^2");
custom->setNonbondedMethod(method);
custom->setCutoffDistance(cutoff);
NonbondedForce* nb = new NonbondedForce(); // So that atom reordering will be done
nb->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
......@@ -322,28 +330,34 @@ void testLargeSystem() {
custom->addDonor(i, -1, -1);
else
custom->addAcceptor(i, -1, -1);
positions[i] = Vec3(3.0*genrand_real2(sfmt), 3.0*genrand_real2(sfmt), 3.0*genrand_real2(sfmt));
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
nb->addParticle(0.0, 1.0, 0.0);
}
system.addForce(custom);
system.addForce(nb);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
double expectedEnergy = 0;
for (int i = 0; i < numParticles; i += 2) {
Vec3 expectedForce;
for (int j = 1; j < numParticles; j += 2) {
Vec3 d = positions[i]-positions[j];
if (method == CustomHbondForce::CutoffPeriodic) {
d[0] -= round(d[0]/boxSize)*boxSize;
d[1] -= round(d[1]/boxSize)*boxSize;
d[2] -= round(d[2]/boxSize)*boxSize;
}
double r = sqrt(d.dot(d));
expectedEnergy += r*r;
if (method == CustomHbondForce::NoCutoff || r < cutoff) {
expectedEnergy += (r-1)*(r-1);
expectedForce -= 2*(r-1)*d/r;
}
}
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
for (int i = 0; i < numParticles; i += 2) {
Vec3 expectedForce;
for (int j = 1; j < numParticles; j += 2)
expectedForce += 2*(positions[j]-positions[i]);
ASSERT_EQUAL_VEC(expectedForce, state.getForces()[i], 1e-5);
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
}
void runPlatformTests();
......@@ -358,7 +372,9 @@ int main(int argc, char* argv[]) {
test2DFunction();
testIllegalVariable();
testParameters();
testLargeSystem();
testLargeSystem(CustomHbondForce::NoCutoff);
testLargeSystem(CustomHbondForce::CutoffNonPeriodic);
testLargeSystem(CustomHbondForce::CutoffPeriodic);
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment