Commit 8eba00a7 authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimizations to CustomIntegrator to prevent unnecessary extra force evaluations

parent 0665e20b
...@@ -4492,6 +4492,8 @@ CudaIntegrateCustomStepKernel::~CudaIntegrateCustomStepKernel() { ...@@ -4492,6 +4492,8 @@ CudaIntegrateCustomStepKernel::~CudaIntegrateCustomStepKernel() {
delete randomSeed; delete randomSeed;
if (perDofValues != NULL) if (perDofValues != NULL)
delete perDofValues; delete perDofValues;
for (map<int, CudaArray*>::iterator iter = savedForces.begin(); iter != savedForces.end(); ++iter)
delete iter->second;
} }
void CudaIntegrateCustomStepKernel::initialize(const System& system, const CustomIntegrator& integrator) { void CudaIntegrateCustomStepKernel::initialize(const System& system, const CustomIntegrator& integrator) {
...@@ -4700,6 +4702,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -4700,6 +4702,8 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
invalidatesForces[step] = (stepType[step] == CustomIntegrator::ConstrainPositions || affectsForce.find(variable[step]) != affectsForce.end()); invalidatesForces[step] = (stepType[step] == CustomIntegrator::ConstrainPositions || affectsForce.find(variable[step]) != affectsForce.end());
if (forceGroup[step] == -2 && step > 0) if (forceGroup[step] == -2 && step > 0)
forceGroup[step] = forceGroup[step-1]; forceGroup[step] = forceGroup[step-1];
if (forceGroup[step] != -2 && savedForces.find(forceGroup[step]) == savedForces.end())
savedForces[forceGroup[step]] = new CudaArray(cu, cu.getForce().getSize(), cu.getForce().getElementSize(), "savedForces");
} }
// Determine how each step will represent the position (as just a value, or a value plus a delta). // Determine how each step will represent the position (as just a value, or a value plus a delta).
...@@ -4984,7 +4988,18 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -4984,7 +4988,18 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
void* randomArgs[] = {&uniformRandoms->getDevicePointer(), &randomSeed->getDevicePointer()}; void* randomArgs[] = {&uniformRandoms->getDevicePointer(), &randomSeed->getDevicePointer()};
CUdeviceptr posCorrection = (cu.getUseMixedPrecision() ? cu.getPosqCorrection().getDevicePointer() : 0); CUdeviceptr posCorrection = (cu.getUseMixedPrecision() ? cu.getPosqCorrection().getDevicePointer() : 0);
for (int i = 0; i < numSteps; i++) { for (int i = 0; i < numSteps; i++) {
if ((needsForces[i] || needsEnergy[i]) && (!forcesAreValid || context.getLastForceGroups() != forceGroup[i])) { int lastForceGroups = context.getLastForceGroups();
if ((needsForces[i] || needsEnergy[i]) && (!forcesAreValid || lastForceGroups != forceGroup[i])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
cu.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
}
else
validSavedForces.clear();
// Recompute forces and/or energy. Figure out what is actually needed // Recompute forces and/or energy. Figure out what is actually needed
// between now and the next time they get invalidated again. // between now and the next time they get invalidated again.
...@@ -5001,12 +5016,19 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -5001,12 +5016,19 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
if (j == i-1) if (j == i-1)
break; break;
} }
if (!computeEnergy && validSavedForces.find(forceGroup[i]) != validSavedForces.end()) {
// We can just restore the forces we saved earlier.
savedForces[forceGroup[i]]->copyTo(cu.getForce());
}
else {
recordChangedParameters(context); recordChangedParameters(context);
context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroup[i]); context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroup[i]);
if (computeEnergy) { if (computeEnergy) {
void* args[] = {&cu.getEnergyBuffer().getDevicePointer(), &potentialEnergy->getDevicePointer()}; void* args[] = {&cu.getEnergyBuffer().getDevicePointer(), &potentialEnergy->getDevicePointer()};
cu.executeKernel(sumPotentialEnergyKernel, &args[0], CudaContext::ThreadBlockSize, CudaContext::ThreadBlockSize); cu.executeKernel(sumPotentialEnergyKernel, &args[0], CudaContext::ThreadBlockSize, CudaContext::ThreadBlockSize);
} }
}
forcesAreValid = true; forcesAreValid = true;
} }
if (stepType[i] == CustomIntegrator::ComputePerDof && !merged[i]) { if (stepType[i] == CustomIntegrator::ComputePerDof && !merged[i]) {
......
...@@ -1185,6 +1185,8 @@ private: ...@@ -1185,6 +1185,8 @@ private:
CudaArray* kineticEnergy; CudaArray* kineticEnergy;
CudaArray* uniformRandoms; CudaArray* uniformRandoms;
CudaArray* randomSeed; CudaArray* randomSeed;
std::map<int, CudaArray*> savedForces;
std::set<int> validSavedForces;
CudaParameterSet* perDofValues; CudaParameterSet* perDofValues;
mutable std::vector<std::vector<float> > localPerDofValuesFloat; mutable std::vector<std::vector<float> > localPerDofValuesFloat;
mutable std::vector<std::vector<double> > localPerDofValuesDouble; mutable std::vector<std::vector<double> > localPerDofValuesDouble;
......
...@@ -4725,6 +4725,8 @@ OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() { ...@@ -4725,6 +4725,8 @@ OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() {
delete randomSeed; delete randomSeed;
if (perDofValues != NULL) if (perDofValues != NULL)
delete perDofValues; delete perDofValues;
for (map<int, OpenCLArray*>::iterator iter = savedForces.begin(); iter != savedForces.end(); ++iter)
delete iter->second;
} }
void OpenCLIntegrateCustomStepKernel::initialize(const System& system, const CustomIntegrator& integrator) { void OpenCLIntegrateCustomStepKernel::initialize(const System& system, const CustomIntegrator& integrator) {
...@@ -4930,6 +4932,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -4930,6 +4932,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
invalidatesForces[step] = (stepType[step] == CustomIntegrator::ConstrainPositions || affectsForce.find(variable[step]) != affectsForce.end()); invalidatesForces[step] = (stepType[step] == CustomIntegrator::ConstrainPositions || affectsForce.find(variable[step]) != affectsForce.end());
if (forceGroup[step] == -2 && step > 0) if (forceGroup[step] == -2 && step > 0)
forceGroup[step] = forceGroup[step-1]; forceGroup[step] = forceGroup[step-1];
if (forceGroup[step] != -2 && savedForces.find(forceGroup[step]) == savedForces.end())
savedForces[forceGroup[step]] = new OpenCLArray(cl, cl.getForce().getSize(), cl.getForce().getElementSize(), "savedForces");
} }
// Determine how each step will represent the position (as just a value, or a value plus a delta). // Determine how each step will represent the position (as just a value, or a value plus a delta).
...@@ -5218,7 +5222,18 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -5218,7 +5222,18 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
// Loop over computation steps in the integrator and execute them. // Loop over computation steps in the integrator and execute them.
for (int i = 0; i < numSteps; i++) { for (int i = 0; i < numSteps; i++) {
if ((needsForces[i] || needsEnergy[i]) && (!forcesAreValid || context.getLastForceGroups() != forceGroup[i])) { int lastForceGroups = context.getLastForceGroups();
if ((needsForces[i] || needsEnergy[i]) && (!forcesAreValid || lastForceGroups != forceGroup[i])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
cl.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
}
else
validSavedForces.clear();
// Recompute forces and/or energy. Figure out what is actually needed // Recompute forces and/or energy. Figure out what is actually needed
// between now and the next time they get invalidated again. // between now and the next time they get invalidated again.
...@@ -5235,12 +5250,19 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -5235,12 +5250,19 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (j == i-1) if (j == i-1)
break; break;
} }
if (!computeEnergy && validSavedForces.find(forceGroup[i]) != validSavedForces.end()) {
// We can just restore the forces we saved earlier.
savedForces[forceGroup[i]]->copyTo(cl.getForce());
}
else {
recordChangedParameters(context); recordChangedParameters(context);
context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroup[i]); context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroup[i]);
if (computeEnergy) if (computeEnergy)
cl.executeKernel(sumPotentialEnergyKernel, OpenCLContext::ThreadBlockSize, OpenCLContext::ThreadBlockSize); cl.executeKernel(sumPotentialEnergyKernel, OpenCLContext::ThreadBlockSize, OpenCLContext::ThreadBlockSize);
forcesAreValid = true; forcesAreValid = true;
} }
}
if (stepType[i] == CustomIntegrator::ComputePerDof && !merged[i]) { if (stepType[i] == CustomIntegrator::ComputePerDof && !merged[i]) {
kernels[i][0].setArg<cl_uint>(10, integration.prepareRandomNumbers(requiredGaussian[i])); kernels[i][0].setArg<cl_uint>(10, integration.prepareRandomNumbers(requiredGaussian[i]));
if (requiredUniform[i] > 0) if (requiredUniform[i] > 0)
......
...@@ -1199,6 +1199,8 @@ private: ...@@ -1199,6 +1199,8 @@ private:
OpenCLArray* kineticEnergy; OpenCLArray* kineticEnergy;
OpenCLArray* uniformRandoms; OpenCLArray* uniformRandoms;
OpenCLArray* randomSeed; OpenCLArray* randomSeed;
std::map<int, OpenCLArray*> savedForces;
std::set<int> validSavedForces;
OpenCLParameterSet* perDofValues; OpenCLParameterSet* perDofValues;
mutable std::vector<std::vector<cl_float> > localPerDofValuesFloat; mutable std::vector<std::vector<cl_float> > localPerDofValuesFloat;
mutable std::vector<std::vector<cl_double> > localPerDofValuesDouble; mutable std::vector<std::vector<cl_double> > localPerDofValuesDouble;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment