"vscode:/vscode.git/clone" did not exist on "f7ef2dd001984e3be35498e41f8f8106ee0e25fa"
Unverified Commit fdfdb0f0 authored by peastman's avatar peastman Committed by GitHub
Browse files

Cleaned up checkpointing code for NoseHooverIntegrator (#2671)

* Cleaned up checkpointing code for NoseHooverIntegrator

* Fixed compilation error
parent 389f79e6
......@@ -1128,6 +1128,14 @@ public:
* @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled.
*/
virtual void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor) = 0;
/**
* Write the chain states to a checkpoint.
*/
virtual void createCheckpoint(ContextImpl& context, std::ostream& stream) const = 0;
/**
* Load the chain states from a checkpoint.
*/
virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0;
};
/**
......
......@@ -260,6 +260,16 @@ protected:
* Computing kinetic energy for this integrator does not require forces.
*/
bool kineticEnergyRequiresForce() const;
/**
* This is called while writing checkpoints. It gives the integrator a chance to write
* its own data.
*/
void createCheckpoint(std::ostream& stream) const;
/**
* This is called while loading a checkpoint. The integrator should read in whatever
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
void loadCheckpoint(std::istream& stream);
std::vector<NoseHooverChain> noseHooverChains;
std::vector<int> allAtoms;
......
......@@ -341,3 +341,11 @@ void NoseHooverIntegrator::step(int steps) {
kernel.getAs<IntegrateNoseHooverStepKernel>().execute(*context, *this, forcesAreValid);
}
}
void NoseHooverIntegrator::createCheckpoint(std::ostream& stream) const {
kernel.getAs<IntegrateNoseHooverStepKernel>().createCheckpoint(*context, stream);
}
void NoseHooverIntegrator::loadCheckpoint(std::istream& stream) {
kernel.getAs<IntegrateNoseHooverStepKernel>().loadCheckpoint(*context, stream);
}
\ No newline at end of file
......@@ -1009,10 +1009,19 @@ public:
* @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled.
*/
void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor);
/**
* Write the chain states to a checkpoint.
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream) const;
/**
* Load the chain states from a checkpoint.
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
ComputeContext& cc;
float prevMaxPairDistance;
ComputeArray maxPairDistanceBuffer, pairListBuffer, atomListBuffer, pairTemperatureBuffer, oldDelta;
std::map<int, ComputeArray> chainState;
ComputeKernel kernel1, kernel2, kernel3, kernel4, kernelHardWall;
bool hasInitializedKernels;
ComputeKernel reduceEnergyKernel;
......
......@@ -130,12 +130,6 @@ public:
* @param timeShift the amount by which to shift the velocities in time
*/
double computeKineticEnergy(double timeShift);
/**
* Get the data structure that holds the state of all Nose-Hoover chains
*/
std::map<int, ComputeArray>& getNoseHooverChainState() {
return noseHooverChainState;
}
protected:
virtual void applyConstraintsImpl(bool constrainVelocities, double tol) = 0;
ComputeContext& context;
......@@ -174,7 +168,6 @@ protected:
ComputeArray vsiteLocalCoordsWeights;
ComputeArray vsiteLocalCoordsPos;
ComputeArray vsiteLocalCoordsStartIndex;
std::map<int, ComputeArray> noseHooverChainState;
int randomPos, lastSeed, numVsites;
bool hasOverlappingVsites;
mm_double2 lastStepSize;
......
......@@ -5844,8 +5844,6 @@ std::pair<double, double> CommonIntegrateNoseHooverStepKernel::propagateChain(Co
throw OpenMMException("Number of Yoshida Suzuki time steps has to be 1, 3, 5, or 7.");
}
auto & chainState = cc.getIntegrationUtilities().getNoseHooverChainState();
if (!scaleFactorBuffer.isInitialized() || scaleFactorBuffer.getSize() == 0) {
if (useDouble) {
std::vector<mm_double2> zeros{{0,0}};
......@@ -6007,8 +6005,6 @@ double CommonIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl& c
int chainID = nhc.getChainID();
int chainLength = nhc.getChainLength();
auto & chainState = cc.getIntegrationUtilities().getNoseHooverChainState();
bool absChainIsValid = chainState.count(2*chainID) != 0 &&
chainState[2*chainID].isInitialized() &&
chainState[2*chainID].getSize() == chainLength;
......@@ -6211,6 +6207,54 @@ void CommonIntegrateNoseHooverStepKernel::scaleVelocities(ContextImpl& context,
}
}
void CommonIntegrateNoseHooverStepKernel::createCheckpoint(ContextImpl& context, ostream& stream) const {
int numChains = chainState.size();
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
stream.write((char*) &numChains, sizeof(int));
for (auto& state : chainState){
int chainID = state.first;
int chainLength = state.second.getSize();
stream.write((char*) &chainID, sizeof(int));
stream.write((char*) &chainLength, sizeof(int));
if (useDouble) {
vector<mm_double2> stateVec;
state.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_double2)*chainLength);
}
else {
vector<mm_float2> stateVec;
state.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_float2)*chainLength);
}
}
}
void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
int numChains;
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
stream.read((char*) &numChains, sizeof(int));
chainState.clear();
for (int i = 0; i < numChains; i++) {
int chainID, chainLength;
stream.read((char*) &chainID, sizeof(int));
stream.read((char*) &chainLength, sizeof(int));
if (useDouble) {
chainState[chainID] = ComputeArray();
chainState[chainID].initialize<mm_double2>(cc, chainLength, "chainState" + to_string(chainID));
vector<mm_double2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_double2)*chainLength);
chainState[chainID].upload(stateVec);
}
else {
chainState[chainID] = ComputeArray();
chainState[chainID].initialize<mm_float2>(cc, chainLength, "chainState" + to_string(chainID));
vector<mm_float2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_float2)*chainLength);
chainState[chainID].upload(stateVec);
}
}
}
void CommonIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) {
cc.initializeContexts();
cc.setAsCurrent();
......
......@@ -753,25 +753,6 @@ int IntegrationUtilities::prepareRandomNumbers(int numValues) {
}
void IntegrationUtilities::createCheckpoint(ostream& stream) {
int numChains = noseHooverChainState.size();
bool useDouble = context.getUseDoublePrecision() || context.getUseMixedPrecision();
stream.write((char*) &numChains, sizeof(int));
for (auto &chainState: noseHooverChainState){
int chainID = chainState.first;
int chainLength = chainState.second.getSize();
stream.write((char*) &chainID, sizeof(int));
stream.write((char*) &chainLength, sizeof(int));
if (useDouble) {
vector<mm_double2> stateVec;
chainState.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_double2)*chainLength);
}
else {
vector<mm_float2> stateVec;
chainState.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_float2)*chainLength);
}
}
if (!random.isInitialized())
return;
stream.write((char*) &randomPos, sizeof(int));
......@@ -784,29 +765,6 @@ void IntegrationUtilities::createCheckpoint(ostream& stream) {
}
void IntegrationUtilities::loadCheckpoint(istream& stream) {
int numChains;
bool useDouble = context.getUseDoublePrecision() || context.getUseMixedPrecision();
stream.read((char*) &numChains, sizeof(int));
noseHooverChainState.clear();
for (int i = 0; i < numChains; i++) {
int chainID, chainLength;
stream.read((char*) &chainID, sizeof(int));
stream.read((char*) &chainLength, sizeof(int));
if (useDouble) {
noseHooverChainState[chainID] = ComputeArray();
noseHooverChainState[chainID].initialize<mm_double2>(context, chainLength, "chainState" + to_string(chainID));
vector<mm_double2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_double2)*chainLength);
noseHooverChainState[chainID].upload(stateVec);
}
else {
noseHooverChainState[chainID] = ComputeArray();
noseHooverChainState[chainID].initialize<mm_float2>(context, chainLength, "chainState" + to_string(chainID));
vector<mm_float2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_float2)*chainLength);
noseHooverChainState[chainID].upload(stateVec);
}
}
if (!random.isInitialized())
return;
stream.read((char*) &randomPos, sizeof(int));
......
......@@ -1203,11 +1203,21 @@ public:
* @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled.
*/
void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor);
/**
* Write the chain states to a checkpoint.
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream) const;
/**
* Load the chain states from a checkpoint.
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
ReferencePlatform::PlatformData& data;
ReferenceNoseHooverChain* chainPropagator;
ReferenceNoseHooverDynamics* dynamics;
std::vector<double> masses;
std::vector<std::vector<double> > chainPositions;
std::vector<std::vector<double> > chainVelocities;
double prevStepSize;
};
......
......@@ -72,8 +72,6 @@ public:
Vec3* periodicBoxVectors;
ReferenceConstraints* constraints;
std::map<std::string, double>* energyParameterDerivatives;
std::vector<std::vector<double>>* noseHooverPositions;
std::vector<std::vector<double>>* noseHooverVelocities;
};
} // namespace OpenMM
......
......@@ -127,15 +127,6 @@ static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& conte
return *data->energyParameterDerivatives;
}
static vector<vector<double> >& extractNoseHooverPositions(ContextImpl& context) {
ReferencePlatform::PlatformData *data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return *((vector<vector<double> >*) data->noseHooverPositions);
}
static vector<vector<double> >& extractNoseHooverVelocities(ContextImpl& context) {
ReferencePlatform::PlatformData *data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return *((vector<vector<double> >*) data->noseHooverVelocities);
}
/**
* Make sure an expression doesn't use any undefined variables.
*/
......@@ -297,20 +288,6 @@ void ReferenceUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostr
stream.write((char*) &velData[0], sizeof(Vec3)*velData.size());
Vec3* vectors = extractBoxVectors(context);
stream.write((char*) vectors, 3*sizeof(Vec3));
auto& allNoseHooverPositions = extractNoseHooverPositions(context);
auto& allNoseHooverVelocities = extractNoseHooverVelocities(context);
size_t numChains = allNoseHooverPositions.size();
assert(numChains == allNoseHooverVelocities.size());
stream.write((char*) &numChains, sizeof(size_t));
for (size_t i=0; i<numChains; i++){
auto & noseHooverPositions = allNoseHooverPositions.at(i);
auto & noseHooverVelocities = allNoseHooverVelocities.at(i);
size_t numBeads = noseHooverPositions.size();
assert(numBeads == noseHooverVelocities.size());
stream.write((char*) &numBeads, sizeof(size_t));
stream.write((char*) noseHooverPositions.data(), sizeof(double)*numBeads);
stream.write((char*) noseHooverVelocities.data(), sizeof(double)*numBeads);
}
SimTKOpenMMUtilities::createCheckpoint(stream);
}
......@@ -326,21 +303,6 @@ void ReferenceUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istrea
stream.read((char*) &velData[0], sizeof(Vec3)*velData.size());
Vec3* vectors = extractBoxVectors(context);
stream.read((char*) vectors, 3*sizeof(Vec3));
size_t numChains, numBeads;
auto& allNoseHooverPositions = extractNoseHooverPositions(context);
auto& allNoseHooverVelocities = extractNoseHooverVelocities(context);
stream.read((char*) &numChains, sizeof(size_t));
allNoseHooverPositions.clear();
allNoseHooverVelocities.clear();
for (size_t i=0; i<numChains; i++){
stream.read((char*) &numBeads, sizeof(size_t));
std::vector<double> noseHooverPositions(numBeads);
std::vector<double> noseHooverVelocities(numBeads);
stream.read((char*) &noseHooverPositions[0], sizeof(double)*numBeads);
stream.read((char*) &noseHooverVelocities[0], sizeof(double)*numBeads);
allNoseHooverPositions.push_back(noseHooverPositions);
allNoseHooverVelocities.push_back(noseHooverVelocities);
}
SimTKOpenMMUtilities::loadCheckpoint(stream);
}
......@@ -2211,53 +2173,49 @@ std::pair<double, double> ReferenceIntegrateNoseHooverStepKernel::propagateChain
int numDOFs = nhc.getNumDegreesOfFreedom();
int numMTS = nhc.getNumMultiTimeSteps();
// Get the state of the NHC from the context
auto& allChainPositions = extractNoseHooverPositions(context);
auto& allChainVelocities = extractNoseHooverVelocities(context);
int nAtoms = nhc.getThermostatedAtoms().size();
double absScale = 0;
if (nAtoms) {
if (allChainPositions.size() < 2*chainID+1){
allChainPositions.resize(2*chainID+1);
if (chainPositions.size() < 2*chainID+1){
chainPositions.resize(2*chainID+1);
}
if (allChainVelocities.size() < 2*chainID+1){
allChainVelocities.resize(2*chainID+1);
if (chainVelocities.size() < 2*chainID+1){
chainVelocities.resize(2*chainID+1);
}
auto& chainPositions = allChainPositions.at(2*chainID);
auto& chainVelocities = allChainVelocities.at(2*chainID);
if (chainPositions.size() < chainLength){
chainPositions.resize(chainLength, 0);
auto& positions = chainPositions.at(2*chainID);
auto& velocities = chainVelocities.at(2*chainID);
if (positions.size() < chainLength){
positions.resize(chainLength, 0);
}
if (chainVelocities.size() < chainLength){
chainVelocities.resize(chainLength, 0);
if (velocities.size() < chainLength){
velocities.resize(chainLength, 0);
}
double temperature = nhc.getTemperature();
double collisionFrequency = nhc.getCollisionFrequency();
absScale = chainPropagator->propagate(absKE, chainVelocities, chainPositions, numDOFs,
absScale = chainPropagator->propagate(absKE, velocities, positions, numDOFs,
temperature, collisionFrequency, timeStep,
numMTS, nhc.getYoshidaSuzukiWeights());
}
double relScale = 0;
int nPairs = nhc.getThermostatedPairs().size();
if (nPairs) {
if (allChainPositions.size() < 2*chainID+2){
allChainPositions.resize(2*chainID+2);
if (chainPositions.size() < 2*chainID+2){
chainPositions.resize(2*chainID+2);
}
if (allChainVelocities.size() < 2*chainID+2){
allChainVelocities.resize(2*chainID+2);
if (chainVelocities.size() < 2*chainID+2){
chainVelocities.resize(2*chainID+2);
}
auto& chainPositions = allChainPositions.at(2*chainID+1);
auto& chainVelocities = allChainVelocities.at(2*chainID+1);
if (chainPositions.size() < chainLength){
chainPositions.resize(chainLength, 0);
auto& positions = chainPositions.at(2*chainID+1);
auto& velocities = chainVelocities.at(2*chainID+1);
if (positions.size() < chainLength){
positions.resize(chainLength, 0);
}
if (chainVelocities.size() < chainLength){
chainVelocities.resize(chainLength, 0);
if (velocities.size() < chainLength){
velocities.resize(chainLength, 0);
}
double temperature = nhc.getRelativeTemperature();
double collisionFrequency = nhc.getRelativeCollisionFrequency();
relScale = chainPropagator->propagate(relKE, chainVelocities, chainPositions, 3*nPairs,
relScale = chainPropagator->propagate(relKE, velocities, positions, 3*nPairs,
temperature, collisionFrequency, timeStep,
numMTS, nhc.getYoshidaSuzukiWeights());
}
......@@ -2271,8 +2229,6 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl
int chainID = nhc.getChainID();
int nAtoms = nhc.getThermostatedAtoms().size();
int nPairs = nhc.getThermostatedPairs().size();
auto& nhcPositions = extractNoseHooverPositions(context);
auto& nhcVelocities = extractNoseHooverVelocities(context);
if (nAtoms) {
double temperature = nhc.getTemperature();
double collisionFrequency = nhc.getCollisionFrequency();
......@@ -2281,11 +2237,11 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl
for(int i = 0; i < chainLength; ++i) {
double prefac = i ? 1 : numDOFs;
double mass = prefac * kT / (collisionFrequency * collisionFrequency);
double velocity = nhcVelocities[2*chainID][i];
double velocity = chainVelocities[2*chainID][i];
// The kinetic energy of this bead
kineticEnergy += 0.5 * mass * velocity * velocity;
// The potential energy of this bead
double position = nhcPositions[2*chainID][i];
double position = chainPositions[2*chainID][i];
potentialEnergy += prefac * kT * position;
}
}
......@@ -2297,11 +2253,11 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl
for(int i = 0; i < chainLength; ++i) {
double prefac = i ? 1 : numDOFs;
double mass = prefac * kT / (collisionFrequency * collisionFrequency);
double velocity = nhcVelocities[2*chainID+1][i];
double velocity = chainVelocities[2*chainID+1][i];
// The kinetic energy of this bead
kineticEnergy += 0.5 * mass * velocity * velocity;
// The potential energy of this bead
double position = nhcPositions[2*chainID+1][i];
double position = chainPositions[2*chainID+1][i];
potentialEnergy += prefac * kT * position;
}
}
......@@ -2380,6 +2336,37 @@ void ReferenceIntegrateNoseHooverStepKernel::scaleVelocities(ContextImpl& contex
}
}
void ReferenceIntegrateNoseHooverStepKernel::createCheckpoint(ContextImpl& context, ostream& stream) const {
size_t numChains = chainPositions.size();
assert(numChains == chainVelocities.size());
stream.write((char*) &numChains, sizeof(size_t));
for (size_t i=0; i<numChains; i++){
auto & noseHooverPositions = chainPositions.at(i);
auto & noseHooverVelocities = chainVelocities.at(i);
size_t numBeads = noseHooverPositions.size();
assert(numBeads == noseHooverVelocities.size());
stream.write((char*) &numBeads, sizeof(size_t));
stream.write((char*) noseHooverPositions.data(), sizeof(double)*numBeads);
stream.write((char*) noseHooverVelocities.data(), sizeof(double)*numBeads);
}
}
void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
size_t numChains, numBeads;
stream.read((char*) &numChains, sizeof(size_t));
chainPositions.clear();
chainVelocities.clear();
for (size_t i=0; i<numChains; i++){
stream.read((char*) &numBeads, sizeof(size_t));
std::vector<double> noseHooverPositions(numBeads);
std::vector<double> noseHooverVelocities(numBeads);
stream.read((char*) &noseHooverPositions[0], sizeof(double)*numBeads);
stream.read((char*) &noseHooverVelocities[0], sizeof(double)*numBeads);
chainPositions.push_back(noseHooverPositions);
chainVelocities.push_back(noseHooverVelocities);
}
}
ReferenceIntegrateLangevinStepKernel::~ReferenceIntegrateLangevinStepKernel() {
if (dynamics)
delete dynamics;
......
......@@ -103,8 +103,6 @@ ReferencePlatform::PlatformData::PlatformData(const System& system) : time(0.0),
periodicBoxVectors = new Vec3[3];
constraints = new ReferenceConstraints(system);
energyParameterDerivatives = new map<string, double>();
noseHooverPositions = new vector<vector<double> >();
noseHooverVelocities = new vector<vector<double> >();
}
ReferencePlatform::PlatformData::~PlatformData() {
......@@ -115,6 +113,4 @@ ReferencePlatform::PlatformData::~PlatformData() {
delete[] periodicBoxVectors;
delete constraints;
delete energyParameterDerivatives;
delete noseHooverPositions;
delete noseHooverVelocities;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment