Commit 283e97b2 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1600 from peastman/customdt

Fixed error when CustomIntegrator modifies step size
parents 690f5c23 1de311c9
...@@ -1432,7 +1432,7 @@ private: ...@@ -1432,7 +1432,7 @@ private:
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid); void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context); Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes); void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
void recordGlobalValue(double value, GlobalTarget target); void recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator);
void recordChangedParameters(ContextImpl& context); void recordChangedParameters(ContextImpl& context);
bool evaluateCondition(int step); bool evaluateCondition(int step);
CudaContext& cu; CudaContext& cu;
......
...@@ -7309,7 +7309,7 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, ...@@ -7309,7 +7309,7 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
} }
localValuesAreCurrent = false; localValuesAreCurrent = false;
double stepSize = integrator.getStepSize(); double stepSize = integrator.getStepSize();
recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex)); recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex), integrator);
for (int i = 0; i < (int) parameterNames.size(); i++) { for (int i = 0; i < (int) parameterNames.size(); i++) {
double value = context.getParameter(parameterNames[i]); double value = context.getParameter(parameterNames[i]);
if (value != globalValuesDouble[parameterVariableIndex[i]]) { if (value != globalValuesDouble[parameterVariableIndex[i]]) {
...@@ -7442,7 +7442,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7442,7 +7442,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber()); expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber()); expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
expressionSet.setVariable(stepEnergyVariableIndex[step], energy); expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step]); recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step], integrator);
} }
else if (stepType[step] == CustomIntegrator::ComputeSum) { else if (stepType[step] == CustomIntegrator::ComputeSum) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]); int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
...@@ -7458,12 +7458,12 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7458,12 +7458,12 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) { if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
double value; double value;
summedValue->download(&value); summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]); recordGlobalValue(value, stepTarget[step], integrator);
} }
else { else {
float value; float value;
summedValue->download(&value); summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]); recordGlobalValue(value, stepTarget[step], integrator);
} }
} }
else if (stepType[step] == CustomIntegrator::UpdateContextState) { else if (stepType[step] == CustomIntegrator::UpdateContextState) {
...@@ -7567,7 +7567,7 @@ double CudaIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context, ...@@ -7567,7 +7567,7 @@ double CudaIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context,
} }
} }
void CudaIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target) { void CudaIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator) {
switch (target.type) { switch (target.type) {
case DT: case DT:
if (value != globalValuesDouble[dtVariableIndex]) if (value != globalValuesDouble[dtVariableIndex])
...@@ -7575,6 +7575,7 @@ void CudaIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget ...@@ -7575,6 +7575,7 @@ void CudaIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget
expressionSet.setVariable(dtVariableIndex, value); expressionSet.setVariable(dtVariableIndex, value);
globalValuesDouble[dtVariableIndex] = value; globalValuesDouble[dtVariableIndex] = value;
cu.getIntegrationUtilities().setNextStepSize(value); cu.getIntegrationUtilities().setNextStepSize(value);
integrator.setStepSize(value);
break; break;
case VARIABLE: case VARIABLE:
case PARAMETER: case PARAMETER:
......
...@@ -1419,7 +1419,7 @@ private: ...@@ -1419,7 +1419,7 @@ private:
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid); void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context); Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes); void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
void recordGlobalValue(double value, GlobalTarget target); void recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator);
void recordChangedParameters(ContextImpl& context); void recordChangedParameters(ContextImpl& context);
bool evaluateCondition(int step); bool evaluateCondition(int step);
OpenCLContext& cl; OpenCLContext& cl;
......
...@@ -7598,7 +7598,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7598,7 +7598,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
} }
localValuesAreCurrent = false; localValuesAreCurrent = false;
double stepSize = integrator.getStepSize(); double stepSize = integrator.getStepSize();
recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex)); recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex), integrator);
for (int i = 0; i < (int) parameterNames.size(); i++) { for (int i = 0; i < (int) parameterNames.size(); i++) {
double value = context.getParameter(parameterNames[i]); double value = context.getParameter(parameterNames[i]);
if (value != globalValuesDouble[parameterVariableIndex[i]]) { if (value != globalValuesDouble[parameterVariableIndex[i]]) {
...@@ -7729,7 +7729,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7729,7 +7729,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber()); expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber()); expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
expressionSet.setVariable(stepEnergyVariableIndex[step], energy); expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step]); recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step], integrator);
} }
else if (stepType[step] == CustomIntegrator::ComputeSum) { else if (stepType[step] == CustomIntegrator::ComputeSum) {
kernels[step][0].setArg<cl_uint>(9, integration.prepareRandomNumbers(requiredGaussian[step])); kernels[step][0].setArg<cl_uint>(9, integration.prepareRandomNumbers(requiredGaussian[step]));
...@@ -7747,12 +7747,12 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7747,12 +7747,12 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) { if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
double value; double value;
summedValue->download(&value); summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]); recordGlobalValue(value, stepTarget[step], integrator);
} }
else { else {
float value; float value;
summedValue->download(&value); summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]); recordGlobalValue(value, stepTarget[step], integrator);
} }
} }
else if (stepType[step] == CustomIntegrator::UpdateContextState) { else if (stepType[step] == CustomIntegrator::UpdateContextState) {
...@@ -7856,7 +7856,7 @@ double OpenCLIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& contex ...@@ -7856,7 +7856,7 @@ double OpenCLIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& contex
} }
} }
void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target) { void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator) {
switch (target.type) { switch (target.type) {
case DT: case DT:
if (value != globalValuesDouble[dtVariableIndex]) if (value != globalValuesDouble[dtVariableIndex])
...@@ -7864,6 +7864,7 @@ void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarg ...@@ -7864,6 +7864,7 @@ void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarg
expressionSet.setVariable(dtVariableIndex, value); expressionSet.setVariable(dtVariableIndex, value);
globalValuesDouble[dtVariableIndex] = value; globalValuesDouble[dtVariableIndex] = value;
cl.getIntegrationUtilities().setNextStepSize(value); cl.getIntegrationUtilities().setNextStepSize(value);
integrator.setStepSize(value);
break; break;
case VARIABLE: case VARIABLE:
case PARAMETER: case PARAMETER:
......
...@@ -858,6 +858,43 @@ void testEnergyParameterDerivatives() { ...@@ -858,6 +858,43 @@ void testEnergyParameterDerivatives() {
ASSERT_EQUAL_TOL(dEdtheta0, values[2][1], 1e-5); ASSERT_EQUAL_TOL(dEdtheta0, values[2][1], 1e-5);
} }
/**
* Test an integrator that modifies the step size.
*/
void testChangeDT() {
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.5);
integrator.addGlobalVariable("dt_global", 0.0);
integrator.addPerDofVariable("dt_dof", 0.0);
integrator.addComputeGlobal("dt", "dt+1");
integrator.addComputePerDof("dt_dof", "dt");
integrator.addComputeGlobal("dt_global", "dt");
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(0, 0, 0);
context.setPositions(positions);
for (int i = 0; i < 5; i++) {
integrator.step(1);
double dt = 1.5+i;
ASSERT_EQUAL_TOL(dt, integrator.getStepSize(), 1e-5);
ASSERT_EQUAL_TOL(dt, integrator.getGlobalVariable(0), 1e-5);
vector<Vec3> values;
integrator.getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(dt, dt, dt), values[0], 1e-5);
}
integrator.setStepSize(1.0);
for (int i = 0; i < 5; i++) {
integrator.step(1);
double dt = 2.0+i;
ASSERT_EQUAL_TOL(dt, integrator.getStepSize(), 1e-5);
ASSERT_EQUAL_TOL(dt, integrator.getGlobalVariable(0), 1e-5);
vector<Vec3> values;
integrator.getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(dt, dt, dt), values[0], 1e-5);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -879,6 +916,7 @@ int main(int argc, char* argv[]) { ...@@ -879,6 +916,7 @@ int main(int argc, char* argv[]) {
testWhileBlock(); testWhileBlock();
testChangingGlobal(); testChangingGlobal();
testEnergyParameterDerivatives(); testEnergyParameterDerivatives();
testChangeDT();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment