Unverified Commit 22da37a4 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Allow querying current step count (#3248)

* Allow querying current step count

* Fixed error building Python wrapper
parent b7c9526a
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2020 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -148,6 +148,18 @@ public:
* @param time the time
*/
virtual void setTime(ContextImpl& context, double time) = 0;
/**
* Get the current step count
*
* @param context the context in which to execute this kernel
*/
virtual long long getStepCount(const ContextImpl& context) const = 0;
/**
* Set the current step count
*
* @param context the context in which to execute this kernel
*/
virtual void setStepCount(const ContextImpl& context, long long count) = 0;
/**
* Get the positions of all particles.
*
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -137,10 +137,22 @@ public:
* specific.
*/
void setState(const State& state);
/**
* Get the current time of the simulation (in picoseconds).
*/
double getTime() const;
/**
* Set the current time of the simulation (in picoseconds).
*/
void setTime(double time);
/**
* Get the current step count.
*/
long long getStepCount() const;
/**
* Set the current step count.
*/
void setStepCount(long long count);
/**
* Set the positions of all particles in the System (measured in nm). This method simply sets the positions
* without checking to see whether they satisfy distance constraints. If you want constraints to be
......
......@@ -68,6 +68,10 @@ public:
* Get the time for which this State was created.
*/
double getTime() const;
/**
* Get the number of integration steps that had been performed when this State was created.
*/
long long getStepCount() const;
/**
* Get the position of each particle. If this State does not contain positions, this will throw an exception.
*/
......@@ -127,7 +131,7 @@ public:
private:
friend class Context;
friend class StateProxy;
State(double time);
State(double time, long long stepCount);
void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel);
void setForces(const std::vector<Vec3>& force);
......@@ -139,6 +143,7 @@ private:
const SerializationNode& getIntegratorParameters() const;
int types;
double time, ke, pe;
long long stepCount;
std::vector<Vec3> positions;
std::vector<Vec3> velocities;
std::vector<Vec3> forces;
......@@ -154,7 +159,7 @@ private:
class OPENMM_EXPORT State::StateBuilder {
public:
StateBuilder(double time);
StateBuilder(double time, long long stepCount);
State getState();
void setPositions(const std::vector<Vec3>& pos);
void setVelocities(const std::vector<Vec3>& vel);
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -92,6 +92,14 @@ public:
* Set the current time (in picoseconds).
*/
void setTime(double t);
/**
* Get the current step count.
*/
long long getStepCount() const;
/**
* Set the current step count.
*/
void setStepCount(long long count);
/**
* Get the positions of all particles.
*
......
......@@ -87,7 +87,7 @@ Platform& Context::getPlatform() {
}
State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
State::StateBuilder builder(impl->getTime());
State::StateBuilder builder(impl->getTime(), impl->getStepCount());
Vec3 periodicBoxSize[3];
impl->getPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
builder.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
......@@ -155,6 +155,7 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
void Context::setState(const State& state) {
setTime(state.getTime());
setStepCount(state.getStepCount());
Vec3 a, b, c;
state.getPeriodicBoxVectors(a, b, c);
setPeriodicBoxVectors(a, b, c);
......@@ -169,10 +170,22 @@ void Context::setState(const State& state) {
getIntegrator().deserializeParameters(state.getIntegratorParameters());
}
double Context::getTime() const {
return impl->getTime();
}
void Context::setTime(double time) {
impl->setTime(time);
}
long long Context::getStepCount() const {
return impl->getStepCount();
}
void Context::setStepCount(long long count) {
impl->setStepCount(count);
}
void Context::setPositions(const vector<Vec3>& positions) {
if ((int) positions.size() != impl->getSystem().getNumParticles())
throw OpenMMException("Called setPositions() on a Context with the wrong number of positions");
......
......@@ -216,6 +216,14 @@ void ContextImpl::setTime(double t) {
updateStateDataKernel.getAs<UpdateStateDataKernel>().setTime(*this, t);
}
long long ContextImpl::getStepCount() const {
return updateStateDataKernel.getAs<const UpdateStateDataKernel>().getStepCount(*this);
}
void ContextImpl::setStepCount(long long count) {
updateStateDataKernel.getAs<UpdateStateDataKernel>().setStepCount(*this, count);
}
void ContextImpl::getPositions(std::vector<Vec3>& positions) {
updateStateDataKernel.getAs<UpdateStateDataKernel>().getPositions(*this, positions);
}
......
......@@ -38,6 +38,9 @@ using namespace std;
double State::getTime() const {
return time;
}
long long State::getStepCount() const {
return stepCount;
}
const vector<Vec3>& State::getPositions() const {
if ((types&Positions) == 0)
throw OpenMMException("Invoked getPositions() on a State which does not contain positions.");
......@@ -94,7 +97,7 @@ SerializationNode& State::updateIntegratorParameters() {
int State::getDataTypes() const {
return types;
}
State::State(double time) : types(0), time(time), ke(0), pe(0) {
State::State(double time, long long stepCount) : types(0), time(time), stepCount(stepCount), ke(0), pe(0) {
}
State::State() : types(0), time(0.0), ke(0), pe(0) {
}
......@@ -135,7 +138,7 @@ void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
periodicBoxVectors[2] = c;
}
State::StateBuilder::StateBuilder(double time) : state(time) {
State::StateBuilder::StateBuilder(double time, long long stepCount) : state(time, stepCount) {
}
State State::StateBuilder::getState() {
......
......@@ -179,13 +179,13 @@ public:
/**
* Get the number of integration steps that have been taken.
*/
int getStepCount() {
long long getStepCount() {
return stepCount;
}
/**
* Set the number of integration steps that have been taken.
*/
void setStepCount(int steps) {
void setStepCount(long long steps) {
stepCount = steps;
}
/**
......@@ -480,7 +480,8 @@ protected:
void reorderAtomsImpl();
const System& system;
double time;
int numAtoms, paddedNumAtoms, stepCount, computeForceCount, stepsSinceReorder;
int numAtoms, paddedNumAtoms, computeForceCount, stepsSinceReorder;
long long stepCount;
bool atomsWereReordered, forcesValid;
std::vector<ComputeForceInfo*> forces;
std::vector<Molecule> molecules;
......
......@@ -134,6 +134,18 @@ public:
* @param context the context in which to execute this kernel
*/
void setTime(ContextImpl& context, double time);
/**
* Get the current step count
*
* @param context the context in which to execute this kernel
*/
long long getStepCount(const ContextImpl& context) const;
/**
* Set the current step count
*
* @param context the context in which to execute this kernel
*/
void setStepCount(const ContextImpl& context, long long count);
/**
* Get the positions of all particles.
*
......
......@@ -139,8 +139,8 @@ public:
std::vector<CudaContext*> contexts;
std::vector<double> contextEnergy;
bool hasInitializedContexts, removeCM, peerAccessSupported, useCpuPme, disablePmeStream, deterministicForces, allowRuntimeCompiler;
int cmMotionFrequency;
int stepCount, computeForceCount;
int cmMotionFrequency, computeForceCount;
long long stepCount;
double time;
std::map<std::string, std::string> propertyValues;
ThreadPool threads;
......
......@@ -98,6 +98,16 @@ void CudaUpdateStateDataKernel::setTime(ContextImpl& context, double time) {
ctx->setTime(time);
}
long long CudaUpdateStateDataKernel::getStepCount(const ContextImpl& context) const {
return cu.getStepCount();
}
void CudaUpdateStateDataKernel::setStepCount(const ContextImpl& context, long long count) {
vector<CudaContext*>& contexts = cu.getPlatformData().contexts;
for (auto ctx : contexts)
ctx->setStepCount(count);
}
void CudaUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
cu.setAsCurrent();
int numParticles = context.getSystem().getNumParticles();
......@@ -343,8 +353,8 @@ void CudaUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream&
stream.write((char*) &precision, sizeof(int));
double time = cu.getTime();
stream.write((char*) &time, sizeof(double));
int stepCount = cu.getStepCount();
stream.write((char*) &stepCount, sizeof(int));
long long stepCount = cu.getStepCount();
stream.write((char*) &stepCount, sizeof(long long));
int stepsSinceReorder = cu.getStepsSinceReorder();
stream.write((char*) &stepsSinceReorder, sizeof(int));
char* buffer = (char*) cu.getPinnedBuffer();
......@@ -378,8 +388,9 @@ void CudaUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& st
throw OpenMMException("Checkpoint was created with a different numeric precision");
double time;
stream.read((char*) &time, sizeof(double));
int stepCount, stepsSinceReorder;
stream.read((char*) &stepCount, sizeof(int));
long long stepCount;
stream.read((char*) &stepCount, sizeof(long long));
int stepsSinceReorder;
stream.read((char*) &stepsSinceReorder, sizeof(int));
vector<CudaContext*>& contexts = cu.getPlatformData().contexts;
for (auto ctx : contexts) {
......
......@@ -394,13 +394,13 @@ public:
/**
* Get the number of integration steps that have been taken.
*/
int getStepCount() {
long long getStepCount() {
return stepCount;
}
/**
* Set the number of integration steps that have been taken.
*/
void setStepCount(int steps) {
void setStepCount(long long steps) {
stepCount = steps;
}
/**
......
......@@ -108,6 +108,18 @@ public:
* @param context the context in which to execute this kernel
*/
void setTime(ContextImpl& context, double time);
/**
* Get the current step count
*
* @param context the context in which to execute this kernel
*/
long long getStepCount(const ContextImpl& context) const;
/**
* Set the current step count
*
* @param context the context in which to execute this kernel
*/
void setStepCount(const ContextImpl& context, long long count);
/**
* Get the positions of all particles.
*
......
......@@ -117,8 +117,8 @@ public:
std::vector<OpenCLContext*> contexts;
std::vector<double> contextEnergy;
bool hasInitializedContexts, removeCM, useCpuPme, disablePmeStream;
int cmMotionFrequency;
int stepCount, computeForceCount;
int cmMotionFrequency, computeForceCount;
long long stepCount;
double time;
std::map<std::string, std::string> propertyValues;
ThreadPool threads;
......
......@@ -114,6 +114,16 @@ void OpenCLUpdateStateDataKernel::setTime(ContextImpl& context, double time) {
ctx->setTime(time);
}
long long OpenCLUpdateStateDataKernel::getStepCount(const ContextImpl& context) const {
return cl.getStepCount();
}
void OpenCLUpdateStateDataKernel::setStepCount(const ContextImpl& context, long long count) {
vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
for (auto ctx : contexts)
ctx->setStepCount(count);
}
void OpenCLUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
int numParticles = context.getSystem().getNumParticles();
positions.resize(numParticles);
......@@ -363,8 +373,8 @@ void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream
stream.write((char*) &precision, sizeof(int));
double time = cl.getTime();
stream.write((char*) &time, sizeof(double));
int stepCount = cl.getStepCount();
stream.write((char*) &stepCount, sizeof(int));
long long stepCount = cl.getStepCount();
stream.write((char*) &stepCount, sizeof(long long));
int stepsSinceReorder = cl.getStepsSinceReorder();
stream.write((char*) &stepsSinceReorder, sizeof(int));
char* buffer = (char*) cl.getPinnedBuffer();
......@@ -397,8 +407,9 @@ void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream&
throw OpenMMException("Checkpoint was created with a different numeric precision");
double time;
stream.read((char*) &time, sizeof(double));
int stepCount, stepsSinceReorder;
stream.read((char*) &stepCount, sizeof(int));
long long stepCount;
stream.read((char*) &stepCount, sizeof(long long));
int stepsSinceReorder;
stream.read((char*) &stepsSinceReorder, sizeof(int));
vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
for (auto ctx : contexts) {
......
......@@ -137,6 +137,18 @@ public:
* @param context the context in which to execute this kernel
*/
void setTime(ContextImpl& context, double time);
/**
* Get the current step count
*
* @param context the context in which to execute this kernel
*/
long long getStepCount(const ContextImpl& context) const;
/**
* Set the current step count
*
* @param context the context in which to execute this kernel
*/
void setStepCount(const ContextImpl& context, long long count);
/**
* Get the positions of all particles.
*
......
......@@ -63,7 +63,8 @@ class OPENMM_EXPORT ReferencePlatform::PlatformData {
public:
PlatformData(const System& system);
~PlatformData();
int numParticles, stepCount;
int numParticles;
long long stepCount;
double time;
std::vector<Vec3>* positions;
std::vector<Vec3>* velocities;
......
......@@ -213,6 +213,14 @@ void ReferenceUpdateStateDataKernel::setTime(ContextImpl& context, double time)
data.time = time;
}
long long ReferenceUpdateStateDataKernel::getStepCount(const ContextImpl& context) const {
return data.stepCount;
}
void ReferenceUpdateStateDataKernel::setStepCount(const ContextImpl& context, long long count) {
data.stepCount = count;
}
void ReferenceUpdateStateDataKernel::getPositions(ContextImpl& context, std::vector<Vec3>& positions) {
int numParticles = context.getSystem().getNumParticles();
vector<Vec3>& posData = extractPositions(context);
......@@ -283,6 +291,7 @@ void ReferenceUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostr
int version = 3;
stream.write((char*) &version, sizeof(int));
stream.write((char*) &data.time, sizeof(data.time));
stream.write((char*) &data.stepCount, sizeof(long long));
vector<Vec3>& posData = extractPositions(context);
stream.write((char*) &posData[0], sizeof(Vec3)*posData.size());
vector<Vec3>& velData = extractVelocities(context);
......@@ -298,6 +307,7 @@ void ReferenceUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istrea
if (version != 3)
throw OpenMMException("Checkpoint was created with a different version of OpenMM");
stream.read((char*) &data.time, sizeof(data.time));
stream.read((char*) &data.stepCount, sizeof(long long));
vector<Vec3>& posData = extractPositions(context);
stream.read((char*) &posData[0], sizeof(Vec3)*posData.size());
vector<Vec3>& velData = extractVelocities(context);
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -142,7 +142,7 @@ State RPMDIntegrator::getState(int copy, int types, bool enforcePeriodicBox, int
// Construct the new State.
State::StateBuilder builder(state.getTime());
State::StateBuilder builder(state.getTime(), state.getStepCount());
builder.setPositions(positions);
builder.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
if (types&State::Velocities)
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2010 Stanford University and the Authors. *
* Portions copyright (c) 2010-2021 Stanford University and the Authors *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -152,6 +152,28 @@ public:
* @param value the value to set for the property
*/
SerializationNode& setIntProperty(const std::string& name, int value);
/**
* Get the property with a particular name, specified as a long long. If there is no property with
* the specified name, an exception is thrown.
*
* @param name the name of the property to get
*/
long long getLongProperty(const std::string& name) const;
/**
* Get the property with a particular name, specified as a long long. If there is no property with
* the specified name, a default value is returned instead.
*
* @param name the name of the property to get
* @param defaultValue the value to return if the specified property does not exist
*/
long long getLongProperty(const std::string& name, long long defaultValue) const;
/**
* Set the value of a property, specified as a long long.
*
* @param name the name of the property to set
* @param value the value to set for the property
*/
SerializationNode& setLongProperty(const std::string& name, long long value);
/**
* Get the property with a particular name, specified as an bool. If there is no property with
* the specified name, an exception is thrown.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment