Commit 717df453 authored by peastman's avatar peastman
Browse files

CustomIntegrator avoids extra force computations when UpdateContextState doesn't change them

parent feb79f77
......@@ -55,7 +55,7 @@ public:
const NonbondedForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -53,7 +53,7 @@ public:
const PeriodicTorsionForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -53,7 +53,7 @@ public:
const RBTorsionForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -47,7 +47,7 @@ void AndersenThermostatImpl::initialize(ContextImpl& context) {
kernel.getAs<ApplyAndersenThermostatKernel>().initialize(context.getSystem(), owner);
}
void AndersenThermostatImpl::updateContextState(ContextImpl& context) {
void AndersenThermostatImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
kernel.getAs<ApplyAndersenThermostatKernel>().execute(context);
}
......
......@@ -48,7 +48,7 @@ void CMMotionRemoverImpl::initialize(ContextImpl& context) {
kernel.getAs<RemoveCMMotionKernel>().initialize(system, owner);
}
void CMMotionRemoverImpl::updateContextState(ContextImpl& context) {
void CMMotionRemoverImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
kernel.getAs<RemoveCMMotionKernel>().execute(context);
}
......
......@@ -315,9 +315,11 @@ double ContextImpl::calcKineticEnergy() {
return integrator.computeKineticEnergy();
}
void ContextImpl::updateContextState() {
bool ContextImpl::updateContextState() {
bool forcesInvalid = false;
for (auto force : forceImpls)
force->updateContextState(*this);
force->updateContextState(*this, forcesInvalid);
return forcesInvalid;
}
const vector<ForceImpl*>& ContextImpl::getForceImpls() const {
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2017 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/internal/ForceImpl.h"
using namespace OpenMM;
using namespace std;
void ForceImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
// Usually subclasses will override this. If they don't, call the old
// (single argument) version instead, and just assume they invalidate forces.
updateContextState(context);
forcesInvalid = true;
}
void ForceImpl::updateContextState(ContextImpl& context) {
}
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2010-2016 Stanford University and the Authors. *
* Portions copyright (c) 2010-2017 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -68,7 +68,7 @@ void MonteCarloBarostatImpl::initialize(ContextImpl& context) {
init_gen_rand(randSeed, random);
}
void MonteCarloBarostatImpl::updateContextState(ContextImpl& context) {
void MonteCarloBarostatImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
if (++step < owner.getFrequency() || owner.getFrequency() == 0)
return;
step = 0;
......@@ -101,8 +101,10 @@ void MonteCarloBarostatImpl::updateContextState(ContextImpl& context) {
context.getOwner().setPeriodicBoxVectors(box[0], box[1], box[2]);
volume = newVolume;
}
else
else {
numAccepted++;
forcesInvalid = true;
}
numAttempted++;
if (numAttempted >= 10) {
if (numAccepted < 0.25*numAttempted) {
......
......@@ -1640,7 +1640,7 @@ private:
class CudaApplyMonteCarloBarostatKernel : public ApplyMonteCarloBarostatKernel {
public:
CudaApplyMonteCarloBarostatKernel(std::string name, const Platform& platform, CudaContext& cu) : ApplyMonteCarloBarostatKernel(name, platform), cu(cu),
hasInitializedKernels(false), savedPositions(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) {
hasInitializedKernels(false), savedPositions(NULL), savedForces(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) {
}
~CudaApplyMonteCarloBarostatKernel();
/**
......@@ -1675,6 +1675,7 @@ private:
bool hasInitializedKernels;
int numMolecules;
CudaArray* savedPositions;
CudaArray* savedForces;
CudaArray* moleculeAtoms;
CudaArray* moleculeStartIndex;
CUfunction kernel;
......
......@@ -7827,6 +7827,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
globalValues->upload(globalValuesFloat);
}
}
bool stepInvalidatesForces = invalidatesForces[step];
if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
kernelArgs[step][0][1] = &posCorrection;
......@@ -7867,7 +7868,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
}
else if (stepType[step] == CustomIntegrator::UpdateContextState) {
recordChangedParameters(context);
context.updateContextState();
stepInvalidatesForces = context.updateContextState();
}
else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
if (hasAnyConstraints) {
......@@ -7892,7 +7893,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block.
}
if (invalidatesForces[step]) {
if (stepInvalidatesForces) {
forcesAreValid = false;
savedEnergy.clear();
}
......@@ -8111,6 +8112,8 @@ CudaApplyMonteCarloBarostatKernel::~CudaApplyMonteCarloBarostatKernel() {
cu.setAsCurrent();
if (savedPositions != NULL)
delete savedPositions;
if (savedForces != NULL)
delete savedForces;
if (moleculeAtoms != NULL)
delete moleculeAtoms;
if (moleculeStartIndex != NULL)
......@@ -8120,6 +8123,7 @@ CudaApplyMonteCarloBarostatKernel::~CudaApplyMonteCarloBarostatKernel() {
void CudaApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) {
cu.setAsCurrent();
savedPositions = new CudaArray(cu, cu.getPaddedNumAtoms(), cu.getUseDoublePrecision() ? sizeof(double4) : sizeof(float4), "savedPositions");
savedForces = CudaArray::create<long long>(cu, cu.getPaddedNumAtoms()*3, "savedForces");
CUmodule module = cu.createModule(CudaKernelSources::monteCarloBarostat);
kernel = cu.getKernel(module, "scalePositions");
}
......@@ -8157,6 +8161,12 @@ void CudaApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, d
m<<"Error saving positions for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str());
}
result = cuMemcpyDtoD(savedForces->getDevicePointer(), cu.getForce().getDevicePointer(), savedForces->getSize()*savedForces->getElementSize());
if (result != CUDA_SUCCESS) {
std::stringstream m;
m<<"Error saving forces for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str());
}
float scalefX = (float) scaleX;
float scalefY = (float) scaleY;
float scalefZ = (float) scaleZ;
......@@ -8178,6 +8188,12 @@ void CudaApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context)
m<<"Error restoring positions for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str());
}
result = cuMemcpyDtoD(cu.getForce().getDevicePointer(), savedForces->getDevicePointer(), savedForces->getSize()*savedForces->getElementSize());
if (result != CUDA_SUCCESS) {
std::stringstream m;
m<<"Error restoring forces for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str());
}
}
CudaRemoveCMMotionKernel::~CudaRemoveCMMotionKernel() {
......
......@@ -1626,7 +1626,7 @@ private:
class OpenCLApplyMonteCarloBarostatKernel : public ApplyMonteCarloBarostatKernel {
public:
OpenCLApplyMonteCarloBarostatKernel(std::string name, const Platform& platform, OpenCLContext& cl) : ApplyMonteCarloBarostatKernel(name, platform), cl(cl),
hasInitializedKernels(false), savedPositions(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) {
hasInitializedKernels(false), savedPositions(NULL), savedForces(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) {
}
~OpenCLApplyMonteCarloBarostatKernel();
/**
......@@ -1661,6 +1661,7 @@ private:
bool hasInitializedKernels;
int numMolecules;
OpenCLArray* savedPositions;
OpenCLArray* savedForces;
OpenCLArray* moleculeAtoms;
OpenCLArray* moleculeStartIndex;
cl::Kernel kernel;
......
......@@ -8182,6 +8182,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
globalValues->upload(globalValuesFloat);
}
}
bool stepInvalidatesForces = invalidatesForces[step];
if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
kernels[step][0].setArg<cl_uint>(9, integration.prepareRandomNumbers(requiredGaussian[step]));
kernels[step][0].setArg<cl::Buffer>(8, integration.getRandom().getDeviceBuffer());
......@@ -8226,7 +8227,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
}
else if (stepType[step] == CustomIntegrator::UpdateContextState) {
recordChangedParameters(context);
context.updateContextState();
stepInvalidatesForces = context.updateContextState();
}
else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
if (hasAnyConstraints) {
......@@ -8250,7 +8251,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block.
}
if (invalidatesForces[step]) {
if (stepInvalidatesForces) {
forcesAreValid = false;
savedEnergy.clear();
}
......@@ -8471,6 +8472,8 @@ void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) {
OpenCLApplyMonteCarloBarostatKernel::~OpenCLApplyMonteCarloBarostatKernel() {
if (savedPositions != NULL)
delete savedPositions;
if (savedForces != NULL)
delete savedForces;
if (moleculeAtoms != NULL)
delete moleculeAtoms;
if (moleculeStartIndex != NULL)
......@@ -8479,6 +8482,7 @@ OpenCLApplyMonteCarloBarostatKernel::~OpenCLApplyMonteCarloBarostatKernel() {
void OpenCLApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) {
savedPositions = new OpenCLArray(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedPositions");
savedForces = new OpenCLArray(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedForces");
cl::Program program = cl.createProgram(OpenCLKernelSources::monteCarloBarostat);
kernel = cl::Kernel(program, "scalePositions");
}
......@@ -8514,6 +8518,7 @@ void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context,
}
int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
cl.getQueue().enqueueCopyBuffer(cl.getPosq().getDeviceBuffer(), savedPositions->getDeviceBuffer(), 0, 0, bytesToCopy);
cl.getQueue().enqueueCopyBuffer(cl.getForce().getDeviceBuffer(), savedForces->getDeviceBuffer(), 0, 0, bytesToCopy);
kernel.setArg<cl_float>(0, (cl_float) scaleX);
kernel.setArg<cl_float>(1, (cl_float) scaleY);
kernel.setArg<cl_float>(2, (cl_float) scaleZ);
......@@ -8527,6 +8532,7 @@ void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context,
void OpenCLApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) {
int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
cl.getQueue().enqueueCopyBuffer(savedPositions->getDeviceBuffer(), cl.getPosq().getDeviceBuffer(), 0, 0, bytesToCopy);
cl.getQueue().enqueueCopyBuffer(savedForces->getDeviceBuffer(), cl.getForce().getDeviceBuffer(), 0, 0, bytesToCopy);
}
OpenCLRemoveCMMotionKernel::~OpenCLRemoveCMMotionKernel() {
......
......@@ -249,6 +249,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
energy = (needsEnergy[step] ? groupEnergy[flags] : 0);
vector<Vec3>& stepForces = (needsForces[step] ? groupForces[flags] : forces);
int nextStep = step+1;
bool stepInvalidatesForces = invalidatesForces[step];
switch (stepType[step]) {
case CustomIntegrator::ComputeGlobal: {
uniform = SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber();
......@@ -295,7 +296,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
}
case CustomIntegrator::UpdateContextState: {
recordChangedParameters(context, globals);
context.updateContextState();
stepInvalidatesForces = context.updateContextState();
globals.insert(context.getParameters().begin(), context.getParameters().end());
for (auto& global : globals)
expressionSet.setVariable(expressionSet.getVariableIndex(global.first), global.second);
......@@ -317,7 +318,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
break;
}
}
if (invalidatesForces[step]) {
if (stepInvalidatesForces) {
forcesAreValid = false;
groupForces.clear();
groupEnergy.clear();
......
......@@ -53,7 +53,7 @@ public:
const AmoebaAngleForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -54,7 +54,7 @@ public:
const AmoebaBondForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -50,7 +50,7 @@ public:
const AmoebaGeneralizedKirkwoodForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -53,7 +53,7 @@ public:
const AmoebaInPlaneAngleForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -53,7 +53,7 @@ public:
const AmoebaMultipoleForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -53,7 +53,7 @@ public:
const AmoebaOutOfPlaneBendForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
......@@ -53,7 +53,7 @@ public:
const AmoebaPiTorsionForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context) {
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment