Unverified Commit f67ae730 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

CPU platform checkpoints random number generator (#4740)

* CPU platform checkpoints random number generator

* Fix Windows compilation error

* Another Windows compilation error
parent 9fe1bae6
......@@ -95,7 +95,7 @@ OPENMM_EXPORT int get_min_array_size64(void);
OPENMM_EXPORT SFMTData* createSFMTData(void);
OPENMM_EXPORT void deleteSFMTData(SFMTData* data);
class SFMT {
class OPENMM_EXPORT SFMT {
public:
SFMT() : data(createSFMTData()) {
}
......
......@@ -42,6 +42,7 @@
#include "CpuNeighborList.h"
#include "CpuNonbondedForce.h"
#include "CpuPlatform.h"
#include "ReferenceKernels.h"
#include "openmm/kernels.h"
#include "openmm/System.h"
#include "openmm/internal/CustomNonbondedForceImpl.h"
......@@ -95,6 +96,31 @@ private:
std::vector<Vec3> lastPositions;
};
/**
* This kernel provides methods for setting and retrieving various state data: time, positions,
* velocities, and forces.
*/
class CpuUpdateStateDataKernel : public ReferenceUpdateStateDataKernel {
public:
CpuUpdateStateDataKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ReferencePlatform::PlatformData& refdata) :
ReferenceUpdateStateDataKernel(name, platform, refdata), data(data) {
}
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
CpuPlatform::PlatformData& data;
};
/**
* This kernel is invoked by HarmonicAngleForce to calculate the forces acting on the system and the energy of the system.
*/
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013 Stanford University and the Authors. *
* Portions copyright (c) 2013-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -34,6 +34,7 @@
#include "sfmt/SFMT.h"
#include "windowsExportCpu.h"
#include <iosfwd>
#include <vector>
namespace OpenMM {
......@@ -48,6 +49,8 @@ public:
void initialize(int seed, int numThreads);
float getGaussianRandom(int threadIndex);
float getUniformRandom(int threadIndex);
void createCheckpoint(std::ostream& stream);
void loadCheckpoint(std::istream& stream);
private:
bool hasInitialized;
int randomSeed;
......
......@@ -39,8 +39,11 @@ using namespace OpenMM;
KernelImpl* CpuKernelFactory::createKernelImpl(std::string name, const Platform& platform, ContextImpl& context) const {
CpuPlatform::PlatformData& data = CpuPlatform::getPlatformData(context);
ReferencePlatform::PlatformData& refdata = *static_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
if (name == CalcForcesAndEnergyKernel::Name())
return new CpuCalcForcesAndEnergyKernel(name, platform, data, context);
if (name == UpdateStateDataKernel::Name())
return new CpuUpdateStateDataKernel(name, platform, data, refdata);
if (name == CalcHarmonicAngleForceKernel::Name())
return new CpuCalcHarmonicAngleForceKernel(name, platform, data);
if (name == CalcPeriodicTorsionForceKernel::Name())
......
......@@ -308,6 +308,16 @@ double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, boo
return referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().finishComputation(context, includeForce, includeEnergy, groups, valid);
}
void CpuUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
ReferenceUpdateStateDataKernel::createCheckpoint(context, stream);
data.random.createCheckpoint(stream);
}
void CpuUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
ReferenceUpdateStateDataKernel::loadCheckpoint(context, stream);
data.random.loadCheckpoint(stream);
}
void CpuCalcHarmonicAngleForceKernel::initialize(const System& system, const HarmonicAngleForce& force) {
numAngles = force.getNumAngles();
angleIndexArray.resize(numAngles, vector<int>(3));
......
......@@ -64,6 +64,7 @@ CpuPlatform::CpuPlatform() {
deprecatedPropertyReplacements["CpuThreads"] = CpuThreads();
CpuKernelFactory* factory = new CpuKernelFactory();
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(UpdateStateDataKernel::Name(), factory);
registerKernelFactory(CalcHarmonicAngleForceKernel::Name(), factory);
registerKernelFactory(CalcPeriodicTorsionForceKernel::Name(), factory);
registerKernelFactory(CalcRBTorsionForceKernel::Name(), factory);
......
/* Portions copyright (c) 2013 Stanford University and Simbios.
/* Portions copyright (c) 2013-2024 Stanford University and Simbios.
* Authors: Peter Eastman
* Contributors:
*
......@@ -26,6 +26,7 @@
#include "openmm/internal/OSRngSeed.h"
#include "openmm/OpenMMException.h"
#include <cmath>
#include <iostream>
using namespace std;
using namespace OpenMM;
......@@ -87,3 +88,36 @@ float CpuRandom::getGaussianRandom(int threadIndex) {
float CpuRandom::getUniformRandom(int threadIndex) {
return genrand_real2(*threadRandom[threadIndex]);
}
void CpuRandom::createCheckpoint(std::ostream& stream) {
int initialized = hasInitialized;
stream.write((char*) &initialized, sizeof(int));
if (hasInitialized) {
stream.write((char*) &randomSeed, sizeof(int));
int numThreads = threadRandom.size();
stream.write((char*) &numThreads, sizeof(int));
stream.write((char*) nextGaussian.data(), sizeof(float)*numThreads);
stream.write((char*) nextGaussianIsValid.data(), sizeof(int)*numThreads);
for (int i = 0; i < numThreads; i++)
threadRandom[i]->createCheckpoint(stream);
}
}
void CpuRandom::loadCheckpoint(std::istream& stream) {
int initialized;
stream.read((char*) &initialized, sizeof(int));
hasInitialized = false;
threadRandom.clear();
nextGaussian.clear();
nextGaussianIsValid.clear();
if (initialized) {
int seed, numThreads;
stream.read((char*) &seed, sizeof(int));
stream.read((char*) &numThreads, sizeof(int));
initialize(seed, numThreads);
stream.read((char*) nextGaussian.data(), sizeof(float)*numThreads);
stream.read((char*) nextGaussianIsValid.data(), sizeof(float)*numThreads);
for (int i = 0; i < numThreads; i++)
threadRandom[i]->loadCheckpoint(stream);
}
}
......@@ -36,6 +36,7 @@
#include "openmm/kernels.h"
#include "openmm/internal/CustomCPPForceImpl.h"
#include "openmm/internal/CustomNonbondedForceImpl.h"
#include "openmm/internal/windowsExport.h"
#include "SimTKOpenMMRealType.h"
#include "ReferenceNeighborList.h"
#include "lepton/CompiledExpression.h"
......@@ -116,7 +117,7 @@ private:
* This kernel provides methods for setting and retrieving various state data: time, positions,
* velocities, and forces.
*/
class ReferenceUpdateStateDataKernel : public UpdateStateDataKernel {
class OPENMM_EXPORT ReferenceUpdateStateDataKernel : public UpdateStateDataKernel {
public:
ReferenceUpdateStateDataKernel(std::string name, const Platform& platform, ReferencePlatform::PlatformData& data) : UpdateStateDataKernel(name, platform), data(data) {
}
......
......@@ -32,6 +32,7 @@
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/Context.h"
#include "openmm/LangevinIntegrator.h"
#include "openmm/NonbondedForce.h"
#include "openmm/System.h"
#include "openmm/VerletIntegrator.h"
......@@ -220,6 +221,55 @@ void testMultipleDevices() {
compareStates(s1, s9);
}
void testLangevin() {
const int numParticles = 10;
const double boxSize = 3.0;
System system;
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(i%2 == 0 ? 0.1 : -0.1, 0.2, 0.1);
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
}
LangevinIntegrator integrator(300.0, 1.0, 0.001);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
// Run for a little while.
integrator.step(100);
// Record the current state and make a checkpoint.
State s1 = context.getState(State::Positions | State::Velocities | State::Parameters);
stringstream stream1(ios_base::out | ios_base::in | ios_base::binary);
context.createCheckpoint(stream1);
// Continue the simulation for a few more steps and record the state again.
integrator.step(10);
State s2 = context.getState(State::Positions | State::Velocities | State::Parameters);
// Restore from the checkpoint and see if everything gets restored correctly.
context.setPeriodicBoxVectors(Vec3(2*boxSize, 0, 0), Vec3(0, 2*boxSize, 0), Vec3(0, 0, 2*boxSize));
context.loadCheckpoint(stream1);
State s3 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s1, s3);
// Now simulate from there and see if the trajectory is identical.
integrator.step(10);
State s4 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s2, s4);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -227,6 +277,7 @@ int main(int argc, char* argv[]) {
initializeTests(argc, argv);
testSetState();
testMultipleDevices();
testLangevin();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment