Unverified Commit 62394b00 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2651 from peastman/checkpoint

CustomIntegrator variables are saved to checkpoints
parents fde2ca6d 2b9ef981
...@@ -668,6 +668,16 @@ protected: ...@@ -668,6 +668,16 @@ protected:
* Get whether computeKineticEnergy() expects forces to have been computed. * Get whether computeKineticEnergy() expects forces to have been computed.
*/ */
bool kineticEnergyRequiresForce() const; bool kineticEnergyRequiresForce() const;
/**
* This is called while writing checkpoints. It gives the integrator a chance to write
* its own data.
*/
void createCheckpoint(std::ostream& stream) const;
/**
* This is called while loading a checkpoint. The integrator should read in whatever
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
void loadCheckpoint(std::istream& stream);
private: private:
class ComputationInfo; class ComputationInfo;
class FunctionInfo; class FunctionInfo;
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "State.h" #include "State.h"
#include "Vec3.h" #include "Vec3.h"
#include <iosfwd>
#include <map> #include <map>
#include <vector> #include <vector>
#include "internal/windowsExport.h" #include "internal/windowsExport.h"
...@@ -151,6 +152,18 @@ protected: ...@@ -151,6 +152,18 @@ protected:
virtual double getVelocityTimeOffset() const { virtual double getVelocityTimeOffset() const {
return 0.0; return 0.0;
} }
/**
* This is called while writing checkpoints. It gives the integrator a chance to write
* its own data. The default implementation does nothing.
*/
virtual void createCheckpoint(std::ostream& stream) const {
}
/**
* This is called while loading a checkpoint. The integrator should read in whatever
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
virtual void loadCheckpoint(std::istream& stream) {
}
private: private:
double stepSize, constraintTol; double stepSize, constraintTol;
}; };
......
...@@ -454,6 +454,7 @@ void ContextImpl::createCheckpoint(ostream& stream) { ...@@ -454,6 +454,7 @@ void ContextImpl::createCheckpoint(ostream& stream) {
stream.write((char*) &param.second, sizeof(double)); stream.write((char*) &param.second, sizeof(double));
} }
updateStateDataKernel.getAs<UpdateStateDataKernel>().createCheckpoint(*this, stream); updateStateDataKernel.getAs<UpdateStateDataKernel>().createCheckpoint(*this, stream);
integrator.createCheckpoint(stream);
stream.flush(); stream.flush();
} }
...@@ -480,6 +481,7 @@ void ContextImpl::loadCheckpoint(istream& stream) { ...@@ -480,6 +481,7 @@ void ContextImpl::loadCheckpoint(istream& stream) {
parameters[name] = value; parameters[name] = value;
} }
updateStateDataKernel.getAs<UpdateStateDataKernel>().loadCheckpoint(*this, stream); updateStateDataKernel.getAs<UpdateStateDataKernel>().loadCheckpoint(*this, stream);
integrator.loadCheckpoint(stream);
hasSetPositions = true; hasSetPositions = true;
integrator.stateChanged(State::Positions); integrator.stateChanged(State::Positions);
integrator.stateChanged(State::Velocities); integrator.stateChanged(State::Velocities);
......
...@@ -114,6 +114,31 @@ bool CustomIntegrator::kineticEnergyRequiresForce() const { ...@@ -114,6 +114,31 @@ bool CustomIntegrator::kineticEnergyRequiresForce() const {
return keNeedsForce; return keNeedsForce;
} }
void CustomIntegrator::createCheckpoint(std::ostream& stream) const {
for (int i = 0; i < getNumGlobalVariables(); i++) {
double value = getGlobalVariable(i);
stream.write((char*) &value, sizeof(double));
}
vector<Vec3> values;
for (int i = 0; i < getNumPerDofVariables(); i++) {
getPerDofVariable(i, values);
stream.write((char*) values.data(), sizeof(Vec3)*values.size());
}
}
void CustomIntegrator::loadCheckpoint(std::istream& stream) {
double value;
for (int i = 0; i < getNumGlobalVariables(); i++) {
stream.read((char*) &value, sizeof(double));
setGlobalVariable(i, value);
}
vector<Vec3> values(context->getSystem().getNumParticles());
for (int i = 0; i < getNumPerDofVariables(); i++) {
stream.read((char*) values.data(), sizeof(Vec3)*values.size());
setPerDofVariable(i, values);
}
}
void CustomIntegrator::step(int steps) { void CustomIntegrator::step(int steps) {
if (context == NULL) if (context == NULL)
throw OpenMMException("This Integrator is not bound to a context!"); throw OpenMMException("This Integrator is not bound to a context!");
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2019 Stanford University and the Authors. * * Portions copyright (c) 2008-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include <sstream>
#include <vector> #include <vector>
using namespace OpenMM; using namespace OpenMM;
...@@ -1156,6 +1157,31 @@ void testInitialTemperature() { ...@@ -1156,6 +1157,31 @@ void testInitialTemperature() {
ASSERT_USUALLY_EQUAL_TOL(targetTemperature, temperature, 0.01); ASSERT_USUALLY_EQUAL_TOL(targetTemperature, temperature, 0.01);
} }
void testCheckpoint() {
// Test that integrator variables get loaded correctly from checkpoints.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
integrator.setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
integrator.setPerDofVariable(0, b1);
stringstream checkpoint;
context.createCheckpoint(checkpoint);
integrator.setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
integrator.setPerDofVariable(0, b2);
context.loadCheckpoint(checkpoint);
ASSERT_EQUAL(5.0, integrator.getGlobalVariable(0));
vector<Vec3> b3;
integrator.getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -1184,6 +1210,7 @@ int main(int argc, char* argv[]) { ...@@ -1184,6 +1210,7 @@ int main(int argc, char* argv[]) {
testVectorFunctions(); testVectorFunctions();
testRecordEnergy(); testRecordEnergy();
testInitialTemperature(); testInitialTemperature();
testCheckpoint();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment