"platforms/cuda/vscode:/vscode.git/clone" did not exist on "7639d0bdecff280b8d6570ac4eb09ec952f1f78c"
Commit 15b9fb48 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1853 from peastman/updatecontextstate

CustomIntegrator avoids extra force computations when UpdateContextState doesn't change them
parents feb79f77 717df453
...@@ -55,7 +55,7 @@ public: ...@@ -55,7 +55,7 @@ public:
const NonbondedForce& getOwner() const { const NonbondedForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
const PeriodicTorsionForce& getOwner() const { const PeriodicTorsionForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
const RBTorsionForce& getOwner() const { const RBTorsionForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -47,7 +47,7 @@ void AndersenThermostatImpl::initialize(ContextImpl& context) { ...@@ -47,7 +47,7 @@ void AndersenThermostatImpl::initialize(ContextImpl& context) {
kernel.getAs<ApplyAndersenThermostatKernel>().initialize(context.getSystem(), owner); kernel.getAs<ApplyAndersenThermostatKernel>().initialize(context.getSystem(), owner);
} }
void AndersenThermostatImpl::updateContextState(ContextImpl& context) { void AndersenThermostatImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
kernel.getAs<ApplyAndersenThermostatKernel>().execute(context); kernel.getAs<ApplyAndersenThermostatKernel>().execute(context);
} }
......
...@@ -48,7 +48,7 @@ void CMMotionRemoverImpl::initialize(ContextImpl& context) { ...@@ -48,7 +48,7 @@ void CMMotionRemoverImpl::initialize(ContextImpl& context) {
kernel.getAs<RemoveCMMotionKernel>().initialize(system, owner); kernel.getAs<RemoveCMMotionKernel>().initialize(system, owner);
} }
void CMMotionRemoverImpl::updateContextState(ContextImpl& context) { void CMMotionRemoverImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
kernel.getAs<RemoveCMMotionKernel>().execute(context); kernel.getAs<RemoveCMMotionKernel>().execute(context);
} }
......
...@@ -315,9 +315,11 @@ double ContextImpl::calcKineticEnergy() { ...@@ -315,9 +315,11 @@ double ContextImpl::calcKineticEnergy() {
return integrator.computeKineticEnergy(); return integrator.computeKineticEnergy();
} }
void ContextImpl::updateContextState() { bool ContextImpl::updateContextState() {
bool forcesInvalid = false;
for (auto force : forceImpls) for (auto force : forceImpls)
force->updateContextState(*this); force->updateContextState(*this, forcesInvalid);
return forcesInvalid;
} }
const vector<ForceImpl*>& ContextImpl::getForceImpls() const { const vector<ForceImpl*>& ContextImpl::getForceImpls() const {
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2017 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/internal/ForceImpl.h"
using namespace OpenMM;
using namespace std;
void ForceImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
// Usually subclasses will override this. If they don't, call the old
// (single argument) version instead, and just assume they invalidate forces.
updateContextState(context);
forcesInvalid = true;
}
void ForceImpl::updateContextState(ContextImpl& context) {
}
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2016 Stanford University and the Authors. * * Portions copyright (c) 2010-2017 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -68,7 +68,7 @@ void MonteCarloBarostatImpl::initialize(ContextImpl& context) { ...@@ -68,7 +68,7 @@ void MonteCarloBarostatImpl::initialize(ContextImpl& context) {
init_gen_rand(randSeed, random); init_gen_rand(randSeed, random);
} }
void MonteCarloBarostatImpl::updateContextState(ContextImpl& context) { void MonteCarloBarostatImpl::updateContextState(ContextImpl& context, bool& forcesInvalid) {
if (++step < owner.getFrequency() || owner.getFrequency() == 0) if (++step < owner.getFrequency() || owner.getFrequency() == 0)
return; return;
step = 0; step = 0;
...@@ -101,8 +101,10 @@ void MonteCarloBarostatImpl::updateContextState(ContextImpl& context) { ...@@ -101,8 +101,10 @@ void MonteCarloBarostatImpl::updateContextState(ContextImpl& context) {
context.getOwner().setPeriodicBoxVectors(box[0], box[1], box[2]); context.getOwner().setPeriodicBoxVectors(box[0], box[1], box[2]);
volume = newVolume; volume = newVolume;
} }
else else {
numAccepted++; numAccepted++;
forcesInvalid = true;
}
numAttempted++; numAttempted++;
if (numAttempted >= 10) { if (numAttempted >= 10) {
if (numAccepted < 0.25*numAttempted) { if (numAccepted < 0.25*numAttempted) {
......
...@@ -1640,7 +1640,7 @@ private: ...@@ -1640,7 +1640,7 @@ private:
class CudaApplyMonteCarloBarostatKernel : public ApplyMonteCarloBarostatKernel { class CudaApplyMonteCarloBarostatKernel : public ApplyMonteCarloBarostatKernel {
public: public:
CudaApplyMonteCarloBarostatKernel(std::string name, const Platform& platform, CudaContext& cu) : ApplyMonteCarloBarostatKernel(name, platform), cu(cu), CudaApplyMonteCarloBarostatKernel(std::string name, const Platform& platform, CudaContext& cu) : ApplyMonteCarloBarostatKernel(name, platform), cu(cu),
hasInitializedKernels(false), savedPositions(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) { hasInitializedKernels(false), savedPositions(NULL), savedForces(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) {
} }
~CudaApplyMonteCarloBarostatKernel(); ~CudaApplyMonteCarloBarostatKernel();
/** /**
...@@ -1675,6 +1675,7 @@ private: ...@@ -1675,6 +1675,7 @@ private:
bool hasInitializedKernels; bool hasInitializedKernels;
int numMolecules; int numMolecules;
CudaArray* savedPositions; CudaArray* savedPositions;
CudaArray* savedForces;
CudaArray* moleculeAtoms; CudaArray* moleculeAtoms;
CudaArray* moleculeStartIndex; CudaArray* moleculeStartIndex;
CUfunction kernel; CUfunction kernel;
......
...@@ -7827,6 +7827,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7827,6 +7827,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
globalValues->upload(globalValuesFloat); globalValues->upload(globalValuesFloat);
} }
} }
bool stepInvalidatesForces = invalidatesForces[step];
if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) { if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]); int randomIndex = integration.prepareRandomNumbers(requiredGaussian[step]);
kernelArgs[step][0][1] = &posCorrection; kernelArgs[step][0][1] = &posCorrection;
...@@ -7867,7 +7868,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7867,7 +7868,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
} }
else if (stepType[step] == CustomIntegrator::UpdateContextState) { else if (stepType[step] == CustomIntegrator::UpdateContextState) {
recordChangedParameters(context); recordChangedParameters(context);
context.updateContextState(); stepInvalidatesForces = context.updateContextState();
} }
else if (stepType[step] == CustomIntegrator::ConstrainPositions) { else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
if (hasAnyConstraints) { if (hasAnyConstraints) {
...@@ -7892,7 +7893,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat ...@@ -7892,7 +7893,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
if (blockEnd[step] != -1) if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block. nextStep = blockEnd[step]; // Return to the start of a while block.
} }
if (invalidatesForces[step]) { if (stepInvalidatesForces) {
forcesAreValid = false; forcesAreValid = false;
savedEnergy.clear(); savedEnergy.clear();
} }
...@@ -8111,6 +8112,8 @@ CudaApplyMonteCarloBarostatKernel::~CudaApplyMonteCarloBarostatKernel() { ...@@ -8111,6 +8112,8 @@ CudaApplyMonteCarloBarostatKernel::~CudaApplyMonteCarloBarostatKernel() {
cu.setAsCurrent(); cu.setAsCurrent();
if (savedPositions != NULL) if (savedPositions != NULL)
delete savedPositions; delete savedPositions;
if (savedForces != NULL)
delete savedForces;
if (moleculeAtoms != NULL) if (moleculeAtoms != NULL)
delete moleculeAtoms; delete moleculeAtoms;
if (moleculeStartIndex != NULL) if (moleculeStartIndex != NULL)
...@@ -8120,6 +8123,7 @@ CudaApplyMonteCarloBarostatKernel::~CudaApplyMonteCarloBarostatKernel() { ...@@ -8120,6 +8123,7 @@ CudaApplyMonteCarloBarostatKernel::~CudaApplyMonteCarloBarostatKernel() {
void CudaApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) { void CudaApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) {
cu.setAsCurrent(); cu.setAsCurrent();
savedPositions = new CudaArray(cu, cu.getPaddedNumAtoms(), cu.getUseDoublePrecision() ? sizeof(double4) : sizeof(float4), "savedPositions"); savedPositions = new CudaArray(cu, cu.getPaddedNumAtoms(), cu.getUseDoublePrecision() ? sizeof(double4) : sizeof(float4), "savedPositions");
savedForces = CudaArray::create<long long>(cu, cu.getPaddedNumAtoms()*3, "savedForces");
CUmodule module = cu.createModule(CudaKernelSources::monteCarloBarostat); CUmodule module = cu.createModule(CudaKernelSources::monteCarloBarostat);
kernel = cu.getKernel(module, "scalePositions"); kernel = cu.getKernel(module, "scalePositions");
} }
...@@ -8157,6 +8161,12 @@ void CudaApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, d ...@@ -8157,6 +8161,12 @@ void CudaApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, d
m<<"Error saving positions for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")"; m<<"Error saving positions for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str()); throw OpenMMException(m.str());
} }
result = cuMemcpyDtoD(savedForces->getDevicePointer(), cu.getForce().getDevicePointer(), savedForces->getSize()*savedForces->getElementSize());
if (result != CUDA_SUCCESS) {
std::stringstream m;
m<<"Error saving forces for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str());
}
float scalefX = (float) scaleX; float scalefX = (float) scaleX;
float scalefY = (float) scaleY; float scalefY = (float) scaleY;
float scalefZ = (float) scaleZ; float scalefZ = (float) scaleZ;
...@@ -8178,6 +8188,12 @@ void CudaApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) ...@@ -8178,6 +8188,12 @@ void CudaApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context)
m<<"Error restoring positions for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")"; m<<"Error restoring positions for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str()); throw OpenMMException(m.str());
} }
result = cuMemcpyDtoD(cu.getForce().getDevicePointer(), savedForces->getDevicePointer(), savedForces->getSize()*savedForces->getElementSize());
if (result != CUDA_SUCCESS) {
std::stringstream m;
m<<"Error restoring forces for MC barostat: "<<cu.getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str());
}
} }
CudaRemoveCMMotionKernel::~CudaRemoveCMMotionKernel() { CudaRemoveCMMotionKernel::~CudaRemoveCMMotionKernel() {
......
...@@ -1626,7 +1626,7 @@ private: ...@@ -1626,7 +1626,7 @@ private:
class OpenCLApplyMonteCarloBarostatKernel : public ApplyMonteCarloBarostatKernel { class OpenCLApplyMonteCarloBarostatKernel : public ApplyMonteCarloBarostatKernel {
public: public:
OpenCLApplyMonteCarloBarostatKernel(std::string name, const Platform& platform, OpenCLContext& cl) : ApplyMonteCarloBarostatKernel(name, platform), cl(cl), OpenCLApplyMonteCarloBarostatKernel(std::string name, const Platform& platform, OpenCLContext& cl) : ApplyMonteCarloBarostatKernel(name, platform), cl(cl),
hasInitializedKernels(false), savedPositions(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) { hasInitializedKernels(false), savedPositions(NULL), savedForces(NULL), moleculeAtoms(NULL), moleculeStartIndex(NULL) {
} }
~OpenCLApplyMonteCarloBarostatKernel(); ~OpenCLApplyMonteCarloBarostatKernel();
/** /**
...@@ -1661,6 +1661,7 @@ private: ...@@ -1661,6 +1661,7 @@ private:
bool hasInitializedKernels; bool hasInitializedKernels;
int numMolecules; int numMolecules;
OpenCLArray* savedPositions; OpenCLArray* savedPositions;
OpenCLArray* savedForces;
OpenCLArray* moleculeAtoms; OpenCLArray* moleculeAtoms;
OpenCLArray* moleculeStartIndex; OpenCLArray* moleculeStartIndex;
cl::Kernel kernel; cl::Kernel kernel;
......
...@@ -8182,6 +8182,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -8182,6 +8182,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
globalValues->upload(globalValuesFloat); globalValues->upload(globalValuesFloat);
} }
} }
bool stepInvalidatesForces = invalidatesForces[step];
if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) { if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
kernels[step][0].setArg<cl_uint>(9, integration.prepareRandomNumbers(requiredGaussian[step])); kernels[step][0].setArg<cl_uint>(9, integration.prepareRandomNumbers(requiredGaussian[step]));
kernels[step][0].setArg<cl::Buffer>(8, integration.getRandom().getDeviceBuffer()); kernels[step][0].setArg<cl::Buffer>(8, integration.getRandom().getDeviceBuffer());
...@@ -8226,7 +8227,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -8226,7 +8227,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
} }
else if (stepType[step] == CustomIntegrator::UpdateContextState) { else if (stepType[step] == CustomIntegrator::UpdateContextState) {
recordChangedParameters(context); recordChangedParameters(context);
context.updateContextState(); stepInvalidatesForces = context.updateContextState();
} }
else if (stepType[step] == CustomIntegrator::ConstrainPositions) { else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
if (hasAnyConstraints) { if (hasAnyConstraints) {
...@@ -8250,7 +8251,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -8250,7 +8251,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (blockEnd[step] != -1) if (blockEnd[step] != -1)
nextStep = blockEnd[step]; // Return to the start of a while block. nextStep = blockEnd[step]; // Return to the start of a while block.
} }
if (invalidatesForces[step]) { if (stepInvalidatesForces) {
forcesAreValid = false; forcesAreValid = false;
savedEnergy.clear(); savedEnergy.clear();
} }
...@@ -8471,6 +8472,8 @@ void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) { ...@@ -8471,6 +8472,8 @@ void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) {
OpenCLApplyMonteCarloBarostatKernel::~OpenCLApplyMonteCarloBarostatKernel() { OpenCLApplyMonteCarloBarostatKernel::~OpenCLApplyMonteCarloBarostatKernel() {
if (savedPositions != NULL) if (savedPositions != NULL)
delete savedPositions; delete savedPositions;
if (savedForces != NULL)
delete savedForces;
if (moleculeAtoms != NULL) if (moleculeAtoms != NULL)
delete moleculeAtoms; delete moleculeAtoms;
if (moleculeStartIndex != NULL) if (moleculeStartIndex != NULL)
...@@ -8479,6 +8482,7 @@ OpenCLApplyMonteCarloBarostatKernel::~OpenCLApplyMonteCarloBarostatKernel() { ...@@ -8479,6 +8482,7 @@ OpenCLApplyMonteCarloBarostatKernel::~OpenCLApplyMonteCarloBarostatKernel() {
void OpenCLApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) { void OpenCLApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) {
savedPositions = new OpenCLArray(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedPositions"); savedPositions = new OpenCLArray(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedPositions");
savedForces = new OpenCLArray(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedForces");
cl::Program program = cl.createProgram(OpenCLKernelSources::monteCarloBarostat); cl::Program program = cl.createProgram(OpenCLKernelSources::monteCarloBarostat);
kernel = cl::Kernel(program, "scalePositions"); kernel = cl::Kernel(program, "scalePositions");
} }
...@@ -8514,6 +8518,7 @@ void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, ...@@ -8514,6 +8518,7 @@ void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context,
} }
int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4)); int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
cl.getQueue().enqueueCopyBuffer(cl.getPosq().getDeviceBuffer(), savedPositions->getDeviceBuffer(), 0, 0, bytesToCopy); cl.getQueue().enqueueCopyBuffer(cl.getPosq().getDeviceBuffer(), savedPositions->getDeviceBuffer(), 0, 0, bytesToCopy);
cl.getQueue().enqueueCopyBuffer(cl.getForce().getDeviceBuffer(), savedForces->getDeviceBuffer(), 0, 0, bytesToCopy);
kernel.setArg<cl_float>(0, (cl_float) scaleX); kernel.setArg<cl_float>(0, (cl_float) scaleX);
kernel.setArg<cl_float>(1, (cl_float) scaleY); kernel.setArg<cl_float>(1, (cl_float) scaleY);
kernel.setArg<cl_float>(2, (cl_float) scaleZ); kernel.setArg<cl_float>(2, (cl_float) scaleZ);
...@@ -8527,6 +8532,7 @@ void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, ...@@ -8527,6 +8532,7 @@ void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context,
void OpenCLApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) { void OpenCLApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) {
int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4)); int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
cl.getQueue().enqueueCopyBuffer(savedPositions->getDeviceBuffer(), cl.getPosq().getDeviceBuffer(), 0, 0, bytesToCopy); cl.getQueue().enqueueCopyBuffer(savedPositions->getDeviceBuffer(), cl.getPosq().getDeviceBuffer(), 0, 0, bytesToCopy);
cl.getQueue().enqueueCopyBuffer(savedForces->getDeviceBuffer(), cl.getForce().getDeviceBuffer(), 0, 0, bytesToCopy);
} }
OpenCLRemoveCMMotionKernel::~OpenCLRemoveCMMotionKernel() { OpenCLRemoveCMMotionKernel::~OpenCLRemoveCMMotionKernel() {
......
...@@ -249,6 +249,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -249,6 +249,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
energy = (needsEnergy[step] ? groupEnergy[flags] : 0); energy = (needsEnergy[step] ? groupEnergy[flags] : 0);
vector<Vec3>& stepForces = (needsForces[step] ? groupForces[flags] : forces); vector<Vec3>& stepForces = (needsForces[step] ? groupForces[flags] : forces);
int nextStep = step+1; int nextStep = step+1;
bool stepInvalidatesForces = invalidatesForces[step];
switch (stepType[step]) { switch (stepType[step]) {
case CustomIntegrator::ComputeGlobal: { case CustomIntegrator::ComputeGlobal: {
uniform = SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber(); uniform = SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber();
...@@ -295,7 +296,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -295,7 +296,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
} }
case CustomIntegrator::UpdateContextState: { case CustomIntegrator::UpdateContextState: {
recordChangedParameters(context, globals); recordChangedParameters(context, globals);
context.updateContextState(); stepInvalidatesForces = context.updateContextState();
globals.insert(context.getParameters().begin(), context.getParameters().end()); globals.insert(context.getParameters().begin(), context.getParameters().end());
for (auto& global : globals) for (auto& global : globals)
expressionSet.setVariable(expressionSet.getVariableIndex(global.first), global.second); expressionSet.setVariable(expressionSet.getVariableIndex(global.first), global.second);
...@@ -317,7 +318,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -317,7 +318,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
break; break;
} }
} }
if (invalidatesForces[step]) { if (stepInvalidatesForces) {
forcesAreValid = false; forcesAreValid = false;
groupForces.clear(); groupForces.clear();
groupEnergy.clear(); groupEnergy.clear();
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
const AmoebaAngleForce& getOwner() const { const AmoebaAngleForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
const AmoebaBondForce& getOwner() const { const AmoebaBondForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -50,7 +50,7 @@ public: ...@@ -50,7 +50,7 @@ public:
const AmoebaGeneralizedKirkwoodForce& getOwner() const { const AmoebaGeneralizedKirkwoodForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
const AmoebaInPlaneAngleForce& getOwner() const { const AmoebaInPlaneAngleForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
const AmoebaMultipoleForce& getOwner() const { const AmoebaMultipoleForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
const AmoebaOutOfPlaneBendForce& getOwner() const { const AmoebaOutOfPlaneBendForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
const AmoebaPiTorsionForce& getOwner() const { const AmoebaPiTorsionForce& getOwner() const {
return owner; return owner;
} }
void updateContextState(ContextImpl& context) { void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly. // This force field doesn't update the state directly.
} }
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups); double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment