Commit 68b065e5 authored by Peter Eastman's avatar Peter Eastman
Browse files

Created StateBuilder as a mechanism for creating new State objects

parent a414ed23
......@@ -52,6 +52,7 @@ namespace OpenMM {
class OPENMM_EXPORT State {
public:
class StateBuilder;
/**
* This is an enumeration of the types of data which may be stored in a State. When you create
* a State, use these values to specify which data types it should contain.
......@@ -103,13 +104,11 @@ public:
*/
const std::map<std::string, double>& getParameters() const;
private:
friend class Context;
friend class StateProxy;
State(double time, int numParticles, int types);
std::vector<Vec3>& updPositions();
std::vector<Vec3>& updVelocities();
std::vector<Vec3>& updForces();
std::map<std::string, double>& updParameters();
State(double time);
void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel);
void setForces(const std::vector<Vec3>& force);
void setParameters(const std::map<std::string, double>& params);
void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
int types;
......@@ -121,6 +120,25 @@ private:
std::map<std::string, double> parameters;
};
/**
* Internal class used to construct new State objects.
* @private
*/
class State::StateBuilder {
public:
StateBuilder(double time);
State getState();
void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel);
void setForces(const std::vector<Vec3>& force);
void setParameters(const std::map<std::string, double>& params);
void setEnergy(double ke, double pe);
void setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c);
private:
State state;
};
} // namespace OpenMM
#endif /*OPENMM_STATE_H_*/
......@@ -82,28 +82,33 @@ Platform& Context::getPlatform() {
}
State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
State state(impl->getTime(), impl->getSystem().getNumParticles(), types);
State::StateBuilder builder(impl->getTime());
Vec3 periodicBoxSize[3];
impl->getPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
state.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
builder.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
bool includeForces = types&State::Forces;
bool includeEnergy = types&State::Energy;
if (includeForces || includeEnergy) {
double energy = impl->calcForcesAndEnergy(includeForces || includeEnergy, includeEnergy, groups);
if (includeEnergy)
state.setEnergy(impl->calcKineticEnergy(), energy);
if (includeForces)
impl->getForces(state.updForces());
builder.setEnergy(impl->calcKineticEnergy(), energy);
if (includeForces) {
vector<Vec3> forces;
impl->getForces(forces);
builder.setForces(forces);
}
}
if (types&State::Parameters) {
map<string, double> params;
for (map<string, double>::const_iterator iter = impl->parameters.begin(); iter != impl->parameters.end(); iter++)
state.updParameters()[iter->first] = iter->second;
params[iter->first] = iter->second;
builder.setParameters(params);
}
if (types&State::Positions) {
impl->getPositions(state.updPositions());
vector<Vec3> positions;
impl->getPositions(positions);
if (enforcePeriodicBox) {
const vector<vector<int> >& molecules = impl->getMolecules();
vector<Vec3>& positions = state.updPositions();
for (int i = 0; i < (int) molecules.size(); i++) {
// Find the molecule center.
......@@ -131,10 +136,14 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
}
}
}
builder.setPositions(positions);
}
if (types&State::Velocities) {
vector<Vec3> velocities;
impl->getVelocities(velocities);
builder.setVelocities(velocities);
}
if (types&State::Velocities)
impl->getVelocities(state.updVelocities());
return state;
return builder.getState();
}
void Context::setTime(double time) {
......
......@@ -73,27 +73,34 @@ const map<string, double>& State::getParameters() const {
throw OpenMMException("Invoked getParameters() on a State which does not contain parameters.");
return parameters;
}
State::State(double time, int numParticles, int types) : types(types), time(time), ke(0), pe(0),
positions( (types & Positions) == 0 ? 0 : numParticles), velocities( (types & Velocities) == 0 ? 0 : numParticles),
forces( (types & Forces) == 0 ? 0 : numParticles) {
State::State(double time) : types(0), time(time), ke(0), pe(0) {
}
State::State() : types(0), time(0.0), ke(0), pe(0), positions(0), velocities(0), forces(0) {
State::State() : types(0), time(0.0), ke(0), pe(0) {
}
vector<Vec3>& State::updPositions() {
return positions;
void State::setPositions(const std::vector<Vec3>& pos) {
positions = pos;
types |= Positions;
}
vector<Vec3>& State::updVelocities() {
return velocities;
void State::setVelocities(const std::vector<Vec3>& vel) {
velocities = vel;
types |= Velocities;
}
vector<Vec3>& State::updForces() {
return forces;
void State::setForces(const std::vector<Vec3>& force) {
forces = force;
types |= Forces;
}
map<string, double>& State::updParameters() {
return parameters;
void State::setParameters(const std::map<std::string, double>& params) {
parameters = params;
types |= Parameters;
}
void State::setEnergy(double kinetic, double potential) {
ke = kinetic;
pe = potential;
types |= Energy;
}
void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
......@@ -101,3 +108,34 @@ void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
periodicBoxVectors[1] = b;
periodicBoxVectors[2] = c;
}
State::StateBuilder::StateBuilder(double time) : state(time) {
}
State State::StateBuilder::getState() {
return state;
}
void State::StateBuilder::setPositions(const std::vector<Vec3>& pos) {
state.setPositions(pos);
}
void State::StateBuilder::setVelocities(const std::vector<Vec3>& vel) {
state.setVelocities(vel);
}
void State::StateBuilder::setForces(const std::vector<Vec3>& force) {
state.setForces(force);
}
void State::StateBuilder::setParameters(const std::map<std::string, double>& params) {
state.setParameters(params);
}
void State::StateBuilder::setEnergy(double ke, double pe) {
state.setEnergy(ke, pe);
}
void State::StateBuilder::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
state.setPeriodicBoxVectors(a, b, c);
}
......@@ -55,7 +55,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
SerializationNode& parametersNode = node.createChildNode("Parameters");
map<string, double> stateParams = s.getParameters();
map<string, double>::const_iterator it;
for(it = stateParams.begin(); it!=stateParams.end();it++) {
for (it = stateParams.begin(); it!=stateParams.end();it++) {
parametersNode.setDoubleProperty(it->first, it->second);
}
} catch (const OpenMMException &) {
......@@ -73,7 +73,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
s.getPositions();
SerializationNode& positionsNode = node.createChildNode("Positions");
vector<Vec3> statePositions = s.getPositions();
for(int i=0; i<statePositions.size();i++) {
for (int i=0; i<statePositions.size();i++) {
positionsNode.createChildNode("Position").setDoubleProperty("x", statePositions[i][0]).setDoubleProperty("y", statePositions[i][1]).setDoubleProperty("z", statePositions[i][2]);
}
} catch (const OpenMMException &) {
......@@ -83,7 +83,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
s.getVelocities();
SerializationNode& velocitiesNode = node.createChildNode("Velocities");
vector<Vec3> stateVelocities = s.getVelocities();
for(int i=0; i<stateVelocities.size();i++) {
for (int i=0; i<stateVelocities.size();i++) {
velocitiesNode.createChildNode("Velocity").setDoubleProperty("x", stateVelocities[i][0]).setDoubleProperty("y", stateVelocities[i][1]).setDoubleProperty("z", stateVelocities[i][2]);
}
} catch (const OpenMMException &) {
......@@ -93,7 +93,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
s.getForces();
SerializationNode& forcesNode = node.createChildNode("Forces");
vector<Vec3> stateForces = s.getForces();
for(int i=0; i<stateForces.size();i++) {
for (int i=0; i<stateForces.size();i++) {
forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]);
}
} catch (const OpenMMException &) {
......@@ -119,7 +119,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
// inStateParams is really a <string,double> pair, where string is the name and double is the value
// but we want to avoid casting a string to a double and instead use the built in routines,
map<string, string> inStateParams = parametersNode.getProperties();
for(map<string, string>::const_iterator pit = inStateParams.begin(); pit != inStateParams.end(); pit++) {
for (map<string, string>::const_iterator pit = inStateParams.begin(); pit != inStateParams.end(); pit++) {
outStateParams[pit->first] = parametersNode.getDoubleProperty(pit->first);
}
types = types | State::Parameters;
......@@ -141,7 +141,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
vector<Vec3> outForces;
try {
const SerializationNode& positionsNode = node.getChildNode("Positions");
for(int i=0; i<(int) positionsNode.getChildren().size();i++) {
for (int i = 0; i < (int) positionsNode.getChildren().size(); i++) {
const SerializationNode& particle = positionsNode.getChildren()[i];
outPositions.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
}
......@@ -151,7 +151,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
}
try {
const SerializationNode& velocitiesNode = node.getChildNode("Velocities");
for(int i=0; i<(int) velocitiesNode.getChildren().size();i++) {
for (int i = 0; i < (int) velocitiesNode.getChildren().size(); i++) {
const SerializationNode& particle = velocitiesNode.getChildren()[i];
outVelocities.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
}
......@@ -161,7 +161,7 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
}
try {
const SerializationNode& forcesNode = node.getChildNode("Forces");
for(int i=0; i<(int) forcesNode.getChildren().size();i++) {
for (int i = 0; i < (int) forcesNode.getChildren().size(); i++) {
const SerializationNode& particle = forcesNode.getChildren()[i];
outForces.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
}
......@@ -169,33 +169,34 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
} catch (const OpenMMException &) {
// do nothing
}
int numParticles = max(outPositions.size(), max(outForces.size(), outVelocities.size()));
vector<int> arraySizes;
State *s = new State(outTime,numParticles,types);
if(types & State::Positions) {
s->updPositions() = outPositions;
State::StateBuilder builder(outTime);
if (types & State::Positions) {
builder.setPositions(outPositions);
arraySizes.push_back(outPositions.size());
}
if(types & State::Velocities) {
s->updVelocities() = outVelocities;
if (types & State::Velocities) {
builder.setVelocities(outVelocities);
arraySizes.push_back(outVelocities.size());
}
if(types & State::Forces) {
s->updForces() = outForces;
arraySizes.push_back(outVelocities.size());
if (types & State::Forces) {
builder.setForces(outForces);
arraySizes.push_back(outForces.size());
}
if(types & State::Energy) {
s->setEnergy(kineticEnergy, potentialEnergy);
if (types & State::Energy) {
builder.setEnergy(kineticEnergy, potentialEnergy);
}
if(types & State::Parameters) {
s->updParameters() = outStateParams;
if (types & State::Parameters) {
builder.setParameters(outStateParams);
}
for(int i=1; i<arraySizes.size();i++) {
if(arraySizes[i] != arraySizes[i-1]) {
for (int i = 1; i < arraySizes.size(); i++) {
if (arraySizes[i] != arraySizes[i-1]) {
throw(OpenMMException("State Deserialization Particle Size Mismatch, check number of particles in Forces, Velocities, Positions!"));
}
}
s->setPeriodicBoxVectors(outAVec, outBVec, outCVec);
builder.setPeriodicBoxVectors(outAVec, outBVec, outCVec);
State *s = new State();
*s = builder.getState();
return s;
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment