Commit 506908ce authored by peastman's avatar peastman
Browse files

CUDA CustomIntegrator supports if and while blocks

parent 0360c770
...@@ -1288,6 +1288,7 @@ private: ...@@ -1288,6 +1288,7 @@ private:
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid); void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
void recordGlobalValue(double value, GlobalTarget target); void recordGlobalValue(double value, GlobalTarget target);
void recordChangedParameters(ContextImpl& context); void recordChangedParameters(ContextImpl& context);
bool evaluateCondition(int step);
CudaContext& cu; CudaContext& cu;
double prevStepSize, energy; double prevStepSize, energy;
float energyFloat; float energyFloat;
......
...@@ -6124,9 +6124,10 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -6124,9 +6124,10 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
int maxUniformRandoms = uniformRandoms->getSize(); int maxUniformRandoms = uniformRandoms->getSize();
void* randomArgs[] = {&maxUniformRandoms, &uniformRandoms->getDevicePointer(), &randomSeed->getDevicePointer()}; void* randomArgs[] = {&maxUniformRandoms, &uniformRandoms->getDevicePointer(), &randomSeed->getDevicePointer()};
CUdeviceptr posCorrection = (cu.getUseMixedPrecision() ? cu.getPosqCorrection().getDevicePointer() : 0); CUdeviceptr posCorrection = (cu.getUseMixedPrecision() ? cu.getPosqCorrection().getDevicePointer() : 0);
for (int i = 0; i < numSteps; i++) { for (int step = 0; step < numSteps; ) {
int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups(); int lastForceGroups = context.getLastForceGroups();
if ((needsForces[i] || needsEnergy[i]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[i])) { if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) { if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old // The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again. // forces in case we need them again.
...@@ -6140,21 +6141,21 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -6140,21 +6141,21 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
// Recompute forces and/or energy. Figure out what is actually needed // Recompute forces and/or energy. Figure out what is actually needed
// between now and the next time they get invalidated again. // between now and the next time they get invalidated again.
bool computeForce = (needsForces[i] || computeBothForceAndEnergy[i]); bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]);
bool computeEnergy = (needsEnergy[i] || computeBothForceAndEnergy[i]); bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]);
if (!computeEnergy && validSavedForces.find(forceGroupFlags[i]) != validSavedForces.end()) { if (!computeEnergy && validSavedForces.find(forceGroupFlags[step]) != validSavedForces.end()) {
// We can just restore the forces we saved earlier. // We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[i]]->copyTo(cu.getForce()); savedForces[forceGroupFlags[step]]->copyTo(cu.getForce());
} }
else { else {
recordChangedParameters(context); recordChangedParameters(context);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[i]); energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]);
energyFloat = (float) energy; energyFloat = (float) energy;
} }
forcesAreValid = true; forcesAreValid = true;
} }
if (needsGlobals[i] && !deviceGlobalsAreCurrent) { if (needsGlobals[step] && !deviceGlobalsAreCurrent) {
// Upload the global values to the device. // Upload the global values to the device.
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision())
...@@ -6165,59 +6166,72 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -6165,59 +6166,72 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
globalValues->upload(globalValuesFloat); globalValues->upload(globalValuesFloat);
} }
} }
if (stepType[i] == CustomIntegrator::ComputePerDof && !merged[i]) { if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[i]); int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
kernelArgs[i][0][1] = &posCorrection; kernelArgs[step][0][1] = &posCorrection;
kernelArgs[i][0][8] = &integration.getRandom().getDevicePointer(); kernelArgs[step][0][8] = &integration.getRandom().getDevicePointer();
kernelArgs[i][0][9] = &randomIndex; kernelArgs[step][0][9] = &randomIndex;
kernelArgs[i][0][10] = &uniformRandoms->getDevicePointer(); kernelArgs[step][0][10] = &uniformRandoms->getDevicePointer();
if (requiredUniform[i] > 0) if (requiredUniform[step] > 0)
cu.executeKernel(randomKernel, &randomArgs[0], numAtoms); cu.executeKernel(randomKernel, &randomArgs[0], numAtoms);
cu.executeKernel(kernels[i][0], &kernelArgs[i][0][0], numAtoms); cu.executeKernel(kernels[step][0], &kernelArgs[step][0][0], numAtoms);
} }
else if (stepType[i] == CustomIntegrator::ComputeGlobal) { else if (stepType[step] == CustomIntegrator::ComputeGlobal) {
expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber()); expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber()); expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
expressionSet.setVariable(stepEnergyVariableIndex[i], energy); expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
recordGlobalValue(globalExpressions[i][0].evaluate(), stepTarget[i]); recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step]);
} }
else if (stepType[i] == CustomIntegrator::ComputeSum) { else if (stepType[step] == CustomIntegrator::ComputeSum) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[i]); int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
kernelArgs[i][0][1] = &posCorrection; kernelArgs[step][0][1] = &posCorrection;
kernelArgs[i][0][8] = &integration.getRandom().getDevicePointer(); kernelArgs[step][0][8] = &integration.getRandom().getDevicePointer();
kernelArgs[i][0][9] = &randomIndex; kernelArgs[step][0][9] = &randomIndex;
kernelArgs[i][0][10] = &uniformRandoms->getDevicePointer(); kernelArgs[step][0][10] = &uniformRandoms->getDevicePointer();
if (requiredUniform[i] > 0) if (requiredUniform[step] > 0)
cu.executeKernel(randomKernel, &randomArgs[0], numAtoms); cu.executeKernel(randomKernel, &randomArgs[0], numAtoms);
cu.clearBuffer(*sumBuffer); cu.clearBuffer(*sumBuffer);
cu.executeKernel(kernels[i][0], &kernelArgs[i][0][0], numAtoms); cu.executeKernel(kernels[step][0], &kernelArgs[step][0][0], numAtoms);
cu.executeKernel(kernels[i][1], &kernelArgs[i][1][0], CudaContext::ThreadBlockSize, CudaContext::ThreadBlockSize); cu.executeKernel(kernels[step][1], &kernelArgs[step][1][0], CudaContext::ThreadBlockSize, CudaContext::ThreadBlockSize);
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) { if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
double value; double value;
summedValue->download(&value); summedValue->download(&value);
globalValuesDouble[stepTarget[i].variableIndex] = value; globalValuesDouble[stepTarget[step].variableIndex] = value;
} }
else { else {
float value; float value;
summedValue->download(&value); summedValue->download(&value);
globalValuesDouble[stepTarget[i].variableIndex] = value; globalValuesDouble[stepTarget[step].variableIndex] = value;
} }
} }
else if (stepType[i] == CustomIntegrator::UpdateContextState) { else if (stepType[step] == CustomIntegrator::UpdateContextState) {
recordChangedParameters(context); recordChangedParameters(context);
context.updateContextState(); context.updateContextState();
} }
else if (stepType[i] == CustomIntegrator::ConstrainPositions) { else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
cu.getIntegrationUtilities().applyConstraints(integrator.getConstraintTolerance()); cu.getIntegrationUtilities().applyConstraints(integrator.getConstraintTolerance());
kernelArgs[i][0][1] = &posCorrection; kernelArgs[step][0][1] = &posCorrection;
cu.executeKernel(kernels[i][0], &kernelArgs[i][0][0], numAtoms); cu.executeKernel(kernels[step][0], &kernelArgs[step][0][0], numAtoms);
cu.getIntegrationUtilities().computeVirtualSites(); cu.getIntegrationUtilities().computeVirtualSites();
} }
else if (stepType[i] == CustomIntegrator::ConstrainVelocities) { else if (stepType[step] == CustomIntegrator::ConstrainVelocities) {
cu.getIntegrationUtilities().applyVelocityConstraints(integrator.getConstraintTolerance()); cu.getIntegrationUtilities().applyVelocityConstraints(integrator.getConstraintTolerance());
} }
if (invalidatesForces[i]) else if (stepType[step] == CustomIntegrator::BeginIfBlock) {
if (!evaluateCondition(step))
nextStep = blockEnd[step]+1;
}
else if (stepType[step] == CustomIntegrator::BeginWhileBlock) {
if (!evaluateCondition(step))
nextStep = blockEnd[step]+1;
}
else if (stepType[step] == CustomIntegrator::EndBlock) {
if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block.
}
if (invalidatesForces[step])
forcesAreValid = false; forcesAreValid = false;
step = nextStep;
} }
recordChangedParameters(context); recordChangedParameters(context);
...@@ -6232,6 +6246,29 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -6232,6 +6246,29 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
} }
} }
bool CudaIntegrateCustomStepKernel::evaluateCondition(int step) {
expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
double lhs = globalExpressions[step][0].evaluate();
double rhs = globalExpressions[step][1].evaluate();
switch (comparisons[step]) {
case CustomIntegratorUtilities::EQUAL:
return (lhs == rhs);
case CustomIntegratorUtilities::LESS_THAN:
return (lhs < rhs);
case CustomIntegratorUtilities::GREATER_THAN:
return (lhs > rhs);
case CustomIntegratorUtilities::NOT_EQUAL:
return (lhs != rhs);
case CustomIntegratorUtilities::LESS_THAN_OR_EQUAL:
return (lhs <= rhs);
case CustomIntegratorUtilities::GREATER_THAN_OR_EQUAL:
return (lhs >= rhs);
}
throw OpenMMException("Invalid comparison operator");
}
double CudaIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { double CudaIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
prepareForComputation(context, integrator, forcesAreValid); prepareForComputation(context, integrator, forcesAreValid);
if (keNeedsForce && !forcesAreValid) { if (keNeedsForce && !forcesAreValid) {
...@@ -6320,8 +6357,10 @@ void CudaIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, con ...@@ -6320,8 +6357,10 @@ void CudaIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, con
initialGlobalVariables = values; initialGlobalVariables = values;
return; return;
} }
for (int i = 0; i < numGlobalVariables; i++) for (int i = 0; i < numGlobalVariables; i++) {
globalValuesDouble[globalVariableIndex[i]] = values[i]; globalValuesDouble[globalVariableIndex[i]] = values[i];
expressionSet.setVariable(globalVariableIndex[i], values[i]);
}
deviceGlobalsAreCurrent = false; deviceGlobalsAreCurrent = false;
} }
......
...@@ -756,6 +756,66 @@ void testMergedRandoms() { ...@@ -756,6 +756,66 @@ void testMergedRandoms() {
} }
} }
void testIfBlock() {
System system;
system.addParticle(2.0);
system.addParticle(2.0);
const double dt = 0.01;
CustomIntegrator integrator(dt);
integrator.addGlobalVariable("a", 0);
integrator.addGlobalVariable("b", 0);
integrator.addComputeGlobal("b", "1");
integrator.beginIfBlock("a < 3.5");
integrator.addComputeGlobal("b", "a+1");
integrator.endBlock();
Context context(system, integrator, platform);
// Set "a" to 1.7 and verify that "b" gets set to a+1.
integrator.setGlobalVariable(0, 1.7);
integrator.step(1);
ASSERT_EQUAL_TOL(2.7, integrator.getGlobalVariable(1), 1e-6);
// Now set it to a value that should cause the block to be skipped.
integrator.setGlobalVariable(0, 4.1);
integrator.step(1);
ASSERT_EQUAL_TOL(1.0, integrator.getGlobalVariable(1), 1e-6);
}
void testWhileBlock() {
System system;
system.addParticle(2.0);
system.addParticle(2.0);
const double dt = 0.01;
CustomIntegrator integrator(dt);
integrator.addGlobalVariable("a", 0);
integrator.addGlobalVariable("b", 0);
integrator.addComputeGlobal("b", "1");
integrator.beginWhileBlock("b <= a");
integrator.addComputeGlobal("b", "b+1");
integrator.endBlock();
Context context(system, integrator, platform);
// Try a case where the loop should be skipped.
integrator.setGlobalVariable(0, -3.3);
integrator.step(1);
ASSERT_EQUAL_TOL(1.0, integrator.getGlobalVariable(1), 1e-6);
// In this case it should be executed exactly once.
integrator.setGlobalVariable(0, 1.2);
integrator.step(1);
ASSERT_EQUAL_TOL(2.0, integrator.getGlobalVariable(1), 1e-6);
// In this case, it should be executed several times.
integrator.setGlobalVariable(0, 5.3);
integrator.step(1);
ASSERT_EQUAL_TOL(6.0, integrator.getGlobalVariable(1), 1e-6);
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -773,6 +833,8 @@ int main(int argc, char* argv[]) { ...@@ -773,6 +833,8 @@ int main(int argc, char* argv[]) {
testForceGroups(); testForceGroups();
testRespa(); testRespa();
testMergedRandoms(); testMergedRandoms();
testIfBlock();
testWhileBlock();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment