Commit 283e97b2 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1600 from peastman/customdt

Fixed error when CustomIntegrator modifies step size
parents 690f5c23 1de311c9
......@@ -1432,7 +1432,7 @@ private:
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
void recordGlobalValue(double value, GlobalTarget target);
void recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator);
void recordChangedParameters(ContextImpl& context);
bool evaluateCondition(int step);
CudaContext& cu;
......
......@@ -7309,7 +7309,7 @@ void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context,
}
localValuesAreCurrent = false;
double stepSize = integrator.getStepSize();
recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex));
recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex), integrator);
for (int i = 0; i < (int) parameterNames.size(); i++) {
double value = context.getParameter(parameterNames[i]);
if (value != globalValuesDouble[parameterVariableIndex[i]]) {
......@@ -7442,7 +7442,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step]);
recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step], integrator);
}
else if (stepType[step] == CustomIntegrator::ComputeSum) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
......@@ -7458,12 +7458,12 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
double value;
summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]);
recordGlobalValue(value, stepTarget[step], integrator);
}
else {
float value;
summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]);
recordGlobalValue(value, stepTarget[step], integrator);
}
}
else if (stepType[step] == CustomIntegrator::UpdateContextState) {
......@@ -7567,7 +7567,7 @@ double CudaIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context,
}
}
void CudaIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target) {
void CudaIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator) {
switch (target.type) {
case DT:
if (value != globalValuesDouble[dtVariableIndex])
......@@ -7575,6 +7575,7 @@ void CudaIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget
expressionSet.setVariable(dtVariableIndex, value);
globalValuesDouble[dtVariableIndex] = value;
cu.getIntegrationUtilities().setNextStepSize(value);
integrator.setStepSize(value);
break;
case VARIABLE:
case PARAMETER:
......
......@@ -1419,7 +1419,7 @@ private:
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
void recordGlobalValue(double value, GlobalTarget target);
void recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator);
void recordChangedParameters(ContextImpl& context);
bool evaluateCondition(int step);
OpenCLContext& cl;
......
......@@ -7598,7 +7598,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
}
localValuesAreCurrent = false;
double stepSize = integrator.getStepSize();
recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex));
recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex), integrator);
for (int i = 0; i < (int) parameterNames.size(); i++) {
double value = context.getParameter(parameterNames[i]);
if (value != globalValuesDouble[parameterVariableIndex[i]]) {
......@@ -7729,7 +7729,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step]);
recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step], integrator);
}
else if (stepType[step] == CustomIntegrator::ComputeSum) {
kernels[step][0].setArg<cl_uint>(9, integration.prepareRandomNumbers(requiredGaussian[step]));
......@@ -7747,12 +7747,12 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
double value;
summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]);
recordGlobalValue(value, stepTarget[step], integrator);
}
else {
float value;
summedValue->download(&value);
recordGlobalValue(value, stepTarget[step]);
recordGlobalValue(value, stepTarget[step], integrator);
}
}
else if (stepType[step] == CustomIntegrator::UpdateContextState) {
......@@ -7856,7 +7856,7 @@ double OpenCLIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& contex
}
}
void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target) {
void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator) {
switch (target.type) {
case DT:
if (value != globalValuesDouble[dtVariableIndex])
......@@ -7864,6 +7864,7 @@ void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarg
expressionSet.setVariable(dtVariableIndex, value);
globalValuesDouble[dtVariableIndex] = value;
cl.getIntegrationUtilities().setNextStepSize(value);
integrator.setStepSize(value);
break;
case VARIABLE:
case PARAMETER:
......
......@@ -858,6 +858,43 @@ void testEnergyParameterDerivatives() {
ASSERT_EQUAL_TOL(dEdtheta0, values[2][1], 1e-5);
}
/**
* Test an integrator that modifies the step size.
*/
void testChangeDT() {
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.5);
integrator.addGlobalVariable("dt_global", 0.0);
integrator.addPerDofVariable("dt_dof", 0.0);
integrator.addComputeGlobal("dt", "dt+1");
integrator.addComputePerDof("dt_dof", "dt");
integrator.addComputeGlobal("dt_global", "dt");
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(0, 0, 0);
context.setPositions(positions);
for (int i = 0; i < 5; i++) {
integrator.step(1);
double dt = 1.5+i;
ASSERT_EQUAL_TOL(dt, integrator.getStepSize(), 1e-5);
ASSERT_EQUAL_TOL(dt, integrator.getGlobalVariable(0), 1e-5);
vector<Vec3> values;
integrator.getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(dt, dt, dt), values[0], 1e-5);
}
integrator.setStepSize(1.0);
for (int i = 0; i < 5; i++) {
integrator.step(1);
double dt = 2.0+i;
ASSERT_EQUAL_TOL(dt, integrator.getStepSize(), 1e-5);
ASSERT_EQUAL_TOL(dt, integrator.getGlobalVariable(0), 1e-5);
vector<Vec3> values;
integrator.getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(dt, dt, dt), values[0], 1e-5);
}
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -879,6 +916,7 @@ int main(int argc, char* argv[]) {
testWhileBlock();
testChangingGlobal();
testEnergyParameterDerivatives();
testChangeDT();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment