Commit c1d643e2 authored by one's avatar one
Browse files

Split LJ-PME atom-grid sorting from Coulomb PME

Avoid forcing Coulomb PME to re-sort whenever LJ-PME is enabled, and give dispersion PME its own atom-grid index and sort state so the performance impact can be measured independently.
parent b1a1c54c
...@@ -45,7 +45,7 @@ namespace OpenMM { ...@@ -45,7 +45,7 @@ namespace OpenMM {
class CommonCalcNonbondedForceKernel : public CalcNonbondedForceKernel { class CommonCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public: public:
CommonCalcNonbondedForceKernel(std::string name, const Platform& platform, ComputeContext& cc, const System& system) : CalcNonbondedForceKernel(name, platform), CommonCalcNonbondedForceKernel(std::string name, const Platform& platform, ComputeContext& cc, const System& system) : CalcNonbondedForceKernel(name, platform),
hasInitializedKernel(false), cc(cc), pmeio(NULL), stepsToSort(0) { hasInitializedKernel(false), cc(cc), pmeio(NULL), stepsToSort(0), dispersionStepsToSort(0) {
} }
~CommonCalcNonbondedForceKernel(); ~CommonCalcNonbondedForceKernel();
/** /**
...@@ -140,6 +140,7 @@ private: ...@@ -140,6 +140,7 @@ private:
ComputeArray pmeDispersionBsplineModuliY; ComputeArray pmeDispersionBsplineModuliY;
ComputeArray pmeDispersionBsplineModuliZ; ComputeArray pmeDispersionBsplineModuliZ;
ComputeArray pmeAtomGridIndex; ComputeArray pmeAtomGridIndex;
ComputeArray pmeDispersionAtomGridIndex;
ComputeArray pmeEnergyBuffer; ComputeArray pmeEnergyBuffer;
ComputeArray chargeBuffer; ComputeArray chargeBuffer;
ComputeSort sort; ComputeSort sort;
...@@ -167,6 +168,7 @@ private: ...@@ -167,6 +168,7 @@ private:
int gridSizeX, gridSizeY, gridSizeZ; int gridSizeX, gridSizeY, gridSizeZ;
int dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ; int dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ;
int stepsToSort; int stepsToSort;
int dispersionStepsToSort;
bool usePmeQueue, deviceIsCpu, useFixedPointChargeSpreading, useCpuPme; bool usePmeQueue, deviceIsCpu, useFixedPointChargeSpreading, useCpuPme;
bool hasCoulomb, hasLJ, doLJPME, usePosqCharges, recomputeParams, hasOffsets; bool hasCoulomb, hasLJ, doLJPME, usePosqCharges, recomputeParams, hasOffsets;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
......
...@@ -457,6 +457,8 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons ...@@ -457,6 +457,8 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
pmeDispersionBsplineModuliZ.initialize(cc, dispersionGridSizeZ, elementSize, "pmeDispersionBsplineModuliZ"); pmeDispersionBsplineModuliZ.initialize(cc, dispersionGridSizeZ, elementSize, "pmeDispersionBsplineModuliZ");
} }
pmeAtomGridIndex.initialize<mm_int2>(cc, numParticles, "pmeAtomGridIndex"); pmeAtomGridIndex.initialize<mm_int2>(cc, numParticles, "pmeAtomGridIndex");
if (doLJPME)
pmeDispersionAtomGridIndex.initialize<mm_int2>(cc, numParticles, "pmeDispersionAtomGridIndex");
int energyElementSize = (cc.getUseDoublePrecision() || cc.getUseMixedPrecision() ? sizeof(double) : sizeof(float)); int energyElementSize = (cc.getUseDoublePrecision() || cc.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
pmeEnergyBuffer.initialize(cc, cc.getNumThreadBlocks()*ComputeContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer"); pmeEnergyBuffer.initialize(cc, cc.getNumThreadBlocks()*ComputeContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer");
cc.clearBuffer(pmeEnergyBuffer); cc.clearBuffer(pmeEnergyBuffer);
...@@ -836,7 +838,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -836,7 +838,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionEvalEnergyKernel = program->createKernel("gridEvaluateEnergy"); pmeDispersionEvalEnergyKernel = program->createKernel("gridEvaluateEnergy");
pmeDispersionInterpolateForceKernel = program->createKernel("gridInterpolateForce"); pmeDispersionInterpolateForceKernel = program->createKernel("gridInterpolateForce");
pmeDispersionGridIndexKernel->addArg(cc.getPosq()); pmeDispersionGridIndexKernel->addArg(cc.getPosq());
pmeDispersionGridIndexKernel->addArg(pmeAtomGridIndex); pmeDispersionGridIndexKernel->addArg(pmeDispersionAtomGridIndex);
for (int i = 0; i < 8; i++) for (int i = 0; i < 8; i++)
pmeDispersionGridIndexKernel->addArg(); pmeDispersionGridIndexKernel->addArg();
pmeDispersionSpreadChargeKernel->addArg(cc.getPosq()); pmeDispersionSpreadChargeKernel->addArg(cc.getPosq());
...@@ -846,7 +848,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -846,7 +848,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionSpreadChargeKernel->addArg(pmeGrid1); pmeDispersionSpreadChargeKernel->addArg(pmeGrid1);
for (int i = 0; i < 8; i++) for (int i = 0; i < 8; i++)
pmeDispersionSpreadChargeKernel->addArg(); pmeDispersionSpreadChargeKernel->addArg();
pmeDispersionSpreadChargeKernel->addArg(pmeAtomGridIndex); pmeDispersionSpreadChargeKernel->addArg(pmeDispersionAtomGridIndex);
pmeDispersionSpreadChargeKernel->addArg(sigmaEpsilon); pmeDispersionSpreadChargeKernel->addArg(sigmaEpsilon);
pmeDispersionConvolutionKernel->addArg(pmeGrid2); pmeDispersionConvolutionKernel->addArg(pmeGrid2);
pmeDispersionConvolutionKernel->addArg(pmeDispersionBsplineModuliX); pmeDispersionConvolutionKernel->addArg(pmeDispersionBsplineModuliX);
...@@ -869,7 +871,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -869,7 +871,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionInterpolateForceKernel->addArg(pmeGrid1); pmeDispersionInterpolateForceKernel->addArg(pmeGrid1);
for (int i = 0; i < 8; i++) for (int i = 0; i < 8; i++)
pmeDispersionInterpolateForceKernel->addArg(); pmeDispersionInterpolateForceKernel->addArg();
pmeDispersionInterpolateForceKernel->addArg(pmeAtomGridIndex); pmeDispersionInterpolateForceKernel->addArg(pmeDispersionAtomGridIndex);
pmeDispersionInterpolateForceKernel->addArg(sigmaEpsilon); pmeDispersionInterpolateForceKernel->addArg(sigmaEpsilon);
if (useFixedPointChargeSpreading) { if (useFixedPointChargeSpreading) {
pmeDispersionFinishSpreadChargeKernel = program->createKernel("finishSpreadCharge"); pmeDispersionFinishSpreadChargeKernel = program->createKernel("finishSpreadCharge");
...@@ -962,7 +964,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -962,7 +964,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
// Execute the reciprocal space kernels. // Execute the reciprocal space kernels.
if (hasCoulomb) { if (hasCoulomb) {
if (stepsToSort <= 0 || doLJPME) { if (stepsToSort <= 0) {
setPeriodicBoxArgs(cc, pmeGridIndexKernel, 2); setPeriodicBoxArgs(cc, pmeGridIndexKernel, 2);
if (cc.getUseDoublePrecision()) { if (cc.getUseDoublePrecision()) {
pmeGridIndexKernel->setArg(7, recipBoxVectors[0]); pmeGridIndexKernel->setArg(7, recipBoxVectors[0]);
...@@ -1033,6 +1035,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -1033,6 +1035,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
} }
if (doLJPME && hasLJ) { if (doLJPME && hasLJ) {
if (dispersionStepsToSort <= 0) {
setPeriodicBoxArgs(cc, pmeDispersionGridIndexKernel, 2); setPeriodicBoxArgs(cc, pmeDispersionGridIndexKernel, 2);
if (cc.getUseDoublePrecision()) { if (cc.getUseDoublePrecision()) {
pmeDispersionGridIndexKernel->setArg(7, recipBoxVectors[0]); pmeDispersionGridIndexKernel->setArg(7, recipBoxVectors[0]);
...@@ -1045,7 +1048,11 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -1045,7 +1048,11 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionGridIndexKernel->setArg(9, recipBoxVectorsFloat[2]); pmeDispersionGridIndexKernel->setArg(9, recipBoxVectorsFloat[2]);
} }
pmeDispersionGridIndexKernel->execute(cc.getNumAtoms()); pmeDispersionGridIndexKernel->execute(cc.getNumAtoms());
sort->sort(pmeAtomGridIndex); sort->sort(pmeDispersionAtomGridIndex);
dispersionStepsToSort = 3;
}
else
dispersionStepsToSort--;
if (useFixedPointChargeSpreading) if (useFixedPointChargeSpreading)
cc.clearBuffer(pmeGrid2); cc.clearBuffer(pmeGrid2);
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment