Commit 506908ce authored by peastman
Browse files

CUDA CustomIntegrator supports if and while blocks

parent 0360c770
......@@ -1288,6 +1288,7 @@ private:
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
void recordGlobalValue(double value, GlobalTarget target);
void recordChangedParameters(ContextImpl& context);
bool evaluateCondition(int step);
CudaContext& cu;
double prevStepSize, energy;
float energyFloat;
......
......@@ -6124,9 +6124,10 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
int maxUniformRandoms = uniformRandoms->getSize();
void* randomArgs[] = {&maxUniformRandoms, &uniformRandoms->getDevicePointer(), &randomSeed->getDevicePointer()};
CUdeviceptr posCorrection = (cu.getUseMixedPrecision() ? cu.getPosqCorrection().getDevicePointer() : 0);
for (int i = 0; i < numSteps; i++) {
for (int step = 0; step < numSteps; ) {
int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups();
if ((needsForces[i] || needsEnergy[i]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[i])) {
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
......@@ -6140,21 +6141,21 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
// Recompute forces and/or energy. Figure out what is actually needed
// between now and the next time they get invalidated again.
bool computeForce = (needsForces[i] || computeBothForceAndEnergy[i]);
bool computeEnergy = (needsEnergy[i] || computeBothForceAndEnergy[i]);
if (!computeEnergy && validSavedForces.find(forceGroupFlags[i]) != validSavedForces.end()) {
bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]);
bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]);
if (!computeEnergy && validSavedForces.find(forceGroupFlags[step]) != validSavedForces.end()) {
// We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[i]]->copyTo(cu.getForce());
savedForces[forceGroupFlags[step]]->copyTo(cu.getForce());
}
else {
recordChangedParameters(context);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[i]);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]);
energyFloat = (float) energy;
}
forcesAreValid = true;
}
if (needsGlobals[i] && !deviceGlobalsAreCurrent) {
if (needsGlobals[step] && !deviceGlobalsAreCurrent) {
// Upload the global values to the device.
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision())
......@@ -6165,59 +6166,72 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
globalValues->upload(globalValuesFloat);
}
}
if (stepType[i] == CustomIntegrator::ComputePerDof && !merged[i]) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[i]);
kernelArgs[i][0][1] = &posCorrection;
kernelArgs[i][0][8] = &integration.getRandom().getDevicePointer();
kernelArgs[i][0][9] = &randomIndex;
kernelArgs[i][0][10] = &uniformRandoms->getDevicePointer();
if (requiredUniform[i] > 0)
if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
kernelArgs[step][0][1] = &posCorrection;
kernelArgs[step][0][8] = &integration.getRandom().getDevicePointer();
kernelArgs[step][0][9] = &randomIndex;
kernelArgs[step][0][10] = &uniformRandoms->getDevicePointer();
if (requiredUniform[step] > 0)
cu.executeKernel(randomKernel, &randomArgs[0], numAtoms);
cu.executeKernel(kernels[i][0], &kernelArgs[i][0][0], numAtoms);
cu.executeKernel(kernels[step][0], &kernelArgs[step][0][0], numAtoms);
}
else if (stepType[i] == CustomIntegrator::ComputeGlobal) {
else if (stepType[step] == CustomIntegrator::ComputeGlobal) {
expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
expressionSet.setVariable(stepEnergyVariableIndex[i], energy);
recordGlobalValue(globalExpressions[i][0].evaluate(), stepTarget[i]);
}
else if (stepType[i] == CustomIntegrator::ComputeSum) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[i]);
kernelArgs[i][0][1] = &posCorrection;
kernelArgs[i][0][8] = &integration.getRandom().getDevicePointer();
kernelArgs[i][0][9] = &randomIndex;
kernelArgs[i][0][10] = &uniformRandoms->getDevicePointer();
if (requiredUniform[i] > 0)
expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step]);
}
else if (stepType[step] == CustomIntegrator::ComputeSum) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
kernelArgs[step][0][1] = &posCorrection;
kernelArgs[step][0][8] = &integration.getRandom().getDevicePointer();
kernelArgs[step][0][9] = &randomIndex;
kernelArgs[step][0][10] = &uniformRandoms->getDevicePointer();
if (requiredUniform[step] > 0)
cu.executeKernel(randomKernel, &randomArgs[0], numAtoms);
cu.clearBuffer(*sumBuffer);
cu.executeKernel(kernels[i][0], &kernelArgs[i][0][0], numAtoms);
cu.executeKernel(kernels[i][1], &kernelArgs[i][1][0], CudaContext::ThreadBlockSize, CudaContext::ThreadBlockSize);
cu.executeKernel(kernels[step][0], &kernelArgs[step][0][0], numAtoms);
cu.executeKernel(kernels[step][1], &kernelArgs[step][1][0], CudaContext::ThreadBlockSize, CudaContext::ThreadBlockSize);
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
double value;
summedValue->download(&value);
globalValuesDouble[stepTarget[i].variableIndex] = value;
globalValuesDouble[stepTarget[step].variableIndex] = value;
}
else {
float value;
summedValue->download(&value);
globalValuesDouble[stepTarget[i].variableIndex] = value;
globalValuesDouble[stepTarget[step].variableIndex] = value;
}
}
else if (stepType[i] == CustomIntegrator::UpdateContextState) {
else if (stepType[step] == CustomIntegrator::UpdateContextState) {
recordChangedParameters(context);
context.updateContextState();
}
else if (stepType[i] == CustomIntegrator::ConstrainPositions) {
else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
cu.getIntegrationUtilities().applyConstraints(integrator.getConstraintTolerance());
kernelArgs[i][0][1] = &posCorrection;
cu.executeKernel(kernels[i][0], &kernelArgs[i][0][0], numAtoms);
kernelArgs[step][0][1] = &posCorrection;
cu.executeKernel(kernels[step][0], &kernelArgs[step][0][0], numAtoms);
cu.getIntegrationUtilities().computeVirtualSites();
}
else if (stepType[i] == CustomIntegrator::ConstrainVelocities) {
else if (stepType[step] == CustomIntegrator::ConstrainVelocities) {
cu.getIntegrationUtilities().applyVelocityConstraints(integrator.getConstraintTolerance());
}
if (invalidatesForces[i])
else if (stepType[step] == CustomIntegrator::BeginIfBlock) {
if (!evaluateCondition(step))
nextStep = blockEnd[step]+1;
}
else if (stepType[step] == CustomIntegrator::BeginWhileBlock) {
if (!evaluateCondition(step))
nextStep = blockEnd[step]+1;
}
else if (stepType[step] == CustomIntegrator::EndBlock) {
if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block.
}
if (invalidatesForces[step])
forcesAreValid = false;
step = nextStep;
}
recordChangedParameters(context);
......@@ -6232,6 +6246,29 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
}
}
/**
 * Evaluate the condition attached to the given step.
 *
 * Before the two sides of the comparison are evaluated, the expression variables at
 * uniformVariableIndex and gaussianVariableIndex are given freshly drawn random values
 * (uniform first, then normal), and the variable at stepEnergyVariableIndex[step] is set
 * to the currently stored energy.
 *
 * @param step   index of the step whose condition should be evaluated
 * @return the result of comparing globalExpressions[step][0] against globalExpressions[step][1]
 *         using the operator stored in comparisons[step]
 * @throws OpenMMException if comparisons[step] is not one of the handled operators
 */
bool CudaIntegrateCustomStepKernel::evaluateCondition(int step) {
    // Draw the random values in the same order the variables are assigned.
    const double uniformValue = SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber();
    expressionSet.setVariable(uniformVariableIndex, uniformValue);
    const double gaussianValue = SimTKOpenMMUtilities::getNormallyDistributedRandomNumber();
    expressionSet.setVariable(gaussianVariableIndex, gaussianValue);
    expressionSet.setVariable(stepEnergyVariableIndex[step], energy);

    // Left-hand side is evaluated before the right-hand side.
    const double left = globalExpressions[step][0].evaluate();
    const double right = globalExpressions[step][1].evaluate();

    bool satisfied;
    switch (comparisons[step]) {
        case CustomIntegratorUtilities::EQUAL:
            satisfied = (left == right);
            break;
        case CustomIntegratorUtilities::LESS_THAN:
            satisfied = (left < right);
            break;
        case CustomIntegratorUtilities::GREATER_THAN:
            satisfied = (left > right);
            break;
        case CustomIntegratorUtilities::NOT_EQUAL:
            satisfied = (left != right);
            break;
        case CustomIntegratorUtilities::LESS_THAN_OR_EQUAL:
            satisfied = (left <= right);
            break;
        case CustomIntegratorUtilities::GREATER_THAN_OR_EQUAL:
            satisfied = (left >= right);
            break;
        default:
            throw OpenMMException("Invalid comparison operator");
    }
    return satisfied;
}
double CudaIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
prepareForComputation(context, integrator, forcesAreValid);
if (keNeedsForce && !forcesAreValid) {
......@@ -6320,8 +6357,10 @@ void CudaIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, con
initialGlobalVariables = values;
return;
}
for (int i = 0; i < numGlobalVariables; i++)
for (int i = 0; i < numGlobalVariables; i++) {
globalValuesDouble[globalVariableIndex[i]] = values[i];
expressionSet.setVariable(globalVariableIndex[i], values[i]);
}
deviceGlobalsAreCurrent = false;
}
......
......@@ -756,6 +756,66 @@ void testMergedRandoms() {
}
}
/**
 * Check that the steps inside an "if" block run only when the block's condition holds.
 *
 * The integrator first sets b = 1, then (only if a < 3.5) sets b = a+1.
 */
void testIfBlock() {
    System system;
    for (int particle = 0; particle < 2; particle++)
        system.addParticle(2.0);
    const double stepSize = 0.01;
    CustomIntegrator integrator(stepSize);
    integrator.addGlobalVariable("a", 0);
    integrator.addGlobalVariable("b", 0);
    integrator.addComputeGlobal("b", "1");
    integrator.beginIfBlock("a < 3.5");
    integrator.addComputeGlobal("b", "a+1");
    integrator.endBlock();
    Context context(system, integrator, platform);

    // Each row holds the value assigned to "a" (global 0) and the value
    // expected in "b" (global 1) after taking one step.
    const int numCases = 2;
    const double cases[numCases][2] = {
        {1.7, 2.7},  // Condition holds, so "b" gets set to a+1.
        {4.1, 1.0}   // Condition fails, so the block is skipped and "b" stays at 1.
    };
    for (int c = 0; c < numCases; c++) {
        integrator.setGlobalVariable(0, cases[c][0]);
        integrator.step(1);
        ASSERT_EQUAL_TOL(cases[c][1], integrator.getGlobalVariable(1), 1e-6);
    }
}
/**
 * Check that the steps inside a "while" block repeat for as long as the block's condition holds.
 *
 * The integrator first sets b = 1, then keeps applying b = b+1 while b <= a.
 */
void testWhileBlock() {
    System system;
    for (int particle = 0; particle < 2; particle++)
        system.addParticle(2.0);
    const double stepSize = 0.01;
    CustomIntegrator integrator(stepSize);
    integrator.addGlobalVariable("a", 0);
    integrator.addGlobalVariable("b", 0);
    integrator.addComputeGlobal("b", "1");
    integrator.beginWhileBlock("b <= a");
    integrator.addComputeGlobal("b", "b+1");
    integrator.endBlock();
    Context context(system, integrator, platform);

    // Each row holds the value assigned to "a" (global 0) and the value
    // expected in "b" (global 1) after taking one step.
    const int numCases = 3;
    const double cases[numCases][2] = {
        {-3.3, 1.0},  // The loop body should be skipped entirely.
        {1.2, 2.0},   // The loop body should be executed exactly once.
        {5.3, 6.0}    // The loop body should be executed several times.
    };
    for (int c = 0; c < numCases; c++) {
        integrator.setGlobalVariable(0, cases[c][0]);
        integrator.step(1);
        ASSERT_EQUAL_TOL(cases[c][1], integrator.getGlobalVariable(1), 1e-6);
    }
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
......@@ -773,6 +833,8 @@ int main(int argc, char* argv[]) {
testForceGroups();
testRespa();
testMergedRandoms();
testIfBlock();
testWhileBlock();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment