Unverified Commit f67ae730 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

CPU platform checkpoints random number generator (#4740)

* CPU platform checkpoints random number generator

* Fix Windows compilation error

* Another Windows compilation error
parent 9fe1bae6
...@@ -95,7 +95,7 @@ OPENMM_EXPORT int get_min_array_size64(void); ...@@ -95,7 +95,7 @@ OPENMM_EXPORT int get_min_array_size64(void);
OPENMM_EXPORT SFMTData* createSFMTData(void); OPENMM_EXPORT SFMTData* createSFMTData(void);
OPENMM_EXPORT void deleteSFMTData(SFMTData* data); OPENMM_EXPORT void deleteSFMTData(SFMTData* data);
class SFMT { class OPENMM_EXPORT SFMT {
public: public:
SFMT() : data(createSFMTData()) { SFMT() : data(createSFMTData()) {
} }
......
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include "CpuNeighborList.h" #include "CpuNeighborList.h"
#include "CpuNonbondedForce.h" #include "CpuNonbondedForce.h"
#include "CpuPlatform.h" #include "CpuPlatform.h"
#include "ReferenceKernels.h"
#include "openmm/kernels.h" #include "openmm/kernels.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/internal/CustomNonbondedForceImpl.h" #include "openmm/internal/CustomNonbondedForceImpl.h"
...@@ -95,6 +96,31 @@ private: ...@@ -95,6 +96,31 @@ private:
std::vector<Vec3> lastPositions; std::vector<Vec3> lastPositions;
}; };
/**
* This kernel provides methods for setting and retrieving various state data: time, positions,
* velocities, and forces.
*/
class CpuUpdateStateDataKernel : public ReferenceUpdateStateDataKernel {
public:
CpuUpdateStateDataKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ReferencePlatform::PlatformData& refdata) :
ReferenceUpdateStateDataKernel(name, platform, refdata), data(data) {
}
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
CpuPlatform::PlatformData& data;
};
/** /**
* This kernel is invoked by HarmonicAngleForce to calculate the forces acting on the system and the energy of the system. * This kernel is invoked by HarmonicAngleForce to calculate the forces acting on the system and the energy of the system.
*/ */
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2013 Stanford University and the Authors. * * Portions copyright (c) 2013-2024 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "sfmt/SFMT.h" #include "sfmt/SFMT.h"
#include "windowsExportCpu.h" #include "windowsExportCpu.h"
#include <iosfwd>
#include <vector> #include <vector>
namespace OpenMM { namespace OpenMM {
...@@ -48,6 +49,8 @@ public: ...@@ -48,6 +49,8 @@ public:
void initialize(int seed, int numThreads); void initialize(int seed, int numThreads);
float getGaussianRandom(int threadIndex); float getGaussianRandom(int threadIndex);
float getUniformRandom(int threadIndex); float getUniformRandom(int threadIndex);
void createCheckpoint(std::ostream& stream);
void loadCheckpoint(std::istream& stream);
private: private:
bool hasInitialized; bool hasInitialized;
int randomSeed; int randomSeed;
......
...@@ -39,8 +39,11 @@ using namespace OpenMM; ...@@ -39,8 +39,11 @@ using namespace OpenMM;
KernelImpl* CpuKernelFactory::createKernelImpl(std::string name, const Platform& platform, ContextImpl& context) const { KernelImpl* CpuKernelFactory::createKernelImpl(std::string name, const Platform& platform, ContextImpl& context) const {
CpuPlatform::PlatformData& data = CpuPlatform::getPlatformData(context); CpuPlatform::PlatformData& data = CpuPlatform::getPlatformData(context);
ReferencePlatform::PlatformData& refdata = *static_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
if (name == CalcForcesAndEnergyKernel::Name()) if (name == CalcForcesAndEnergyKernel::Name())
return new CpuCalcForcesAndEnergyKernel(name, platform, data, context); return new CpuCalcForcesAndEnergyKernel(name, platform, data, context);
if (name == UpdateStateDataKernel::Name())
return new CpuUpdateStateDataKernel(name, platform, data, refdata);
if (name == CalcHarmonicAngleForceKernel::Name()) if (name == CalcHarmonicAngleForceKernel::Name())
return new CpuCalcHarmonicAngleForceKernel(name, platform, data); return new CpuCalcHarmonicAngleForceKernel(name, platform, data);
if (name == CalcPeriodicTorsionForceKernel::Name()) if (name == CalcPeriodicTorsionForceKernel::Name())
......
...@@ -308,6 +308,16 @@ double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, boo ...@@ -308,6 +308,16 @@ double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, boo
return referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().finishComputation(context, includeForce, includeEnergy, groups, valid); return referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().finishComputation(context, includeForce, includeEnergy, groups, valid);
} }
void CpuUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
ReferenceUpdateStateDataKernel::createCheckpoint(context, stream);
data.random.createCheckpoint(stream);
}
void CpuUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
ReferenceUpdateStateDataKernel::loadCheckpoint(context, stream);
data.random.loadCheckpoint(stream);
}
void CpuCalcHarmonicAngleForceKernel::initialize(const System& system, const HarmonicAngleForce& force) { void CpuCalcHarmonicAngleForceKernel::initialize(const System& system, const HarmonicAngleForce& force) {
numAngles = force.getNumAngles(); numAngles = force.getNumAngles();
angleIndexArray.resize(numAngles, vector<int>(3)); angleIndexArray.resize(numAngles, vector<int>(3));
......
...@@ -64,6 +64,7 @@ CpuPlatform::CpuPlatform() { ...@@ -64,6 +64,7 @@ CpuPlatform::CpuPlatform() {
deprecatedPropertyReplacements["CpuThreads"] = CpuThreads(); deprecatedPropertyReplacements["CpuThreads"] = CpuThreads();
CpuKernelFactory* factory = new CpuKernelFactory(); CpuKernelFactory* factory = new CpuKernelFactory();
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory); registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(UpdateStateDataKernel::Name(), factory);
registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory); registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory);
registerKernelFactory(CalcPeriodicTorsionForceKernel::Name(), factory); registerKernelFactory(CalcPeriodicTorsionForceKernel::Name(), factory);
registerKernelFactory(CalcRBTorsionForceKernel::Name(), factory); registerKernelFactory(CalcRBTorsionForceKernel::Name(), factory);
......
/* Portions copyright (c) 2013 Stanford University and Simbios. /* Portions copyright (c) 2013-2024 Stanford University and Simbios.
* Authors: Peter Eastman * Authors: Peter Eastman
* Contributors: * Contributors:
* *
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "openmm/internal/OSRngSeed.h" #include "openmm/internal/OSRngSeed.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include <cmath> #include <cmath>
#include <iostream>
using namespace std; using namespace std;
using namespace OpenMM; using namespace OpenMM;
...@@ -87,3 +88,36 @@ float CpuRandom::getGaussianRandom(int threadIndex) { ...@@ -87,3 +88,36 @@ float CpuRandom::getGaussianRandom(int threadIndex) {
float CpuRandom::getUniformRandom(int threadIndex) { float CpuRandom::getUniformRandom(int threadIndex) {
return genrand_real2(*threadRandom[threadIndex]); return genrand_real2(*threadRandom[threadIndex]);
} }
void CpuRandom::createCheckpoint(std::ostream& stream) {
int initialized = hasInitialized;
stream.write((char*) &initialized, sizeof(int));
if (hasInitialized) {
stream.write((char*) &randomSeed, sizeof(int));
int numThreads = threadRandom.size();
stream.write((char*) &numThreads, sizeof(int));
stream.write((char*) nextGaussian.data(), sizeof(float)*numThreads);
stream.write((char*) nextGaussianIsValid.data(), sizeof(int)*numThreads);
for (int i = 0; i < numThreads; i++)
threadRandom[i]->createCheckpoint(stream);
}
}
void CpuRandom::loadCheckpoint(std::istream& stream) {
int initialized;
stream.read((char*) &initialized, sizeof(int));
hasInitialized = false;
threadRandom.clear();
nextGaussian.clear();
nextGaussianIsValid.clear();
if (initialized) {
int seed, numThreads;
stream.read((char*) &seed, sizeof(int));
stream.read((char*) &numThreads, sizeof(int));
initialize(seed, numThreads);
stream.read((char*) nextGaussian.data(), sizeof(float)*numThreads);
stream.read((char*) nextGaussianIsValid.data(), sizeof(float)*numThreads);
for (int i = 0; i < numThreads; i++)
threadRandom[i]->loadCheckpoint(stream);
}
}
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include "openmm/kernels.h" #include "openmm/kernels.h"
#include "openmm/internal/CustomCPPForceImpl.h" #include "openmm/internal/CustomCPPForceImpl.h"
#include "openmm/internal/CustomNonbondedForceImpl.h" #include "openmm/internal/CustomNonbondedForceImpl.h"
#include "openmm/internal/windowsExport.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "ReferenceNeighborList.h" #include "ReferenceNeighborList.h"
#include "lepton/CompiledExpression.h" #include "lepton/CompiledExpression.h"
...@@ -116,7 +117,7 @@ private: ...@@ -116,7 +117,7 @@ private:
* This kernel provides methods for setting and retrieving various state data: time, positions, * This kernel provides methods for setting and retrieving various state data: time, positions,
* velocities, and forces. * velocities, and forces.
*/ */
class ReferenceUpdateStateDataKernel : public UpdateStateDataKernel { class OPENMM_EXPORT ReferenceUpdateStateDataKernel : public UpdateStateDataKernel {
public: public:
ReferenceUpdateStateDataKernel(std::string name, const Platform& platform, ReferencePlatform::PlatformData& data) : UpdateStateDataKernel(name, platform), data(data) { ReferenceUpdateStateDataKernel(std::string name, const Platform& platform, ReferencePlatform::PlatformData& data) : UpdateStateDataKernel(name, platform), data(data) {
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/AndersenThermostat.h" #include "openmm/AndersenThermostat.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/LangevinIntegrator.h"
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/VerletIntegrator.h" #include "openmm/VerletIntegrator.h"
...@@ -220,6 +221,55 @@ void testMultipleDevices() { ...@@ -220,6 +221,55 @@ void testMultipleDevices() {
compareStates(s1, s9); compareStates(s1, s9);
} }
void testLangevin() {
const int numParticles = 10;
const double boxSize = 3.0;
System system;
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(i%2 == 0 ? 0.1 : -0.1, 0.2, 0.1);
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
}
LangevinIntegrator integrator(300.0, 1.0, 0.001);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
// Run for a little while.
integrator.step(100);
// Record the current state and make a checkpoint.
State s1 = context.getState(State::Positions | State::Velocities | State::Parameters);
stringstream stream1(ios_base::out | ios_base::in | ios_base::binary);
context.createCheckpoint(stream1);
// Continue the simulation for a few more steps and record the state again.
integrator.step(10);
State s2 = context.getState(State::Positions | State::Velocities | State::Parameters);
// Restore from the checkpoint and see if everything gets restored correctly.
context.setPeriodicBoxVectors(Vec3(2*boxSize, 0, 0), Vec3(0, 2*boxSize, 0), Vec3(0, 0, 2*boxSize));
context.loadCheckpoint(stream1);
State s3 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s1, s3);
// Now simulate from there and see if the trajectory is identical.
integrator.step(10);
State s4 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s2, s4);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -227,6 +277,7 @@ int main(int argc, char* argv[]) { ...@@ -227,6 +277,7 @@ int main(int argc, char* argv[]) {
initializeTests(argc, argv); initializeTests(argc, argv);
testSetState(); testSetState();
testMultipleDevices(); testMultipleDevices();
testLangevin();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment