Commit 46098e35 authored by peastman's avatar peastman
Browse files

States can save integrator parameters

parent 694c3930
...@@ -1017,6 +1017,22 @@ public: ...@@ -1017,6 +1017,22 @@ public:
* Load the chain states from a checkpoint. * Load the chain states from a checkpoint.
*/ */
void loadCheckpoint(ContextImpl& context, std::istream& stream); void loadCheckpoint(ContextImpl& context, std::istream& stream);
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void getChainStates(ContextImpl& context, std::vector<std::vector<double> >& positions, std::vector<std::vector<double> >& velocities) const;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void setChainStates(ContextImpl& context, const std::vector<std::vector<double> >& positions, const std::vector<std::vector<double> >& velocities);
private: private:
ComputeContext& cc; ComputeContext& cc;
float prevMaxPairDistance; float prevMaxPairDistance;
......
...@@ -6255,6 +6255,58 @@ void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, i ...@@ -6255,6 +6255,58 @@ void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, i
} }
} }
void CommonIntegrateNoseHooverStepKernel::getChainStates(ContextImpl& context, vector<vector<double> >& positions, vector<vector<double> >& velocities) const {
int numChains = chainState.size();
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
positions.clear();
velocities.clear();
positions.resize(numChains);
velocities.resize(numChains);
for (int i = 0; i < numChains; i++) {
const ComputeArray& state = chainState.at(i);
if (useDouble) {
vector<mm_double2> stateVec;
state.download(stateVec);
for (int j = 0; j < stateVec.size(); j++) {
positions[i].push_back(stateVec[i].x);
velocities[i].push_back(stateVec[i].y);
}
}
else {
vector<mm_float2> stateVec;
state.download(stateVec);
for (int j = 0; j < stateVec.size(); j++) {
positions[i].push_back((float) stateVec[i].x);
velocities[i].push_back((float) stateVec[i].y);
}
}
}
}
void CommonIntegrateNoseHooverStepKernel::setChainStates(ContextImpl& context, const vector<vector<double> >& positions, const vector<vector<double> >& velocities) {
int numChains = chainState.size();
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
if (positions.size() != numChains || velocities.size() != numChains)
throw OpenMMException("setChainStates(): wrong number of chains");
for (int i = 0; i < numChains; i++) {
ComputeArray& state = chainState[i];
if (positions[i].size() != state.getSize() || velocities[i].size() != state.getSize())
throw OpenMMException("setChainStates(): wrong number of beads in chain");
if (useDouble) {
vector<mm_double2> stateVec;
for (int j = 0; j < state.getSize(); j++)
stateVec.push_back(mm_double2(positions[i][j], velocities[i][j]));
state.upload(stateVec);
}
else {
vector<mm_float2> stateVec;
for (int j = 0; j < state.getSize(); j++)
stateVec.push_back(mm_float2((float) positions[i][j], (float) velocities[i][j]));
state.upload(stateVec);
}
}
}
void CommonIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) { void CommonIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) {
cc.initializeContexts(); cc.initializeContexts();
cc.setAsCurrent(); cc.setAsCurrent();
......
...@@ -1211,6 +1211,22 @@ public: ...@@ -1211,6 +1211,22 @@ public:
* Load the chain states from a checkpoint. * Load the chain states from a checkpoint.
*/ */
void loadCheckpoint(ContextImpl& context, std::istream& stream); void loadCheckpoint(ContextImpl& context, std::istream& stream);
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void getChainStates(ContextImpl& context, std::vector<std::vector<double> >& positions, std::vector<std::vector<double> >& velocities) const;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void setChainStates(ContextImpl& context, const std::vector<std::vector<double> >& positions, const std::vector<std::vector<double> >& velocities);
private: private:
ReferencePlatform::PlatformData& data; ReferencePlatform::PlatformData& data;
ReferenceNoseHooverChain* chainPropagator; ReferenceNoseHooverChain* chainPropagator;
......
...@@ -2378,6 +2378,16 @@ void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context ...@@ -2378,6 +2378,16 @@ void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context
} }
} }
void ReferenceIntegrateNoseHooverStepKernel::getChainStates(ContextImpl& context, vector<vector<double> >& positions, vector<vector<double> >& velocities) const {
positions = chainPositions;
velocities = chainVelocities;
}
void ReferenceIntegrateNoseHooverStepKernel::setChainStates(ContextImpl& context, const vector<vector<double> >& positions, const vector<vector<double> >& velocities) {
chainPositions = positions;
chainVelocities = velocities;
}
ReferenceIntegrateLangevinStepKernel::~ReferenceIntegrateLangevinStepKernel() { ReferenceIntegrateLangevinStepKernel::~ReferenceIntegrateLangevinStepKernel() {
if (dynamics) if (dynamics)
delete dynamics; delete dynamics;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment