Unverified Commit b4c54303 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Faster implementation of CustomHbondForce (#4060)

* Faster implementation of CustomHbondForce

* Minor optimization

* Optimized writing forces, which are often zero

* Fix test failure on CPU OpenCL

* Bug fix
parent 2ae50f9d
......@@ -793,7 +793,7 @@ private:
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system;
ComputeKernel donorKernel, acceptorKernel;
ComputeKernel kernel;
};
/**
......
......@@ -3767,18 +3767,13 @@ CommonCalcCustomHbondForceKernel::~CommonCalcCustomHbondForceKernel() {
delete acceptorParams;
}
static void addDonorAndAcceptorCode(stringstream& computeDonor, stringstream& computeAcceptor, const string& value) {
computeDonor << value;
computeAcceptor << value;
}
static void applyDonorAndAcceptorForces(stringstream& applyToDonor, stringstream& applyToAcceptor, int atom, const string& value, bool trim=true) {
static void applyDonorAndAcceptorForces(stringstream& apply, int atom, const string& value, bool trim=true) {
string forceNames[] = {"f1", "f2", "f3"};
string toAdd = (trim ? "trimTo3("+value+")" : value);
if (atom < 3)
applyToAcceptor << forceNames[atom]<<" += "<<toAdd<<";\n";
apply << "localData[tbx+index]." << forceNames[atom]<<" += "<<toAdd<<";\n";
else
applyToDonor << forceNames[atom-3]<<" += "<<toAdd<<";\n";
apply << forceNames[atom-3]<<" += "<<toAdd<<";\n";
}
void CommonCalcCustomHbondForceKernel::initialize(const System& system, const CustomHbondForce& force) {
......@@ -3920,16 +3915,16 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
computedDeltas.insert("D1A1");
string atomNames[] = {"A1", "A2", "A3", "D1", "D2", "D3"};
string atomNamesLower[] = {"a1", "a2", "a3", "d1", "d2", "d3"};
stringstream computeDonor, computeAcceptor, extraArgs;
stringstream compute, extraArgs;
int index = 0;
for (auto& distance : distances) {
const vector<int>& atoms = distance.second;
string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
if (computedDeltas.count(deltaName) == 0) {
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
compute << "real4 delta"+deltaName+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
computedDeltas.insert(deltaName);
}
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real r_"+deltaName+" = SQRT(delta"+deltaName+".w);\n");
compute << "real r_"+deltaName+" = SQRT(delta"+deltaName+".w);\n";
variables[distance.first] = "r_"+deltaName;
forceExpressions["real dEdDistance"+cc.intToString(index)+" = "] = energyExpression.differentiate(distance.first).optimize();
index++;
......@@ -3941,14 +3936,14 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
string angleName = "angle_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]];
if (computedDeltas.count(deltaName1) == 0) {
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[0]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
compute << "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[0]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
computedDeltas.insert(deltaName1);
}
if (computedDeltas.count(deltaName2) == 0) {
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[2]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
compute << "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[2]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
computedDeltas.insert(deltaName2);
}
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real "+angleName+" = computeAngle(delta"+deltaName1+", delta"+deltaName2+");\n");
compute << "real "+angleName+" = computeAngle(delta"+deltaName1+", delta"+deltaName2+");\n";
variables[angle.first] = angleName;
forceExpressions["real dEdAngle"+cc.intToString(index)+" = "] = energyExpression.differentiate(angle.first).optimize();
index++;
......@@ -3963,21 +3958,21 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
string dihedralName = "dihedral_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]]+atomNames[atoms[3]];
if (computedDeltas.count(deltaName1) == 0) {
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
compute << "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
computedDeltas.insert(deltaName1);
}
if (computedDeltas.count(deltaName2) == 0) {
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
compute << "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
computedDeltas.insert(deltaName2);
}
if (computedDeltas.count(deltaName3) == 0) {
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName3+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[3]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
compute << "real4 delta"+deltaName3+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[3]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
computedDeltas.insert(deltaName3);
}
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 "+crossName1+" = computeCross(delta"+deltaName1+", delta"+deltaName2+");\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 "+crossName2+" = computeCross(delta"+deltaName2+", delta"+deltaName3+");\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real "+dihedralName+" = computeAngle("+crossName1+", "+crossName2+");\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, dihedralName+" *= (delta"+deltaName1+".x*"+crossName2+".x + delta"+deltaName1+".y*"+crossName2+".y + delta"+deltaName1+".z*"+crossName2+".z < 0 ? -1 : 1);\n");
compute << "real4 "+crossName1+" = computeCross(delta"+deltaName1+", delta"+deltaName2+");\n";
compute << "real4 "+crossName2+" = computeCross(delta"+deltaName2+", delta"+deltaName3+");\n";
compute << "real "+dihedralName+" = computeAngle("+crossName1+", "+crossName2+");\n";
compute << dihedralName+" *= (delta"+deltaName1+".x*"+crossName2+".x + delta"+deltaName1+".y*"+crossName2+".y + delta"+deltaName1+".z*"+crossName2+".z < 0 ? -1 : 1);\n";
variables[dihedral.first] = dihedralName;
forceExpressions["real dEdDihedral"+cc.intToString(index)+" = "] = energyExpression.differentiate(dihedral.first).optimize();
index++;
......@@ -3990,19 +3985,18 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
for (int i = 0; i < (int) donorParams->getParameterInfos().size(); i++) {
ComputeParameterInfo& parameter = donorParams->getParameterInfos()[i];
extraArgs << ", GLOBAL const "+parameter.getType()+"* RESTRICT donor"+parameter.getName();
addDonorAndAcceptorCode(computeDonor, computeAcceptor, parameter.getType()+" donorParams"+cc.intToString(i+1)+" = donor"+parameter.getName()+"[donorIndex];\n");
compute << parameter.getType()+" donorParams"+cc.intToString(i+1)+" = donor"+parameter.getName()+"[donorIndex];\n";
}
for (int i = 0; i < (int) acceptorParams->getParameterInfos().size(); i++) {
ComputeParameterInfo& parameter = acceptorParams->getParameterInfos()[i];
extraArgs << ", GLOBAL const "+parameter.getType()+"* RESTRICT acceptor"+parameter.getName();
addDonorAndAcceptorCode(computeDonor, computeAcceptor, parameter.getType()+" acceptorParams"+cc.intToString(i+1)+" = acceptor"+parameter.getName()+"[acceptorIndex];\n");
compute << parameter.getType()+" acceptorParams"+cc.intToString(i+1)+" = acceptor"+parameter.getName()+"[acceptorIndex];\n";
}
// Now evaluate the expressions.
computeAcceptor << cc.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
forceExpressions["energy += "] = energyExpression;
computeDonor << cc.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
compute << cc.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
// Finally, apply forces to atoms.
......@@ -4011,8 +4005,8 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
const vector<int>& atoms = distance.second;
string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
string value = "(dEdDistance"+cc.intToString(index)+"/r_"+deltaName+")*delta"+deltaName;
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "-"+value);
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], value);
applyDonorAndAcceptorForces(compute, atoms[0], "-"+value);
applyDonorAndAcceptorForces(compute, atoms[1], value);
index++;
}
index = 0;
......@@ -4020,16 +4014,16 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
const vector<int>& atoms = angle.second;
string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "{\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real3 crossProd = trimTo3(cross(delta"+deltaName2+", delta"+deltaName1+"));\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real lengthCross = max(SQRT(dot(crossProd,crossProd)), (real) 1e-6f);\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real3 deltaCross0 = -cross(trimTo3(delta"+deltaName1+"), crossProd)*dEdAngle"+cc.intToString(index)+"/(delta"+deltaName1+".w*lengthCross);\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real3 deltaCross2 = cross(trimTo3(delta"+deltaName2+"), crossProd)*dEdAngle"+cc.intToString(index)+"/(delta"+deltaName2+".w*lengthCross);\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real3 deltaCross1 = -(deltaCross0+deltaCross2);\n");
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "deltaCross0", false);
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], "deltaCross1", false);
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[2], "deltaCross2", false);
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "}\n");
compute << "{\n";
compute << "real3 crossProd = trimTo3(cross(delta"+deltaName2+", delta"+deltaName1+"));\n";
compute << "real lengthCross = max(SQRT(dot(crossProd,crossProd)), (real) 1e-6f);\n";
compute << "real3 deltaCross0 = -cross(trimTo3(delta"+deltaName1+"), crossProd)*dEdAngle"+cc.intToString(index)+"/(delta"+deltaName1+".w*lengthCross);\n";
compute << "real3 deltaCross2 = cross(trimTo3(delta"+deltaName2+"), crossProd)*dEdAngle"+cc.intToString(index)+"/(delta"+deltaName2+".w*lengthCross);\n";
compute << "real3 deltaCross1 = -(deltaCross0+deltaCross2);\n";
applyDonorAndAcceptorForces(compute, atoms[0], "deltaCross0", false);
applyDonorAndAcceptorForces(compute, atoms[1], "deltaCross1", false);
applyDonorAndAcceptorForces(compute, atoms[2], "deltaCross2", false);
compute << "}\n";
index++;
}
index = 0;
......@@ -4040,34 +4034,35 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "{\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real r = SQRT(delta"+deltaName2+".w);\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 ff;\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.x = (-dEdDihedral"+cc.intToString(index)+"*r)/"+crossName1+".w;\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.y = (delta"+deltaName1+".x*delta"+deltaName2+".x + delta"+deltaName1+".y*delta"+deltaName2+".y + delta"+deltaName1+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.z = (delta"+deltaName3+".x*delta"+deltaName2+".x + delta"+deltaName3+".y*delta"+deltaName2+".y + delta"+deltaName3+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.w = (dEdDihedral"+cc.intToString(index)+"*r)/"+crossName2+".w;\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 internalF0 = ff.x*"+crossName1+";\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 internalF3 = ff.w*"+crossName2+";\n");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 s = ff.y*internalF0 - ff.z*internalF3;\n");
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "internalF0");
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], "s-internalF0");
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[2], "-s-internalF3");
applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[3], "internalF3");
addDonorAndAcceptorCode(computeDonor, computeAcceptor, "}\n");
compute << "{\n";
compute << "real r = SQRT(delta"+deltaName2+".w);\n";
compute << "real4 ff;\n";
compute << "ff.x = (-dEdDihedral"+cc.intToString(index)+"*r)/"+crossName1+".w;\n";
compute << "ff.y = (delta"+deltaName1+".x*delta"+deltaName2+".x + delta"+deltaName1+".y*delta"+deltaName2+".y + delta"+deltaName1+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n";
compute << "ff.z = (delta"+deltaName3+".x*delta"+deltaName2+".x + delta"+deltaName3+".y*delta"+deltaName2+".y + delta"+deltaName3+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n";
compute << "ff.w = (dEdDihedral"+cc.intToString(index)+"*r)/"+crossName2+".w;\n";
compute << "real4 internalF0 = ff.x*"+crossName1+";\n";
compute << "real4 internalF3 = ff.w*"+crossName2+";\n";
compute << "real4 s = ff.y*internalF0 - ff.z*internalF3;\n";
applyDonorAndAcceptorForces(compute, atoms[0], "internalF0");
applyDonorAndAcceptorForces(compute, atoms[1], "s-internalF0");
applyDonorAndAcceptorForces(compute, atoms[2], "-s-internalF3");
applyDonorAndAcceptorForces(compute, atoms[3], "internalF3");
compute << "}\n";
index++;
}
// Generate the kernels.
map<string, string> replacements;
replacements["COMPUTE_DONOR_FORCE"] = computeDonor.str();
replacements["COMPUTE_ACCEPTOR_FORCE"] = computeAcceptor.str();
replacements["COMPUTE_FORCE"] = compute.str();
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
map<string, string> defines;
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
defines["NUM_DONORS"] = cc.intToString(numDonors);
defines["NUM_ACCEPTORS"] = cc.intToString(numAcceptors);
defines["NUM_DONOR_BLOCKS"] = cc.intToString((numDonors+31)/32);
defines["NUM_ACCEPTOR_BLOCKS"] = cc.intToString((numAcceptors+31)/32);
defines["M_PI"] = cc.doubleToString(M_PI);
defines["THREAD_BLOCK_SIZE"] = "64";
if (force.getNonbondedMethod() != CustomHbondForce::NoCutoff) {
......@@ -4079,8 +4074,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
if (force.getNumExclusions() > 0)
defines["USE_EXCLUSIONS"] = "1";
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customHbondForce, replacements), defines);
donorKernel = program->createKernel("computeDonorForces");
acceptorKernel = program->createKernel("computeAcceptorForces");
kernel = program->createKernel("computeHbondForces");
}
double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......@@ -4100,43 +4094,27 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
}
if (!hasInitializedKernel) {
hasInitializedKernel = true;
donorKernel->addArg(cc.getLongForceBuffer());
donorKernel->addArg(cc.getEnergyBuffer());
donorKernel->addArg(cc.getPosq());
donorKernel->addArg(donorExclusions);
donorKernel->addArg(donors);
donorKernel->addArg(acceptors);
for (int i = 0; i < 5; i++)
donorKernel->addArg(); // Periodic box size arguments are set when the kernel is executed.
if (globals.isInitialized())
donorKernel->addArg(globals);
for (auto& parameter : donorParams->getParameterInfos())
donorKernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos())
donorKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctionArrays)
donorKernel->addArg(function);
acceptorKernel->addArg(cc.getLongForceBuffer());
acceptorKernel->addArg(cc.getEnergyBuffer());
acceptorKernel->addArg(cc.getPosq());
acceptorKernel->addArg(acceptorExclusions);
acceptorKernel->addArg(donors);
acceptorKernel->addArg(acceptors);
kernel->addArg(cc.getLongForceBuffer());
kernel->addArg(cc.getEnergyBuffer());
kernel->addArg(cc.getPosq());
kernel->addArg(donorExclusions);
kernel->addArg(donors);
kernel->addArg(acceptors);
for (int i = 0; i < 5; i++)
acceptorKernel->addArg(); // Periodic box size arguments are set when the kernel is executed.
kernel->addArg(); // Periodic box size arguments are set when the kernel is executed.
if (globals.isInitialized())
acceptorKernel->addArg(globals);
kernel->addArg(globals);
for (auto& parameter : donorParams->getParameterInfos())
acceptorKernel->addArg(parameter.getArray());
kernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos())
acceptorKernel->addArg(parameter.getArray());
kernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctionArrays)
acceptorKernel->addArg(function);
kernel->addArg(function);
}
setPeriodicBoxArgs(cc, donorKernel, 6);
donorKernel->execute(max(numDonors, numAcceptors), 64);
setPeriodicBoxArgs(cc, acceptorKernel, 6);
acceptorKernel->execute(max(numDonors, numAcceptors), 64);
setPeriodicBoxArgs(cc, kernel, 6);
int numDonorBlocks = (numDonors+31)/32;
int numAcceptorBlocks = (numAcceptors+31)/32;
kernel->execute(numDonorBlocks*numAcceptorBlocks*32, cc.getIsCPU() ? 32 : 64);
return 0.0;
}
......@@ -7797,4 +7775,4 @@ void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& contex
cc.setPosCellOffsets(lastPosCellOffsets);
if (savedFloatForces.isInitialized())
savedFloatForces.copyTo(cc.getFloatForceBuffer());
}
}
\ No newline at end of file
......@@ -2,7 +2,7 @@
* Compute the difference between two vectors, optionally taking periodic boundary conditions into account
* and setting the fourth component to the squared magnitude.
*/
inline DEVICE real4 delta(real4 vec1, real4 vec2, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ) {
inline DEVICE real4 delta(real3 vec1, real3 vec2, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ) {
real4 result = make_real4(vec1.x-vec2.x, vec1.y-vec2.y, vec1.z-vec2.z, 0);
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(result)
......@@ -41,184 +41,116 @@ inline DEVICE real4 computeCross(real4 vec1, real4 vec2) {
}
/**
* Compute forces on donors.
* Write the force on an atom to global memory.
*/
KERNEL void computeDonorForces(
inline DEVICE void applyForce(int atom, real3 f, GLOBAL mm_ulong* force) {
if (atom > -1) {
if (f.x != 0)
ATOMIC_ADD(&force[atom], (mm_ulong) realToFixedPoint(f.x));
if (f.y != 0)
ATOMIC_ADD(&force[atom+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f.y));
if (f.z != 0)
ATOMIC_ADD(&force[atom+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f.z));
MEM_FENCE;
}
}
typedef struct {
real3 pos1, pos2, pos3;
real3 f1, f2, f3;
} AcceptorData;
/**
* Compute forces on donors and acceptors.
*/
KERNEL void computeHbondForces(
GLOBAL mm_ulong* RESTRICT force,
GLOBAL mixed* RESTRICT energyBuffer, GLOBAL const real4* RESTRICT posq, GLOBAL const int4* RESTRICT exclusions,
GLOBAL const int4* RESTRICT donorAtoms, GLOBAL const int4* RESTRICT acceptorAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
PARAMETER_ARGUMENTS) {
LOCAL real4 posBuffer[3*THREAD_BLOCK_SIZE];
const unsigned int totalWarps = GLOBAL_SIZE/32;
const unsigned int warp = GLOBAL_ID/32;
const int indexInWarp = GLOBAL_ID%32;
const int tbx = LOCAL_ID-indexInWarp;
LOCAL AcceptorData localData[THREAD_BLOCK_SIZE];
mixed energy = 0;
real3 f1 = make_real3(0);
real3 f2 = make_real3(0);
real3 f3 = make_real3(0);
for (int donorStart = 0; donorStart < NUM_DONORS; donorStart += GLOBAL_SIZE) {
for (int tile = warp; tile < NUM_DONOR_BLOCKS*NUM_ACCEPTOR_BLOCKS; tile += totalWarps) {
int donorStart = (tile/NUM_ACCEPTOR_BLOCKS)*32;
int acceptorStart = (tile%NUM_ACCEPTOR_BLOCKS)*32;
// Load information about the donor this thread will compute forces on.
int donorIndex = donorStart+GLOBAL_ID;
real3 f1 = make_real3(0);
real3 f2 = make_real3(0);
real3 f3 = make_real3(0);
int donorIndex = donorStart+indexInWarp;
int4 atoms, exclusionIndices;
real4 d1, d2, d3;
real3 d1, d2, d3;
if (donorIndex < NUM_DONORS) {
atoms = donorAtoms[donorIndex];
d1 = (atoms.x > -1 ? posq[atoms.x] : make_real4(0));
d2 = (atoms.y > -1 ? posq[atoms.y] : make_real4(0));
d3 = (atoms.z > -1 ? posq[atoms.z] : make_real4(0));
d1 = (atoms.x > -1 ? trimTo3(posq[atoms.x]) : make_real3(0));
d2 = (atoms.y > -1 ? trimTo3(posq[atoms.y]) : make_real3(0));
d3 = (atoms.z > -1 ? trimTo3(posq[atoms.z]) : make_real3(0));
#ifdef USE_EXCLUSIONS
exclusionIndices = exclusions[donorIndex];
#endif
}
else
atoms = make_int4(-1, -1, -1, -1);
for (int acceptorStart = 0; acceptorStart < NUM_ACCEPTORS; acceptorStart += LOCAL_SIZE) {
// Load the next block of acceptors into local memory.
SYNC_THREADS;
int blockSize = min((int) LOCAL_SIZE, NUM_ACCEPTORS-acceptorStart);
if (LOCAL_ID < blockSize) {
int4 atoms2 = acceptorAtoms[acceptorStart+LOCAL_ID];
posBuffer[3*LOCAL_ID] = (atoms2.x > -1 ? posq[atoms2.x] : make_real4(0));
posBuffer[3*LOCAL_ID+1] = (atoms2.y > -1 ? posq[atoms2.y] : make_real4(0));
posBuffer[3*LOCAL_ID+2] = (atoms2.z > -1 ? posq[atoms2.z] : make_real4(0));
}
SYNC_THREADS;
if (donorIndex < NUM_DONORS) {
for (int index = 0; index < blockSize; index++) {
int acceptorIndex = acceptorStart+index;
#ifdef USE_EXCLUSIONS
if (acceptorIndex == exclusionIndices.x || acceptorIndex == exclusionIndices.y || acceptorIndex == exclusionIndices.z || acceptorIndex == exclusionIndices.w)
continue;
#endif
// Compute the interaction between a donor and an acceptor.
// Load information about the acceptors into local memory.
real4 a1 = posBuffer[3*index];
real4 a2 = posBuffer[3*index+1];
real4 a3 = posBuffer[3*index+2];
real4 deltaD1A1 = delta(d1, a1, periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);
#ifdef USE_CUTOFF
if (deltaD1A1.w < CUTOFF_SQUARED) {
#endif
COMPUTE_DONOR_FORCE
#ifdef USE_CUTOFF
}
#endif
}
}
SYNC_WARPS;
localData[LOCAL_ID].f1 = make_real3(0);
localData[LOCAL_ID].f2 = make_real3(0);
localData[LOCAL_ID].f3 = make_real3(0);
int blockSize = min(32, NUM_ACCEPTORS-acceptorStart);
int4 atoms2 = (indexInWarp < blockSize ? acceptorAtoms[acceptorStart+indexInWarp] : make_int4(-1));
if (indexInWarp < blockSize) {
localData[LOCAL_ID].pos1 = (atoms2.x > -1 ? trimTo3(posq[atoms2.x]) : make_real3(0));
localData[LOCAL_ID].pos2 = (atoms2.y > -1 ? trimTo3(posq[atoms2.y]) : make_real3(0));
localData[LOCAL_ID].pos3 = (atoms2.z > -1 ? trimTo3(posq[atoms2.z]) : make_real3(0));
}
// Write results
SYNC_WARPS;
if (donorIndex < NUM_DONORS) {
if (atoms.x > -1) {
ATOMIC_ADD(&force[atoms.x], (mm_ulong) realToFixedPoint(f1.x));
ATOMIC_ADD(&force[atoms.x+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f1.y));
ATOMIC_ADD(&force[atoms.x+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f1.z));
MEM_FENCE;
}
if (atoms.y > -1) {
ATOMIC_ADD(&force[atoms.y], (mm_ulong) realToFixedPoint(f2.x));
ATOMIC_ADD(&force[atoms.y+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f2.y));
ATOMIC_ADD(&force[atoms.y+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f2.z));
MEM_FENCE;
}
if (atoms.z > -1) {
ATOMIC_ADD(&force[atoms.z], (mm_ulong) realToFixedPoint(f3.x));
ATOMIC_ADD(&force[atoms.z+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f3.y));
ATOMIC_ADD(&force[atoms.z+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f3.z));
MEM_FENCE;
}
}
}
energyBuffer[GLOBAL_ID] += energy;
}
/**
* Compute forces on acceptors.
*/
KERNEL void computeAcceptorForces(
GLOBAL mm_ulong* RESTRICT force,
GLOBAL mixed* RESTRICT energyBuffer, GLOBAL const real4* RESTRICT posq, GLOBAL const int4* RESTRICT exclusions,
GLOBAL const int4* RESTRICT donorAtoms, GLOBAL const int4* RESTRICT acceptorAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize,
real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
PARAMETER_ARGUMENTS) {
LOCAL real4 posBuffer[3*THREAD_BLOCK_SIZE];
real3 f1 = make_real3(0);
real3 f2 = make_real3(0);
real3 f3 = make_real3(0);
for (int acceptorStart = 0; acceptorStart < NUM_ACCEPTORS; acceptorStart += GLOBAL_SIZE) {
// Load information about the acceptor this thread will compute forces on.
int acceptorIndex = acceptorStart+GLOBAL_ID;
int4 atoms, exclusionIndices;
real4 a1, a2, a3;
if (acceptorIndex < NUM_ACCEPTORS) {
atoms = acceptorAtoms[acceptorIndex];
a1 = (atoms.x > -1 ? posq[atoms.x] : make_real4(0));
a2 = (atoms.y > -1 ? posq[atoms.y] : make_real4(0));
a3 = (atoms.z > -1 ? posq[atoms.z] : make_real4(0));
int index = indexInWarp;
for (int j = 0; j < 32; j++) {
int acceptorIndex = acceptorStart+index;
#ifdef USE_EXCLUSIONS
exclusionIndices = exclusions[acceptorIndex];
#endif
}
else
atoms = make_int4(-1, -1, -1, -1);
for (int donorStart = 0; donorStart < NUM_DONORS; donorStart += LOCAL_SIZE) {
// Load the next block of donors into local memory.
SYNC_THREADS;
int blockSize = min((int) LOCAL_SIZE, NUM_DONORS-donorStart);
if (LOCAL_ID < blockSize) {
int4 atoms2 = donorAtoms[donorStart+LOCAL_ID];
posBuffer[3*LOCAL_ID] = (atoms2.x > -1 ? posq[atoms2.x] : make_real4(0));
posBuffer[3*LOCAL_ID+1] = (atoms2.y > -1 ? posq[atoms2.y] : make_real4(0));
posBuffer[3*LOCAL_ID+2] = (atoms2.z > -1 ? posq[atoms2.z] : make_real4(0));
}
SYNC_THREADS;
if (acceptorIndex < NUM_ACCEPTORS) {
for (int index = 0; index < blockSize; index++) {
int donorIndex = donorStart+index;
#ifdef USE_EXCLUSIONS
if (donorIndex == exclusionIndices.x || donorIndex == exclusionIndices.y || donorIndex == exclusionIndices.z || donorIndex == exclusionIndices.w)
continue;
if (acceptorIndex < NUM_ACCEPTORS && acceptorIndex != exclusionIndices.x && acceptorIndex != exclusionIndices.y && acceptorIndex != exclusionIndices.z && acceptorIndex != exclusionIndices.w) {
#else
if (acceptorIndex < NUM_ACCEPTORS) {
#endif
// Compute the interaction between a donor and an acceptor.
real4 d1 = posBuffer[3*index];
real4 d2 = posBuffer[3*index+1];
real4 d3 = posBuffer[3*index+2];
real3 a1 = localData[tbx+index].pos1;
real3 a2 = localData[tbx+index].pos2;
real3 a3 = localData[tbx+index].pos3;
real4 deltaD1A1 = delta(d1, a1, periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);
#ifdef USE_CUTOFF
if (deltaD1A1.w < CUTOFF_SQUARED) {
#endif
COMPUTE_ACCEPTOR_FORCE
COMPUTE_FORCE
#ifdef USE_CUTOFF
}
#endif
}
index = (index+1)%32;
}
}
// Write results
if (acceptorIndex < NUM_ACCEPTORS) {
if (atoms.x > -1) {
ATOMIC_ADD(&force[atoms.x], (mm_ulong) realToFixedPoint(f1.x));
ATOMIC_ADD(&force[atoms.x+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f1.y));
ATOMIC_ADD(&force[atoms.x+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f1.z));
MEM_FENCE;
}
if (atoms.y > -1) {
ATOMIC_ADD(&force[atoms.y], (mm_ulong) realToFixedPoint(f2.x));
ATOMIC_ADD(&force[atoms.y+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f2.y));
ATOMIC_ADD(&force[atoms.y+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f2.z));
MEM_FENCE;
}
if (atoms.z > -1) {
ATOMIC_ADD(&force[atoms.z], (mm_ulong) realToFixedPoint(f3.x));
ATOMIC_ADD(&force[atoms.z+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f3.y));
ATOMIC_ADD(&force[atoms.z+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(f3.z));
MEM_FENCE;
}
if (donorIndex < NUM_DONORS) {
applyForce(atoms.x, f1, force);
applyForce(atoms.y, f2, force);
applyForce(atoms.z, f3, force);
}
SYNC_WARPS;
applyForce(atoms2.x, localData[LOCAL_ID].f1, force);
applyForce(atoms2.y, localData[LOCAL_ID].f2, force);
applyForce(atoms2.z, localData[LOCAL_ID].f3, force);
}
energyBuffer[GLOBAL_ID] += energy;
}
......@@ -309,6 +309,43 @@ void testParameters() {
ASSERT_EQUAL_TOL(2*(2*1.8+2.1)+2*(2*1.5+2.1), state.getPotentialEnergy(), TOL);
}
void testLargeSystem() {
int numParticles = 5000;
System system;
CustomHbondForce* custom = new CustomHbondForce("distance(d1,a1)^2");
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
if (i%2 == 0)
custom->addDonor(i, -1, -1);
else
custom->addAcceptor(i, -1, -1);
positions[i] = Vec3(3.0*genrand_real2(sfmt), 3.0*genrand_real2(sfmt), 3.0*genrand_real2(sfmt));
}
system.addForce(custom);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
double expectedEnergy = 0;
for (int i = 0; i < numParticles; i += 2) {
for (int j = 1; j < numParticles; j += 2) {
Vec3 d = positions[i]-positions[j];
double r = sqrt(d.dot(d));
expectedEnergy += r*r;
}
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
for (int i = 0; i < numParticles; i += 2) {
Vec3 expectedForce;
for (int j = 1; j < numParticles; j += 2)
expectedForce += 2*(positions[j]-positions[i]);
ASSERT_EQUAL_VEC(expectedForce, state.getForces()[i], 1e-5);
}
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -321,6 +358,7 @@ int main(int argc, char* argv[]) {
test2DFunction();
testIllegalVariable();
testParameters();
testLargeSystem();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment