Commit 694c3930 authored by peastman's avatar peastman
Browse files

States can save integrator parameters

parent 03861778
...@@ -1136,6 +1136,22 @@ public: ...@@ -1136,6 +1136,22 @@ public:
* Load the chain states from a checkpoint. * Load the chain states from a checkpoint.
*/ */
virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0; virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0;
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
virtual void getChainStates(ContextImpl& context, std::vector<std::vector<double> >& positions, std::vector<std::vector<double> >& velocities) const = 0;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
virtual void setChainStates(ContextImpl& context, const std::vector<std::vector<double> >& positions, const std::vector<std::vector<double> >& velocities) = 0;
}; };
/** /**
......
...@@ -204,6 +204,19 @@ protected: ...@@ -204,6 +204,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly. * data it wrote in createCheckpoint() and update its internal state accordingly.
*/ */
void loadCheckpoint(std::istream& stream); void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
private: private:
int currentIntegrator; int currentIntegrator;
std::vector<Integrator*> integrators; std::vector<Integrator*> integrators;
......
...@@ -678,6 +678,19 @@ protected: ...@@ -678,6 +678,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly. * data it wrote in createCheckpoint() and update its internal state accordingly.
*/ */
void loadCheckpoint(std::istream& stream); void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
private: private:
class ComputationInfo; class ComputationInfo;
class FunctionInfo; class FunctionInfo;
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "State.h" #include "State.h"
#include "Vec3.h" #include "Vec3.h"
#include "openmm/serialization/SerializationNode.h"
#include <iosfwd> #include <iosfwd>
#include <map> #include <map>
#include <vector> #include <vector>
...@@ -176,6 +177,21 @@ protected: ...@@ -176,6 +177,21 @@ protected:
*/ */
virtual void loadCheckpoint(std::istream& stream) { virtual void loadCheckpoint(std::istream& stream) {
} }
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
virtual void serializeParameters(SerializationNode& node) const {
}
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
virtual void deserializeParameters(const SerializationNode& node) {
}
private: private:
double stepSize, constraintTol; double stepSize, constraintTol;
int forceGroups; int forceGroups;
......
...@@ -9,9 +9,9 @@ ...@@ -9,9 +9,9 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2019 Stanford University and the Authors. * * Portions copyright (c) 2019-2020 Stanford University and the Authors. *
* Authors: Andreas Krämer and Andrew C. Simmonett * * Authors: Andreas Krämer and Andrew C. Simmonett *
* Contributors: * * Contributors: Peter Eastman *
* * * *
* Permission is hereby granted, free of charge, to any person obtaining a * * Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), * * copy of this software and associated documentation files (the "Software"), *
...@@ -270,6 +270,19 @@ protected: ...@@ -270,6 +270,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly. * data it wrote in createCheckpoint() and update its internal state accordingly.
*/ */
void loadCheckpoint(std::istream& stream); void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
std::vector<NoseHooverChain> noseHooverChains; std::vector<NoseHooverChain> noseHooverChains;
std::vector<int> allAtoms; std::vector<int> allAtoms;
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "Vec3.h" #include "Vec3.h"
#include "openmm/serialization/SerializationNode.h"
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -58,7 +59,7 @@ public: ...@@ -58,7 +59,7 @@ public:
* This is an enumeration of the types of data which may be stored in a State. When you create * This is an enumeration of the types of data which may be stored in a State. When you create
* a State, use these values to specify which data types it should contain. * a State, use these values to specify which data types it should contain.
*/ */
enum DataType {Positions=1, Velocities=2, Forces=4, Energy=8, Parameters=16, ParameterDerivatives=32}; enum DataType {Positions=1, Velocities=2, Forces=4, Energy=8, Parameters=16, ParameterDerivatives=32, IntegratorParameters=64};
/** /**
* Construct an empty State containing no data. This exists so State objects can be used in STL containers. * Construct an empty State containing no data. This exists so State objects can be used in STL containers.
*/ */
...@@ -124,6 +125,8 @@ public: ...@@ -124,6 +125,8 @@ public:
*/ */
int getDataTypes() const; int getDataTypes() const;
private: private:
friend class Context;
friend class StateProxy;
State(double time); State(double time);
void setPositions(const std::vector<Vec3>& pos); void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel); void setVelocities(const std::vector<Vec3>& vel);
...@@ -132,6 +135,8 @@ private: ...@@ -132,6 +135,8 @@ private:
void setEnergyParameterDerivatives(const std::map<std::string, double>& derivs); void setEnergyParameterDerivatives(const std::map<std::string, double>& derivs);
void setEnergy(double ke, double pe); void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
SerializationNode& updateIntegratorParameters();
const SerializationNode& getIntegratorParameters() const;
int types; int types;
double time, ke, pe; double time, ke, pe;
std::vector<Vec3> positions; std::vector<Vec3> positions;
...@@ -139,6 +144,7 @@ private: ...@@ -139,6 +144,7 @@ private:
std::vector<Vec3> forces; std::vector<Vec3> forces;
Vec3 periodicBoxVectors[3]; Vec3 periodicBoxVectors[3];
std::map<std::string, double> parameters, energyParameterDerivatives; std::map<std::string, double> parameters, energyParameterDerivatives;
SerializationNode integratorParameters;
}; };
/** /**
...@@ -157,6 +163,7 @@ public: ...@@ -157,6 +163,7 @@ public:
void setEnergyParameterDerivatives(const std::map<std::string, double>& params); void setEnergyParameterDerivatives(const std::map<std::string, double>& params);
void setEnergy(double ke, double pe); void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
SerializationNode& updateIntegratorParameters();
private: private:
State state; State state;
}; };
......
...@@ -140,3 +140,22 @@ void CompoundIntegrator::loadCheckpoint(std::istream& stream) { ...@@ -140,3 +140,22 @@ void CompoundIntegrator::loadCheckpoint(std::istream& stream) {
for (int i = 0; i < integrators.size(); i++) for (int i = 0; i < integrators.size(); i++)
integrators[i]->loadCheckpoint(stream); integrators[i]->loadCheckpoint(stream);
} }
void CompoundIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
node.setIntProperty("currentIntegrator", currentIntegrator);
for (int i = 0; i < getNumIntegrators(); i++) {
SerializationNode& child = node.createChildNode("IntegratorParameters");
integrators[i]->serializeParameters(child);
}
}
void CompoundIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
if (node.getChildren().size() != getNumIntegrators())
throw OpenMMException("State has wrong number of integrators for CompoundIntegrator");
setCurrentIntegrator(node.getIntProperty("currentIntegrator"));
for (int i = 0; i < node.getChildren().size(); i++)
integrators[i]->deserializeParameters(node.getChildren()[i]);
}
...@@ -147,6 +147,9 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const { ...@@ -147,6 +147,9 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
impl->getVelocities(velocities); impl->getVelocities(velocities);
builder.setVelocities(velocities); builder.setVelocities(velocities);
} }
if (types&State::IntegratorParameters) {
getIntegrator().serializeParameters(builder.updateIntegratorParameters());
}
return builder.getState(); return builder.getState();
} }
...@@ -162,6 +165,8 @@ void Context::setState(const State& state) { ...@@ -162,6 +165,8 @@ void Context::setState(const State& state) {
if ((state.getDataTypes()&State::Parameters) != 0) if ((state.getDataTypes()&State::Parameters) != 0)
for (auto& param : state.getParameters()) for (auto& param : state.getParameters())
setParameter(param.first, param.second); setParameter(param.first, param.second);
if ((state.getDataTypes()&State::IntegratorParameters) != 0)
getIntegrator().deserializeParameters(state.getIntegratorParameters());
} }
void Context::setTime(double time) { void Context::setTime(double time) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2019 Stanford University and the Authors. * * Portions copyright (c) 2011-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -348,3 +348,33 @@ void CustomIntegrator::setKineticEnergyExpression(const string& expression) { ...@@ -348,3 +348,33 @@ void CustomIntegrator::setKineticEnergyExpression(const string& expression) {
Lepton::CompiledExpression expr = Lepton::Parser::parse(kineticEnergy).createCompiledExpression(); Lepton::CompiledExpression expr = Lepton::Parser::parse(kineticEnergy).createCompiledExpression();
keNeedsForce = (expr.getVariables().find("f") != expr.getVariables().end()); keNeedsForce = (expr.getVariables().find("f") != expr.getVariables().end());
} }
void CustomIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
SerializationNode& globalVariablesNode = node.createChildNode("GlobalVariables");
for (int i = 0; i < getNumGlobalVariables(); i++)
globalVariablesNode.setDoubleProperty(getGlobalVariableName(i), getGlobalVariable(i));
SerializationNode& perDofVariablesNode = node.createChildNode("PerDofVariables");
for (int i = 0; i < getNumPerDofVariables(); i++) {
SerializationNode& perDofValuesNode = perDofVariablesNode.createChildNode(getPerDofVariableName(i));
vector<Vec3> perDofValues;
getPerDofVariable(i, perDofValues);
for (int j = 0; j < perDofValues.size(); j++)
perDofValuesNode.createChildNode("Value").setDoubleProperty("x",perDofValues[j][0]).setDoubleProperty("y",perDofValues[j][1]).setDoubleProperty("z",perDofValues[j][2]);
}
}
void CustomIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& globalVariablesNode = node.getChildNode("GlobalVariables");
for (auto& prop : globalVariablesNode.getProperties())
setGlobalVariableByName(prop.first, globalVariablesNode.getDoubleProperty(prop.first));
const SerializationNode& perDofVariablesNode = node.getChildNode("PerDofVariables");
for (auto& var : perDofVariablesNode.getChildren()) {
vector<Vec3> perDofValues;
for (auto& child : var.getChildren())
perDofValues.push_back(Vec3(child.getDoubleProperty("x"), child.getDoubleProperty("y"), child.getDoubleProperty("z")));
setPerDofVariableByName(var.getName(), perDofValues);
}
}
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* * * *
* Portions copyright (c) 2019-2020 Stanford University and the Authors. * * Portions copyright (c) 2019-2020 Stanford University and the Authors. *
* Authors: Andreas Krämer and Andrew C. Simmonett * * Authors: Andreas Krämer and Andrew C. Simmonett *
* Contributors: * * Contributors: Peter Eastman *
* * * *
* Permission is hereby granted, free of charge, to any person obtaining a * * Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), * * copy of this software and associated documentation files (the "Software"), *
...@@ -348,4 +348,30 @@ void NoseHooverIntegrator::createCheckpoint(std::ostream& stream) const { ...@@ -348,4 +348,30 @@ void NoseHooverIntegrator::createCheckpoint(std::ostream& stream) const {
void NoseHooverIntegrator::loadCheckpoint(std::istream& stream) { void NoseHooverIntegrator::loadCheckpoint(std::istream& stream) {
kernel.getAs<IntegrateNoseHooverStepKernel>().loadCheckpoint(*context, stream); kernel.getAs<IntegrateNoseHooverStepKernel>().loadCheckpoint(*context, stream);
} }
\ No newline at end of file
void NoseHooverIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
vector<vector<double> > positions, velocities;
kernel.getAs<IntegrateNoseHooverStepKernel>().getChainStates(*context, positions, velocities);
for (int i = 0; i < positions.size(); i++) {
SerializationNode& chain = node.createChildNode("Chain");
for (int j = 0; j < positions[i].size(); j++)
chain.createChildNode("Bead").setDoubleProperty("position", positions[i][j]).setDoubleProperty("velocity", velocities[i][j]);
}
}
void NoseHooverIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
int numChains = node.getChildren().size();
vector<vector<double> > positions(numChains), velocities(numChains);
for (int i = 0; i < numChains; i++) {
auto& chain = node.getChildren()[i];
for (auto& bead : chain.getChildren()) {
positions[i].push_back(bead.getDoubleProperty("position"));
velocities[i].push_back(bead.getDoubleProperty("velocity"));
}
}
kernel.getAs<IntegrateNoseHooverStepKernel>().setChainStates(*context, positions, velocities);
}
...@@ -81,6 +81,16 @@ const map<string, double>& State::getEnergyParameterDerivatives() const { ...@@ -81,6 +81,16 @@ const map<string, double>& State::getEnergyParameterDerivatives() const {
throw OpenMMException("Invoked getEnergyParameterDerivatives() on a State which does not contain parameter derivatives."); throw OpenMMException("Invoked getEnergyParameterDerivatives() on a State which does not contain parameter derivatives.");
return energyParameterDerivatives; return energyParameterDerivatives;
} }
const SerializationNode& State::getIntegratorParameters() const {
if ((types&IntegratorParameters) == 0)
throw OpenMMException("Invoked getPIntegratorarameters() on a State which does not contain integrator parameters.");
return integratorParameters;
}
SerializationNode& State::updateIntegratorParameters() {
types |= IntegratorParameters;
integratorParameters.setName("IntegratorParameters");
return integratorParameters;
}
int State::getDataTypes() const { int State::getDataTypes() const {
return types; return types;
} }
...@@ -159,3 +169,7 @@ void State::StateBuilder::setEnergy(double ke, double pe) { ...@@ -159,3 +169,7 @@ void State::StateBuilder::setEnergy(double ke, double pe) {
void State::StateBuilder::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) { void State::StateBuilder::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
state.setPeriodicBoxVectors(a, b, c); state.setPeriodicBoxVectors(a, b, c);
} }
SerializationNode& State::StateBuilder::updateIntegratorParameters() {
return state.updateIntegratorParameters();
}
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2015 Stanford University and the Authors. * * Portions copyright (c) 2010-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -89,6 +89,9 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -89,6 +89,9 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]); forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]);
} }
} }
if ((s.getDataTypes()&State::IntegratorParameters) != 0) {
node.getChildren().push_back(s.getIntegratorParameters());
}
} }
void* StateProxy::deserialize(const SerializationNode& node) const { void* StateProxy::deserialize(const SerializationNode& node) const {
...@@ -138,6 +141,9 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -138,6 +141,9 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
builder.setForces(outForces); builder.setForces(outForces);
arraySizes.push_back(outForces.size()); arraySizes.push_back(outForces.size());
} }
else if (child.getName() == "IntegratorParameters") {
builder.updateIntegratorParameters() = child;
}
} }
for (int i = 1; i < arraySizes.size(); i++) { for (int i = 1; i < arraySizes.size(); i++) {
if (arraySizes[i] != arraySizes[i-1]) { if (arraySizes[i] != arraySizes[i-1]) {
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/LangevinIntegrator.h" #include "openmm/LangevinIntegrator.h"
#include "openmm/AndersenThermostat.h" #include "openmm/AndersenThermostat.h"
#include "openmm/MonteCarloBarostat.h" #include "openmm/MonteCarloBarostat.h"
...@@ -48,9 +49,9 @@ using namespace std; ...@@ -48,9 +49,9 @@ using namespace std;
void testSerialization() { void testSerialization() {
// Create a System. // Create a System.
const int numParticles=50; const int numParticles=50;
System system; System system;
system.setDefaultPeriodicBoxVectors(Vec3(6.2, 0, 0), Vec3(0, 6.2, 0), Vec3(0, 0, 6.2 )); system.setDefaultPeriodicBoxVectors(Vec3(6.2, 0, 0), Vec3(0, 6.2, 0), Vec3(0, 0, 6.2 ));
NonbondedForce* nonbonded = new NonbondedForce(); NonbondedForce* nonbonded = new NonbondedForce();
nonbonded->setNonbondedMethod(NonbondedForce::Ewald); nonbonded->setNonbondedMethod(NonbondedForce::Ewald);
nonbonded->setCutoffDistance(0.8); nonbonded->setCutoffDistance(0.8);
...@@ -64,74 +65,74 @@ void testSerialization() { ...@@ -64,74 +65,74 @@ void testSerialization() {
for (int i = 0; i < numParticles/2; i++) for (int i = 0; i < numParticles/2; i++)
nonbonded->addParticle(-1.0, 1.0,0.0); nonbonded->addParticle(-1.0, 1.0,0.0);
system.addForce(nonbonded); system.addForce(nonbonded);
system.addForce(new AndersenThermostat(393.3, 19.3)); system.addForce(new AndersenThermostat(393.3, 19.3));
system.addForce(new MonteCarloBarostat(25, 393.3, 25)); system.addForce(new MonteCarloBarostat(25, 393.3, 25));
LangevinIntegrator intg(300,79,0.002); LangevinIntegrator intg(300,79,0.002);
Context context(system, intg); Context context(system, intg);
// Set positions, velocities, forces // Set positions, velocities, forces
vector<Vec3> positions; vector<Vec3> positions;
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
positions.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2)); positions.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
} }
vector<Vec3> velocities; vector<Vec3> velocities;
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
velocities.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2)); velocities.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
} }
context.setPositions(positions); context.setPositions(positions);
context.setVelocities(velocities); context.setVelocities(velocities);
// Serialize and then deserialize it. // Serialize and then deserialize it.
State s1 = context.getState(State::Positions | State::Velocities | State::Forces | State::Energy | State::Parameters); State s1 = context.getState(State::Positions | State::Velocities | State::Forces | State::Energy | State::Parameters);
stringstream buffer; stringstream buffer;
XmlSerializer::serialize<State>(&s1, "State", buffer); XmlSerializer::serialize<State>(&s1, "State", buffer);
State* copy = XmlSerializer::deserialize<State>(buffer); State* copy = XmlSerializer::deserialize<State>(buffer);
State& s2 = *copy; State& s2 = *copy;
// Compare the two states to see if they are identical. // Compare the two states to see if they are identical.
vector<Vec3> pos1 = s1.getPositions(); vector<Vec3> pos1 = s1.getPositions();
vector<Vec3> pos2 = s2.getPositions(); vector<Vec3> pos2 = s2.getPositions();
ASSERT_EQUAL(pos1.size(), pos2.size()); ASSERT_EQUAL(pos1.size(), pos2.size());
ASSERT_EQUAL(pos1.size(), positions.size()); ASSERT_EQUAL(pos1.size(), positions.size());
for (int i = 0; i < (int) pos1.size(); i++) { for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(pos1[i],pos2[i],0); ASSERT_EQUAL_VEC(pos1[i],pos2[i],0);
} }
vector<Vec3> vel1 = s1.getVelocities(); vector<Vec3> vel1 = s1.getVelocities();
vector<Vec3> vel2 = s2.getVelocities(); vector<Vec3> vel2 = s2.getVelocities();
ASSERT_EQUAL(vel1.size(), vel2.size()); ASSERT_EQUAL(vel1.size(), vel2.size());
for (int i = 0; i < (int) pos1.size(); i++) { for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(vel1[i],vel2[i],0); ASSERT_EQUAL_VEC(vel1[i],vel2[i],0);
} }
vector<Vec3> forces1 = s1.getForces(); vector<Vec3> forces1 = s1.getForces();
vector<Vec3> forces2 = s2.getForces(); vector<Vec3> forces2 = s2.getForces();
ASSERT_EQUAL(forces1.size(), forces2.size()); ASSERT_EQUAL(forces1.size(), forces2.size());
for (int i = 0; i < (int) pos1.size(); i++) { for (int i = 0; i < (int) pos1.size(); i++) {
ASSERT_EQUAL_VEC(forces1[i],forces2[i],0); ASSERT_EQUAL_VEC(forces1[i],forces2[i],0);
} }
Vec3 a1,a2,a3,b1,b2,b3; Vec3 a1,a2,a3,b1,b2,b3;
s1.getPeriodicBoxVectors(a1,a2,a3); s1.getPeriodicBoxVectors(a1,a2,a3);
s2.getPeriodicBoxVectors(b1,b2,b3); s2.getPeriodicBoxVectors(b1,b2,b3);
ASSERT_EQUAL_VEC(a1,b1,0); ASSERT_EQUAL_VEC(a1,b1,0);
ASSERT_EQUAL_VEC(a2,b2,0); ASSERT_EQUAL_VEC(a2,b2,0);
ASSERT_EQUAL_VEC(a3,b3,0); ASSERT_EQUAL_VEC(a3,b3,0);
ASSERT_EQUAL(s1.getPotentialEnergy(), s2.getPotentialEnergy()); ASSERT_EQUAL(s1.getPotentialEnergy(), s2.getPotentialEnergy());
ASSERT_EQUAL(s1.getKineticEnergy(), s2.getKineticEnergy()); ASSERT_EQUAL(s1.getKineticEnergy(), s2.getKineticEnergy());
ASSERT_EQUAL(s1.getTime(), s2.getTime()); ASSERT_EQUAL(s1.getTime(), s2.getTime());
map<string, double> p1 = s1.getParameters(); map<string, double> p1 = s1.getParameters();
map<string, double> p2 = s2.getParameters(); map<string, double> p2 = s2.getParameters();
ASSERT_EQUAL(p1.size(), p2.size()); ASSERT_EQUAL(p1.size(), p2.size());
map<string, double>::const_iterator it1=p1.begin(); map<string, double>::const_iterator it1=p1.begin();
map<string, double>::const_iterator it2=p2.begin(); map<string, double>::const_iterator it2=p2.begin();
//maps are ordered, so iterators should be in the same order. //maps are ordered, so iterators should be in the same order.
for (it1 = p1.begin(); it1 != p1.end(); ++it1, ++it2) { for (it1 = p1.begin(); it1 != p1.end(); ++it1, ++it2) {
assert((it1->first).compare(it2->first) == 0); assert((it1->first).compare(it2->first) == 0);
ASSERT_EQUAL(it1->second, it2->second); ASSERT_EQUAL(it1->second, it2->second);
} }
delete copy; delete copy;
// Now create a series of States that include only one type of information. Verify // Now create a series of States that include only one type of information. Verify
...@@ -183,13 +184,46 @@ void testSerialization() { ...@@ -183,13 +184,46 @@ void testSerialization() {
} }
} }
void testIntegratorParameters() {
// Create a Context involving integrator parameters.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator);
integrator.setGlobalVariable(0, 3.0);
integrator.setPerDofVariable(0, {Vec3(1.0, 2.0, 3.0)});
// Create a State, then serialize and deserialize it.
State s1 = context.getState(State::IntegratorParameters);
stringstream buffer;
XmlSerializer::serialize<State>(&s1, "State", buffer);
State* copy = XmlSerializer::deserialize<State>(buffer);
State& s2 = *copy;
// Set the State on a new Context and make sure all the integrator parameters
// survived the serialization and deserialization.
CustomIntegrator* integrator2 = XmlSerializer::clone(integrator);
Context context2(system, *integrator2);
context2.setState(s2);
ASSERT_EQUAL(3.0, integrator2->getGlobalVariable(0));
vector<Vec3> values;
integrator2->getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(1.0, 2.0, 3.0), values[0], 1e-6);
}
int main() { int main() {
try { try {
testSerialization(); testSerialization();
testIntegratorParameters();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
return 1; return 1;
} }
cout << "Done" << endl; cout << "Done" << endl;
return 0; return 0;
......
...@@ -239,6 +239,35 @@ void testCheckpoint() { ...@@ -239,6 +239,35 @@ void testCheckpoint() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6); ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
} }
void testSaveParameters() {
// Test that integrator variables get loaded correctly from States.
System system;
system.addParticle(1.0);
CustomIntegrator* custom = new CustomIntegrator(0.001);
custom->addGlobalVariable("a", 1.0);
custom->addPerDofVariable("b", 2.0);
CompoundIntegrator integrator;
integrator.addIntegrator(custom);
integrator.addIntegrator(new VerletIntegrator(0.005));
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
custom->setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
custom->setPerDofVariable(0, b1);
State savedState = context.getState(State::IntegratorParameters);
custom->setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
custom->setPerDofVariable(0, b2);
integrator.setCurrentIntegrator(1);
context.setState(savedState);
ASSERT_EQUAL(0, integrator.getCurrentIntegrator());
ASSERT_EQUAL(5.0, custom->getGlobalVariable(0));
vector<Vec3> b3;
custom->getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -248,6 +277,7 @@ int main(int argc, char* argv[]) { ...@@ -248,6 +277,7 @@ int main(int argc, char* argv[]) {
testChangingParameters(); testChangingParameters();
testDifferentStepSizes(); testDifferentStepSizes();
testCheckpoint(); testCheckpoint();
testSaveParameters();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -1187,6 +1187,30 @@ void testCheckpoint() { ...@@ -1187,6 +1187,30 @@ void testCheckpoint() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6); ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
} }
void testSaveParameters() {
// Test that integrator variables get loaded correctly from States.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
integrator.setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
integrator.setPerDofVariable(0, b1);
State savedState = context.getState(State::IntegratorParameters);
integrator.setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
integrator.setPerDofVariable(0, b2);
context.setState(savedState);
ASSERT_EQUAL(5.0, integrator.getGlobalVariable(0));
vector<Vec3> b3;
integrator.getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -1216,6 +1240,7 @@ int main(int argc, char* argv[]) { ...@@ -1216,6 +1240,7 @@ int main(int argc, char* argv[]) {
testRecordEnergy(); testRecordEnergy();
testInitialTemperature(); testInitialTemperature();
testCheckpoint(); testCheckpoint();
testSaveParameters();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -499,6 +499,57 @@ void testCheckpoints() { ...@@ -499,6 +499,57 @@ void testCheckpoints() {
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6); ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
} }
void testSaveParameters() {
// Create a system with Drude-like particles to be thermostated as a pair, as well as another
// particle to be thermostated independently, to test all integrator features.
double timeStep = 0.001;
NoseHooverIntegrator integrator(timeStep), newIntegrator(timeStep);
System system;
double mass = 1;
system.addParticle(8*mass);
system.addParticle(mass);
system.addParticle(5*mass);
HarmonicBondForce* force = new HarmonicBondForce();
force->addBond(0, 1, 0.1, 50.0);
force->addBond(0, 2, 0.1, 50.0);
system.addForce(force);
double kineticEnergy = 1e6;
double temperature=300, collisionFrequency=1, chainLength=3, numMTS=3, numYS=3;
chainLength = 10;
integrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
newIntegrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
Context context(system, integrator, platform);
Context newContext(system, newIntegrator, platform);
std::vector<Vec3> positions(3);
std::vector<Vec3> velocities(3);
positions[1] = {0.1, 0.0, 0.0};
velocities[1] = {0.1,0.2,-0.2};
positions[2] = {-0.1, 0.001, 0.001};
velocities[2] = {-0.1,0.2,-0.2};
context.setPositions(positions);
context.setVelocities(velocities);
// Run a short simulation and save a state..
integrator.step(500);
State savedState = context.getState(State::Positions | State::Velocities | State::IntegratorParameters);
// Now continue the simulation
integrator.step(5);
// And try the same, starting from the state
newContext.setState(savedState);
newIntegrator.step(5);
State state1 = context.getState(State::Positions | State::Velocities);
State state2 = newContext.getState(State::Positions | State::Velocities);
ASSERT_EQUAL_VEC(state1.getPositions()[0], state2.getPositions()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getPositions()[1], state2.getPositions()[1], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[0], state2.getVelocities()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
}
void testAPIChangeNumParticles() { void testAPIChangeNumParticles() {
bool constrain = true; bool constrain = true;
int numMolecules = 20; int numMolecules = 20;
...@@ -553,6 +604,7 @@ int main(int argc, char* argv[]) { ...@@ -553,6 +604,7 @@ int main(int argc, char* argv[]) {
constrain = false; testDimerBox(constrain); constrain = false; testDimerBox(constrain);
constrain = true; testDimerBox(constrain); constrain = true; testDimerBox(constrain);
testCheckpoints(); testCheckpoints();
testSaveParameters();
testForceGroups(); testForceGroups();
runPlatformTests(); runPlatformTests();
} }
......
...@@ -312,7 +312,7 @@ class Simulation(object): ...@@ -312,7 +312,7 @@ class Simulation(object):
a File-like object to write the state to, or alternatively a a File-like object to write the state to, or alternatively a
filename filename
""" """
state = self.context.getState(getPositions=True, getVelocities=True, getParameters=True) state = self.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True)
xml = mm.XmlSerializer.serialize(state) xml = mm.XmlSerializer.serialize(state)
if isinstance(file, str): if isinstance(file, str):
with open(file, 'w') as f: with open(file, 'w') as f:
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
def getState(self, getPositions=False, getVelocities=False, def getState(self, getPositions=False, getVelocities=False,
getForces=False, getEnergy=False, getParameters=False, getForces=False, getEnergy=False, getParameters=False,
getParameterDerivatives=False, enforcePeriodicBox=False, groups=-1): getParameterDerivatives=False, getIntegratorParameters=False,
enforcePeriodicBox=False, groups=-1):
"""Get a State object recording the current state information stored in this context. """Get a State object recording the current state information stored in this context.
Parameters Parameters
...@@ -28,6 +29,8 @@ ...@@ -28,6 +29,8 @@
whether to store context parameters in the State whether to store context parameters in the State
getParameterDerivatives : bool=False getParameterDerivatives : bool=False
whether to store parameter derivatives in the State whether to store parameter derivatives in the State
getIntegratorParameters : bool=False
whether to store integrator parameters in the State
enforcePeriodicBox : bool=False enforcePeriodicBox : bool=False
if false, the position of each particle will be whatever position if false, the position of each particle will be whatever position
is stored in the Context, regardless of periodic boundary conditions. is stored in the Context, regardless of periodic boundary conditions.
...@@ -64,6 +67,8 @@ ...@@ -64,6 +67,8 @@
types += State.Parameters types += State.Parameters
if getParameterDerivatives: if getParameterDerivatives:
types += State.ParameterDerivatives types += State.ParameterDerivatives
if getIntegratorParameters:
types += State.IntegratorParameters
state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask) state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask)
return state return state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment