Unverified commit b64a82a8, authored by Evan Pretti, committed by GitHub
Browse files

Fix issue with constant potential and large neighbor lists (#5109)

* Ensure that neighbor list is valid before solving for charges

* Add test with neighbor list that needs to be resized

* Try another approach to skip interactions for neighbor list generation only

* Increase CG error tolerance for test
parent 2b20ffba
...@@ -290,6 +290,7 @@ private: ...@@ -290,6 +290,7 @@ private:
void initPmeExecute(); void initPmeExecute();
void pmeExecute(bool includeEnergy, bool includeForces, bool includeChargeDerivatives, bool init = true); void pmeExecute(bool includeEnergy, bool includeForces, bool includeChargeDerivatives, bool init = true);
void setKernelInputs(bool includeEnergy, bool includeForces); void setKernelInputs(bool includeEnergy, bool includeForces);
void ensureValidNeighborList();
class SortTrait : public ComputeSortImpl::SortTrait { class SortTrait : public ComputeSortImpl::SortTrait {
int getDataSize() const {return 8;} int getDataSize() const {return 8;}
int getKeySize() const {return 4;} int getKeySize() const {return 4;}
...@@ -302,6 +303,7 @@ private: ...@@ -302,6 +303,7 @@ private:
}; };
class ForceInfo; class ForceInfo;
class ReorderListener; class ReorderListener;
class InvalidatePostComputation;
ComputeContext& cc; ComputeContext& cc;
ForceInfo* info; ForceInfo* info;
CommonConstantPotentialSolver* solver; CommonConstantPotentialSolver* solver;
......
...@@ -129,6 +129,21 @@ private: ...@@ -129,6 +129,21 @@ private:
vector<ComputeArray*> chargeArrays; vector<ComputeArray*> chargeArrays;
}; };
/**
 * Post-computation hook registered with the ComputeContext.  Each time it is
 * invoked it checks whether the context reports the forces as valid; if they
 * are not, the solver is told to discard its saved solution.  The hook itself
 * contributes no energy.
 */
class CommonCalcConstantPotentialForceKernel::InvalidatePostComputation : public ComputeContext::ForcePostComputation {
public:
    InvalidatePostComputation(ComputeContext& cc, CommonConstantPotentialSolver* solver) : context(cc), potentialSolver(solver) {
    }
    /**
     * Discard the solver's saved solution when the context's forces are flagged invalid.
     *
     * @param includeForces  unused by this hook
     * @param includeEnergy  unused by this hook
     * @param groups         unused by this hook
     * @return always 0.0
     */
    double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        const bool forcesAreValid = context.getForcesValid();
        if (forcesAreValid)
            return 0.0;
        potentialSolver->discardSavedSolution();
        return 0.0;
    }
private:
    ComputeContext& context;
    CommonConstantPotentialSolver* potentialSolver;
};
CommonConstantPotentialSolver::CommonConstantPotentialSolver(ComputeContext& cc, int numParticles, int numElectrodeParticles, int paddedProblemSize) : CommonConstantPotentialSolver::CommonConstantPotentialSolver(ComputeContext& cc, int numParticles, int numElectrodeParticles, int paddedProblemSize) :
numParticles(numParticles), numParticles(numParticles),
numElectrodeParticles(numElectrodeParticles), numElectrodeParticles(numElectrodeParticles),
...@@ -291,6 +306,9 @@ void CommonConstantPotentialMatrixSolver::ensureValid(CommonCalcConstantPotentia ...@@ -291,6 +306,9 @@ void CommonConstantPotentialMatrixSolver::ensureValid(CommonCalcConstantPotentia
return; return;
} }
// We must have a valid neighbor list before populating the matrix.
kernel.ensureValidNeighborList();
// Store the current box vectors and electrode positions before updating the // Store the current box vectors and electrode positions before updating the
// capacitance matrix. // capacitance matrix.
valid = true; valid = true;
...@@ -999,6 +1017,12 @@ void CommonCalcConstantPotentialForceKernel::commonInitialize(const System& syst ...@@ -999,6 +1017,12 @@ void CommonCalcConstantPotentialForceKernel::commonInitialize(const System& syst
} }
} }
cc.addReorderListener(listener); cc.addReorderListener(listener);
// Create a post computation object to avoid caching solutions if forces
// are invalid (e.g., due to the neighbor list needing to be resized).
if (hasElectrodes) {
cc.addPostComputation(new InvalidatePostComputation(cc, solver));
}
} }
void CommonCalcConstantPotentialForceKernel::copyParametersToContext(ContextImpl& context, const ConstantPotentialForce& force, int firstParticle, int lastParticle, int firstException, int lastException, int firstElectrode, int lastElectrode) { void CommonCalcConstantPotentialForceKernel::copyParametersToContext(ContextImpl& context, const ConstantPotentialForce& force, int firstParticle, int lastParticle, int firstException, int lastException, int firstElectrode, int lastElectrode) {
...@@ -1146,9 +1170,7 @@ void CommonCalcConstantPotentialForceKernel::getCharges(ContextImpl& context, ve ...@@ -1146,9 +1170,7 @@ void CommonCalcConstantPotentialForceKernel::getCharges(ContextImpl& context, ve
ContextSelector selector(cc); ContextSelector selector(cc);
ensureInitialized(context); ensureInitialized(context);
ensureValidNeighborList();
// We need to have a neighbor list to evaluate direct space derivatives.
cc.getNonbondedUtilities().prepareInteractions(1 << forceGroup);
cc.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]); cc.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
setKernelInputs(false, false); setKernelInputs(false, false);
...@@ -1686,3 +1708,23 @@ void CommonCalcConstantPotentialForceKernel::setKernelInputs(bool includeEnergy, ...@@ -1686,3 +1708,23 @@ void CommonCalcConstantPotentialForceKernel::setKernelInputs(bool includeEnergy,
posCellOffsets.upload(hostPosCellOffsets); posCellOffsets.upload(hostPosCellOffsets);
} }
} }
/**
 * Build the neighbor list for this force's group, repeating the build until
 * the context no longer flags it as needing another attempt.  The context's
 * forcesValid flag is borrowed as the retry signal and is put back to its
 * incoming value before returning.
 */
void CommonCalcConstantPotentialForceKernel::ensureValidNeighborList() {
    // forcesValid is used below to watch the neighbor list build, so remember
    // the value it had on entry.
    const bool savedForcesValid = cc.getForcesValid();
    const int groupMask = 1 << forceGroup;
    auto& nonbonded = cc.getNonbondedUtilities();

    for (;;) {
        // Start each attempt with the flag set; if the list has to be rebuilt
        // (i.e., it needs to be made bigger), getForcesValid() will return
        // false once computeInteractions() has completed.
        cc.setForcesValid(true);
        nonbonded.prepareInteractions(groupMask);
        nonbonded.computeInteractions(groupMask, false, false);
        if (cc.getForcesValid())
            break;
    }

    // Hand the current size of the interacting tiles array to the direct
    // derivatives kernel (argument 17) when electrodes are present.
    if (hasElectrodes) {
        evaluateDirectDerivativesKernel->setArg(17, (unsigned int) nonbonded.getInteractingTiles().getSize());
    }

    // Put the flag back the way the caller left it.
    cc.setForcesValid(savedForcesValid);
}
...@@ -424,7 +424,7 @@ void CudaNonbondedUtilities::computeInteractions(int forceGroups, bool includeFo ...@@ -424,7 +424,7 @@ void CudaNonbondedUtilities::computeInteractions(int forceGroups, bool includeFo
if ((forceGroups&groupFlags) == 0) if ((forceGroups&groupFlags) == 0)
return; return;
KernelSet& kernels = groupKernels[forceGroups]; KernelSet& kernels = groupKernels[forceGroups];
if (kernels.hasForces) { if (kernels.hasForces && (includeForces || includeEnergy)) {
CUfunction& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel); CUfunction& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
if (kernel == NULL) if (kernel == NULL)
kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy); kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
......
...@@ -36,4 +36,5 @@ void platformInitialize() { ...@@ -36,4 +36,5 @@ void platformInitialize() {
void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner) { void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner) {
testEnergyConservation(method, usePreconditioner, 10); testEnergyConservation(method, usePreconditioner, 10);
testCompareToReferencePlatform(method, usePreconditioner); testCompareToReferencePlatform(method, usePreconditioner);
testLargeNeighborList(method, usePreconditioner);
} }
...@@ -446,7 +446,7 @@ void HipNonbondedUtilities::computeInteractions(int forceGroups, bool includeFor ...@@ -446,7 +446,7 @@ void HipNonbondedUtilities::computeInteractions(int forceGroups, bool includeFor
if ((forceGroups&groupFlags) == 0) if ((forceGroups&groupFlags) == 0)
return; return;
KernelSet& kernels = groupKernels[forceGroups]; KernelSet& kernels = groupKernels[forceGroups];
if (kernels.hasForces) { if (kernels.hasForces && (includeForces || includeEnergy)) {
hipFunction_t& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel); hipFunction_t& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
if (kernel == NULL) if (kernel == NULL)
kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy); kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
......
...@@ -36,4 +36,5 @@ void platformInitialize() { ...@@ -36,4 +36,5 @@ void platformInitialize() {
void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner) { void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner) {
testEnergyConservation(method, usePreconditioner, 10); testEnergyConservation(method, usePreconditioner, 10);
testCompareToReferencePlatform(method, usePreconditioner); testCompareToReferencePlatform(method, usePreconditioner);
testLargeNeighborList(method, usePreconditioner);
} }
...@@ -375,7 +375,7 @@ void OpenCLNonbondedUtilities::computeInteractions(int forceGroups, bool include ...@@ -375,7 +375,7 @@ void OpenCLNonbondedUtilities::computeInteractions(int forceGroups, bool include
if ((forceGroups&groupFlags) == 0) if ((forceGroups&groupFlags) == 0)
return; return;
KernelSet& kernels = groupKernels[forceGroups]; KernelSet& kernels = groupKernels[forceGroups];
if (kernels.hasForces) { if (kernels.hasForces && (includeForces || includeEnergy)) {
if (isAMD) if (isAMD)
context.getQueue().flush(); context.getQueue().flush();
cl::Kernel& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel); cl::Kernel& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
......
...@@ -36,4 +36,5 @@ void platformInitialize() { ...@@ -36,4 +36,5 @@ void platformInitialize() {
void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner) { void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner) {
testEnergyConservation(method, usePreconditioner, 10); testEnergyConservation(method, usePreconditioner, 10);
testCompareToReferencePlatform(method, usePreconditioner); testCompareToReferencePlatform(method, usePreconditioner);
testLargeNeighborList(method, usePreconditioner);
} }
...@@ -1242,6 +1242,61 @@ void testCompareToReferencePlatform(ConstantPotentialForce::ConstantPotentialMet ...@@ -1242,6 +1242,61 @@ void testCompareToReferencePlatform(ConstantPotentialForce::ConstantPotentialMet
compareToReferencePlatform(system, force, positions); compareToReferencePlatform(system, force, positions);
} }
/**
 * Runs a test where the initial neighbor list should overflow on GPU platforms.
 *
 * A periodic cubic box is filled with a lattice of electrode particles (charge
 * parameter 0) interleaved with offset particles of charge 1.  The solved
 * electrode charges are checked twice: once straight after setting positions,
 * and once after an energy/force evaluation on a reinitialized context.
 *
 * @param method            the constant potential method to test
 * @param usePreconditioner whether the force should use its preconditioner
 */
void testLargeNeighborList(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner) {
    const int cellsPerSide = 9;
    const double boxLength = 3.0;
    const double spacing = boxLength / cellsPerSide;

    System system;
    system.setDefaultPeriodicBoxVectors(Vec3(boxLength, 0, 0), Vec3(0, boxLength, 0), Vec3(0, 0, boxLength));

    ConstantPotentialForce* force = new ConstantPotentialForce();
    force->setConstantPotentialMethod(method);
    force->setUsePreconditioner(usePreconditioner);
    force->setUseChargeConstraint(true);
    force->setCutoffDistance(1.4);
    force->setCGErrorTolerance(5e-4);
    system.addForce(force);

    // Each lattice cell contributes one electrode particle at the cell corner
    // followed by one non-electrode particle at the cell center.
    vector<Vec3> positions;
    set<int> electrodeParticles;
    for (int ix = 0; ix < cellsPerSide; ix++) {
        for (int iy = 0; iy < cellsPerSide; iy++) {
            for (int iz = 0; iz < cellsPerSide; iz++) {
                const int electrodeIndex = system.addParticle(0.0);
                electrodeParticles.insert(electrodeIndex);
                system.addParticle(0.0);

                force->addParticle(0.0);
                force->addParticle(1.0);

                positions.push_back(spacing * Vec3(ix, iy, iz));
                positions.push_back(spacing * Vec3(ix + 0.5, iy + 0.5, iz + 0.5));
            }
        }
    }
    force->addElectrode(electrodeParticles, 0.0, 0.01, 0.0);

    VerletIntegrator integrator(0.001);
    Context context(system, integrator, platform);
    context.setPositions(positions);

    // Fetch the charges and check every electrode particle: if the neighbor
    // list is incomplete, they will not be uniformly equal to -1.
    vector<double> charges;
    auto verifyElectrodeCharges = [&]() {
        force->getCharges(context, charges);
        for (set<int>::const_iterator particle = electrodeParticles.begin(); particle != electrodeParticles.end(); ++particle) {
            ASSERT_EQUAL_TOL(-1.0, charges[*particle], 2e-3);
        }
    };

    // First pass: get charges without any prior energy/force calculation.
    verifyElectrodeCharges();

    // Second pass: this time do an energy/force calculation before getting charges.
    context.reinitialize();
    context.setPositions(positions);
    context.getState(State::Energy | State::Forces);
    verifyElectrodeCharges();
}
void platformInitialize(); void platformInitialize();
void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner); void runPlatformTests(ConstantPotentialForce::ConstantPotentialMethod method, bool usePreconditioner);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment