Unverified Commit c456dd54 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Use cuCtxPushCurrent() and cuCtxPopCurrent() for selecting CUDA context (#3258)

* Use cuCtxPushCurrent() and cuCtxPopCurrent() for selecting CUDA context

* Fixed errors in amoeba coda

* Fixed more errors in context selection
parent b2c35a8b
...@@ -93,9 +93,30 @@ public: ...@@ -93,9 +93,30 @@ public:
* doing any computation when you do not know what other code has just been executing on * doing any computation when you do not know what other code has just been executing on
* the thread. Platforms that rely on binding contexts to threads (such as CUDA) need to * the thread. Platforms that rely on binding contexts to threads (such as CUDA) need to
* implement this. * implement this.
*
* @deprecated It is recommended to use pushAsCurrent() and popAsCurrent() instead, or even better to create a ContextSelector.
* This provides better interoperability with other libraries that use CUDA and create
* their own contexts.
*/ */
virtual void setAsCurrent() { virtual void setAsCurrent() {
} }
/**
* Set this as the current context for the calling thread, maintaining any previous context
* on a stack. This should be called before doing any computation when you do not know what
* other code has just been executing on the thread. It must be paired with popAsCurrent()
* when you are done to restore the previous context. Alternatively, you can create a
* ContextSelector object to automate this for a block of code.
*
* Platforms that rely on binding contexts to threads (such as CUDA) need to implement this.
*/
virtual void pushAsCurrent() {
}
/**
* Restore a previous context that was replaced by pushAsCurrent(). Platforms that rely on binding
* contexts to threads (such as CUDA) need to implement this.
*/
virtual void popAsCurrent() {
}
/** /**
* Get the number of contexts being used for the current simulation. * Get the number of contexts being used for the current simulation.
* This is relevant when a simulation is parallelized across multiple devices. In that case, * This is relevant when a simulation is parallelized across multiple devices. In that case,
......
#ifndef OPENMM_CONTEXTSELECTOR_H_
#define OPENMM_CONTEXTSELECTOR_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "ComputeContext.h"
namespace OpenMM {
/**
* This class provides a safe and easy way to select a ComputeContext as current
* for a block of code. The constructor calls pushAsCurrent() on the context.
* When it goes out of scope, the destructor calls popAsCurrent() on it. Simply
* create a local variable of this class, and the context will be current for
* the remainder of the code block in which it is declared.
*/
class OPENMM_EXPORT_COMMON ContextSelector {
public:
ContextSelector(ComputeContext& context) : context(context) {
context.pushAsCurrent();
}
~ContextSelector() {
context.popAsCurrent();
}
private:
ComputeContext& context;
};
} // namespace OpenMM
#endif /*OPENMM_CONTEXTSELECTOR_H_*/
This diff is collapsed.
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2019 Stanford University and the Authors. * * Portions copyright (c) 2019-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "openmm/common/ComputeContext.h" #include "openmm/common/ComputeContext.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/VirtualSite.h" #include "openmm/VirtualSite.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
...@@ -362,6 +363,7 @@ bool ComputeContext::invalidateMolecules(ComputeForceInfo* force) { ...@@ -362,6 +363,7 @@ bool ComputeContext::invalidateMolecules(ComputeForceInfo* force) {
// atoms to their original order, rebuild the list of identical molecules, and sort them // atoms to their original order, rebuild the list of identical molecules, and sort them
// again. // again.
ContextSelector selector(*this);
vector<mm_int4> newCellOffsets(numAtoms); vector<mm_int4> newCellOffsets(numAtoms);
if (getUseDoublePrecision()) { if (getUseDoublePrecision()) {
vector<mm_double4> oldPosq(paddedNumAtoms); vector<mm_double4> oldPosq(paddedNumAtoms);
...@@ -598,6 +600,7 @@ void ComputeContext::reorderAtomsImpl() { ...@@ -598,6 +600,7 @@ void ComputeContext::reorderAtomsImpl() {
// Update the arrays. // Update the arrays.
ContextSelector selector(*this);
for (int i = 0; i < numAtoms; i++) { for (int i = 0; i < numAtoms; i++) {
atomIndex[i] = originalIndex[i]; atomIndex[i] = originalIndex[i];
posCellOffsets[i] = newCellOffsets[i]; posCellOffsets[i] = newCellOffsets[i];
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2020 Stanford University and the Authors. * * Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "openmm/common/IntegrationUtilities.h" #include "openmm/common/IntegrationUtilities.h"
#include "openmm/common/ComputeContext.h" #include "openmm/common/ComputeContext.h"
#include "openmm/common/ContextSelector.h"
#include "CommonKernelSources.h" #include "CommonKernelSources.h"
#include "openmm/internal/OSRngSeed.h" #include "openmm/internal/OSRngSeed.h"
#include "openmm/HarmonicAngleForce.h" #include "openmm/HarmonicAngleForce.h"
...@@ -736,6 +737,7 @@ void IntegrationUtilities::applyVelocityConstraints(double tol) { ...@@ -736,6 +737,7 @@ void IntegrationUtilities::applyVelocityConstraints(double tol) {
} }
void IntegrationUtilities::computeVirtualSites() { void IntegrationUtilities::computeVirtualSites() {
ContextSelector selector(context);
if (numVsites > 0) if (numVsites > 0)
vsitePositionKernel->execute(numVsites); vsitePositionKernel->execute(numVsites);
} }
...@@ -812,6 +814,7 @@ void IntegrationUtilities::loadCheckpoint(istream& stream) { ...@@ -812,6 +814,7 @@ void IntegrationUtilities::loadCheckpoint(istream& stream) {
} }
double IntegrationUtilities::computeKineticEnergy(double timeShift) { double IntegrationUtilities::computeKineticEnergy(double timeShift) {
ContextSelector selector(context);
int numParticles = context.getNumAtoms(); int numParticles = context.getNumAtoms();
if (timeShift != 0) { if (timeShift != 0) {
// Copy the velocities into the posDelta array while we temporarily modify them. // Copy the velocities into the posDelta array while we temporarily modify them.
......
...@@ -100,6 +100,16 @@ public: ...@@ -100,6 +100,16 @@ public:
* valid, this returns without doing anything. * valid, this returns without doing anything.
*/ */
void setAsCurrent(); void setAsCurrent();
/**
* Push the CUcontext associated with this object to be the current context. If the context is not
* valid, this returns without doing anything.
*/
void pushAsCurrent();
/**
* Pop the CUcontext associated with this object off the stack of contexts. If the context is not
* valid, this returns without doing anything.
*/
void popAsCurrent();
/** /**
* Get the CUdevice associated with this object. * Get the CUdevice associated with this object.
*/ */
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "CudaArray.h" #include "CudaArray.h"
#include "CudaContext.h" #include "CudaContext.h"
#include "openmm/common/ContextSelector.h"
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <vector> #include <vector>
...@@ -41,7 +42,7 @@ CudaArray::CudaArray(CudaContext& context, int size, int elementSize, const std: ...@@ -41,7 +42,7 @@ CudaArray::CudaArray(CudaContext& context, int size, int elementSize, const std:
CudaArray::~CudaArray() { CudaArray::~CudaArray() {
if (pointer != 0 && ownsMemory && context->getContextIsValid()) { if (pointer != 0 && ownsMemory && context->getContextIsValid()) {
context->setAsCurrent(); ContextSelector selector(*context);
CUresult result = cuMemFree(pointer); CUresult result = cuMemFree(pointer);
if (result != CUDA_SUCCESS) { if (result != CUDA_SUCCESS) {
std::stringstream str; std::stringstream str;
...@@ -59,6 +60,7 @@ void CudaArray::initialize(ComputeContext& context, int size, int elementSize, c ...@@ -59,6 +60,7 @@ void CudaArray::initialize(ComputeContext& context, int size, int elementSize, c
this->elementSize = elementSize; this->elementSize = elementSize;
this->name = name; this->name = name;
ownsMemory = true; ownsMemory = true;
ContextSelector selector(*this->context);
CUresult result = cuMemAlloc(&pointer, size*elementSize); CUresult result = cuMemAlloc(&pointer, size*elementSize);
if (result != CUDA_SUCCESS) { if (result != CUDA_SUCCESS) {
std::stringstream str; std::stringstream str;
...@@ -72,6 +74,7 @@ void CudaArray::resize(int size) { ...@@ -72,6 +74,7 @@ void CudaArray::resize(int size) {
throw OpenMMException("CudaArray has not been initialized"); throw OpenMMException("CudaArray has not been initialized");
if (!ownsMemory) if (!ownsMemory)
throw OpenMMException("Cannot resize an array that does not own its storage"); throw OpenMMException("Cannot resize an array that does not own its storage");
ContextSelector selector(*context);
CUresult result = cuMemFree(pointer); CUresult result = cuMemFree(pointer);
if (result != CUDA_SUCCESS) { if (result != CUDA_SUCCESS) {
std::stringstream str; std::stringstream str;
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "CudaNonbondedUtilities.h" #include "CudaNonbondedUtilities.h"
#include "CudaProgram.h" #include "CudaProgram.h"
#include "openmm/common/ComputeArray.h" #include "openmm/common/ComputeArray.h"
#include "openmm/common/ContextSelector.h"
#include "SHA1.h" #include "SHA1.h"
#include "openmm/Platform.h" #include "openmm/Platform.h"
#include "openmm/System.h" #include "openmm/System.h"
...@@ -190,6 +191,8 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking ...@@ -190,6 +191,8 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
if (cuCtxCreate(&context, flags, device) == CUDA_SUCCESS) { if (cuCtxCreate(&context, flags, device) == CUDA_SUCCESS) {
this->deviceIndex = trialDeviceIndex; this->deviceIndex = trialDeviceIndex;
CUcontext popped;
cuCtxPopCurrent(&popped);
break; break;
} }
} }
...@@ -231,14 +234,16 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking ...@@ -231,14 +234,16 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
computeCapability = major+0.1*minor; computeCapability = major+0.1*minor;
contextIsValid = true; contextIsValid = true;
ContextSelector selector(*this);
CHECK_RESULT(cuCtxSetCacheConfig(CU_FUNC_CACHE_PREFER_SHARED)); CHECK_RESULT(cuCtxSetCacheConfig(CU_FUNC_CACHE_PREFER_SHARED));
if (contextIndex > 0) { if (contextIndex > 0) {
int canAccess; int canAccess;
cuDeviceCanAccessPeer(&canAccess, getDevice(), platformData.contexts[0]->getDevice()); cuDeviceCanAccessPeer(&canAccess, getDevice(), platformData.contexts[0]->getDevice());
if (canAccess) { if (canAccess) {
platformData.contexts[0]->setAsCurrent(); {
ContextSelector selector2(*platformData.contexts[0]);
CHECK_RESULT(cuCtxEnablePeerAccess(getContext(), 0)); CHECK_RESULT(cuCtxEnablePeerAccess(getContext(), 0));
setAsCurrent(); }
CHECK_RESULT(cuCtxEnablePeerAccess(platformData.contexts[0]->getContext(), 0)); CHECK_RESULT(cuCtxEnablePeerAccess(platformData.contexts[0]->getContext(), 0));
} }
} }
...@@ -397,7 +402,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking ...@@ -397,7 +402,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
} }
CudaContext::~CudaContext() { CudaContext::~CudaContext() {
setAsCurrent(); pushAsCurrent();
for (auto force : forces) for (auto force : forces)
delete force; delete force;
for (auto listener : reorderListeners) for (auto listener : reorderListeners)
...@@ -416,6 +421,7 @@ CudaContext::~CudaContext() { ...@@ -416,6 +421,7 @@ CudaContext::~CudaContext() {
delete bonded; delete bonded;
if (nonbonded != NULL) if (nonbonded != NULL)
delete nonbonded; delete nonbonded;
popAsCurrent();
string errorMessage = "Error deleting Context"; string errorMessage = "Error deleting Context";
if (contextIsValid && !isLinkedContext) { if (contextIsValid && !isLinkedContext) {
cuProfilerStop(); cuProfilerStop();
...@@ -425,7 +431,7 @@ CudaContext::~CudaContext() { ...@@ -425,7 +431,7 @@ CudaContext::~CudaContext() {
} }
void CudaContext::initialize() { void CudaContext::initialize() {
cuCtxSetCurrent(context); ContextSelector selector(*this);
string errorMessage = "Error initializing Context"; string errorMessage = "Error initializing Context";
int numEnergyBuffers = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers()); int numEnergyBuffers = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers());
if (useDoublePrecision) { if (useDoublePrecision) {
...@@ -478,6 +484,17 @@ void CudaContext::setAsCurrent() { ...@@ -478,6 +484,17 @@ void CudaContext::setAsCurrent() {
cuCtxSetCurrent(context); cuCtxSetCurrent(context);
} }
void CudaContext::pushAsCurrent() {
if (contextIsValid)
cuCtxPushCurrent(context);
}
void CudaContext::popAsCurrent() {
CUcontext popped;
if (contextIsValid)
cuCtxPopCurrent(&popped);
}
CUmodule CudaContext::createModule(const string source, const char* optimizationFlags) { CUmodule CudaContext::createModule(const string source, const char* optimizationFlags) {
return createModule(source, map<string, string>(), optimizationFlags); return createModule(source, map<string, string>(), optimizationFlags);
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2020 Stanford University and the Authors. * * Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "CudaIntegrationUtilities.h" #include "CudaIntegrationUtilities.h"
#include "CudaContext.h" #include "CudaContext.h"
#include "openmm/common/ContextSelector.h"
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -46,7 +47,7 @@ CudaIntegrationUtilities::CudaIntegrationUtilities(CudaContext& context, const S ...@@ -46,7 +47,7 @@ CudaIntegrationUtilities::CudaIntegrationUtilities(CudaContext& context, const S
} }
CudaIntegrationUtilities::~CudaIntegrationUtilities() { CudaIntegrationUtilities::~CudaIntegrationUtilities() {
context.setAsCurrent(); ContextSelector selector(context);
if (ccmaConvergedMemory != NULL) { if (ccmaConvergedMemory != NULL) {
cuMemFreeHost(ccmaConvergedMemory); cuMemFreeHost(ccmaConvergedMemory);
cuEventDestroy(ccmaEvent); cuEventDestroy(ccmaEvent);
...@@ -66,6 +67,7 @@ CudaArray& CudaIntegrationUtilities::getStepSize() { ...@@ -66,6 +67,7 @@ CudaArray& CudaIntegrationUtilities::getStepSize() {
} }
void CudaIntegrationUtilities::applyConstraintsImpl(bool constrainVelocities, double tol) { void CudaIntegrationUtilities::applyConstraintsImpl(bool constrainVelocities, double tol) {
ContextSelector selector(context);
ComputeKernel settleKernel, shakeKernel, ccmaForceKernel; ComputeKernel settleKernel, shakeKernel, ccmaForceKernel;
if (constrainVelocities) { if (constrainVelocities) {
settleKernel = settleVelKernel; settleKernel = settleVelKernel;
...@@ -131,6 +133,7 @@ void CudaIntegrationUtilities::applyConstraintsImpl(bool constrainVelocities, do ...@@ -131,6 +133,7 @@ void CudaIntegrationUtilities::applyConstraintsImpl(bool constrainVelocities, do
} }
void CudaIntegrationUtilities::distributeForcesFromVirtualSites() { void CudaIntegrationUtilities::distributeForcesFromVirtualSites() {
ContextSelector selector(context);
if (numVsites > 0) { if (numVsites > 0) {
vsiteForceKernel->setArg(2, context.getLongForceBuffer()); vsiteForceKernel->setArg(2, context.getLongForceBuffer());
vsiteForceKernel->execute(numVsites); vsiteForceKernel->execute(numVsites);
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/common/ContextSelector.h"
#include "CommonKernelSources.h" #include "CommonKernelSources.h"
#include "CudaBondedUtilities.h" #include "CudaBondedUtilities.h"
#include "CudaExpressionUtilities.h" #include "CudaExpressionUtilities.h"
...@@ -58,7 +59,7 @@ void CudaCalcForcesAndEnergyKernel::initialize(const System& system) { ...@@ -58,7 +59,7 @@ void CudaCalcForcesAndEnergyKernel::initialize(const System& system) {
void CudaCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { void CudaCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
cu.setForcesValid(true); cu.setForcesValid(true);
cu.setAsCurrent(); ContextSelector selector(cu);
cu.clearAutoclearBuffers(); cu.clearAutoclearBuffers();
for (auto computation : cu.getPreComputations()) for (auto computation : cu.getPreComputations())
computation->computeForceAndEnergy(includeForces, includeEnergy, groups); computation->computeForceAndEnergy(includeForces, includeEnergy, groups);
...@@ -71,7 +72,7 @@ void CudaCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool ...@@ -71,7 +72,7 @@ void CudaCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool
} }
double CudaCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups, bool& valid) { double CudaCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups, bool& valid) {
cu.setAsCurrent(); ContextSelector selector(cu);
cu.getBondedUtilities().computeInteractions(groups); cu.getBondedUtilities().computeInteractions(groups);
cu.getNonbondedUtilities().computeInteractions(groups, includeForces, includeEnergy); cu.getNonbondedUtilities().computeInteractions(groups, includeForces, includeEnergy);
double sum = 0.0; double sum = 0.0;
...@@ -109,7 +110,7 @@ void CudaUpdateStateDataKernel::setStepCount(const ContextImpl& context, long lo ...@@ -109,7 +110,7 @@ void CudaUpdateStateDataKernel::setStepCount(const ContextImpl& context, long lo
} }
void CudaUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) { void CudaUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
cu.setAsCurrent(); ContextSelector selector(cu);
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
positions.resize(numParticles); positions.resize(numParticles);
vector<float4> posCorrection; vector<float4> posCorrection;
...@@ -170,7 +171,7 @@ void CudaUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& ...@@ -170,7 +171,7 @@ void CudaUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>&
} }
void CudaUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<Vec3>& positions) { void CudaUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<Vec3>& positions) {
cu.setAsCurrent(); ContextSelector selector(cu);
const vector<int>& order = cu.getAtomIndex(); const vector<int>& order = cu.getAtomIndex();
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
if (cu.getUseDoublePrecision()) { if (cu.getUseDoublePrecision()) {
...@@ -221,7 +222,7 @@ void CudaUpdateStateDataKernel::setPositions(ContextImpl& context, const vector< ...@@ -221,7 +222,7 @@ void CudaUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<
} }
void CudaUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>& velocities) { void CudaUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>& velocities) {
cu.setAsCurrent(); ContextSelector selector(cu);
const vector<int>& order = cu.getAtomIndex(); const vector<int>& order = cu.getAtomIndex();
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
velocities.resize(numParticles); velocities.resize(numParticles);
...@@ -246,7 +247,7 @@ void CudaUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3> ...@@ -246,7 +247,7 @@ void CudaUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>
} }
void CudaUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector<Vec3>& velocities) { void CudaUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector<Vec3>& velocities) {
cu.setAsCurrent(); ContextSelector selector(cu);
const vector<int>& order = cu.getAtomIndex(); const vector<int>& order = cu.getAtomIndex();
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) { if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
...@@ -280,7 +281,7 @@ void CudaUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector ...@@ -280,7 +281,7 @@ void CudaUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector
} }
void CudaUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& forces) { void CudaUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& forces) {
cu.setAsCurrent(); ContextSelector selector(cu);
long long* force = (long long*) cu.getPinnedBuffer(); long long* force = (long long*) cu.getPinnedBuffer();
cu.getForce().download(force); cu.getForce().download(force);
const vector<int>& order = cu.getAtomIndex(); const vector<int>& order = cu.getAtomIndex();
...@@ -293,6 +294,7 @@ void CudaUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& fo ...@@ -293,6 +294,7 @@ void CudaUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& fo
} }
void CudaUpdateStateDataKernel::getEnergyParameterDerivatives(ContextImpl& context, map<string, double>& derivs) { void CudaUpdateStateDataKernel::getEnergyParameterDerivatives(ContextImpl& context, map<string, double>& derivs) {
ContextSelector selector(cu);
const vector<string>& paramDerivNames = cu.getEnergyParamDerivNames(); const vector<string>& paramDerivNames = cu.getEnergyParamDerivNames();
int numDerivs = paramDerivNames.size(); int numDerivs = paramDerivNames.size();
if (numDerivs == 0) if (numDerivs == 0)
...@@ -346,7 +348,7 @@ void CudaUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, cons ...@@ -346,7 +348,7 @@ void CudaUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, cons
} }
void CudaUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) { void CudaUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
cu.setAsCurrent(); ContextSelector selector(cu);
int version = 3; int version = 3;
stream.write((char*) &version, sizeof(int)); stream.write((char*) &version, sizeof(int));
int precision = (cu.getUseDoublePrecision() ? 2 : cu.getUseMixedPrecision() ? 1 : 0); int precision = (cu.getUseDoublePrecision() ? 2 : cu.getUseMixedPrecision() ? 1 : 0);
...@@ -376,7 +378,7 @@ void CudaUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& ...@@ -376,7 +378,7 @@ void CudaUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream&
} }
void CudaUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) { void CudaUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
cu.setAsCurrent(); ContextSelector selector(cu);
int version; int version;
stream.read((char*) &version, sizeof(int)); stream.read((char*) &version, sizeof(int));
if (version != 3) if (version != 3)
...@@ -458,7 +460,7 @@ public: ...@@ -458,7 +460,7 @@ public:
forceTemp.initialize<float4>(cu, cu.getNumAtoms(), "PmeForce"); forceTemp.initialize<float4>(cu, cu.getNumAtoms(), "PmeForce");
} }
float* getPosq() { float* getPosq() {
cu.setAsCurrent(); ContextSelector selector(cu);
cu.getPosq().download(posq); cu.getPosq().download(posq);
return (float*) &posq[0]; return (float*) &posq[0];
} }
...@@ -542,7 +544,7 @@ private: ...@@ -542,7 +544,7 @@ private:
}; };
CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() { CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() {
cu.setAsCurrent(); ContextSelector selector(cu);
if (sort != NULL) if (sort != NULL)
delete sort; delete sort;
if (fft != NULL) if (fft != NULL)
...@@ -569,7 +571,7 @@ CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() { ...@@ -569,7 +571,7 @@ CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() {
} }
void CudaCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) { void CudaCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) {
cu.setAsCurrent(); ContextSelector selector(cu);
int forceIndex; int forceIndex;
for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex) for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
; ;
...@@ -1129,6 +1131,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon ...@@ -1129,6 +1131,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) { double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) {
// Update particle and exception parameters. // Update particle and exception parameters.
ContextSelector selector(cu);
bool paramChanged = false; bool paramChanged = false;
for (int i = 0; i < paramNames.size(); i++) { for (int i = 0; i < paramNames.size(); i++) {
double value = context.getParameter(paramNames[i]); double value = context.getParameter(paramNames[i]);
...@@ -1364,7 +1367,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -1364,7 +1367,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
void CudaCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) { void CudaCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
// Make sure the new parameters are acceptable. // Make sure the new parameters are acceptable.
cu.setAsCurrent(); ContextSelector selector(cu);
if (force.getNumParticles() != cu.getNumAtoms()) if (force.getNumParticles() != cu.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed"); throw OpenMMException("updateParametersInContext: The number of particles has changed");
if (!hasCoulomb || !hasLJ) { if (!hasCoulomb || !hasLJ) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2019 Stanford University and the Authors. * * Portions copyright (c) 2011-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "CudaParallelKernels.h" #include "CudaParallelKernels.h"
#include "CudaKernelSources.h" #include "CudaKernelSources.h"
#include "openmm/common/ContextSelector.h"
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -69,7 +70,7 @@ public: ...@@ -69,7 +70,7 @@ public:
void execute() { void execute() {
// Copy coordinates over to this device and execute the kernel. // Copy coordinates over to this device and execute the kernel.
cu.setAsCurrent(); ContextSelector selector(cu);
if (cu.getContextIndex() > 0) { if (cu.getContextIndex() > 0) {
cuStreamWaitEvent(cu.getCurrentStream(), event, 0); cuStreamWaitEvent(cu.getCurrentStream(), event, 0);
if (!cu.getPlatformData().peerAccessSupported) if (!cu.getPlatformData().peerAccessSupported)
...@@ -100,6 +101,7 @@ public: ...@@ -100,6 +101,7 @@ public:
void execute() { void execute() {
// Execute the kernel, then download forces. // Execute the kernel, then download forces.
ContextSelector selector(cu);
energy += kernel.finishComputation(context, includeForce, includeEnergy, groups, valid); energy += kernel.finishComputation(context, includeForce, includeEnergy, groups, valid);
if (cu.getComputeForceCount() < 200) { if (cu.getComputeForceCount() < 200) {
// Record timing information for load balancing. Since this takes time, only do it at the start of the simulation. // Record timing information for load balancing. Since this takes time, only do it at the start of the simulation.
...@@ -148,7 +150,7 @@ CudaParallelCalcForcesAndEnergyKernel::CudaParallelCalcForcesAndEnergyKernel(str ...@@ -148,7 +150,7 @@ CudaParallelCalcForcesAndEnergyKernel::CudaParallelCalcForcesAndEnergyKernel(str
} }
CudaParallelCalcForcesAndEnergyKernel::~CudaParallelCalcForcesAndEnergyKernel() { CudaParallelCalcForcesAndEnergyKernel::~CudaParallelCalcForcesAndEnergyKernel() {
data.contexts[0]->setAsCurrent(); ContextSelector selector(*data.contexts[0]);
if (pinnedPositionBuffer != NULL) if (pinnedPositionBuffer != NULL)
cuMemFreeHost(pinnedPositionBuffer); cuMemFreeHost(pinnedPositionBuffer);
if (pinnedForceBuffer != NULL) if (pinnedForceBuffer != NULL)
...@@ -161,7 +163,7 @@ CudaParallelCalcForcesAndEnergyKernel::~CudaParallelCalcForcesAndEnergyKernel() ...@@ -161,7 +163,7 @@ CudaParallelCalcForcesAndEnergyKernel::~CudaParallelCalcForcesAndEnergyKernel()
void CudaParallelCalcForcesAndEnergyKernel::initialize(const System& system) { void CudaParallelCalcForcesAndEnergyKernel::initialize(const System& system) {
CudaContext& cu = *data.contexts[0]; CudaContext& cu = *data.contexts[0];
cu.setAsCurrent(); ContextSelector selector(cu);
CUmodule module = cu.createModule(CudaKernelSources::parallel); CUmodule module = cu.createModule(CudaKernelSources::parallel);
sumKernel = cu.getKernel(module, "sumForces"); sumKernel = cu.getKernel(module, "sumForces");
int numContexts = data.contexts.size(); int numContexts = data.contexts.size();
...@@ -176,7 +178,7 @@ void CudaParallelCalcForcesAndEnergyKernel::initialize(const System& system) { ...@@ -176,7 +178,7 @@ void CudaParallelCalcForcesAndEnergyKernel::initialize(const System& system) {
void CudaParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) { void CudaParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
CudaContext& cu = *data.contexts[0]; CudaContext& cu = *data.contexts[0];
cu.setAsCurrent(); ContextSelector selector(cu);
if (!contextForces.isInitialized()) { if (!contextForces.isInitialized()) {
contextForces.initialize<long long>(cu, 3*(data.contexts.size()-1)*cu.getPaddedNumAtoms(), "contextForces"); contextForces.initialize<long long>(cu, 3*(data.contexts.size()-1)*cu.getPaddedNumAtoms(), "contextForces");
CHECK_RESULT(cuMemHostAlloc((void**) &pinnedForceBuffer, 3*(data.contexts.size()-1)*cu.getPaddedNumAtoms()*sizeof(long long), CU_MEMHOSTALLOC_PORTABLE), "Error allocating pinned memory"); CHECK_RESULT(cuMemHostAlloc((void**) &pinnedForceBuffer, 3*(data.contexts.size()-1)*cu.getPaddedNumAtoms()*sizeof(long long), CU_MEMHOSTALLOC_PORTABLE), "Error allocating pinned memory");
...@@ -219,6 +221,7 @@ double CudaParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& con ...@@ -219,6 +221,7 @@ double CudaParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& con
// Sum the forces from all devices. // Sum the forces from all devices.
CudaContext& cu = *data.contexts[0]; CudaContext& cu = *data.contexts[0];
ContextSelector selector(cu);
if (!cu.getPlatformData().peerAccessSupported) if (!cu.getPlatformData().peerAccessSupported)
contextForces.upload(pinnedForceBuffer, false); contextForces.upload(pinnedForceBuffer, false);
int bufferSize = 3*cu.getPaddedNumAtoms(); int bufferSize = 3*cu.getPaddedNumAtoms();
......
...@@ -59,6 +59,7 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize) { ...@@ -59,6 +59,7 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
platform.getPropertyDefaultValue(CudaPlatform::CudaHostCompiler()), platform.getPropertyDefaultValue(CudaPlatform::CudaDisablePmeStream()), "false", true, 1, NULL); platform.getPropertyDefaultValue(CudaPlatform::CudaHostCompiler()), platform.getPropertyDefaultValue(CudaPlatform::CudaDisablePmeStream()), "false", true, 1, NULL);
CudaContext& context = *platformData.contexts[0]; CudaContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
context.setAsCurrent();
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt); init_gen_rand(0, sfmt);
vector<Real2> original(xsize*ysize*zsize); vector<Real2> original(xsize*ysize*zsize);
......
...@@ -172,7 +172,7 @@ bool canRunHugeTest() { ...@@ -172,7 +172,7 @@ bool canRunHugeTest() {
// Only run the huge test if the device has at least 4 GB of memory. // Only run the huge test if the device has at least 4 GB of memory.
return (memory >= 4*(1<<30)); return (memory >= 4L*(1<<30));
} }
void runPlatformTests() { void runPlatformTests() {
......
...@@ -59,6 +59,7 @@ void testGaussian() { ...@@ -59,6 +59,7 @@ void testGaussian() {
platform.getPropertyDefaultValue(CudaPlatform::CudaHostCompiler()), platform.getPropertyDefaultValue(CudaPlatform::CudaDisablePmeStream()), "false", true, 1, NULL); platform.getPropertyDefaultValue(CudaPlatform::CudaHostCompiler()), platform.getPropertyDefaultValue(CudaPlatform::CudaDisablePmeStream()), "false", true, 1, NULL);
CudaContext& context = *platformData.contexts[0]; CudaContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
context.setAsCurrent();
context.getIntegrationUtilities().initRandomNumberGenerator(0); context.getIntegrationUtilities().initRandomNumberGenerator(0);
CudaArray& random = context.getIntegrationUtilities().getRandom(); CudaArray& random = context.getIntegrationUtilities().getRandom();
context.getIntegrationUtilities().prepareRandomNumbers(random.getSize()); context.getIntegrationUtilities().prepareRandomNumbers(random.getSize());
......
...@@ -69,6 +69,7 @@ void verifySorting(vector<float> array) { ...@@ -69,6 +69,7 @@ void verifySorting(vector<float> array) {
platform.getPropertyDefaultValue(CudaPlatform::CudaHostCompiler()), platform.getPropertyDefaultValue(CudaPlatform::CudaDisablePmeStream()), "false", true, 1, NULL); platform.getPropertyDefaultValue(CudaPlatform::CudaHostCompiler()), platform.getPropertyDefaultValue(CudaPlatform::CudaDisablePmeStream()), "false", true, 1, NULL);
CudaContext& context = *platformData.contexts[0]; CudaContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
context.setAsCurrent();
CudaArray data(context, array.size(), 4, "sortData"); CudaArray data(context, array.size(), 4, "sortData");
data.upload(array); data.upload(array);
CudaSort sort(context, new SortTrait(), array.size()); CudaSort sort(context, new SortTrait(), array.size());
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#endif #endif
#include "AmoebaCommonKernels.h" #include "AmoebaCommonKernels.h"
#include "CommonAmoebaKernelSources.h" #include "CommonAmoebaKernelSources.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/AmoebaGeneralizedKirkwoodForceImpl.h" #include "openmm/internal/AmoebaGeneralizedKirkwoodForceImpl.h"
#include "openmm/internal/AmoebaMultipoleForceImpl.h" #include "openmm/internal/AmoebaMultipoleForceImpl.h"
...@@ -122,7 +123,7 @@ CommonCalcAmoebaTorsionTorsionForceKernel::CommonCalcAmoebaTorsionTorsionForceKe ...@@ -122,7 +123,7 @@ CommonCalcAmoebaTorsionTorsionForceKernel::CommonCalcAmoebaTorsionTorsionForceKe
} }
void CommonCalcAmoebaTorsionTorsionForceKernel::initialize(const System& system, const AmoebaTorsionTorsionForce& force) { void CommonCalcAmoebaTorsionTorsionForceKernel::initialize(const System& system, const AmoebaTorsionTorsionForce& force) {
cc.setAsCurrent(); ContextSelector selector(cc);
int numContexts = cc.getNumContexts(); int numContexts = cc.getNumContexts();
int startIndex = cc.getContextIndex()*force.getNumTorsionTorsions()/numContexts; int startIndex = cc.getContextIndex()*force.getNumTorsionTorsions()/numContexts;
int endIndex = (cc.getContextIndex()+1)*force.getNumTorsionTorsions()/numContexts; int endIndex = (cc.getContextIndex()+1)*force.getNumTorsionTorsions()/numContexts;
...@@ -230,11 +231,10 @@ CommonCalcAmoebaMultipoleForceKernel::CommonCalcAmoebaMultipoleForceKernel(const ...@@ -230,11 +231,10 @@ CommonCalcAmoebaMultipoleForceKernel::CommonCalcAmoebaMultipoleForceKernel(const
} }
CommonCalcAmoebaMultipoleForceKernel::~CommonCalcAmoebaMultipoleForceKernel() { CommonCalcAmoebaMultipoleForceKernel::~CommonCalcAmoebaMultipoleForceKernel() {
cc.setAsCurrent();
} }
void CommonCalcAmoebaMultipoleForceKernel::initialize(const System& system, const AmoebaMultipoleForce& force) { void CommonCalcAmoebaMultipoleForceKernel::initialize(const System& system, const AmoebaMultipoleForce& force) {
cc.setAsCurrent(); ContextSelector selector(cc);
if (!cc.getSupports64BitGlobalAtomics()) if (!cc.getSupports64BitGlobalAtomics())
throw OpenMMException("AmoebaMultipoleForce requires a device that supports 64 bit atomic operations"); throw OpenMMException("AmoebaMultipoleForce requires a device that supports 64 bit atomic operations");
...@@ -1045,6 +1045,7 @@ void CommonCalcAmoebaMultipoleForceKernel::initializeScaleFactors() { ...@@ -1045,6 +1045,7 @@ void CommonCalcAmoebaMultipoleForceKernel::initializeScaleFactors() {
} }
double CommonCalcAmoebaMultipoleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcAmoebaMultipoleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
if (!hasInitializedScaleFactors) { if (!hasInitializedScaleFactors) {
initializeScaleFactors(); initializeScaleFactors();
for (auto impl : context.getForceImpls()) { for (auto impl : context.getForceImpls()) {
...@@ -1412,6 +1413,7 @@ void CommonCalcAmoebaMultipoleForceKernel::ensureMultipolesValid(ContextImpl& co ...@@ -1412,6 +1413,7 @@ void CommonCalcAmoebaMultipoleForceKernel::ensureMultipolesValid(ContextImpl& co
} }
void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) { void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
ContextSelector selector(cc);
ensureMultipolesValid(context); ensureMultipolesValid(context);
int numParticles = cc.getNumAtoms(); int numParticles = cc.getNumAtoms();
dipoles.resize(numParticles); dipoles.resize(numParticles);
...@@ -1432,6 +1434,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextIm ...@@ -1432,6 +1434,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextIm
void CommonCalcAmoebaMultipoleForceKernel::getInducedDipoles(ContextImpl& context, vector<Vec3>& dipoles) { void CommonCalcAmoebaMultipoleForceKernel::getInducedDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
ContextSelector selector(cc);
ensureMultipolesValid(context); ensureMultipolesValid(context);
int numParticles = cc.getNumAtoms(); int numParticles = cc.getNumAtoms();
dipoles.resize(numParticles); dipoles.resize(numParticles);
...@@ -1452,6 +1455,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getInducedDipoles(ContextImpl& contex ...@@ -1452,6 +1455,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getInducedDipoles(ContextImpl& contex
void CommonCalcAmoebaMultipoleForceKernel::getTotalDipoles(ContextImpl& context, vector<Vec3>& dipoles) { void CommonCalcAmoebaMultipoleForceKernel::getTotalDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
ContextSelector selector(cc);
ensureMultipolesValid(context); ensureMultipolesValid(context);
int numParticles = cc.getNumAtoms(); int numParticles = cc.getNumAtoms();
dipoles.resize(numParticles); dipoles.resize(numParticles);
...@@ -1493,6 +1497,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getTotalDipoles(ContextImpl& context, ...@@ -1493,6 +1497,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getTotalDipoles(ContextImpl& context,
} }
void CommonCalcAmoebaMultipoleForceKernel::getElectrostaticPotential(ContextImpl& context, const vector<Vec3>& inputGrid, vector<double>& outputElectrostaticPotential) { void CommonCalcAmoebaMultipoleForceKernel::getElectrostaticPotential(ContextImpl& context, const vector<Vec3>& inputGrid, vector<double>& outputElectrostaticPotential) {
ContextSelector selector(cc);
ensureMultipolesValid(context); ensureMultipolesValid(context);
int numPoints = inputGrid.size(); int numPoints = inputGrid.size();
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float)); int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
...@@ -1647,6 +1652,7 @@ void CommonCalcAmoebaMultipoleForceKernel::computeSystemMultipoleMoments(Context ...@@ -1647,6 +1652,7 @@ void CommonCalcAmoebaMultipoleForceKernel::computeSystemMultipoleMoments(Context
void CommonCalcAmoebaMultipoleForceKernel::getSystemMultipoleMoments(ContextImpl& context, vector<double>& outputMultipoleMoments) { void CommonCalcAmoebaMultipoleForceKernel::getSystemMultipoleMoments(ContextImpl& context, vector<double>& outputMultipoleMoments) {
ContextSelector selector(cc);
ensureMultipolesValid(context); ensureMultipolesValid(context);
if (cc.getUseDoublePrecision()) if (cc.getUseDoublePrecision())
computeSystemMultipoleMoments<double, mm_double4, mm_double4>(context, outputMultipoleMoments); computeSystemMultipoleMoments<double, mm_double4, mm_double4>(context, outputMultipoleMoments);
...@@ -1659,7 +1665,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getSystemMultipoleMoments(ContextImpl ...@@ -1659,7 +1665,7 @@ void CommonCalcAmoebaMultipoleForceKernel::getSystemMultipoleMoments(ContextImpl
void CommonCalcAmoebaMultipoleForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaMultipoleForce& force) { void CommonCalcAmoebaMultipoleForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaMultipoleForce& force) {
// Make sure the new parameters are acceptable. // Make sure the new parameters are acceptable.
cc.setAsCurrent(); ContextSelector selector(cc);
if (force.getNumMultipoles() != cc.getNumAtoms()) if (force.getNumMultipoles() != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of multipoles has changed"); throw OpenMMException("updateParametersInContext: The number of multipoles has changed");
...@@ -1749,7 +1755,7 @@ CommonCalcAmoebaGeneralizedKirkwoodForceKernel::CommonCalcAmoebaGeneralizedKirkw ...@@ -1749,7 +1755,7 @@ CommonCalcAmoebaGeneralizedKirkwoodForceKernel::CommonCalcAmoebaGeneralizedKirkw
} }
void CommonCalcAmoebaGeneralizedKirkwoodForceKernel::initialize(const System& system, const AmoebaGeneralizedKirkwoodForce& force) { void CommonCalcAmoebaGeneralizedKirkwoodForceKernel::initialize(const System& system, const AmoebaGeneralizedKirkwoodForce& force) {
cc.setAsCurrent(); ContextSelector selector(cc);
if (cc.getNumContexts() > 1) if (cc.getNumContexts() > 1)
throw OpenMMException("AmoebaGeneralizedKirkwoodForce does not support using multiple devices"); throw OpenMMException("AmoebaGeneralizedKirkwoodForce does not support using multiple devices");
const AmoebaMultipoleForce* multipoles = NULL; const AmoebaMultipoleForce* multipoles = NULL;
...@@ -1976,7 +1982,7 @@ void CommonCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation() { ...@@ -1976,7 +1982,7 @@ void CommonCalcAmoebaGeneralizedKirkwoodForceKernel::finishComputation() {
void CommonCalcAmoebaGeneralizedKirkwoodForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaGeneralizedKirkwoodForce& force) { void CommonCalcAmoebaGeneralizedKirkwoodForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaGeneralizedKirkwoodForce& force) {
// Make sure the new parameters are acceptable. // Make sure the new parameters are acceptable.
cc.setAsCurrent(); ContextSelector selector(cc);
if (force.getNumParticles() != cc.getNumAtoms()) if (force.getNumParticles() != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed"); throw OpenMMException("updateParametersInContext: The number of particles has changed");
...@@ -2017,13 +2023,13 @@ CommonCalcAmoebaVdwForceKernel::CommonCalcAmoebaVdwForceKernel(const std::string ...@@ -2017,13 +2023,13 @@ CommonCalcAmoebaVdwForceKernel::CommonCalcAmoebaVdwForceKernel(const std::string
} }
CommonCalcAmoebaVdwForceKernel::~CommonCalcAmoebaVdwForceKernel() { CommonCalcAmoebaVdwForceKernel::~CommonCalcAmoebaVdwForceKernel() {
cc.setAsCurrent(); ContextSelector selector(cc);
if (nonbonded != NULL) if (nonbonded != NULL)
delete nonbonded; delete nonbonded;
} }
void CommonCalcAmoebaVdwForceKernel::initialize(const System& system, const AmoebaVdwForce& force) { void CommonCalcAmoebaVdwForceKernel::initialize(const System& system, const AmoebaVdwForce& force) {
cc.setAsCurrent(); ContextSelector selector(cc);
int paddedNumAtoms = cc.getPaddedNumAtoms(); int paddedNumAtoms = cc.getPaddedNumAtoms();
bondReductionAtoms.initialize<int>(cc, paddedNumAtoms, "bondReductionAtoms"); bondReductionAtoms.initialize<int>(cc, paddedNumAtoms, "bondReductionAtoms");
bondReductionFactors.initialize<float>(cc, paddedNumAtoms, "bondReductionFactors"); bondReductionFactors.initialize<float>(cc, paddedNumAtoms, "bondReductionFactors");
...@@ -2131,6 +2137,7 @@ void CommonCalcAmoebaVdwForceKernel::initialize(const System& system, const Amoe ...@@ -2131,6 +2137,7 @@ void CommonCalcAmoebaVdwForceKernel::initialize(const System& system, const Amoe
} }
double CommonCalcAmoebaVdwForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcAmoebaVdwForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
if (!hasInitializedNonbonded) { if (!hasInitializedNonbonded) {
hasInitializedNonbonded = true; hasInitializedNonbonded = true;
nonbonded->initialize(system); nonbonded->initialize(system);
...@@ -2160,7 +2167,7 @@ double CommonCalcAmoebaVdwForceKernel::execute(ContextImpl& context, bool includ ...@@ -2160,7 +2167,7 @@ double CommonCalcAmoebaVdwForceKernel::execute(ContextImpl& context, bool includ
void CommonCalcAmoebaVdwForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaVdwForce& force) { void CommonCalcAmoebaVdwForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaVdwForce& force) {
// Make sure the new parameters are acceptable. // Make sure the new parameters are acceptable.
cc.setAsCurrent(); ContextSelector selector(cc);
if (force.getNumParticles() != cc.getNumAtoms()) if (force.getNumParticles() != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed"); throw OpenMMException("updateParametersInContext: The number of particles has changed");
...@@ -2229,6 +2236,7 @@ void CommonCalcAmoebaWcaDispersionForceKernel::initialize(const System& system, ...@@ -2229,6 +2236,7 @@ void CommonCalcAmoebaWcaDispersionForceKernel::initialize(const System& system,
// Record parameters. // Record parameters.
ContextSelector selector(cc);
vector<mm_float2> radiusEpsilonVec(paddedNumAtoms, mm_float2(0, 0)); vector<mm_float2> radiusEpsilonVec(paddedNumAtoms, mm_float2(0, 0));
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
double radius, epsilon; double radius, epsilon;
...@@ -2272,6 +2280,7 @@ void CommonCalcAmoebaWcaDispersionForceKernel::initialize(const System& system, ...@@ -2272,6 +2280,7 @@ void CommonCalcAmoebaWcaDispersionForceKernel::initialize(const System& system,
} }
double CommonCalcAmoebaWcaDispersionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcAmoebaWcaDispersionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
NonbondedUtilities& nb = cc.getNonbondedUtilities(); NonbondedUtilities& nb = cc.getNonbondedUtilities();
int startTileIndex = nb.getStartTileIndex(); int startTileIndex = nb.getStartTileIndex();
int numTileIndices = nb.getNumTiles(); int numTileIndices = nb.getNumTiles();
...@@ -2285,7 +2294,7 @@ double CommonCalcAmoebaWcaDispersionForceKernel::execute(ContextImpl& context, b ...@@ -2285,7 +2294,7 @@ double CommonCalcAmoebaWcaDispersionForceKernel::execute(ContextImpl& context, b
void CommonCalcAmoebaWcaDispersionForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaWcaDispersionForce& force) { void CommonCalcAmoebaWcaDispersionForceKernel::copyParametersToContext(ContextImpl& context, const AmoebaWcaDispersionForce& force) {
// Make sure the new parameters are acceptable. // Make sure the new parameters are acceptable.
cc.setAsCurrent(); ContextSelector selector(cc);
if (force.getNumParticles() != cc.getNumAtoms()) if (force.getNumParticles() != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed"); throw OpenMMException("updateParametersInContext: The number of particles has changed");
...@@ -2374,7 +2383,7 @@ CommonCalcHippoNonbondedForceKernel::CommonCalcHippoNonbondedForceKernel(const s ...@@ -2374,7 +2383,7 @@ CommonCalcHippoNonbondedForceKernel::CommonCalcHippoNonbondedForceKernel(const s
} }
void CommonCalcHippoNonbondedForceKernel::initialize(const System& system, const HippoNonbondedForce& force) { void CommonCalcHippoNonbondedForceKernel::initialize(const System& system, const HippoNonbondedForce& force) {
cc.setAsCurrent(); ContextSelector selector(cc);
if (!cc.getSupports64BitGlobalAtomics()) if (!cc.getSupports64BitGlobalAtomics())
throw OpenMMException("HippoNonbondedForce requires a device that supports 64 bit atomic operations"); throw OpenMMException("HippoNonbondedForce requires a device that supports 64 bit atomic operations");
extrapolationCoefficients = force.getExtrapolationCoefficients(); extrapolationCoefficients = force.getExtrapolationCoefficients();
...@@ -3170,6 +3179,7 @@ void CommonCalcHippoNonbondedForceKernel::createFieldKernel(const string& intera ...@@ -3170,6 +3179,7 @@ void CommonCalcHippoNonbondedForceKernel::createFieldKernel(const string& intera
} }
double CommonCalcHippoNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcHippoNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
NonbondedUtilities& nb = cc.getNonbondedUtilities(); NonbondedUtilities& nb = cc.getNonbondedUtilities();
if (!hasInitializedKernels) { if (!hasInitializedKernels) {
hasInitializedKernels = true; hasInitializedKernels = true;
...@@ -3385,6 +3395,7 @@ void CommonCalcHippoNonbondedForceKernel::addTorquesToForces() { ...@@ -3385,6 +3395,7 @@ void CommonCalcHippoNonbondedForceKernel::addTorquesToForces() {
} }
void CommonCalcHippoNonbondedForceKernel::getInducedDipoles(ContextImpl& context, vector<Vec3>& dipoles) { void CommonCalcHippoNonbondedForceKernel::getInducedDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
ContextSelector selector(cc);
ensureMultipolesValid(context); ensureMultipolesValid(context);
int numParticles = cc.getNumAtoms(); int numParticles = cc.getNumAtoms();
dipoles.resize(numParticles); dipoles.resize(numParticles);
...@@ -3432,6 +3443,7 @@ void CommonCalcHippoNonbondedForceKernel::ensureMultipolesValid(ContextImpl& con ...@@ -3432,6 +3443,7 @@ void CommonCalcHippoNonbondedForceKernel::ensureMultipolesValid(ContextImpl& con
} }
void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) { void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
ContextSelector selector(cc);
ensureMultipolesValid(context); ensureMultipolesValid(context);
int numParticles = cc.getNumAtoms(); int numParticles = cc.getNumAtoms();
dipoles.resize(numParticles); dipoles.resize(numParticles);
...@@ -3453,7 +3465,7 @@ void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImp ...@@ -3453,7 +3465,7 @@ void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImp
void CommonCalcHippoNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const HippoNonbondedForce& force) { void CommonCalcHippoNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const HippoNonbondedForce& force) {
// Make sure the new parameters are acceptable. // Make sure the new parameters are acceptable.
cc.setAsCurrent(); ContextSelector selector(cc);
if (force.getNumParticles() != cc.getNumAtoms()) if (force.getNumParticles() != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed"); throw OpenMMException("updateParametersInContext: The number of particles has changed");
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2020 Stanford University and the Authors. * * Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman, Mark Friedrichs * * Authors: Peter Eastman, Mark Friedrichs *
* Contributors: * * Contributors: *
* * * *
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#endif #endif
#include "AmoebaCudaKernels.h" #include "AmoebaCudaKernels.h"
#include "CudaAmoebaKernelSources.h" #include "CudaAmoebaKernelSources.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/AmoebaGeneralizedKirkwoodForceImpl.h" #include "openmm/internal/AmoebaGeneralizedKirkwoodForceImpl.h"
#include "openmm/internal/AmoebaMultipoleForceImpl.h" #include "openmm/internal/AmoebaMultipoleForceImpl.h"
...@@ -83,7 +84,7 @@ static void setPeriodicBoxArgs(ComputeContext& cc, ComputeKernel kernel, int ind ...@@ -83,7 +84,7 @@ static void setPeriodicBoxArgs(ComputeContext& cc, ComputeKernel kernel, int ind
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
CudaCalcAmoebaMultipoleForceKernel::~CudaCalcAmoebaMultipoleForceKernel() { CudaCalcAmoebaMultipoleForceKernel::~CudaCalcAmoebaMultipoleForceKernel() {
cc.setAsCurrent(); ContextSelector selector(cc);
if (hasInitializedFFT) if (hasInitializedFFT)
cufftDestroy(fft); cufftDestroy(fft);
} }
...@@ -91,6 +92,7 @@ CudaCalcAmoebaMultipoleForceKernel::~CudaCalcAmoebaMultipoleForceKernel() { ...@@ -91,6 +92,7 @@ CudaCalcAmoebaMultipoleForceKernel::~CudaCalcAmoebaMultipoleForceKernel() {
void CudaCalcAmoebaMultipoleForceKernel::initialize(const System& system, const AmoebaMultipoleForce& force) { void CudaCalcAmoebaMultipoleForceKernel::initialize(const System& system, const AmoebaMultipoleForce& force) {
CommonCalcAmoebaMultipoleForceKernel::initialize(system, force); CommonCalcAmoebaMultipoleForceKernel::initialize(system, force);
if (usePME) { if (usePME) {
ContextSelector selector(cc);
cufftResult result = cufftPlan3d(&fft, gridSizeX, gridSizeY, gridSizeZ, cc.getUseDoublePrecision() ? CUFFT_Z2Z : CUFFT_C2C); cufftResult result = cufftPlan3d(&fft, gridSizeX, gridSizeY, gridSizeZ, cc.getUseDoublePrecision() ? CUFFT_Z2Z : CUFFT_C2C);
if (result != CUFFT_SUCCESS) if (result != CUFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cc.intToString(result)); throw OpenMMException("Error initializing FFT: "+cc.intToString(result));
...@@ -120,7 +122,7 @@ void CudaCalcAmoebaMultipoleForceKernel::computeFFT(bool forward) { ...@@ -120,7 +122,7 @@ void CudaCalcAmoebaMultipoleForceKernel::computeFFT(bool forward) {
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
CudaCalcHippoNonbondedForceKernel::~CudaCalcHippoNonbondedForceKernel() { CudaCalcHippoNonbondedForceKernel::~CudaCalcHippoNonbondedForceKernel() {
cc.setAsCurrent(); ContextSelector selector(cc);
if (sort != NULL) if (sort != NULL)
delete sort; delete sort;
if (hasInitializedFFT) { if (hasInitializedFFT) {
...@@ -134,6 +136,7 @@ CudaCalcHippoNonbondedForceKernel::~CudaCalcHippoNonbondedForceKernel() { ...@@ -134,6 +136,7 @@ CudaCalcHippoNonbondedForceKernel::~CudaCalcHippoNonbondedForceKernel() {
void CudaCalcHippoNonbondedForceKernel::initialize(const System& system, const HippoNonbondedForce& force) { void CudaCalcHippoNonbondedForceKernel::initialize(const System& system, const HippoNonbondedForce& force) {
CommonCalcHippoNonbondedForceKernel::initialize(system, force); CommonCalcHippoNonbondedForceKernel::initialize(system, force);
if (usePME) { if (usePME) {
ContextSelector selector(cc);
CudaContext& cu = dynamic_cast<CudaContext&>(cc); CudaContext& cu = dynamic_cast<CudaContext&>(cc);
sort = new CudaSort(cu, new SortTrait(), cc.getNumAtoms()); sort = new CudaSort(cu, new SortTrait(), cc.getNumAtoms());
cufftResult result = cufftPlan3d(&fftForward, gridSizeX, gridSizeY, gridSizeZ, cc.getUseDoublePrecision() ? CUFFT_D2Z : CUFFT_R2C); cufftResult result = cufftPlan3d(&fftForward, gridSizeX, gridSizeY, gridSizeZ, cc.getUseDoublePrecision() ? CUFFT_D2Z : CUFFT_R2C);
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/common/BondedUtilities.h" #include "openmm/common/BondedUtilities.h"
#include "openmm/common/ComputeForceInfo.h" #include "openmm/common/ComputeForceInfo.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/common/IntegrationUtilities.h" #include "openmm/common/IntegrationUtilities.h"
#include "CommonKernelSources.h" #include "CommonKernelSources.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
...@@ -101,9 +102,9 @@ private: ...@@ -101,9 +102,9 @@ private:
}; };
void CommonCalcDrudeForceKernel::initialize(const System& system, const DrudeForce& force) { void CommonCalcDrudeForceKernel::initialize(const System& system, const DrudeForce& force) {
cc.setAsCurrent();
if (cc.getContextIndex() != 0) if (cc.getContextIndex() != 0)
return; // This is run entirely on one device return; // This is run entirely on one device
ContextSelector selector(cc);
int numParticles = force.getNumParticles(); int numParticles = force.getNumParticles();
if (numParticles > 0) { if (numParticles > 0) {
// Create the harmonic interaction . // Create the harmonic interaction .
...@@ -173,6 +174,7 @@ void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, c ...@@ -173,6 +174,7 @@ void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, c
// Set the particle parameters. // Set the particle parameters.
ContextSelector selector(cc);
int numParticles = force.getNumParticles(); int numParticles = force.getNumParticles();
if (numParticles > 0) { if (numParticles > 0) {
if (!particleParams.isInitialized() || numParticles != particleParams.getSize()) if (!particleParams.isInitialized() || numParticles != particleParams.getSize())
...@@ -222,6 +224,7 @@ void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, c ...@@ -222,6 +224,7 @@ void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, c
void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, const DrudeLangevinIntegrator& integrator, const DrudeForce& force) { void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, const DrudeLangevinIntegrator& integrator, const DrudeForce& force) {
cc.initializeContexts(); cc.initializeContexts();
ContextSelector selector(cc);
cc.getIntegrationUtilities().initRandomNumberGenerator((unsigned int) integrator.getRandomNumberSeed()); cc.getIntegrationUtilities().initRandomNumberGenerator((unsigned int) integrator.getRandomNumberSeed());
// Identify particle pairs and ordinary particles. // Identify particle pairs and ordinary particles.
...@@ -263,7 +266,7 @@ void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, co ...@@ -263,7 +266,7 @@ void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, co
} }
void CommonIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, const DrudeLangevinIntegrator& integrator) { void CommonIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
cc.setAsCurrent(); ContextSelector selector(cc);
IntegrationUtilities& integration = cc.getIntegrationUtilities(); IntegrationUtilities& integration = cc.getIntegrationUtilities();
int numAtoms = cc.getNumAtoms(); int numAtoms = cc.getNumAtoms();
if (!hasInitializedKernels) { if (!hasInitializedKernels) {
...@@ -378,7 +381,7 @@ CommonIntegrateDrudeSCFStepKernel::~CommonIntegrateDrudeSCFStepKernel() { ...@@ -378,7 +381,7 @@ CommonIntegrateDrudeSCFStepKernel::~CommonIntegrateDrudeSCFStepKernel() {
void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const DrudeSCFIntegrator& integrator, const DrudeForce& force) { void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const DrudeSCFIntegrator& integrator, const DrudeForce& force) {
cc.initializeContexts(); cc.initializeContexts();
cc.setAsCurrent(); ContextSelector selector(cc);
// Identify Drude particles. // Identify Drude particles.
...@@ -406,7 +409,7 @@ void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const D ...@@ -406,7 +409,7 @@ void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const D
} }
void CommonIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const DrudeSCFIntegrator& integrator) { void CommonIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
cc.setAsCurrent(); ContextSelector selector(cc);
IntegrationUtilities& integration = cc.getIntegrationUtilities(); IntegrationUtilities& integration = cc.getIntegrationUtilities();
int numAtoms = cc.getNumAtoms(); int numAtoms = cc.getNumAtoms();
double dt = integrator.getStepSize(); double dt = integrator.getStepSize();
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "CommonRpmdKernels.h" #include "CommonRpmdKernels.h"
#include "CommonRpmdKernelSources.h" #include "CommonRpmdKernelSources.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/common/IntegrationUtilities.h" #include "openmm/common/IntegrationUtilities.h"
#include "openmm/common/ExpressionUtilities.h" #include "openmm/common/ExpressionUtilities.h"
#include "openmm/common/NonbondedUtilities.h" #include "openmm/common/NonbondedUtilities.h"
...@@ -63,6 +64,7 @@ static int findFFTDimension(int minimum) { ...@@ -63,6 +64,7 @@ static int findFFTDimension(int minimum) {
void CommonIntegrateRPMDStepKernel::initialize(const System& system, const RPMDIntegrator& integrator) { void CommonIntegrateRPMDStepKernel::initialize(const System& system, const RPMDIntegrator& integrator) {
cc.initializeContexts(); cc.initializeContexts();
ContextSelector selector(cc);
numCopies = integrator.getNumCopies(); numCopies = integrator.getNumCopies();
numParticles = system.getNumParticles(); numParticles = system.getNumParticles();
workgroupSize = numCopies; workgroupSize = numCopies;
...@@ -213,7 +215,7 @@ void CommonIntegrateRPMDStepKernel::initializeKernels(ContextImpl& context) { ...@@ -213,7 +215,7 @@ void CommonIntegrateRPMDStepKernel::initializeKernels(ContextImpl& context) {
} }
void CommonIntegrateRPMDStepKernel::execute(ContextImpl& context, const RPMDIntegrator& integrator, bool forcesAreValid) { void CommonIntegrateRPMDStepKernel::execute(ContextImpl& context, const RPMDIntegrator& integrator, bool forcesAreValid) {
cc.setAsCurrent(); ContextSelector selector(cc);
if (!hasInitializedKernels) if (!hasInitializedKernels)
initializeKernels(context); initializeKernels(context);
IntegrationUtilities& integration = cc.getIntegrationUtilities(); IntegrationUtilities& integration = cc.getIntegrationUtilities();
...@@ -364,6 +366,7 @@ void CommonIntegrateRPMDStepKernel::setPositions(int copy, const vector<Vec3>& p ...@@ -364,6 +366,7 @@ void CommonIntegrateRPMDStepKernel::setPositions(int copy, const vector<Vec3>& p
// Record the positions. // Record the positions.
ContextSelector selector(cc);
if (cc.getUseDoublePrecision()) { if (cc.getUseDoublePrecision()) {
vector<mm_double4> posq(cc.getPaddedNumAtoms()); vector<mm_double4> posq(cc.getPaddedNumAtoms());
cc.getPosq().download(posq); cc.getPosq().download(posq);
...@@ -393,6 +396,7 @@ void CommonIntegrateRPMDStepKernel::setVelocities(int copy, const vector<Vec3>& ...@@ -393,6 +396,7 @@ void CommonIntegrateRPMDStepKernel::setVelocities(int copy, const vector<Vec3>&
throw OpenMMException("RPMDIntegrator: Cannot set velocities before the integrator is added to a Context"); throw OpenMMException("RPMDIntegrator: Cannot set velocities before the integrator is added to a Context");
if (vel.size() != numParticles) if (vel.size() != numParticles)
throw OpenMMException("RPMDIntegrator: wrong number of values passed to setVelocities()"); throw OpenMMException("RPMDIntegrator: wrong number of values passed to setVelocities()");
ContextSelector selector(cc);
if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) { if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
vector<mm_double4> velm(cc.getPaddedNumAtoms()); vector<mm_double4> velm(cc.getPaddedNumAtoms());
cc.getVelm().download(velm); cc.getVelm().download(velm);
...@@ -410,6 +414,7 @@ void CommonIntegrateRPMDStepKernel::setVelocities(int copy, const vector<Vec3>& ...@@ -410,6 +414,7 @@ void CommonIntegrateRPMDStepKernel::setVelocities(int copy, const vector<Vec3>&
} }
void CommonIntegrateRPMDStepKernel::copyToContext(int copy, ContextImpl& context) { void CommonIntegrateRPMDStepKernel::copyToContext(int copy, ContextImpl& context) {
ContextSelector selector(cc);
if (!hasInitializedKernels) if (!hasInitializedKernels)
initializeKernels(context); initializeKernels(context);
copyToContextKernel->setArg(2, positions); copyToContextKernel->setArg(2, positions);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment