Unverified Commit fdfdb0f0 authored by peastman's avatar peastman Committed by GitHub
Browse files

Cleaned up checkpointing code for NoseHooverIntegrator (#2671)

* Cleaned up checkpointing code for NoseHooverIntegrator

* Fixed compilation error
parent 389f79e6
...@@ -1128,6 +1128,14 @@ public: ...@@ -1128,6 +1128,14 @@ public:
* @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled. * @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled.
*/ */
virtual void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor) = 0; virtual void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor) = 0;
/**
* Write the chain states to a checkpoint.
*/
virtual void createCheckpoint(ContextImpl& context, std::ostream& stream) const = 0;
/**
* Load the chain states from a checkpoint.
*/
virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0;
}; };
/** /**
......
...@@ -260,6 +260,16 @@ protected: ...@@ -260,6 +260,16 @@ protected:
* Computing kinetic energy for this integrator does not require forces. * Computing kinetic energy for this integrator does not require forces.
*/ */
bool kineticEnergyRequiresForce() const; bool kineticEnergyRequiresForce() const;
/**
* This is called while writing checkpoints. It gives the integrator a chance to write
* its own data.
*/
void createCheckpoint(std::ostream& stream) const;
/**
* This is called while loading a checkpoint. The integrator should read in whatever
* data it wrote in createCheckpoint() and update its internal state accordingly.
*/
void loadCheckpoint(std::istream& stream);
std::vector<NoseHooverChain> noseHooverChains; std::vector<NoseHooverChain> noseHooverChains;
std::vector<int> allAtoms; std::vector<int> allAtoms;
......
...@@ -341,3 +341,11 @@ void NoseHooverIntegrator::step(int steps) { ...@@ -341,3 +341,11 @@ void NoseHooverIntegrator::step(int steps) {
kernel.getAs<IntegrateNoseHooverStepKernel>().execute(*context, *this, forcesAreValid); kernel.getAs<IntegrateNoseHooverStepKernel>().execute(*context, *this, forcesAreValid);
} }
} }
void NoseHooverIntegrator::createCheckpoint(std::ostream& stream) const {
kernel.getAs<IntegrateNoseHooverStepKernel>().createCheckpoint(*context, stream);
}
void NoseHooverIntegrator::loadCheckpoint(std::istream& stream) {
kernel.getAs<IntegrateNoseHooverStepKernel>().loadCheckpoint(*context, stream);
}
\ No newline at end of file
...@@ -1009,10 +1009,19 @@ public: ...@@ -1009,10 +1009,19 @@ public:
* @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled. * @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled.
*/ */
void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor); void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor);
/**
* Write the chain states to a checkpoint.
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream) const;
/**
* Load the chain states from a checkpoint.
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private: private:
ComputeContext& cc; ComputeContext& cc;
float prevMaxPairDistance; float prevMaxPairDistance;
ComputeArray maxPairDistanceBuffer, pairListBuffer, atomListBuffer, pairTemperatureBuffer, oldDelta; ComputeArray maxPairDistanceBuffer, pairListBuffer, atomListBuffer, pairTemperatureBuffer, oldDelta;
std::map<int, ComputeArray> chainState;
ComputeKernel kernel1, kernel2, kernel3, kernel4, kernelHardWall; ComputeKernel kernel1, kernel2, kernel3, kernel4, kernelHardWall;
bool hasInitializedKernels; bool hasInitializedKernels;
ComputeKernel reduceEnergyKernel; ComputeKernel reduceEnergyKernel;
......
...@@ -130,12 +130,6 @@ public: ...@@ -130,12 +130,6 @@ public:
* @param timeShift the amount by which to shift the velocities in time * @param timeShift the amount by which to shift the velocities in time
*/ */
double computeKineticEnergy(double timeShift); double computeKineticEnergy(double timeShift);
/**
* Get the data structure that holds the state of all Nose-Hoover chains
*/
std::map<int, ComputeArray>& getNoseHooverChainState() {
return noseHooverChainState;
}
protected: protected:
virtual void applyConstraintsImpl(bool constrainVelocities, double tol) = 0; virtual void applyConstraintsImpl(bool constrainVelocities, double tol) = 0;
ComputeContext& context; ComputeContext& context;
...@@ -174,7 +168,6 @@ protected: ...@@ -174,7 +168,6 @@ protected:
ComputeArray vsiteLocalCoordsWeights; ComputeArray vsiteLocalCoordsWeights;
ComputeArray vsiteLocalCoordsPos; ComputeArray vsiteLocalCoordsPos;
ComputeArray vsiteLocalCoordsStartIndex; ComputeArray vsiteLocalCoordsStartIndex;
std::map<int, ComputeArray> noseHooverChainState;
int randomPos, lastSeed, numVsites; int randomPos, lastSeed, numVsites;
bool hasOverlappingVsites; bool hasOverlappingVsites;
mm_double2 lastStepSize; mm_double2 lastStepSize;
......
...@@ -5844,8 +5844,6 @@ std::pair<double, double> CommonIntegrateNoseHooverStepKernel::propagateChain(Co ...@@ -5844,8 +5844,6 @@ std::pair<double, double> CommonIntegrateNoseHooverStepKernel::propagateChain(Co
throw OpenMMException("Number of Yoshida Suzuki time steps has to be 1, 3, 5, or 7."); throw OpenMMException("Number of Yoshida Suzuki time steps has to be 1, 3, 5, or 7.");
} }
auto & chainState = cc.getIntegrationUtilities().getNoseHooverChainState();
if (!scaleFactorBuffer.isInitialized() || scaleFactorBuffer.getSize() == 0) { if (!scaleFactorBuffer.isInitialized() || scaleFactorBuffer.getSize() == 0) {
if (useDouble) { if (useDouble) {
std::vector<mm_double2> zeros{{0,0}}; std::vector<mm_double2> zeros{{0,0}};
...@@ -6007,8 +6005,6 @@ double CommonIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl& c ...@@ -6007,8 +6005,6 @@ double CommonIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl& c
int chainID = nhc.getChainID(); int chainID = nhc.getChainID();
int chainLength = nhc.getChainLength(); int chainLength = nhc.getChainLength();
auto & chainState = cc.getIntegrationUtilities().getNoseHooverChainState();
bool absChainIsValid = chainState.count(2*chainID) != 0 && bool absChainIsValid = chainState.count(2*chainID) != 0 &&
chainState[2*chainID].isInitialized() && chainState[2*chainID].isInitialized() &&
chainState[2*chainID].getSize() == chainLength; chainState[2*chainID].getSize() == chainLength;
...@@ -6211,6 +6207,54 @@ void CommonIntegrateNoseHooverStepKernel::scaleVelocities(ContextImpl& context, ...@@ -6211,6 +6207,54 @@ void CommonIntegrateNoseHooverStepKernel::scaleVelocities(ContextImpl& context,
} }
} }
void CommonIntegrateNoseHooverStepKernel::createCheckpoint(ContextImpl& context, ostream& stream) const {
int numChains = chainState.size();
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
stream.write((char*) &numChains, sizeof(int));
for (auto& state : chainState){
int chainID = state.first;
int chainLength = state.second.getSize();
stream.write((char*) &chainID, sizeof(int));
stream.write((char*) &chainLength, sizeof(int));
if (useDouble) {
vector<mm_double2> stateVec;
state.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_double2)*chainLength);
}
else {
vector<mm_float2> stateVec;
state.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_float2)*chainLength);
}
}
}
void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
int numChains;
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
stream.read((char*) &numChains, sizeof(int));
chainState.clear();
for (int i = 0; i < numChains; i++) {
int chainID, chainLength;
stream.read((char*) &chainID, sizeof(int));
stream.read((char*) &chainLength, sizeof(int));
if (useDouble) {
chainState[chainID] = ComputeArray();
chainState[chainID].initialize<mm_double2>(cc, chainLength, "chainState" + to_string(chainID));
vector<mm_double2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_double2)*chainLength);
chainState[chainID].upload(stateVec);
}
else {
chainState[chainID] = ComputeArray();
chainState[chainID].initialize<mm_float2>(cc, chainLength, "chainState" + to_string(chainID));
vector<mm_float2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_float2)*chainLength);
chainState[chainID].upload(stateVec);
}
}
}
void CommonIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) { void CommonIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) {
cc.initializeContexts(); cc.initializeContexts();
cc.setAsCurrent(); cc.setAsCurrent();
......
...@@ -753,25 +753,6 @@ int IntegrationUtilities::prepareRandomNumbers(int numValues) { ...@@ -753,25 +753,6 @@ int IntegrationUtilities::prepareRandomNumbers(int numValues) {
} }
void IntegrationUtilities::createCheckpoint(ostream& stream) { void IntegrationUtilities::createCheckpoint(ostream& stream) {
int numChains = noseHooverChainState.size();
bool useDouble = context.getUseDoublePrecision() || context.getUseMixedPrecision();
stream.write((char*) &numChains, sizeof(int));
for (auto &chainState: noseHooverChainState){
int chainID = chainState.first;
int chainLength = chainState.second.getSize();
stream.write((char*) &chainID, sizeof(int));
stream.write((char*) &chainLength, sizeof(int));
if (useDouble) {
vector<mm_double2> stateVec;
chainState.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_double2)*chainLength);
}
else {
vector<mm_float2> stateVec;
chainState.second.download(stateVec);
stream.write((char*) stateVec.data(), sizeof(mm_float2)*chainLength);
}
}
if (!random.isInitialized()) if (!random.isInitialized())
return; return;
stream.write((char*) &randomPos, sizeof(int)); stream.write((char*) &randomPos, sizeof(int));
...@@ -784,29 +765,6 @@ void IntegrationUtilities::createCheckpoint(ostream& stream) { ...@@ -784,29 +765,6 @@ void IntegrationUtilities::createCheckpoint(ostream& stream) {
} }
void IntegrationUtilities::loadCheckpoint(istream& stream) { void IntegrationUtilities::loadCheckpoint(istream& stream) {
int numChains;
bool useDouble = context.getUseDoublePrecision() || context.getUseMixedPrecision();
stream.read((char*) &numChains, sizeof(int));
noseHooverChainState.clear();
for (int i = 0; i < numChains; i++) {
int chainID, chainLength;
stream.read((char*) &chainID, sizeof(int));
stream.read((char*) &chainLength, sizeof(int));
if (useDouble) {
noseHooverChainState[chainID] = ComputeArray();
noseHooverChainState[chainID].initialize<mm_double2>(context, chainLength, "chainState" + to_string(chainID));
vector<mm_double2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_double2)*chainLength);
noseHooverChainState[chainID].upload(stateVec);
}
else {
noseHooverChainState[chainID] = ComputeArray();
noseHooverChainState[chainID].initialize<mm_float2>(context, chainLength, "chainState" + to_string(chainID));
vector<mm_float2> stateVec(chainLength);
stream.read((char*) &stateVec[0], sizeof(mm_float2)*chainLength);
noseHooverChainState[chainID].upload(stateVec);
}
}
if (!random.isInitialized()) if (!random.isInitialized())
return; return;
stream.read((char*) &randomPos, sizeof(int)); stream.read((char*) &randomPos, sizeof(int));
......
...@@ -1203,11 +1203,21 @@ public: ...@@ -1203,11 +1203,21 @@ public:
* @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled. * @param scaleFactor the multiplicative factor by which {absolute, relative} velocities are scaled.
*/ */
void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor); void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactor);
/**
* Write the chain states to a checkpoint.
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream) const;
/**
* Load the chain states from a checkpoint.
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private: private:
ReferencePlatform::PlatformData& data; ReferencePlatform::PlatformData& data;
ReferenceNoseHooverChain* chainPropagator; ReferenceNoseHooverChain* chainPropagator;
ReferenceNoseHooverDynamics* dynamics; ReferenceNoseHooverDynamics* dynamics;
std::vector<double> masses; std::vector<double> masses;
std::vector<std::vector<double> > chainPositions;
std::vector<std::vector<double> > chainVelocities;
double prevStepSize; double prevStepSize;
}; };
......
...@@ -72,8 +72,6 @@ public: ...@@ -72,8 +72,6 @@ public:
Vec3* periodicBoxVectors; Vec3* periodicBoxVectors;
ReferenceConstraints* constraints; ReferenceConstraints* constraints;
std::map<std::string, double>* energyParameterDerivatives; std::map<std::string, double>* energyParameterDerivatives;
std::vector<std::vector<double>>* noseHooverPositions;
std::vector<std::vector<double>>* noseHooverVelocities;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -127,15 +127,6 @@ static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& conte ...@@ -127,15 +127,6 @@ static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& conte
return *data->energyParameterDerivatives; return *data->energyParameterDerivatives;
} }
static vector<vector<double> >& extractNoseHooverPositions(ContextImpl& context) {
ReferencePlatform::PlatformData *data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return *((vector<vector<double> >*) data->noseHooverPositions);
}
static vector<vector<double> >& extractNoseHooverVelocities(ContextImpl& context) {
ReferencePlatform::PlatformData *data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return *((vector<vector<double> >*) data->noseHooverVelocities);
}
/** /**
* Make sure an expression doesn't use any undefined variables. * Make sure an expression doesn't use any undefined variables.
*/ */
...@@ -297,20 +288,6 @@ void ReferenceUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostr ...@@ -297,20 +288,6 @@ void ReferenceUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostr
stream.write((char*) &velData[0], sizeof(Vec3)*velData.size()); stream.write((char*) &velData[0], sizeof(Vec3)*velData.size());
Vec3* vectors = extractBoxVectors(context); Vec3* vectors = extractBoxVectors(context);
stream.write((char*) vectors, 3*sizeof(Vec3)); stream.write((char*) vectors, 3*sizeof(Vec3));
auto& allNoseHooverPositions = extractNoseHooverPositions(context);
auto& allNoseHooverVelocities = extractNoseHooverVelocities(context);
size_t numChains = allNoseHooverPositions.size();
assert(numChains == allNoseHooverVelocities.size());
stream.write((char*) &numChains, sizeof(size_t));
for (size_t i=0; i<numChains; i++){
auto & noseHooverPositions = allNoseHooverPositions.at(i);
auto & noseHooverVelocities = allNoseHooverVelocities.at(i);
size_t numBeads = noseHooverPositions.size();
assert(numBeads == noseHooverVelocities.size());
stream.write((char*) &numBeads, sizeof(size_t));
stream.write((char*) noseHooverPositions.data(), sizeof(double)*numBeads);
stream.write((char*) noseHooverVelocities.data(), sizeof(double)*numBeads);
}
SimTKOpenMMUtilities::createCheckpoint(stream); SimTKOpenMMUtilities::createCheckpoint(stream);
} }
...@@ -326,21 +303,6 @@ void ReferenceUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istrea ...@@ -326,21 +303,6 @@ void ReferenceUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istrea
stream.read((char*) &velData[0], sizeof(Vec3)*velData.size()); stream.read((char*) &velData[0], sizeof(Vec3)*velData.size());
Vec3* vectors = extractBoxVectors(context); Vec3* vectors = extractBoxVectors(context);
stream.read((char*) vectors, 3*sizeof(Vec3)); stream.read((char*) vectors, 3*sizeof(Vec3));
size_t numChains, numBeads;
auto& allNoseHooverPositions = extractNoseHooverPositions(context);
auto& allNoseHooverVelocities = extractNoseHooverVelocities(context);
stream.read((char*) &numChains, sizeof(size_t));
allNoseHooverPositions.clear();
allNoseHooverVelocities.clear();
for (size_t i=0; i<numChains; i++){
stream.read((char*) &numBeads, sizeof(size_t));
std::vector<double> noseHooverPositions(numBeads);
std::vector<double> noseHooverVelocities(numBeads);
stream.read((char*) &noseHooverPositions[0], sizeof(double)*numBeads);
stream.read((char*) &noseHooverVelocities[0], sizeof(double)*numBeads);
allNoseHooverPositions.push_back(noseHooverPositions);
allNoseHooverVelocities.push_back(noseHooverVelocities);
}
SimTKOpenMMUtilities::loadCheckpoint(stream); SimTKOpenMMUtilities::loadCheckpoint(stream);
} }
...@@ -2211,53 +2173,49 @@ std::pair<double, double> ReferenceIntegrateNoseHooverStepKernel::propagateChain ...@@ -2211,53 +2173,49 @@ std::pair<double, double> ReferenceIntegrateNoseHooverStepKernel::propagateChain
int numDOFs = nhc.getNumDegreesOfFreedom(); int numDOFs = nhc.getNumDegreesOfFreedom();
int numMTS = nhc.getNumMultiTimeSteps(); int numMTS = nhc.getNumMultiTimeSteps();
// Get the state of the NHC from the context
auto& allChainPositions = extractNoseHooverPositions(context);
auto& allChainVelocities = extractNoseHooverVelocities(context);
int nAtoms = nhc.getThermostatedAtoms().size(); int nAtoms = nhc.getThermostatedAtoms().size();
double absScale = 0; double absScale = 0;
if (nAtoms) { if (nAtoms) {
if (allChainPositions.size() < 2*chainID+1){ if (chainPositions.size() < 2*chainID+1){
allChainPositions.resize(2*chainID+1); chainPositions.resize(2*chainID+1);
} }
if (allChainVelocities.size() < 2*chainID+1){ if (chainVelocities.size() < 2*chainID+1){
allChainVelocities.resize(2*chainID+1); chainVelocities.resize(2*chainID+1);
} }
auto& chainPositions = allChainPositions.at(2*chainID); auto& positions = chainPositions.at(2*chainID);
auto& chainVelocities = allChainVelocities.at(2*chainID); auto& velocities = chainVelocities.at(2*chainID);
if (chainPositions.size() < chainLength){ if (positions.size() < chainLength){
chainPositions.resize(chainLength, 0); positions.resize(chainLength, 0);
} }
if (chainVelocities.size() < chainLength){ if (velocities.size() < chainLength){
chainVelocities.resize(chainLength, 0); velocities.resize(chainLength, 0);
} }
double temperature = nhc.getTemperature(); double temperature = nhc.getTemperature();
double collisionFrequency = nhc.getCollisionFrequency(); double collisionFrequency = nhc.getCollisionFrequency();
absScale = chainPropagator->propagate(absKE, chainVelocities, chainPositions, numDOFs, absScale = chainPropagator->propagate(absKE, velocities, positions, numDOFs,
temperature, collisionFrequency, timeStep, temperature, collisionFrequency, timeStep,
numMTS, nhc.getYoshidaSuzukiWeights()); numMTS, nhc.getYoshidaSuzukiWeights());
} }
double relScale = 0; double relScale = 0;
int nPairs = nhc.getThermostatedPairs().size(); int nPairs = nhc.getThermostatedPairs().size();
if (nPairs) { if (nPairs) {
if (allChainPositions.size() < 2*chainID+2){ if (chainPositions.size() < 2*chainID+2){
allChainPositions.resize(2*chainID+2); chainPositions.resize(2*chainID+2);
} }
if (allChainVelocities.size() < 2*chainID+2){ if (chainVelocities.size() < 2*chainID+2){
allChainVelocities.resize(2*chainID+2); chainVelocities.resize(2*chainID+2);
} }
auto& chainPositions = allChainPositions.at(2*chainID+1); auto& positions = chainPositions.at(2*chainID+1);
auto& chainVelocities = allChainVelocities.at(2*chainID+1); auto& velocities = chainVelocities.at(2*chainID+1);
if (chainPositions.size() < chainLength){ if (positions.size() < chainLength){
chainPositions.resize(chainLength, 0); positions.resize(chainLength, 0);
} }
if (chainVelocities.size() < chainLength){ if (velocities.size() < chainLength){
chainVelocities.resize(chainLength, 0); velocities.resize(chainLength, 0);
} }
double temperature = nhc.getRelativeTemperature(); double temperature = nhc.getRelativeTemperature();
double collisionFrequency = nhc.getRelativeCollisionFrequency(); double collisionFrequency = nhc.getRelativeCollisionFrequency();
relScale = chainPropagator->propagate(relKE, chainVelocities, chainPositions, 3*nPairs, relScale = chainPropagator->propagate(relKE, velocities, positions, 3*nPairs,
temperature, collisionFrequency, timeStep, temperature, collisionFrequency, timeStep,
numMTS, nhc.getYoshidaSuzukiWeights()); numMTS, nhc.getYoshidaSuzukiWeights());
} }
...@@ -2271,8 +2229,6 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl ...@@ -2271,8 +2229,6 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl
int chainID = nhc.getChainID(); int chainID = nhc.getChainID();
int nAtoms = nhc.getThermostatedAtoms().size(); int nAtoms = nhc.getThermostatedAtoms().size();
int nPairs = nhc.getThermostatedPairs().size(); int nPairs = nhc.getThermostatedPairs().size();
auto& nhcPositions = extractNoseHooverPositions(context);
auto& nhcVelocities = extractNoseHooverVelocities(context);
if (nAtoms) { if (nAtoms) {
double temperature = nhc.getTemperature(); double temperature = nhc.getTemperature();
double collisionFrequency = nhc.getCollisionFrequency(); double collisionFrequency = nhc.getCollisionFrequency();
...@@ -2281,11 +2237,11 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl ...@@ -2281,11 +2237,11 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl
for(int i = 0; i < chainLength; ++i) { for(int i = 0; i < chainLength; ++i) {
double prefac = i ? 1 : numDOFs; double prefac = i ? 1 : numDOFs;
double mass = prefac * kT / (collisionFrequency * collisionFrequency); double mass = prefac * kT / (collisionFrequency * collisionFrequency);
double velocity = nhcVelocities[2*chainID][i]; double velocity = chainVelocities[2*chainID][i];
// The kinetic energy of this bead // The kinetic energy of this bead
kineticEnergy += 0.5 * mass * velocity * velocity; kineticEnergy += 0.5 * mass * velocity * velocity;
// The potential energy of this bead // The potential energy of this bead
double position = nhcPositions[2*chainID][i]; double position = chainPositions[2*chainID][i];
potentialEnergy += prefac * kT * position; potentialEnergy += prefac * kT * position;
} }
} }
...@@ -2297,11 +2253,11 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl ...@@ -2297,11 +2253,11 @@ double ReferenceIntegrateNoseHooverStepKernel::computeHeatBathEnergy(ContextImpl
for(int i = 0; i < chainLength; ++i) { for(int i = 0; i < chainLength; ++i) {
double prefac = i ? 1 : numDOFs; double prefac = i ? 1 : numDOFs;
double mass = prefac * kT / (collisionFrequency * collisionFrequency); double mass = prefac * kT / (collisionFrequency * collisionFrequency);
double velocity = nhcVelocities[2*chainID+1][i]; double velocity = chainVelocities[2*chainID+1][i];
// The kinetic energy of this bead // The kinetic energy of this bead
kineticEnergy += 0.5 * mass * velocity * velocity; kineticEnergy += 0.5 * mass * velocity * velocity;
// The potential energy of this bead // The potential energy of this bead
double position = nhcPositions[2*chainID+1][i]; double position = chainPositions[2*chainID+1][i];
potentialEnergy += prefac * kT * position; potentialEnergy += prefac * kT * position;
} }
} }
...@@ -2380,6 +2336,37 @@ void ReferenceIntegrateNoseHooverStepKernel::scaleVelocities(ContextImpl& contex ...@@ -2380,6 +2336,37 @@ void ReferenceIntegrateNoseHooverStepKernel::scaleVelocities(ContextImpl& contex
} }
} }
void ReferenceIntegrateNoseHooverStepKernel::createCheckpoint(ContextImpl& context, ostream& stream) const {
size_t numChains = chainPositions.size();
assert(numChains == chainVelocities.size());
stream.write((char*) &numChains, sizeof(size_t));
for (size_t i=0; i<numChains; i++){
auto & noseHooverPositions = chainPositions.at(i);
auto & noseHooverVelocities = chainVelocities.at(i);
size_t numBeads = noseHooverPositions.size();
assert(numBeads == noseHooverVelocities.size());
stream.write((char*) &numBeads, sizeof(size_t));
stream.write((char*) noseHooverPositions.data(), sizeof(double)*numBeads);
stream.write((char*) noseHooverVelocities.data(), sizeof(double)*numBeads);
}
}
void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
size_t numChains, numBeads;
stream.read((char*) &numChains, sizeof(size_t));
chainPositions.clear();
chainVelocities.clear();
for (size_t i=0; i<numChains; i++){
stream.read((char*) &numBeads, sizeof(size_t));
std::vector<double> noseHooverPositions(numBeads);
std::vector<double> noseHooverVelocities(numBeads);
stream.read((char*) &noseHooverPositions[0], sizeof(double)*numBeads);
stream.read((char*) &noseHooverVelocities[0], sizeof(double)*numBeads);
chainPositions.push_back(noseHooverPositions);
chainVelocities.push_back(noseHooverVelocities);
}
}
ReferenceIntegrateLangevinStepKernel::~ReferenceIntegrateLangevinStepKernel() { ReferenceIntegrateLangevinStepKernel::~ReferenceIntegrateLangevinStepKernel() {
if (dynamics) if (dynamics)
delete dynamics; delete dynamics;
......
...@@ -103,8 +103,6 @@ ReferencePlatform::PlatformData::PlatformData(const System& system) : time(0.0), ...@@ -103,8 +103,6 @@ ReferencePlatform::PlatformData::PlatformData(const System& system) : time(0.0),
periodicBoxVectors = new Vec3[3]; periodicBoxVectors = new Vec3[3];
constraints = new ReferenceConstraints(system); constraints = new ReferenceConstraints(system);
energyParameterDerivatives = new map<string, double>(); energyParameterDerivatives = new map<string, double>();
noseHooverPositions = new vector<vector<double> >();
noseHooverVelocities = new vector<vector<double> >();
} }
ReferencePlatform::PlatformData::~PlatformData() { ReferencePlatform::PlatformData::~PlatformData() {
...@@ -115,6 +113,4 @@ ReferencePlatform::PlatformData::~PlatformData() { ...@@ -115,6 +113,4 @@ ReferencePlatform::PlatformData::~PlatformData() {
delete[] periodicBoxVectors; delete[] periodicBoxVectors;
delete constraints; delete constraints;
delete energyParameterDerivatives; delete energyParameterDerivatives;
delete noseHooverPositions;
delete noseHooverVelocities;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment