Commit 694c3930 authored by peastman's avatar peastman
Browse files

States can save integrator parameters

parent 03861778
......@@ -1136,6 +1136,22 @@ public:
* Load the chain states from a checkpoint.
*/
virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0;
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
virtual void getChainStates(ContextImpl& context, std::vector<std::vector<double> >& positions, std::vector<std::vector<double> >& velocities) const = 0;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
virtual void setChainStates(ContextImpl& context, const std::vector<std::vector<double> >& positions, const std::vector<std::vector<double> >& velocities) = 0;
};
/**
......
......@@ -204,6 +204,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
private:
int currentIntegrator;
std::vector<Integrator*> integrators;
......
......@@ -678,6 +678,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
private:
class ComputationInfo;
class FunctionInfo;
......
......@@ -34,6 +34,7 @@
#include "State.h"
#include "Vec3.h"
#include "openmm/serialization/SerializationNode.h"
#include <iosfwd>
#include <map>
#include <vector>
......@@ -176,6 +177,21 @@ protected:
*/
virtual void loadCheckpoint(std::istream& stream) {
}
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
virtual void serializeParameters(SerializationNode& node) const {
}
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
virtual void deserializeParameters(const SerializationNode& node) {
}
private:
double stepSize, constraintTol;
int forceGroups;
......
......@@ -9,9 +9,9 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2019 Stanford University and the Authors. *
* Portions copyright (c) 2019-2020 Stanford University and the Authors. *
* Authors: Andreas Krämer and Andrew C. Simmonett *
* Contributors: *
* Contributors: Peter Eastman *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
......@@ -270,6 +270,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
std::vector<NoseHooverChain> noseHooverChains;
std::vector<int> allAtoms;
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2020 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -33,6 +33,7 @@
* -------------------------------------------------------------------------- */
#include "Vec3.h"
#include "openmm/serialization/SerializationNode.h"
#include <map>
#include <string>
#include <vector>
......@@ -58,7 +59,7 @@ public:
* This is an enumeration of the types of data which may be stored in a State. When you create
* a State, use these values to specify which data types it should contain.
*/
enum DataType {Positions=1, Velocities=2, Forces=4, Energy=8, Parameters=16, ParameterDerivatives=32};
enum DataType {Positions=1, Velocities=2, Forces=4, Energy=8, Parameters=16, ParameterDerivatives=32, IntegratorParameters=64};
/**
* Construct an empty State containing no data. This exists so State objects can be used in STL containers.
*/
......@@ -124,6 +125,8 @@ public:
*/
int getDataTypes() const;
private:
friend class Context;
friend class StateProxy;
State(double time);
void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel);
......@@ -132,6 +135,8 @@ private:
void setEnergyParameterDerivatives(const std::map<std::string, double>& derivs);
void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
SerializationNode& updateIntegratorParameters();
const SerializationNode& getIntegratorParameters() const;
int types;
double time, ke, pe;
std::vector<Vec3> positions;
......@@ -139,6 +144,7 @@ private:
std::vector<Vec3> forces;
Vec3 periodicBoxVectors[3];
std::map<std::string, double> parameters, energyParameterDerivatives;
SerializationNode integratorParameters;
};
/**
......@@ -157,6 +163,7 @@ public:
void setEnergyParameterDerivatives(const std::map<std::string, double>& params);
void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
SerializationNode& updateIntegratorParameters();
private:
State state;
};
......
......@@ -140,3 +140,22 @@ void CompoundIntegrator::loadCheckpoint(std::istream& stream) {
for (int i = 0; i < integrators.size(); i++)
integrators[i]->loadCheckpoint(stream);
}
void CompoundIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
node.setIntProperty("currentIntegrator", currentIntegrator);
for (int i = 0; i < getNumIntegrators(); i++) {
SerializationNode& child = node.createChildNode("IntegratorParameters");
integrators[i]->serializeParameters(child);
}
}
void CompoundIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
if (node.getChildren().size() != getNumIntegrators())
throw OpenMMException("State has wrong number of integrators for CompoundIntegrator");
setCurrentIntegrator(node.getIntProperty("currentIntegrator"));
for (int i = 0; i < node.getChildren().size(); i++)
integrators[i]->deserializeParameters(node.getChildren()[i]);
}
......@@ -147,6 +147,9 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
impl->getVelocities(velocities);
builder.setVelocities(velocities);
}
if (types&State::IntegratorParameters) {
getIntegrator().serializeParameters(builder.updateIntegratorParameters());
}
return builder.getState();
}
......@@ -162,6 +165,8 @@ void Context::setState(const State& state) {
if ((state.getDataTypes()&State::Parameters) != 0)
for (auto& param : state.getParameters())
setParameter(param.first, param.second);
if ((state.getDataTypes()&State::IntegratorParameters) != 0)
getIntegrator().deserializeParameters(state.getIntegratorParameters());
}
void Context::setTime(double time) {
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2011-2019 Stanford University and the Authors. *
* Portions copyright (c) 2011-2020 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -348,3 +348,33 @@ void CustomIntegrator::setKineticEnergyExpression(const string& expression) {
Lepton::CompiledExpression expr = Lepton::Parser::parse(kineticEnergy).createCompiledExpression();
keNeedsForce = (expr.getVariables().find("f") != expr.getVariables().end());
}
void CustomIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
SerializationNode& globalVariablesNode = node.createChildNode("GlobalVariables");
for (int i = 0; i < getNumGlobalVariables(); i++)
globalVariablesNode.setDoubleProperty(getGlobalVariableName(i), getGlobalVariable(i));
SerializationNode& perDofVariablesNode = node.createChildNode("PerDofVariables");
for (int i = 0; i < getNumPerDofVariables(); i++) {
SerializationNode& perDofValuesNode = perDofVariablesNode.createChildNode(getPerDofVariableName(i));
vector<Vec3> perDofValues;
getPerDofVariable(i, perDofValues);
for (int j = 0; j < perDofValues.size(); j++)
perDofValuesNode.createChildNode("Value").setDoubleProperty("x",perDofValues[j][0]).setDoubleProperty("y",perDofValues[j][1]).setDoubleProperty("z",perDofValues[j][2]);
}
}
void CustomIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& globalVariablesNode = node.getChildNode("GlobalVariables");
for (auto& prop : globalVariablesNode.getProperties())
setGlobalVariableByName(prop.first, globalVariablesNode.getDoubleProperty(prop.first));
const SerializationNode& perDofVariablesNode = node.getChildNode("PerDofVariables");
for (auto& var : perDofVariablesNode.getChildren()) {
vector<Vec3> perDofValues;
for (auto& child : var.getChildren())
perDofValues.push_back(Vec3(child.getDoubleProperty("x"), child.getDoubleProperty("y"), child.getDoubleProperty("z")));
setPerDofVariableByName(var.getName(), perDofValues);
}
}
......@@ -8,7 +8,7 @@
* *
* Portions copyright (c) 2019-2020 Stanford University and the Authors. *
* Authors: Andreas Krämer and Andrew C. Simmonett *
* Contributors: *
* Contributors: Peter Eastman *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
......@@ -348,4 +348,30 @@ void NoseHooverIntegrator::createCheckpoint(std::ostream& stream) const {
void NoseHooverIntegrator::loadCheckpoint(std::istream& stream) {
kernel.getAs<IntegrateNoseHooverStepKernel>().loadCheckpoint(*context, stream);
}
\ No newline at end of file
}
void NoseHooverIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
vector<vector<double> > positions, velocities;
kernel.getAs<IntegrateNoseHooverStepKernel>().getChainStates(*context, positions, velocities);
for (int i = 0; i < positions.size(); i++) {
SerializationNode& chain = node.createChildNode("Chain");
for (int j = 0; j < positions[i].size(); j++)
chain.createChildNode("Bead").setDoubleProperty("position", positions[i][j]).setDoubleProperty("velocity", velocities[i][j]);
}
}
void NoseHooverIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
int numChains = node.getChildren().size();
vector<vector<double> > positions(numChains), velocities(numChains);
for (int i = 0; i < numChains; i++) {
auto& chain = node.getChildren()[i];
for (auto& bead : chain.getChildren()) {
positions[i].push_back(bead.getDoubleProperty("position"));
velocities[i].push_back(bead.getDoubleProperty("velocity"));
}
}
kernel.getAs<IntegrateNoseHooverStepKernel>().setChainStates(*context, positions, velocities);
}
......@@ -81,6 +81,16 @@ const map<string, double>& State::getEnergyParameterDerivatives() const {
throw OpenMMException("Invoked getEnergyParameterDerivatives() on a State which does not contain parameter derivatives.");
return energyParameterDerivatives;
}
const SerializationNode& State::getIntegratorParameters() const {
if ((types&IntegratorParameters) == 0)
throw OpenMMException("Invoked getPIntegratorarameters() on a State which does not contain integrator parameters.");
return integratorParameters;
}
SerializationNode& State::updateIntegratorParameters() {
types |= IntegratorParameters;
integratorParameters.setName("IntegratorParameters");
return integratorParameters;
}
int State::getDataTypes() const {
return types;
}
......@@ -159,3 +169,7 @@ void State::StateBuilder::setEnergy(double ke, double pe) {
void State::StateBuilder::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
state.setPeriodicBoxVectors(a, b, c);
}
SerializationNode& State::StateBuilder::updateIntegratorParameters() {
return state.updateIntegratorParameters();
}
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2010-2015 Stanford University and the Authors. *
* Portions copyright (c) 2010-2020 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -89,6 +89,9 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]);
}
}
if ((s.getDataTypes()&State::IntegratorParameters) != 0) {
node.getChildren().push_back(s.getIntegratorParameters());
}
}
void* StateProxy::deserialize(const SerializationNode& node) const {
......@@ -138,6 +141,9 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
builder.setForces(outForces);
arraySizes.push_back(outForces.size());
}
else if (child.getName() == "IntegratorParameters") {
builder.updateIntegratorParameters() = child;
}
}
for (int i = 1; i < arraySizes.size(); i++) {
if (arraySizes[i] != arraySizes[i-1]) {
......
......@@ -34,6 +34,7 @@
#include "openmm/NonbondedForce.h"
#include "openmm/System.h"
#include "openmm/Context.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/LangevinIntegrator.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/MonteCarloBarostat.h"
......@@ -48,9 +49,9 @@ using namespace std;
void testSerialization() {
// Create a System.
const int numParticles=50;
const int numParticles=50;
System system;
system.setDefaultPeriodicBoxVectors(Vec3(6.2, 0, 0), Vec3(0, 6.2, 0), Vec3(0, 0, 6.2 ));
system.setDefaultPeriodicBoxVectors(Vec3(6.2, 0, 0), Vec3(0, 6.2, 0), Vec3(0, 0, 6.2 ));
NonbondedForce* nonbonded = new NonbondedForce();
nonbonded->setNonbondedMethod(NonbondedForce::Ewald);
nonbonded->setCutoffDistance(0.8);
......@@ -64,74 +65,74 @@ void testSerialization() {
for (int i = 0; i < numParticles/2; i++)
nonbonded->addParticle(-1.0, 1.0,0.0);
system.addForce(nonbonded);
system.addForce(new AndersenThermostat(393.3, 19.3));
system.addForce(new MonteCarloBarostat(25, 393.3, 25));
LangevinIntegrator intg(300,79,0.002);
Context context(system, intg);
// Set positions, velocities, forces
vector<Vec3> positions;
for (int i = 0; i < numParticles; i++) {
positions.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
}
vector<Vec3> velocities;
for (int i = 0; i < numParticles; i++) {
velocities.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
}
context.setPositions(positions);
context.setVelocities(velocities);
// Serialize and then deserialize it.
State s1 = context.getState(State::Positions | State::Velocities | State::Forces | State::Energy | State::Parameters);
stringstream buffer;
system.addForce(new AndersenThermostat(393.3, 19.3));
system.addForce(new MonteCarloBarostat(25, 393.3, 25));
LangevinIntegrator intg(300,79,0.002);
Context context(system, intg);
// Set positions, velocities, forces
vector<Vec3> positions;
for (int i = 0; i < numParticles; i++) {
positions.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
}
vector<Vec3> velocities;
for (int i = 0; i < numParticles; i++) {
velocities.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
}
context.setPositions(positions);
context.setVelocities(velocities);
// Serialize and then deserialize it.
State s1 = context.getState(State::Positions | State::Velocities | State::Forces | State::Energy | State::Parameters);
stringstream buffer;
XmlSerializer::serialize<State>(&s1, "State", buffer);
State* copy = XmlSerializer::deserialize<State>(buffer);
State& s2 = *copy;
State& s2 = *copy;
// Compare the two states to see if they are identical.
vector<Vec3> pos1 = s1.getPositions();
vector<Vec3> pos2 = s2.getPositions();
ASSERT_EQUAL(pos1.size(), pos2.size());
ASSERT_EQUAL(pos1.size(), positions.size());
for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(pos1[i],pos2[i],0);
}
vector<Vec3> vel1 = s1.getVelocities();
vector<Vec3> vel2 = s2.getVelocities();
ASSERT_EQUAL(vel1.size(), vel2.size());
for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(vel1[i],vel2[i],0);
}
vector<Vec3> forces1 = s1.getForces();
vector<Vec3> forces2 = s2.getForces();
ASSERT_EQUAL(forces1.size(), forces2.size());
for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(forces1[i],forces2[i],0);
}
Vec3 a1,a2,a3,b1,b2,b3;
s1.getPeriodicBoxVectors(a1,a2,a3);
s2.getPeriodicBoxVectors(b1,b2,b3);
ASSERT_EQUAL_VEC(a1,b1,0);
ASSERT_EQUAL_VEC(a2,b2,0);
ASSERT_EQUAL_VEC(a3,b3,0);
ASSERT_EQUAL(s1.getPotentialEnergy(), s2.getPotentialEnergy());
ASSERT_EQUAL(s1.getKineticEnergy(), s2.getKineticEnergy());
ASSERT_EQUAL(s1.getTime(), s2.getTime());
map<string, double> p1 = s1.getParameters();
map<string, double> p2 = s2.getParameters();
ASSERT_EQUAL(p1.size(), p2.size());
map<string, double>::const_iterator it1=p1.begin();
map<string, double>::const_iterator it2=p2.begin();
//maps are ordered, so iterators should be in the same order.
for (it1 = p1.begin(); it1 != p1.end(); ++it1, ++it2) {
assert((it1->first).compare(it2->first) == 0);
ASSERT_EQUAL(it1->second, it2->second);
}
vector<Vec3> pos1 = s1.getPositions();
vector<Vec3> pos2 = s2.getPositions();
ASSERT_EQUAL(pos1.size(), pos2.size());
ASSERT_EQUAL(pos1.size(), positions.size());
for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(pos1[i],pos2[i],0);
}
vector<Vec3> vel1 = s1.getVelocities();
vector<Vec3> vel2 = s2.getVelocities();
ASSERT_EQUAL(vel1.size(), vel2.size());
for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(vel1[i],vel2[i],0);
}
vector<Vec3> forces1 = s1.getForces();
vector<Vec3> forces2 = s2.getForces();
ASSERT_EQUAL(forces1.size(), forces2.size());
for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(forces1[i],forces2[i],0);
}
Vec3 a1,a2,a3,b1,b2,b3;
s1.getPeriodicBoxVectors(a1,a2,a3);
s2.getPeriodicBoxVectors(b1,b2,b3);
ASSERT_EQUAL_VEC(a1,b1,0);
ASSERT_EQUAL_VEC(a2,b2,0);
ASSERT_EQUAL_VEC(a3,b3,0);
ASSERT_EQUAL(s1.getPotentialEnergy(), s2.getPotentialEnergy());
ASSERT_EQUAL(s1.getKineticEnergy(), s2.getKineticEnergy());
ASSERT_EQUAL(s1.getTime(), s2.getTime());
map<string, double> p1 = s1.getParameters();
map<string, double> p2 = s2.getParameters();
ASSERT_EQUAL(p1.size(), p2.size());
map<string, double>::const_iterator it1=p1.begin();
map<string, double>::const_iterator it2=p2.begin();
//maps are ordered, so iterators should be in the same order.
for (it1 = p1.begin(); it1 != p1.end(); ++it1, ++it2) {
assert((it1->first).compare(it2->first) == 0);
ASSERT_EQUAL(it1->second, it2->second);
}
delete copy;
// Now create a series of States that include only one type of information. Verify
......@@ -183,13 +184,46 @@ void testSerialization() {
}
}
void testIntegratorParameters() {
// Create a Context involving integrator parameters.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator);
integrator.setGlobalVariable(0, 3.0);
integrator.setPerDofVariable(0, {Vec3(1.0, 2.0, 3.0)});
// Create a State, then serialize and deserialize it.
State s1 = context.getState(State::IntegratorParameters);
stringstream buffer;
XmlSerializer::serialize<State>(&s1, "State", buffer);
State* copy = XmlSerializer::deserialize<State>(buffer);
State& s2 = *copy;
// Set the State on a new Context and make sure all the integrator parameters
// survived the serialization and deserialization.
CustomIntegrator* integrator2 = XmlSerializer::clone(integrator);
Context context2(system, *integrator2);
context2.setState(s2);
ASSERT_EQUAL(3.0, integrator2->getGlobalVariable(0));
vector<Vec3> values;
integrator2->getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(1.0, 2.0, 3.0), values[0], 1e-6);
}
int main() {
try {
testSerialization();
testIntegratorParameters();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
return 1;
}
cout << "Done" << endl;
return 0;
......
......@@ -239,6 +239,35 @@ void testCheckpoint() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void testSaveParameters() {
// Test that integrator variables get loaded correctly from States.
System system;
system.addParticle(1.0);
CustomIntegrator* custom = new CustomIntegrator(0.001);
custom->addGlobalVariable("a", 1.0);
custom->addPerDofVariable("b", 2.0);
CompoundIntegrator integrator;
integrator.addIntegrator(custom);
integrator.addIntegrator(new VerletIntegrator(0.005));
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
custom->setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
custom->setPerDofVariable(0, b1);
State savedState = context.getState(State::IntegratorParameters);
custom->setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
custom->setPerDofVariable(0, b2);
integrator.setCurrentIntegrator(1);
context.setState(savedState);
ASSERT_EQUAL(0, integrator.getCurrentIntegrator());
ASSERT_EQUAL(5.0, custom->getGlobalVariable(0));
vector<Vec3> b3;
custom->getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -248,6 +277,7 @@ int main(int argc, char* argv[]) {
testChangingParameters();
testDifferentStepSizes();
testCheckpoint();
testSaveParameters();
runPlatformTests();
}
catch(const exception& e) {
......
......@@ -1187,6 +1187,30 @@ void testCheckpoint() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void testSaveParameters() {
// Test that integrator variables get loaded correctly from States.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
integrator.setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
integrator.setPerDofVariable(0, b1);
State savedState = context.getState(State::IntegratorParameters);
integrator.setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
integrator.setPerDofVariable(0, b2);
context.setState(savedState);
ASSERT_EQUAL(5.0, integrator.getGlobalVariable(0));
vector<Vec3> b3;
integrator.getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -1216,6 +1240,7 @@ int main(int argc, char* argv[]) {
testRecordEnergy();
testInitialTemperature();
testCheckpoint();
testSaveParameters();
runPlatformTests();
}
catch(const exception& e) {
......
......@@ -499,6 +499,57 @@ void testCheckpoints() {
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
}
void testSaveParameters() {
// Create a system with Drude-like particles to be thermostated as a pair, as well as another
// particle to be thermostated independently, to test all integrator features.
double timeStep = 0.001;
NoseHooverIntegrator integrator(timeStep), newIntegrator(timeStep);
System system;
double mass = 1;
system.addParticle(8*mass);
system.addParticle(mass);
system.addParticle(5*mass);
HarmonicBondForce* force = new HarmonicBondForce();
force->addBond(0, 1, 0.1, 50.0);
force->addBond(0, 2, 0.1, 50.0);
system.addForce(force);
double kineticEnergy = 1e6;
double temperature=300, collisionFrequency=1, chainLength=3, numMTS=3, numYS=3;
chainLength = 10;
integrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
newIntegrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
Context context(system, integrator, platform);
Context newContext(system, newIntegrator, platform);
std::vector<Vec3> positions(3);
std::vector<Vec3> velocities(3);
positions[1] = {0.1, 0.0, 0.0};
velocities[1] = {0.1,0.2,-0.2};
positions[2] = {-0.1, 0.001, 0.001};
velocities[2] = {-0.1,0.2,-0.2};
context.setPositions(positions);
context.setVelocities(velocities);
// Run a short simulation and save a state..
integrator.step(500);
State savedState = context.getState(State::Positions | State::Velocities | State::IntegratorParameters);
// Now continue the simulation
integrator.step(5);
// And try the same, starting from the state
newContext.setState(savedState);
newIntegrator.step(5);
State state1 = context.getState(State::Positions | State::Velocities);
State state2 = newContext.getState(State::Positions | State::Velocities);
ASSERT_EQUAL_VEC(state1.getPositions()[0], state2.getPositions()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getPositions()[1], state2.getPositions()[1], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[0], state2.getVelocities()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
}
void testAPIChangeNumParticles() {
bool constrain = true;
int numMolecules = 20;
......@@ -553,6 +604,7 @@ int main(int argc, char* argv[]) {
constrain = false; testDimerBox(constrain);
constrain = true; testDimerBox(constrain);
testCheckpoints();
testSaveParameters();
testForceGroups();
runPlatformTests();
}
......
......@@ -312,7 +312,7 @@ class Simulation(object):
a File-like object to write the state to, or alternatively a
filename
"""
state = self.context.getState(getPositions=True, getVelocities=True, getParameters=True)
state = self.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True)
xml = mm.XmlSerializer.serialize(state)
if isinstance(file, str):
with open(file, 'w') as f:
......
......@@ -11,7 +11,8 @@
def getState(self, getPositions=False, getVelocities=False,
getForces=False, getEnergy=False, getParameters=False,
getParameterDerivatives=False, enforcePeriodicBox=False, groups=-1):
getParameterDerivatives=False, getIntegratorParameters=False,
enforcePeriodicBox=False, groups=-1):
"""Get a State object recording the current state information stored in this context.
Parameters
......@@ -28,6 +29,8 @@
whether to store context parameters in the State
getParameterDerivatives : bool=False
whether to store parameter derivatives in the State
getIntegratorParameters : bool=False
whether to store integrator parameters in the State
enforcePeriodicBox : bool=False
if false, the position of each particle will be whatever position
is stored in the Context, regardless of periodic boundary conditions.
......@@ -64,6 +67,8 @@
types += State.Parameters
if getParameterDerivatives:
types += State.ParameterDerivatives
if getIntegratorParameters:
types += State.IntegratorParameters
state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask)
return state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment