Unverified Commit 14f8b061 authored by Evan Pretti's avatar Evan Pretti Committed by GitHub
Browse files

Make sure contexts are deselected before evaluation (#5279)

* Context deselection before energy evaluation

* Check that the correct context is popped by popAsCurrent()
parent 0aee8050
...@@ -49,6 +49,25 @@ private: ...@@ -49,6 +49,25 @@ private:
ComputeContext& context; ComputeContext& context;
}; };
/**
* This class deselects a ComputeContext by calling popAsCurrent() on the
* context when it is created and pushAsCurrent() when it goes out of scope.
* This can be useful to temporarily undo the effect of a ContextSelector and
* must only be used when the context is already selected.
*/
class OPENMM_EXPORT_COMMON ContextDeselector {
public:
ContextDeselector(ComputeContext& context) : context(context) {
context.popAsCurrent();
}
~ContextDeselector() {
context.pushAsCurrent();
}
private:
ComputeContext& context;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_CONTEXTSELECTOR_H_*/ #endif /*OPENMM_CONTEXTSELECTOR_H_*/
...@@ -683,7 +683,10 @@ void CommonIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -683,7 +683,10 @@ void CommonIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
} }
else { else {
recordChangedParameters(context); recordChangedParameters(context);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups); {
ContextDeselector deselector(cc);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
}
savedEnergy[forceGroups] = energy; savedEnergy[forceGroups] = energy;
if (needsEnergyParamDerivs) { if (needsEnergyParamDerivs) {
context.getEnergyParameterDerivatives(energyParamDerivs); context.getEnergyParameterDerivatives(energyParamDerivs);
......
...@@ -96,17 +96,17 @@ void CommonIntegrateNoseHooverStepKernel::initialize(const System& system, const ...@@ -96,17 +96,17 @@ void CommonIntegrateNoseHooverStepKernel::initialize(const System& system, const
} }
void CommonIntegrateNoseHooverStepKernel::execute(ContextImpl& context, const NoseHooverIntegrator& integrator) { void CommonIntegrateNoseHooverStepKernel::execute(ContextImpl& context, const NoseHooverIntegrator& integrator) {
// If the atom reordering has occured, the forces from the previous step are permuted and thus invalid.
// They need to be either sorted or recomputed; here we choose the latter.
if (cc.getAtomsWereReordered())
context.calcForcesAndEnergy(true, false, integrator.getIntegrationForceGroups());
ContextSelector selector(cc); ContextSelector selector(cc);
IntegrationUtilities& integration = cc.getIntegrationUtilities(); IntegrationUtilities& integration = cc.getIntegrationUtilities();
int paddedNumAtoms = cc.getPaddedNumAtoms(); int paddedNumAtoms = cc.getPaddedNumAtoms();
double dt = integrator.getStepSize(); double dt = integrator.getStepSize();
cc.getIntegrationUtilities().setNextStepSize(dt); cc.getIntegrationUtilities().setNextStepSize(dt);
// If the atom reordering has occured, the forces from the previous step are permuted and thus invalid.
// They need to be either sorted or recomputed; here we choose the latter.
if (cc.getAtomsWereReordered())
context.calcForcesAndEnergy(true, false, integrator.getIntegrationForceGroups());
const auto& atomList = integrator.getAllThermostatedIndividualParticles(); const auto& atomList = integrator.getAllThermostatedIndividualParticles();
const auto& pairList = integrator.getAllThermostatedPairs(); const auto& pairList = integrator.getAllThermostatedPairs();
int numAtoms = atomList.size(); int numAtoms = atomList.size();
......
...@@ -4709,7 +4709,7 @@ void CommonCalcCustomCPPForceKernel::initialize(const ContextImpl& context, Cust ...@@ -4709,7 +4709,7 @@ void CommonCalcCustomCPPForceKernel::initialize(const ContextImpl& context, Cust
forceGroupFlag = (1<<force.getOwner().getForceGroup()); forceGroupFlag = (1<<force.getOwner().getForceGroup());
useWorkerThread = (cc.getNumContexts() == 1); useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls()) for (const ForceImpl* impl : context.getForceImpls())
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL) if (impl != &force && (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL))
useWorkerThread = false; useWorkerThread = false;
if (useWorkerThread) { if (useWorkerThread) {
cc.addPreComputation(new StartCalculationPreComputation(*this)); cc.addPreComputation(new StartCalculationPreComputation(*this));
...@@ -4871,7 +4871,7 @@ void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const P ...@@ -4871,7 +4871,7 @@ void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const P
forceGroupFlag = (1<<force.getForceGroup()); forceGroupFlag = (1<<force.getForceGroup());
useWorkerThread = (cc.getNumContexts() == 1); useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls()) for (const ForceImpl* impl : context.getForceImpls())
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL) if (&impl->getOwner() != &force && (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL))
useWorkerThread = false; useWorkerThread = false;
if (useWorkerThread) { if (useWorkerThread) {
cc.addPreComputation(new StartCalculationPreComputation(*this)); cc.addPreComputation(new StartCalculationPreComputation(*this));
......
...@@ -547,7 +547,11 @@ void CommonMinimizeKernel::evaluateGpu(ContextImpl& context) { ...@@ -547,7 +547,11 @@ void CommonMinimizeKernel::evaluateGpu(ContextImpl& context) {
// Evaluate the forces and energy for the desired interactions as well as // Evaluate the forces and energy for the desired interactions as well as
// harmonic restraints to emulate the constraints. // harmonic restraints to emulate the constraints.
energy = context.calcForcesAndEnergy(true, true, forceGroups); {
ContextDeselector deselector(cc);
energy = context.calcForcesAndEnergy(true, true, forceGroups);
}
if (numConstraints) { if (numConstraints) {
if (mixedIsDouble) { if (mixedIsDouble) {
getConstraintEnergyForcesKernel->setArg(8, kRestraint); getConstraintEnergyForcesKernel->setArg(8, kRestraint);
...@@ -592,7 +596,11 @@ double CommonMinimizeKernel::evaluateCpu(ContextImpl& context) { ...@@ -592,7 +596,11 @@ double CommonMinimizeKernel::evaluateCpu(ContextImpl& context) {
cpuContext->setState(context.getOwner().getState(State::Parameters)); cpuContext->setState(context.getOwner().getState(State::Parameters));
cpuContext->setPositions(hostPositions); cpuContext->setPositions(hostPositions);
cpuContext->computeVirtualSites(); cpuContext->computeVirtualSites();
State state = cpuContext->getState(State::Energy | State::Forces, false, forceGroups); State state;
{
ContextDeselector deselector(cc);
state = cpuContext->getState(State::Energy | State::Forces, false, forceGroups);
}
double hostEnergy = state.getPotentialEnergy(); double hostEnergy = state.getPotentialEnergy();
const vector<Vec3>& hostForces = state.getForces(); const vector<Vec3>& hostForces = state.getForces();
...@@ -676,7 +684,10 @@ bool CommonMinimizeKernel::report(ContextImpl& context, int iteration) { ...@@ -676,7 +684,10 @@ bool CommonMinimizeKernel::report(ContextImpl& context, int iteration) {
args["system energy"] = energy - restraintEnergy; args["system energy"] = energy - restraintEnergy;
args["restraint strength"] = kRestraint; args["restraint strength"] = kRestraint;
args["max constraint error"] = maxError; args["max constraint error"] = maxError;
return reporter->report(iteration - 1, hostX, hostGrad, args); {
ContextDeselector deselector(cc);
return reporter->report(iteration - 1, hostX, hostGrad, args);
}
} }
void CommonMinimizeKernel::downloadReturnFlagStart() { void CommonMinimizeKernel::downloadReturnFlagStart() {
......
...@@ -461,8 +461,11 @@ void CudaContext::pushAsCurrent() { ...@@ -461,8 +461,11 @@ void CudaContext::pushAsCurrent() {
void CudaContext::popAsCurrent() { void CudaContext::popAsCurrent() {
CUcontext popped; CUcontext popped;
if (contextIsValid) if (contextIsValid) {
cuCtxPopCurrent(&popped); cuCtxPopCurrent(&popped);
if (popped != context)
throw OpenMMException("Called popAsCurrent() on a context that is not current");
}
} }
CUmodule CudaContext::createModule(const string source, const char* optimizationFlags) { CUmodule CudaContext::createModule(const string source, const char* optimizationFlags) {
...@@ -886,4 +889,4 @@ void CudaContext::ensureCudaInitialized() { ...@@ -886,4 +889,4 @@ void CudaContext::ensureCudaInitialized() {
CHECK_RESULT2(cuInit(0), "Error initializing CUDA"); CHECK_RESULT2(cuInit(0), "Error initializing CUDA");
hasInitializedCuda = true; hasInitializedCuda = true;
} }
} }
\ No newline at end of file
...@@ -1399,8 +1399,10 @@ void CommonCalcAmoebaMultipoleForceKernel::ensureMultipolesValid(ContextImpl& co ...@@ -1399,8 +1399,10 @@ void CommonCalcAmoebaMultipoleForceKernel::ensureMultipolesValid(ContextImpl& co
} }
} }
} }
if (!multipolesAreValid) if (!multipolesAreValid) {
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups()); context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups());
}
} }
void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) { void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
...@@ -3487,8 +3489,10 @@ void CommonCalcHippoNonbondedForceKernel::ensureMultipolesValid(ContextImpl& con ...@@ -3487,8 +3489,10 @@ void CommonCalcHippoNonbondedForceKernel::ensureMultipolesValid(ContextImpl& con
} }
} }
} }
if (!multipolesAreValid) if (!multipolesAreValid) {
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups()); context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups());
}
} }
void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) { void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
......
...@@ -502,7 +502,10 @@ void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double to ...@@ -502,7 +502,10 @@ void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double to
int numDrude = drudeParams.getSize(); int numDrude = drudeParams.getSize();
int paddedNumAtoms = cc.getPaddedNumAtoms(); int paddedNumAtoms = cc.getPaddedNumAtoms();
for (int iteration = 0; iteration < 50; iteration++) { for (int iteration = 0; iteration < 50; iteration++) {
context.calcForcesAndEnergy(true, false, context.getIntegrator().getIntegrationForceGroups()); {
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(true, false, context.getIntegrator().getIntegrationForceGroups());
}
minimizeKernel->execute(drudeParams.getSize()); minimizeKernel->execute(drudeParams.getSize());
cc.getLongForceBuffer().download(forces); cc.getLongForceBuffer().download(forces);
double totalForce = 0; double totalForce = 0;
......
...@@ -297,7 +297,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) { ...@@ -297,7 +297,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) {
context.getPeriodicBoxVectors(finalBox[0], finalBox[1], finalBox[2]); context.getPeriodicBoxVectors(finalBox[0], finalBox[1], finalBox[2]);
if (initialBox[0] != finalBox[0] || initialBox[1] != finalBox[1] || initialBox[2] != finalBox[2]) if (initialBox[0] != finalBox[0] || initialBox[1] != finalBox[1] || initialBox[2] != finalBox[2])
throw OpenMMException("Standard barostats cannot be used with RPMDIntegrator. Use RPMDMonteCarloBarostat instead."); throw OpenMMException("Standard barostats cannot be used with RPMDIntegrator. Use RPMDMonteCarloBarostat instead.");
context.calcForcesAndEnergy(true, false, groupsNotContracted); {
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(true, false, groupsNotContracted);
}
copyFromContextKernel->setArg(7, i); copyFromContextKernel->setArg(7, i);
copyFromContextKernel->execute(cc.getNumAtoms()); copyFromContextKernel->execute(cc.getNumAtoms());
} }
...@@ -322,7 +325,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) { ...@@ -322,7 +325,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) {
copyToContextKernel->setArg(5, i); copyToContextKernel->setArg(5, i);
copyToContextKernel->execute(cc.getNumAtoms()); copyToContextKernel->execute(cc.getNumAtoms());
context.computeVirtualSites(); context.computeVirtualSites();
context.calcForcesAndEnergy(true, false, groupFlags); {
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(true, false, groupFlags);
}
copyFromContextKernel->setArg(7, i); copyFromContextKernel->setArg(7, i);
copyFromContextKernel->execute(cc.getNumAtoms()); copyFromContextKernel->execute(cc.getNumAtoms());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment