Commit 2b9ef981 authored by peastman's avatar peastman
Browse files

CustomIntegrator variables are saved to checkpoints

parent fde2ca6d
......@@ -668,6 +668,16 @@ protected:
* Get whether computeKineticEnergy() expects forces to have been computed.
*/
bool kineticEnergyRequiresForce() const;
/**
* This is called while writing checkpoints. It gives the integrator a chance to write
* its own data.
*/
void createCheckpoint(std::ostream& stream) const;
/**
* This is called while loading a checkpoint. The integrator should read in whatever
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
void loadCheckpoint(std::istream& stream);
private:
class ComputationInfo;
class FunctionInfo;
......
......@@ -34,6 +34,7 @@
#include "State.h"
#include "Vec3.h"
#include <iosfwd>
#include <map>
#include <vector>
#include "internal/windowsExport.h"
......@@ -151,6 +152,18 @@ protected:
virtual double getVelocityTimeOffset() const {
return 0.0;
}
/**
* This is called while writing checkpoints. It gives the integrator a chance to write
* its own data. The default implementation does nothing.
*/
virtual void createCheckpoint(std::ostream& stream) const {
}
/**
* This is called while loading a checkpoint. The integrator should read in whatever
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
virtual void loadCheckpoint(std::istream& stream) {
}
private:
double stepSize, constraintTol;
};
......
......@@ -454,6 +454,7 @@ void ContextImpl::createCheckpoint(ostream& stream) {
stream.write((char*) &param.second, sizeof(double));
}
updateStateDataKernel.getAs<UpdateStateDataKernel>().createCheckpoint(*this, stream);
integrator.createCheckpoint(stream);
stream.flush();
}
......@@ -480,6 +481,7 @@ void ContextImpl::loadCheckpoint(istream& stream) {
parameters[name] = value;
}
updateStateDataKernel.getAs<UpdateStateDataKernel>().loadCheckpoint(*this, stream);
integrator.loadCheckpoint(stream);
hasSetPositions = true;
integrator.stateChanged(State::Positions);
integrator.stateChanged(State::Velocities);
......
......@@ -114,6 +114,31 @@ bool CustomIntegrator::kineticEnergyRequiresForce() const {
return keNeedsForce;
}
void CustomIntegrator::createCheckpoint(std::ostream& stream) const {
for (int i = 0; i < getNumGlobalVariables(); i++) {
double value = getGlobalVariable(i);
stream.write((char*) &value, sizeof(double));
}
vector<Vec3> values;
for (int i = 0; i < getNumPerDofVariables(); i++) {
getPerDofVariable(i, values);
stream.write((char*) values.data(), sizeof(Vec3)*values.size());
}
}
void CustomIntegrator::loadCheckpoint(std::istream& stream) {
double value;
for (int i = 0; i < getNumGlobalVariables(); i++) {
stream.read((char*) &value, sizeof(double));
setGlobalVariable(i, value);
}
vector<Vec3> values(context->getSystem().getNumParticles());
for (int i = 0; i < getNumPerDofVariables(); i++) {
stream.read((char*) values.data(), sizeof(Vec3)*values.size());
setPerDofVariable(i, values);
}
}
void CustomIntegrator::step(int steps) {
if (context == NULL)
throw OpenMMException("This Integrator is not bound to a context!");
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2019 Stanford University and the Authors. *
* Portions copyright (c) 2008-2020 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -48,6 +48,7 @@
#include <algorithm>
#include <cmath>
#include <iostream>
#include <sstream>
#include <vector>
using namespace OpenMM;
......@@ -1156,6 +1157,31 @@ void testInitialTemperature() {
ASSERT_USUALLY_EQUAL_TOL(targetTemperature, temperature, 0.01);
}
void testCheckpoint() {
// Test that integrator variables get loaded correctly from checkpoints.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
integrator.setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
integrator.setPerDofVariable(0, b1);
stringstream checkpoint;
context.createCheckpoint(checkpoint);
integrator.setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
integrator.setPerDofVariable(0, b2);
context.loadCheckpoint(checkpoint);
ASSERT_EQUAL(5.0, integrator.getGlobalVariable(0));
vector<Vec3> b3;
integrator.getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -1184,6 +1210,7 @@ int main(int argc, char* argv[]) {
testVectorFunctions();
testRecordEnergy();
testInitialTemperature();
testCheckpoint();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment