Commit c114f2fe authored by peastman's avatar peastman
Browse files

Bug fix: merging two per-DOF steps could prevent global values from being propagated to the GPU

parent 9ddded35
......@@ -5899,11 +5899,17 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
// Identify steps that can be merged into a single kernel.
for (int step = 1; step < numSteps; step++) {
if (needsForces[step] || needsEnergy[step])
if (invalidatesForces[step] || ((needsForces[step] || needsEnergy[step]) && forceGroupFlags[step] != forceGroupFlags[step-1]))
continue;
if (stepType[step-1] == CustomIntegrator::ComputePerDof && stepType[step] == CustomIntegrator::ComputePerDof)
merged[step] = true;
}
for (int step = numSteps-1; step > 0; step--)
if (merged[step]) {
needsForces[step-1] = (needsForces[step] || needsForces[step-1]);
needsEnergy[step-1] = (needsEnergy[step] || needsEnergy[step-1]);
needsGlobals[step-1] = (needsGlobals[step] || needsGlobals[step-1]);
}
// Loop over all steps and create the kernels for them.
......
......@@ -816,6 +816,32 @@ void testWhileBlock() {
ASSERT_EQUAL_TOL(6.0, integrator.getGlobalVariable(1), 1e-6);
}
/**
 * Check that a per-DOF computation sees the current value of a global
 * variable that was modified by an earlier step of the same integrator.
 *
 * The integrator performs three computations per time step:
 *   g <- g+1   (global)
 *   a <- 0.5   (per-DOF)
 *   b <- a+g   (per-DOF, reads the freshly updated global)
 * so after k time steps g must equal k and every component of b must
 * equal k+0.5.
 */
void testChangingGlobal() {
    const int numIterations = 10;
    const double tolerance = 1e-5;

    // A single particle is enough; only integrator variables are inspected.
    System system;
    system.addParticle(1.0);

    CustomIntegrator integrator(0.1);
    integrator.addGlobalVariable("g", 0);
    integrator.addPerDofVariable("a", 0);
    integrator.addPerDofVariable("b", 0);
    integrator.addComputeGlobal("g", "g+1");
    // Two back-to-back per-DOF computations, the second of which depends on g.
    integrator.addComputePerDof("a", "0.5");
    integrator.addComputePerDof("b", "a+g");
    Context context(system, integrator, platform);

    // Advance one step at a time and verify both g and b after each step.
    for (int completed = 1; completed <= numIterations; completed++) {
        integrator.step(1);
        ASSERT_EQUAL_TOL(completed, integrator.getGlobalVariable(0), tolerance);
        vector<Vec3> values;
        integrator.getPerDofVariable(1, values);
        const double expectedB = completed+0.5;
        ASSERT_EQUAL_VEC(Vec3(expectedB, expectedB, expectedB), values[0], tolerance);
    }
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
......@@ -835,6 +861,7 @@ int main(int argc, char* argv[]) {
testMergedRandoms();
testIfBlock();
testWhileBlock();
testChangingGlobal();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
......@@ -6163,11 +6163,17 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
// Identify steps that can be merged into a single kernel.
for (int step = 1; step < numSteps; step++) {
if (needsForces[step] || needsEnergy[step])
if (invalidatesForces[step] || ((needsForces[step] || needsEnergy[step]) && forceGroupFlags[step] != forceGroupFlags[step-1]))
continue;
if (stepType[step-1] == CustomIntegrator::ComputePerDof && stepType[step] == CustomIntegrator::ComputePerDof)
merged[step] = true;
}
for (int step = numSteps-1; step > 0; step--)
if (merged[step]) {
needsForces[step-1] = (needsForces[step] || needsForces[step-1]);
needsEnergy[step-1] = (needsEnergy[step] || needsEnergy[step-1]);
needsGlobals[step-1] = (needsGlobals[step] || needsGlobals[step-1]);
}
// Loop over all steps and create the kernels for them.
......
......@@ -816,6 +816,32 @@ void testWhileBlock() {
ASSERT_EQUAL_TOL(6.0, integrator.getGlobalVariable(1), 1e-6);
}
/**
 * Check that a per-DOF computation sees the current value of a global
 * variable that was modified by an earlier step of the same integrator.
 *
 * The integrator performs three computations per time step:
 *   g <- g+1   (global)
 *   a <- 0.5   (per-DOF)
 *   b <- a+g   (per-DOF, reads the freshly updated global)
 * so after k time steps g must equal k and every component of b must
 * equal k+0.5.
 */
void testChangingGlobal() {
    const int numIterations = 10;
    const double tolerance = 1e-5;

    // A single particle is enough; only integrator variables are inspected.
    System system;
    system.addParticle(1.0);

    CustomIntegrator integrator(0.1);
    integrator.addGlobalVariable("g", 0);
    integrator.addPerDofVariable("a", 0);
    integrator.addPerDofVariable("b", 0);
    integrator.addComputeGlobal("g", "g+1");
    // Two back-to-back per-DOF computations, the second of which depends on g.
    integrator.addComputePerDof("a", "0.5");
    integrator.addComputePerDof("b", "a+g");
    Context context(system, integrator, platform);

    // Advance one step at a time and verify both g and b after each step.
    for (int completed = 1; completed <= numIterations; completed++) {
        integrator.step(1);
        ASSERT_EQUAL_TOL(completed, integrator.getGlobalVariable(0), tolerance);
        vector<Vec3> values;
        integrator.getPerDofVariable(1, values);
        const double expectedB = completed+0.5;
        ASSERT_EQUAL_VEC(Vec3(expectedB, expectedB, expectedB), values[0], tolerance);
    }
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
......@@ -835,6 +861,7 @@ int main(int argc, char* argv[]) {
testMergedRandoms();
testIfBlock();
testWhileBlock();
testChangingGlobal();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment