Unverified Commit cc3c4b54 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2790 from peastman/integratorparams

States can save integrator parameters
parents 03861778 16b6e4aa
...@@ -1136,6 +1136,22 @@ public: ...@@ -1136,6 +1136,22 @@ public:
* Load the chain states from a checkpoint. * Load the chain states from a checkpoint.
*/ */
virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0; virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0;
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
virtual void getChainStates(ContextImpl& context, std::vector<std::vector<double> >& positions, std::vector<std::vector<double> >& velocities) const = 0;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
virtual void setChainStates(ContextImpl& context, const std::vector<std::vector<double> >& positions, const std::vector<std::vector<double> >& velocities) = 0;
}; };
/** /**
......
...@@ -204,6 +204,19 @@ protected: ...@@ -204,6 +204,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly. * data it wrote in createCheckpoint() and update its internal state accordingly.
*/ */
void loadCheckpoint(std::istream& stream); void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
private: private:
int currentIntegrator; int currentIntegrator;
std::vector<Integrator*> integrators; std::vector<Integrator*> integrators;
......
...@@ -678,6 +678,19 @@ protected: ...@@ -678,6 +678,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly. * data it wrote in createCheckpoint() and update its internal state accordingly.
*/ */
void loadCheckpoint(std::istream& stream); void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
private: private:
class ComputationInfo; class ComputationInfo;
class FunctionInfo; class FunctionInfo;
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "State.h" #include "State.h"
#include "Vec3.h" #include "Vec3.h"
#include "openmm/serialization/SerializationNode.h"
#include <iosfwd> #include <iosfwd>
#include <map> #include <map>
#include <vector> #include <vector>
...@@ -176,6 +177,21 @@ protected: ...@@ -176,6 +177,21 @@ protected:
*/ */
virtual void loadCheckpoint(std::istream& stream) { virtual void loadCheckpoint(std::istream& stream) {
} }
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
virtual void serializeParameters(SerializationNode& node) const {
}
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
virtual void deserializeParameters(const SerializationNode& node) {
}
private: private:
double stepSize, constraintTol; double stepSize, constraintTol;
int forceGroups; int forceGroups;
......
...@@ -9,9 +9,9 @@ ...@@ -9,9 +9,9 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2019 Stanford University and the Authors. * * Portions copyright (c) 2019-2020 Stanford University and the Authors. *
* Authors: Andreas Krämer and Andrew C. Simmonett * * Authors: Andreas Krämer and Andrew C. Simmonett *
* Contributors: * * Contributors: Peter Eastman *
* * * *
* Permission is hereby granted, free of charge, to any person obtaining a * * Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), * * copy of this software and associated documentation files (the "Software"), *
...@@ -270,6 +270,19 @@ protected: ...@@ -270,6 +270,19 @@ protected:
* data it wrote in createCheckpoint() and update its internal state accordingly. * data it wrote in createCheckpoint() and update its internal state accordingly.
*/ */
void loadCheckpoint(std::istream& stream); void loadCheckpoint(std::istream& stream);
/**
* This is called while creating a State. The Integrator should store the values
* of all time-varying parameters into the SerializationNode so they can be saved
* as part of the state.
*/
void serializeParameters(SerializationNode& node) const;
/**
* This is called when loading a previously saved State. The Integrator should
* load the values of all time-varying parameters from the SerializationNode. If
* the node contains parameters that are not defined for this Integrator, it should
* throw an exception.
*/
void deserializeParameters(const SerializationNode& node);
std::vector<NoseHooverChain> noseHooverChains; std::vector<NoseHooverChain> noseHooverChains;
std::vector<int> allAtoms; std::vector<int> allAtoms;
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "Vec3.h" #include "Vec3.h"
#include "openmm/serialization/SerializationNode.h"
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -58,7 +59,7 @@ public: ...@@ -58,7 +59,7 @@ public:
* This is an enumeration of the types of data which may be stored in a State. When you create * This is an enumeration of the types of data which may be stored in a State. When you create
* a State, use these values to specify which data types it should contain. * a State, use these values to specify which data types it should contain.
*/ */
enum DataType {Positions=1, Velocities=2, Forces=4, Energy=8, Parameters=16, ParameterDerivatives=32}; enum DataType {Positions=1, Velocities=2, Forces=4, Energy=8, Parameters=16, ParameterDerivatives=32, IntegratorParameters=64};
/** /**
* Construct an empty State containing no data. This exists so State objects can be used in STL containers. * Construct an empty State containing no data. This exists so State objects can be used in STL containers.
*/ */
...@@ -124,6 +125,8 @@ public: ...@@ -124,6 +125,8 @@ public:
*/ */
int getDataTypes() const; int getDataTypes() const;
private: private:
friend class Context;
friend class StateProxy;
State(double time); State(double time);
void setPositions(const std::vector<Vec3>& pos); void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel); void setVelocities(const std::vector<Vec3>& vel);
...@@ -132,6 +135,8 @@ private: ...@@ -132,6 +135,8 @@ private:
void setEnergyParameterDerivatives(const std::map<std::string, double>& derivs); void setEnergyParameterDerivatives(const std::map<std::string, double>& derivs);
void setEnergy(double ke, double pe); void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
SerializationNode& updateIntegratorParameters();
const SerializationNode& getIntegratorParameters() const;
int types; int types;
double time, ke, pe; double time, ke, pe;
std::vector<Vec3> positions; std::vector<Vec3> positions;
...@@ -139,6 +144,7 @@ private: ...@@ -139,6 +144,7 @@ private:
std::vector<Vec3> forces; std::vector<Vec3> forces;
Vec3 periodicBoxVectors[3]; Vec3 periodicBoxVectors[3];
std::map<std::string, double> parameters, energyParameterDerivatives; std::map<std::string, double> parameters, energyParameterDerivatives;
SerializationNode integratorParameters;
}; };
/** /**
...@@ -157,6 +163,7 @@ public: ...@@ -157,6 +163,7 @@ public:
void setEnergyParameterDerivatives(const std::map<std::string, double>& params); void setEnergyParameterDerivatives(const std::map<std::string, double>& params);
void setEnergy(double ke, double pe); void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
SerializationNode& updateIntegratorParameters();
private: private:
State state; State state;
}; };
......
...@@ -140,3 +140,22 @@ void CompoundIntegrator::loadCheckpoint(std::istream& stream) { ...@@ -140,3 +140,22 @@ void CompoundIntegrator::loadCheckpoint(std::istream& stream) {
for (int i = 0; i < integrators.size(); i++) for (int i = 0; i < integrators.size(); i++)
integrators[i]->loadCheckpoint(stream); integrators[i]->loadCheckpoint(stream);
} }
void CompoundIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
node.setIntProperty("currentIntegrator", currentIntegrator);
for (int i = 0; i < getNumIntegrators(); i++) {
SerializationNode& child = node.createChildNode("IntegratorParameters");
integrators[i]->serializeParameters(child);
}
}
void CompoundIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
if (node.getChildren().size() != getNumIntegrators())
throw OpenMMException("State has wrong number of integrators for CompoundIntegrator");
setCurrentIntegrator(node.getIntProperty("currentIntegrator"));
for (int i = 0; i < node.getChildren().size(); i++)
integrators[i]->deserializeParameters(node.getChildren()[i]);
}
...@@ -147,6 +147,9 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const { ...@@ -147,6 +147,9 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
impl->getVelocities(velocities); impl->getVelocities(velocities);
builder.setVelocities(velocities); builder.setVelocities(velocities);
} }
if (types&State::IntegratorParameters) {
getIntegrator().serializeParameters(builder.updateIntegratorParameters());
}
return builder.getState(); return builder.getState();
} }
...@@ -162,6 +165,8 @@ void Context::setState(const State& state) { ...@@ -162,6 +165,8 @@ void Context::setState(const State& state) {
if ((state.getDataTypes()&State::Parameters) != 0) if ((state.getDataTypes()&State::Parameters) != 0)
for (auto& param : state.getParameters()) for (auto& param : state.getParameters())
setParameter(param.first, param.second); setParameter(param.first, param.second);
if ((state.getDataTypes()&State::IntegratorParameters) != 0)
getIntegrator().deserializeParameters(state.getIntegratorParameters());
} }
void Context::setTime(double time) { void Context::setTime(double time) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2019 Stanford University and the Authors. * * Portions copyright (c) 2011-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -348,3 +348,33 @@ void CustomIntegrator::setKineticEnergyExpression(const string& expression) { ...@@ -348,3 +348,33 @@ void CustomIntegrator::setKineticEnergyExpression(const string& expression) {
Lepton::CompiledExpression expr = Lepton::Parser::parse(kineticEnergy).createCompiledExpression(); Lepton::CompiledExpression expr = Lepton::Parser::parse(kineticEnergy).createCompiledExpression();
keNeedsForce = (expr.getVariables().find("f") != expr.getVariables().end()); keNeedsForce = (expr.getVariables().find("f") != expr.getVariables().end());
} }
void CustomIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
SerializationNode& globalVariablesNode = node.createChildNode("GlobalVariables");
for (int i = 0; i < getNumGlobalVariables(); i++)
globalVariablesNode.setDoubleProperty(getGlobalVariableName(i), getGlobalVariable(i));
SerializationNode& perDofVariablesNode = node.createChildNode("PerDofVariables");
for (int i = 0; i < getNumPerDofVariables(); i++) {
SerializationNode& perDofValuesNode = perDofVariablesNode.createChildNode(getPerDofVariableName(i));
vector<Vec3> perDofValues;
getPerDofVariable(i, perDofValues);
for (int j = 0; j < perDofValues.size(); j++)
perDofValuesNode.createChildNode("Value").setDoubleProperty("x",perDofValues[j][0]).setDoubleProperty("y",perDofValues[j][1]).setDoubleProperty("z",perDofValues[j][2]);
}
}
void CustomIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& globalVariablesNode = node.getChildNode("GlobalVariables");
for (auto& prop : globalVariablesNode.getProperties())
setGlobalVariableByName(prop.first, globalVariablesNode.getDoubleProperty(prop.first));
const SerializationNode& perDofVariablesNode = node.getChildNode("PerDofVariables");
for (auto& var : perDofVariablesNode.getChildren()) {
vector<Vec3> perDofValues;
for (auto& child : var.getChildren())
perDofValues.push_back(Vec3(child.getDoubleProperty("x"), child.getDoubleProperty("y"), child.getDoubleProperty("z")));
setPerDofVariableByName(var.getName(), perDofValues);
}
}
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* * * *
* Portions copyright (c) 2019-2020 Stanford University and the Authors. * * Portions copyright (c) 2019-2020 Stanford University and the Authors. *
* Authors: Andreas Krämer and Andrew C. Simmonett * * Authors: Andreas Krämer and Andrew C. Simmonett *
* Contributors: * * Contributors: Peter Eastman *
* * * *
* Permission is hereby granted, free of charge, to any person obtaining a * * Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), * * copy of this software and associated documentation files (the "Software"), *
...@@ -349,3 +349,29 @@ void NoseHooverIntegrator::createCheckpoint(std::ostream& stream) const { ...@@ -349,3 +349,29 @@ void NoseHooverIntegrator::createCheckpoint(std::ostream& stream) const {
void NoseHooverIntegrator::loadCheckpoint(std::istream& stream) { void NoseHooverIntegrator::loadCheckpoint(std::istream& stream) {
kernel.getAs<IntegrateNoseHooverStepKernel>().loadCheckpoint(*context, stream); kernel.getAs<IntegrateNoseHooverStepKernel>().loadCheckpoint(*context, stream);
} }
void NoseHooverIntegrator::serializeParameters(SerializationNode& node) const {
node.setIntProperty("version", 1);
vector<vector<double> > positions, velocities;
kernel.getAs<IntegrateNoseHooverStepKernel>().getChainStates(*context, positions, velocities);
for (int i = 0; i < positions.size(); i++) {
SerializationNode& chain = node.createChildNode("Chain");
for (int j = 0; j < positions[i].size(); j++)
chain.createChildNode("Bead").setDoubleProperty("position", positions[i][j]).setDoubleProperty("velocity", velocities[i][j]);
}
}
void NoseHooverIntegrator::deserializeParameters(const SerializationNode& node) {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
int numChains = node.getChildren().size();
vector<vector<double> > positions(numChains), velocities(numChains);
for (int i = 0; i < numChains; i++) {
auto& chain = node.getChildren()[i];
for (auto& bead : chain.getChildren()) {
positions[i].push_back(bead.getDoubleProperty("position"));
velocities[i].push_back(bead.getDoubleProperty("velocity"));
}
}
kernel.getAs<IntegrateNoseHooverStepKernel>().setChainStates(*context, positions, velocities);
}
...@@ -81,6 +81,16 @@ const map<string, double>& State::getEnergyParameterDerivatives() const { ...@@ -81,6 +81,16 @@ const map<string, double>& State::getEnergyParameterDerivatives() const {
throw OpenMMException("Invoked getEnergyParameterDerivatives() on a State which does not contain parameter derivatives."); throw OpenMMException("Invoked getEnergyParameterDerivatives() on a State which does not contain parameter derivatives.");
return energyParameterDerivatives; return energyParameterDerivatives;
} }
const SerializationNode& State::getIntegratorParameters() const {
if ((types&IntegratorParameters) == 0)
throw OpenMMException("Invoked getPIntegratorarameters() on a State which does not contain integrator parameters.");
return integratorParameters;
}
SerializationNode& State::updateIntegratorParameters() {
types |= IntegratorParameters;
integratorParameters.setName("IntegratorParameters");
return integratorParameters;
}
int State::getDataTypes() const { int State::getDataTypes() const {
return types; return types;
} }
...@@ -159,3 +169,7 @@ void State::StateBuilder::setEnergy(double ke, double pe) { ...@@ -159,3 +169,7 @@ void State::StateBuilder::setEnergy(double ke, double pe) {
void State::StateBuilder::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) { void State::StateBuilder::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
state.setPeriodicBoxVectors(a, b, c); state.setPeriodicBoxVectors(a, b, c);
} }
SerializationNode& State::StateBuilder::updateIntegratorParameters() {
return state.updateIntegratorParameters();
}
...@@ -1017,6 +1017,22 @@ public: ...@@ -1017,6 +1017,22 @@ public:
* Load the chain states from a checkpoint. * Load the chain states from a checkpoint.
*/ */
void loadCheckpoint(ContextImpl& context, std::istream& stream); void loadCheckpoint(ContextImpl& context, std::istream& stream);
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void getChainStates(ContextImpl& context, std::vector<std::vector<double> >& positions, std::vector<std::vector<double> >& velocities) const;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void setChainStates(ContextImpl& context, const std::vector<std::vector<double> >& positions, const std::vector<std::vector<double> >& velocities);
private: private:
ComputeContext& cc; ComputeContext& cc;
float prevMaxPairDistance; float prevMaxPairDistance;
......
...@@ -6255,6 +6255,58 @@ void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, i ...@@ -6255,6 +6255,58 @@ void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, i
} }
} }
void CommonIntegrateNoseHooverStepKernel::getChainStates(ContextImpl& context, vector<vector<double> >& positions, vector<vector<double> >& velocities) const {
int numChains = chainState.size();
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
positions.clear();
velocities.clear();
positions.resize(numChains);
velocities.resize(numChains);
for (int i = 0; i < numChains; i++) {
const ComputeArray& state = chainState.at(i);
if (useDouble) {
vector<mm_double2> stateVec;
state.download(stateVec);
for (int j = 0; j < stateVec.size(); j++) {
positions[i].push_back(stateVec[j].x);
velocities[i].push_back(stateVec[j].y);
}
}
else {
vector<mm_float2> stateVec;
state.download(stateVec);
for (int j = 0; j < stateVec.size(); j++) {
positions[i].push_back((float) stateVec[j].x);
velocities[i].push_back((float) stateVec[j].y);
}
}
}
}
void CommonIntegrateNoseHooverStepKernel::setChainStates(ContextImpl& context, const vector<vector<double> >& positions, const vector<vector<double> >& velocities) {
int numChains = positions.size();
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
chainState.clear();
for (int i = 0; i < numChains; i++) {
int chainLength = positions[i].size();
chainState[i] = ComputeArray();
if (useDouble) {
chainState[i].initialize<mm_double2>(cc, chainLength, "chainState"+cc.intToString(i));
vector<mm_double2> stateVec;
for (int j = 0; j < chainLength; j++)
stateVec.push_back(mm_double2(positions[i][j], velocities[i][j]));
chainState[i].upload(stateVec);
}
else {
chainState[i].initialize<mm_float2>(cc, chainLength, "chainState"+cc.intToString(i));
vector<mm_float2> stateVec;
for (int j = 0; j < chainLength; j++)
stateVec.push_back(mm_float2((float) positions[i][j], (float) velocities[i][j]));
chainState[i].upload(stateVec);
}
}
}
void CommonIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) { void CommonIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) {
cc.initializeContexts(); cc.initializeContexts();
cc.setAsCurrent(); cc.setAsCurrent();
......
...@@ -1211,6 +1211,22 @@ public: ...@@ -1211,6 +1211,22 @@ public:
* Load the chain states from a checkpoint. * Load the chain states from a checkpoint.
*/ */
void loadCheckpoint(ContextImpl& context, std::istream& stream); void loadCheckpoint(ContextImpl& context, std::istream& stream);
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void getChainStates(ContextImpl& context, std::vector<std::vector<double> >& positions, std::vector<std::vector<double> >& velocities) const;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void setChainStates(ContextImpl& context, const std::vector<std::vector<double> >& positions, const std::vector<std::vector<double> >& velocities);
private: private:
ReferencePlatform::PlatformData& data; ReferencePlatform::PlatformData& data;
ReferenceNoseHooverChain* chainPropagator; ReferenceNoseHooverChain* chainPropagator;
......
...@@ -2378,6 +2378,16 @@ void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context ...@@ -2378,6 +2378,16 @@ void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context
} }
} }
void ReferenceIntegrateNoseHooverStepKernel::getChainStates(ContextImpl& context, vector<vector<double> >& positions, vector<vector<double> >& velocities) const {
positions = chainPositions;
velocities = chainVelocities;
}
void ReferenceIntegrateNoseHooverStepKernel::setChainStates(ContextImpl& context, const vector<vector<double> >& positions, const vector<vector<double> >& velocities) {
chainPositions = positions;
chainVelocities = velocities;
}
ReferenceIntegrateLangevinStepKernel::~ReferenceIntegrateLangevinStepKernel() { ReferenceIntegrateLangevinStepKernel::~ReferenceIntegrateLangevinStepKernel() {
if (dynamics) if (dynamics)
delete dynamics; delete dynamics;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2015 Stanford University and the Authors. * * Portions copyright (c) 2010-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -89,6 +89,9 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -89,6 +89,9 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]); forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]);
} }
} }
if ((s.getDataTypes()&State::IntegratorParameters) != 0) {
node.getChildren().push_back(s.getIntegratorParameters());
}
} }
void* StateProxy::deserialize(const SerializationNode& node) const { void* StateProxy::deserialize(const SerializationNode& node) const {
...@@ -138,6 +141,9 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -138,6 +141,9 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
builder.setForces(outForces); builder.setForces(outForces);
arraySizes.push_back(outForces.size()); arraySizes.push_back(outForces.size());
} }
else if (child.getName() == "IntegratorParameters") {
builder.updateIntegratorParameters() = child;
}
} }
for (int i = 1; i < arraySizes.size(); i++) { for (int i = 1; i < arraySizes.size(); i++) {
if (arraySizes[i] != arraySizes[i-1]) { if (arraySizes[i] != arraySizes[i-1]) {
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/LangevinIntegrator.h" #include "openmm/LangevinIntegrator.h"
#include "openmm/AndersenThermostat.h" #include "openmm/AndersenThermostat.h"
#include "openmm/MonteCarloBarostat.h" #include "openmm/MonteCarloBarostat.h"
...@@ -183,9 +184,42 @@ void testSerialization() { ...@@ -183,9 +184,42 @@ void testSerialization() {
} }
} }
void testIntegratorParameters() {
// Create a Context involving integrator parameters.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator);
integrator.setGlobalVariable(0, 3.0);
integrator.setPerDofVariable(0, {Vec3(1.0, 2.0, 3.0)});
// Create a State, then serialize and deserialize it.
State s1 = context.getState(State::IntegratorParameters);
stringstream buffer;
XmlSerializer::serialize<State>(&s1, "State", buffer);
State* copy = XmlSerializer::deserialize<State>(buffer);
State& s2 = *copy;
// Set the State on a new Context and make sure all the integrator parameters
// survived the serialization and deserialization.
CustomIntegrator* integrator2 = XmlSerializer::clone(integrator);
Context context2(system, *integrator2);
context2.setState(s2);
ASSERT_EQUAL(3.0, integrator2->getGlobalVariable(0));
vector<Vec3> values;
integrator2->getPerDofVariable(0, values);
ASSERT_EQUAL_VEC(Vec3(1.0, 2.0, 3.0), values[0], 1e-6);
}
int main() { int main() {
try { try {
testSerialization(); testSerialization();
testIntegratorParameters();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -239,6 +239,35 @@ void testCheckpoint() { ...@@ -239,6 +239,35 @@ void testCheckpoint() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6); ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
} }
void testSaveParameters() {
// Test that integrator variables get loaded correctly from States.
System system;
system.addParticle(1.0);
CustomIntegrator* custom = new CustomIntegrator(0.001);
custom->addGlobalVariable("a", 1.0);
custom->addPerDofVariable("b", 2.0);
CompoundIntegrator integrator;
integrator.addIntegrator(custom);
integrator.addIntegrator(new VerletIntegrator(0.005));
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
custom->setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
custom->setPerDofVariable(0, b1);
State savedState = context.getState(State::IntegratorParameters);
custom->setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
custom->setPerDofVariable(0, b2);
integrator.setCurrentIntegrator(1);
context.setState(savedState);
ASSERT_EQUAL(0, integrator.getCurrentIntegrator());
ASSERT_EQUAL(5.0, custom->getGlobalVariable(0));
vector<Vec3> b3;
custom->getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -248,6 +277,7 @@ int main(int argc, char* argv[]) { ...@@ -248,6 +277,7 @@ int main(int argc, char* argv[]) {
testChangingParameters(); testChangingParameters();
testDifferentStepSizes(); testDifferentStepSizes();
testCheckpoint(); testCheckpoint();
testSaveParameters();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -1187,6 +1187,30 @@ void testCheckpoint() { ...@@ -1187,6 +1187,30 @@ void testCheckpoint() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6); ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
} }
void testSaveParameters() {
// Test that integrator variables get loaded correctly from States.
System system;
system.addParticle(1.0);
CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("a", 1.0);
integrator.addPerDofVariable("b", 2.0);
Context context(system, integrator, platform);
vector<Vec3> positions(1, Vec3());
context.setPositions(positions);
integrator.setGlobalVariable(0, 5.0);
vector<Vec3> b1(1, Vec3(1, 2, 3));
integrator.setPerDofVariable(0, b1);
State savedState = context.getState(State::IntegratorParameters);
integrator.setGlobalVariable(0, 10.0);
vector<Vec3> b2(1, Vec3(4, 5, 6));
integrator.setPerDofVariable(0, b2);
context.setState(savedState);
ASSERT_EQUAL(5.0, integrator.getGlobalVariable(0));
vector<Vec3> b3;
integrator.getPerDofVariable(0, b3);
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -1216,6 +1240,7 @@ int main(int argc, char* argv[]) { ...@@ -1216,6 +1240,7 @@ int main(int argc, char* argv[]) {
testRecordEnergy(); testRecordEnergy();
testInitialTemperature(); testInitialTemperature();
testCheckpoint(); testCheckpoint();
testSaveParameters();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -499,6 +499,57 @@ void testCheckpoints() { ...@@ -499,6 +499,57 @@ void testCheckpoints() {
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6); ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
} }
void testSaveParameters() {
// Create a system with Drude-like particles to be thermostated as a pair, as well as another
// particle to be thermostated independently, to test all integrator features.
double timeStep = 0.001;
NoseHooverIntegrator integrator(timeStep), newIntegrator(timeStep);
System system;
double mass = 1;
system.addParticle(8*mass);
system.addParticle(mass);
system.addParticle(5*mass);
HarmonicBondForce* force = new HarmonicBondForce();
force->addBond(0, 1, 0.1, 50.0);
force->addBond(0, 2, 0.1, 50.0);
system.addForce(force);
double kineticEnergy = 1e6;
double temperature=300, collisionFrequency=1, chainLength=3, numMTS=3, numYS=3;
chainLength = 10;
integrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
newIntegrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
Context context(system, integrator, platform);
Context newContext(system, newIntegrator, platform);
std::vector<Vec3> positions(3);
std::vector<Vec3> velocities(3);
positions[1] = {0.1, 0.0, 0.0};
velocities[1] = {0.1,0.2,-0.2};
positions[2] = {-0.1, 0.001, 0.001};
velocities[2] = {-0.1,0.2,-0.2};
context.setPositions(positions);
context.setVelocities(velocities);
// Run a short simulation and save a state..
integrator.step(500);
State savedState = context.getState(State::Positions | State::Velocities | State::IntegratorParameters);
// Now continue the simulation
integrator.step(5);
// And try the same, starting from the state
newContext.setState(savedState);
newIntegrator.step(5);
State state1 = context.getState(State::Positions | State::Velocities);
State state2 = newContext.getState(State::Positions | State::Velocities);
ASSERT_EQUAL_VEC(state1.getPositions()[0], state2.getPositions()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getPositions()[1], state2.getPositions()[1], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[0], state2.getVelocities()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
}
void testAPIChangeNumParticles() { void testAPIChangeNumParticles() {
bool constrain = true; bool constrain = true;
int numMolecules = 20; int numMolecules = 20;
...@@ -553,6 +604,7 @@ int main(int argc, char* argv[]) { ...@@ -553,6 +604,7 @@ int main(int argc, char* argv[]) {
constrain = false; testDimerBox(constrain); constrain = false; testDimerBox(constrain);
constrain = true; testDimerBox(constrain); constrain = true; testDimerBox(constrain);
testCheckpoints(); testCheckpoints();
testSaveParameters();
testForceGroups(); testForceGroups();
runPlatformTests(); runPlatformTests();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment