Commit 68b065e5 authored by Peter Eastman's avatar Peter Eastman
Browse files

Created StateBuilder as a mechanism for creating new State objects

parent a414ed23
...@@ -52,6 +52,7 @@ namespace OpenMM { ...@@ -52,6 +52,7 @@ namespace OpenMM {
class OPENMM_EXPORT State { class OPENMM_EXPORT State {
public: public:
class StateBuilder;
/** /**
* This is an enumeration of the types of data which may be stored in a State. When you create * This is an enumeration of the types of data which may be stored in a State. When you create
* a State, use these values to specify which data types it should contain. * a State, use these values to specify which data types it should contain.
...@@ -103,13 +104,11 @@ public: ...@@ -103,13 +104,11 @@ public:
*/ */
const std::map<std::string, double>& getParameters() const; const std::map<std::string, double>& getParameters() const;
private: private:
friend class Context; State(double time);
friend class StateProxy; void setPositions(const std::vector<Vec3>& pos);
State(double time, int numParticles, int types); void setVelocities(const std::vector<Vec3>& vel);
std::vector<Vec3>& updPositions(); void setForces(const std::vector<Vec3>& force);
std::vector<Vec3>& updVelocities(); void setParameters(const std::map<std::string, double>& params);
std::vector<Vec3>& updForces();
std::map<std::string, double>& updParameters();
void setEnergy(double ke, double pe); void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c); void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
int types; int types;
...@@ -121,6 +120,25 @@ private: ...@@ -121,6 +120,25 @@ private:
std::map<std::string, double> parameters; std::map<std::string, double> parameters;
}; };
/**
* Internal class used to construct new State objects.
* @private
*/
class State::StateBuilder {
public:
StateBuilder(double time);
State getState();
void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel);
void setForces(const std::vector<Vec3>& force);
void setParameters(const std::map<std::string, double>& params);
void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
private:
State state;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_STATE_H_*/ #endif /*OPENMM_STATE_H_*/
...@@ -82,28 +82,33 @@ Platform& Context::getPlatform() { ...@@ -82,28 +82,33 @@ Platform& Context::getPlatform() {
} }
State Context::getState(int types, bool enforcePeriodicBox, int groups) const { State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
State state(impl->getTime(), impl->getSystem().getNumParticles(), types); State::StateBuilder builder(impl->getTime());
Vec3 periodicBoxSize[3]; Vec3 periodicBoxSize[3];
impl->getPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]); impl->getPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
state.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]); builder.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
bool includeForces = types&State::Forces; bool includeForces = types&State::Forces;
bool includeEnergy = types&State::Energy; bool includeEnergy = types&State::Energy;
if (includeForces || includeEnergy) { if (includeForces || includeEnergy) {
double energy = impl->calcForcesAndEnergy(includeForces || includeEnergy, includeEnergy, groups); double energy = impl->calcForcesAndEnergy(includeForces || includeEnergy, includeEnergy, groups);
if (includeEnergy) if (includeEnergy)
state.setEnergy(impl->calcKineticEnergy(), energy); builder.setEnergy(impl->calcKineticEnergy(), energy);
if (includeForces) if (includeForces) {
impl->getForces(state.updForces()); vector<Vec3> forces;
impl->getForces(forces);
builder.setForces(forces);
}
} }
if (types&State::Parameters) { if (types&State::Parameters) {
map<string, double> params;
for (map<string, double>::const_iterator iter = impl->parameters.begin(); iter != impl->parameters.end(); iter++) for (map<string, double>::const_iterator iter = impl->parameters.begin(); iter != impl->parameters.end(); iter++)
state.updParameters()[iter->first] = iter->second; params[iter->first] = iter->second;
builder.setParameters(params);
} }
if (types&State::Positions) { if (types&State::Positions) {
impl->getPositions(state.updPositions()); vector<Vec3> positions;
impl->getPositions(positions);
if (enforcePeriodicBox) { if (enforcePeriodicBox) {
const vector<vector<int> >& molecules = impl->getMolecules(); const vector<vector<int> >& molecules = impl->getMolecules();
vector<Vec3>& positions = state.updPositions();
for (int i = 0; i < (int) molecules.size(); i++) { for (int i = 0; i < (int) molecules.size(); i++) {
// Find the molecule center. // Find the molecule center.
...@@ -131,10 +136,14 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const { ...@@ -131,10 +136,14 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
} }
} }
} }
builder.setPositions(positions);
}
if (types&State::Velocities) {
vector<Vec3> velocities;
impl->getVelocities(velocities);
builder.setVelocities(velocities);
} }
if (types&State::Velocities) return builder.getState();
impl->getVelocities(state.updVelocities());
return state;
} }
void Context::setTime(double time) { void Context::setTime(double time) {
......
...@@ -73,27 +73,34 @@ const map<string, double>& State::getParameters() const { ...@@ -73,27 +73,34 @@ const map<string, double>& State::getParameters() const {
throw OpenMMException("Invoked getParameters() on a State which does not contain parameters."); throw OpenMMException("Invoked getParameters() on a State which does not contain parameters.");
return parameters; return parameters;
} }
State::State(double time, int numParticles, int types) : types(types), time(time), ke(0), pe(0), State::State(double time) : types(0), time(time), ke(0), pe(0) {
positions( (types & Positions) == 0 ? 0 : numParticles), velocities( (types & Velocities) == 0 ? 0 : numParticles),
forces( (types & Forces) == 0 ? 0 : numParticles) {
} }
State::State() : types(0), time(0.0), ke(0), pe(0), positions(0), velocities(0), forces(0) { State::State() : types(0), time(0.0), ke(0), pe(0) {
} }
vector<Vec3>& State::updPositions() { void State::setPositions(const std::vector<Vec3>& pos) {
return positions; positions = pos;
types |= Positions;
} }
vector<Vec3>& State::updVelocities() {
return velocities; void State::setVelocities(const std::vector<Vec3>& vel) {
velocities = vel;
types |= Velocities;
} }
vector<Vec3>& State::updForces() {
return forces; void State::setForces(const std::vector<Vec3>& force) {
forces = force;
types |= Forces;
} }
map<string, double>& State::updParameters() {
return parameters; void State::setParameters(const std::map<std::string, double>& params) {
parameters = params;
types |= Parameters;
} }
void State::setEnergy(double kinetic, double potential) { void State::setEnergy(double kinetic, double potential) {
ke = kinetic; ke = kinetic;
pe = potential; pe = potential;
types |= Energy;
} }
void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) { void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
...@@ -101,3 +108,34 @@ void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) { ...@@ -101,3 +108,34 @@ void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
periodicBoxVectors[1] = b; periodicBoxVectors[1] = b;
periodicBoxVectors[2] = c; periodicBoxVectors[2] = c;
} }
State::StateBuilder::StateBuilder(double time) : state(time) {
}
State State::StateBuilder::getState() {
return state;
}
void State::StateBuilder::setPositions(const std::vector<Vec3>& pos) {
state.setPositions(pos);
}
void State::StateBuilder::setVelocities(const std::vector<Vec3>& vel) {
state.setVelocities(vel);
}
void State::StateBuilder::setForces(const std::vector<Vec3>& force) {
state.setForces(force);
}
void State::StateBuilder::setParameters(const std::map<std::string, double>& params) {
state.setParameters(params);
}
void State::StateBuilder::setEnergy(double ke, double pe) {
state.setEnergy(ke, pe);
}
void State::StateBuilder::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
state.setPeriodicBoxVectors(a, b, c);
}
...@@ -55,7 +55,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -55,7 +55,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
SerializationNode& parametersNode = node.createChildNode("Parameters"); SerializationNode& parametersNode = node.createChildNode("Parameters");
map<string, double> stateParams = s.getParameters(); map<string, double> stateParams = s.getParameters();
map<string, double>::const_iterator it; map<string, double>::const_iterator it;
for(it = stateParams.begin(); it!=stateParams.end();it++) { for (it = stateParams.begin(); it!=stateParams.end();it++) {
parametersNode.setDoubleProperty(it->first, it->second); parametersNode.setDoubleProperty(it->first, it->second);
} }
} catch (const OpenMMException &) { } catch (const OpenMMException &) {
...@@ -73,7 +73,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -73,7 +73,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
s.getPositions(); s.getPositions();
SerializationNode& positionsNode = node.createChildNode("Positions"); SerializationNode& positionsNode = node.createChildNode("Positions");
vector<Vec3> statePositions = s.getPositions(); vector<Vec3> statePositions = s.getPositions();
for(int i=0; i<statePositions.size();i++) { for (int i=0; i<statePositions.size();i++) {
positionsNode.createChildNode("Position").setDoubleProperty("x", statePositions[i][0]).setDoubleProperty("y", statePositions[i][1]).setDoubleProperty("z", statePositions[i][2]); positionsNode.createChildNode("Position").setDoubleProperty("x", statePositions[i][0]).setDoubleProperty("y", statePositions[i][1]).setDoubleProperty("z", statePositions[i][2]);
} }
} catch (const OpenMMException &) { } catch (const OpenMMException &) {
...@@ -83,7 +83,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -83,7 +83,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
s.getVelocities(); s.getVelocities();
SerializationNode& velocitiesNode = node.createChildNode("Velocities"); SerializationNode& velocitiesNode = node.createChildNode("Velocities");
vector<Vec3> stateVelocities = s.getVelocities(); vector<Vec3> stateVelocities = s.getVelocities();
for(int i=0; i<stateVelocities.size();i++) { for (int i=0; i<stateVelocities.size();i++) {
velocitiesNode.createChildNode("Velocity").setDoubleProperty("x", stateVelocities[i][0]).setDoubleProperty("y", stateVelocities[i][1]).setDoubleProperty("z", stateVelocities[i][2]); velocitiesNode.createChildNode("Velocity").setDoubleProperty("x", stateVelocities[i][0]).setDoubleProperty("y", stateVelocities[i][1]).setDoubleProperty("z", stateVelocities[i][2]);
} }
} catch (const OpenMMException &) { } catch (const OpenMMException &) {
...@@ -93,7 +93,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -93,7 +93,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
s.getForces(); s.getForces();
SerializationNode& forcesNode = node.createChildNode("Forces"); SerializationNode& forcesNode = node.createChildNode("Forces");
vector<Vec3> stateForces = s.getForces(); vector<Vec3> stateForces = s.getForces();
for(int i=0; i<stateForces.size();i++) { for (int i=0; i<stateForces.size();i++) {
forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]); forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]);
} }
} catch (const OpenMMException &) { } catch (const OpenMMException &) {
...@@ -119,7 +119,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -119,7 +119,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
// inStateParams is really a <string,double> pair, where string is the name and double is the value // inStateParams is really a <string,double> pair, where string is the name and double is the value
// but we want to avoid casting a string to a double and instead use the built in routines, // but we want to avoid casting a string to a double and instead use the built in routines,
map<string, string> inStateParams = parametersNode.getProperties(); map<string, string> inStateParams = parametersNode.getProperties();
for(map<string, string>::const_iterator pit = inStateParams.begin(); pit != inStateParams.end(); pit++) { for (map<string, string>::const_iterator pit = inStateParams.begin(); pit != inStateParams.end(); pit++) {
outStateParams[pit->first] = parametersNode.getDoubleProperty(pit->first); outStateParams[pit->first] = parametersNode.getDoubleProperty(pit->first);
} }
types = types | State::Parameters; types = types | State::Parameters;
...@@ -141,7 +141,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -141,7 +141,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
vector<Vec3> outForces; vector<Vec3> outForces;
try { try {
const SerializationNode& positionsNode = node.getChildNode("Positions"); const SerializationNode& positionsNode = node.getChildNode("Positions");
for(int i=0; i<(int) positionsNode.getChildren().size();i++) { for (int i = 0; i < (int) positionsNode.getChildren().size(); i++) {
const SerializationNode& particle = positionsNode.getChildren()[i]; const SerializationNode& particle = positionsNode.getChildren()[i];
outPositions.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z"))); outPositions.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
} }
...@@ -151,7 +151,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -151,7 +151,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
} }
try { try {
const SerializationNode& velocitiesNode = node.getChildNode("Velocities"); const SerializationNode& velocitiesNode = node.getChildNode("Velocities");
for(int i=0; i<(int) velocitiesNode.getChildren().size();i++) { for (int i = 0; i < (int) velocitiesNode.getChildren().size(); i++) {
const SerializationNode& particle = velocitiesNode.getChildren()[i]; const SerializationNode& particle = velocitiesNode.getChildren()[i];
outVelocities.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z"))); outVelocities.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
} }
...@@ -161,7 +161,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -161,7 +161,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
} }
try { try {
const SerializationNode& forcesNode = node.getChildNode("Forces"); const SerializationNode& forcesNode = node.getChildNode("Forces");
for(int i=0; i<(int) forcesNode.getChildren().size();i++) { for (int i = 0; i < (int) forcesNode.getChildren().size(); i++) {
const SerializationNode& particle = forcesNode.getChildren()[i]; const SerializationNode& particle = forcesNode.getChildren()[i];
outForces.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z"))); outForces.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
} }
...@@ -169,33 +169,34 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -169,33 +169,34 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
} catch (const OpenMMException &) { } catch (const OpenMMException &) {
// do nothing // do nothing
} }
int numParticles = max(outPositions.size(), max(outForces.size(), outVelocities.size()));
vector<int> arraySizes; vector<int> arraySizes;
State *s = new State(outTime,numParticles,types); State::StateBuilder builder(outTime);
if(types & State::Positions) { if (types & State::Positions) {
s->updPositions() = outPositions; builder.setPositions(outPositions);
arraySizes.push_back(outPositions.size()); arraySizes.push_back(outPositions.size());
} }
if(types & State::Velocities) { if (types & State::Velocities) {
s->updVelocities() = outVelocities; builder.setVelocities(outVelocities);
arraySizes.push_back(outVelocities.size()); arraySizes.push_back(outVelocities.size());
} }
if(types & State::Forces) { if (types & State::Forces) {
s->updForces() = outForces; builder.setForces(outForces);
arraySizes.push_back(outVelocities.size()); arraySizes.push_back(outForces.size());
} }
if(types & State::Energy) { if (types & State::Energy) {
s->setEnergy(kineticEnergy, potentialEnergy); builder.setEnergy(kineticEnergy, potentialEnergy);
} }
if(types & State::Parameters) { if (types & State::Parameters) {
s->updParameters() = outStateParams; builder.setParameters(outStateParams);
} }
for(int i=1; i<arraySizes.size();i++) { for (int i = 1; i < arraySizes.size(); i++) {
if(arraySizes[i] != arraySizes[i-1]) { if (arraySizes[i] != arraySizes[i-1]) {
throw(OpenMMException("State Deserialization Particle Size Mismatch, check number of particles in Forces, Velocities, Positions!")); throw(OpenMMException("State Deserialization Particle Size Mismatch, check number of particles in Forces, Velocities, Positions!"));
} }
} }
s->setPeriodicBoxVectors(outAVec, outBVec, outCVec); builder.setPeriodicBoxVectors(outAVec, outBVec, outCVec);
State *s = new State();
*s = builder.getState();
return s; return s;
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment