"...src/ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "b619675415da2666ed773109b025da6541a7eb4a"
Commit aee84c82 authored by peastman's avatar peastman
Browse files

Completed OpenCL implementation of CustomCVForce

parent 48378da7
...@@ -127,6 +127,14 @@ public: ...@@ -127,6 +127,14 @@ public:
* @param properties a set of values for platform-specific properties. Keys are the property names. * @param properties a set of values for platform-specific properties. Keys are the property names.
*/ */
virtual void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const; virtual void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
/**
* This is called whenever a new Context is created using ContextImpl::createLinkedContext(). It gives the
* Platform a chance to initialize the context and store platform-specific data in it.
*
* @param context the newly created context
* @param originalContext the original context it is linked to
*/
virtual void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const;
/** /**
* This is called whenever a Context is deleted. It gives the Platform a chance to clean up * This is called whenever a Context is deleted. It gives the Platform a chance to clean up
* any platform-specific data that was stored in it. * any platform-specific data that was stored in it.
......
...@@ -959,8 +959,9 @@ public: ...@@ -959,8 +959,9 @@ public:
* *
* @param system the System this kernel will be applied to * @param system the System this kernel will be applied to
* @param force the CustomCVForce this kernel will be used for * @param force the CustomCVForce this kernel will be used for
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/ */
virtual void initialize(const System& system, const CustomCVForce& force) = 0; virtual void initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) = 0;
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/Kernel.h" #include "openmm/Kernel.h"
#include "openmm/KernelFactory.h" #include "openmm/KernelFactory.h"
#include "openmm/internal/ContextImpl.h"
#ifdef WIN32 #ifdef WIN32
#include <windows.h> #include <windows.h>
#include <sstream> #include <sstream>
...@@ -113,6 +114,16 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val ...@@ -113,6 +114,16 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val
void Platform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void Platform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
} }
/**
 * Default handler for a Context created through ContextImpl::createLinkedContext().
 *
 * @param context          the newly created context
 * @param originalContext  the context the new one is linked to
 */
void Platform::linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const {
    // Generic behavior: collect the value of every platform property as it is set on
    // the original context's owner, then hand that map to contextCreated() as if the
    // new context had been created with those properties. Subclasses may override
    // this to do something different.
    const Context& originalOwner = originalContext.getOwner();
    map<string, string> inheritedProperties;
    for (const string& propertyName : getPropertyNames())
        inheritedProperties[propertyName] = getPropertyValue(originalOwner, propertyName);
    contextCreated(context, inheritedProperties);
}
void Platform::contextDestroyed(ContextImpl& context) const { void Platform::contextDestroyed(ContextImpl& context) const {
} }
......
...@@ -263,9 +263,11 @@ public: ...@@ -263,9 +263,11 @@ public:
*/ */
const std::vector<std::vector<int> >& getMolecules() const; const std::vector<std::vector<int> >& getMolecules() const;
private: private:
friend class ContextImpl;
friend class Force; friend class Force;
friend class ForceImpl; friend class ForceImpl;
friend class Platform; friend class Platform;
Context(const System& system, Integrator& integrator, ContextImpl& linked);
ContextImpl& getImpl(); ContextImpl& getImpl();
const ContextImpl& getImpl() const; const ContextImpl& getImpl() const;
ContextImpl* impl; ContextImpl* impl;
......
...@@ -55,7 +55,8 @@ public: ...@@ -55,7 +55,8 @@ public:
/** /**
* Create a ContextImpl for a Context. * Create a ContextImpl for a Context.
*/ */
ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const std::map<std::string, std::string>& properties); ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const std::map<std::string, std::string>& properties,
ContextImpl* originalContext=NULL);
~ContextImpl(); ~ContextImpl();
/** /**
* Get the Context for which this is the implementation. * Get the Context for which this is the implementation.
...@@ -264,6 +265,15 @@ public: ...@@ -264,6 +265,15 @@ public:
* you should never call it. It is exposed here because the same logic is useful to other classes too. * you should never call it. It is exposed here because the same logic is useful to other classes too.
*/ */
static std::vector<std::vector<int> > findMolecules(int numParticles, std::vector<std::vector<int> >& particleBonds); static std::vector<std::vector<int> > findMolecules(int numParticles, std::vector<std::vector<int> >& particleBonds);
/**
* Create a new Context based on this one. The new context will use the same Platform, device, and property
* values as this one. With the CUDA and OpenCL platforms, it also shares the same GPU context, allowing data
* to be transferred between them without leaving the GPU.
*
* This method exists for very specialized purposes. If you aren't certain whether you should use it, that probably
* means you shouldn't.
*/
Context* createLinkedContext(const System& system, Integrator& integrator);
private: private:
friend class Context; friend class Context;
void initialize(); void initialize();
......
...@@ -40,6 +40,12 @@ ...@@ -40,6 +40,12 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
/**
 * Construct a Context that is linked to an existing one. The property map is copied
 * from the Context that owns `linked`.
 *
 * This constructor is the one invoked by ContextImpl::createLinkedContext().
 */
Context::Context(const System& system, Integrator& integrator, ContextImpl& linked) : properties(linked.getOwner().properties) {
    // Reuse the linked context's Platform, and pass the linked ContextImpl along as
    // the "original" context for the new implementation object.
    Platform* linkedPlatform = &linked.getPlatform();
    impl = new ContextImpl(*this, system, integrator, linkedPlatform, properties, &linked);
    impl->initialize();
}
Context::Context(const System& system, Integrator& integrator) : properties(map<string, string>()) { Context::Context(const System& system, Integrator& integrator) : properties(map<string, string>()) {
impl = new ContextImpl(*this, system, integrator, 0, properties); impl = new ContextImpl(*this, system, integrator, 0, properties);
impl->initialize(); impl->initialize();
......
...@@ -53,7 +53,7 @@ using namespace std; ...@@ -53,7 +53,7 @@ using namespace std;
const static char CHECKPOINT_MAGIC_BYTES[] = "OpenMM Binary Checkpoint\n"; const static char CHECKPOINT_MAGIC_BYTES[] = "OpenMM Binary Checkpoint\n";
ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) : ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties, ContextImpl* originalContext) :
owner(owner), system(system), integrator(integrator), hasInitializedForces(false), hasSetPositions(false), integratorIsDeleted(false), owner(owner), system(system), integrator(integrator), hasInitializedForces(false), hasSetPositions(false), integratorIsDeleted(false),
lastForceGroups(-1), platform(platform), platformData(NULL) { lastForceGroups(-1), platform(platform), platformData(NULL) {
int numParticles = system.getNumParticles(); int numParticles = system.getNumParticles();
...@@ -152,7 +152,10 @@ ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integ ...@@ -152,7 +152,10 @@ ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integ
for (int i = candidatePlatforms.size()-1; i >= 0; i--) { for (int i = candidatePlatforms.size()-1; i >= 0; i--) {
try { try {
this->platform = platform = candidatePlatforms[i].second; this->platform = platform = candidatePlatforms[i].second;
platform->contextCreated(*this, validatedProperties); if (originalContext == NULL)
platform->contextCreated(*this, validatedProperties);
else
platform->linkedContextCreated(*this, *originalContext);
break; break;
} }
catch (...) { catch (...) {
...@@ -481,3 +484,7 @@ void ContextImpl::loadCheckpoint(istream& stream) { ...@@ -481,3 +484,7 @@ void ContextImpl::loadCheckpoint(istream& stream) {
void ContextImpl::systemChanged() { void ContextImpl::systemChanged() {
integrator.stateChanged(State::Energy); integrator.stateChanged(State::Energy);
} }
/**
 * Create a new Context for the given System and Integrator that is linked to this one.
 *
 * The Context is allocated with `new` and returned as a raw pointer; this function
 * keeps no reference to it, so releasing it is left to the caller.
 */
Context* ContextImpl::createLinkedContext(const System& system, Integrator& integrator) {
    Context* linkedContext = new Context(system, integrator, *this);
    return linkedContext;
}
...@@ -66,18 +66,14 @@ void CustomCVForceImpl::initialize(ContextImpl& context) { ...@@ -66,18 +66,14 @@ void CustomCVForceImpl::initialize(ContextImpl& context) {
// Create the inner context. // Create the inner context.
Platform& platform = context.getPlatform(); innerContext = context.createLinkedContext(innerSystem, innerIntegrator);
map<string, string> properties;
for (auto& name : platform.getPropertyNames())
properties[name] = platform.getPropertyValue(context.getOwner(), name);
innerContext = new Context(innerSystem, innerIntegrator, platform, properties);
vector<Vec3> positions(system.getNumParticles(), Vec3()); vector<Vec3> positions(system.getNumParticles(), Vec3());
innerContext->setPositions(positions); innerContext->setPositions(positions);
// Create the kernel. // Create the kernel.
kernel = context.getPlatform().createKernel(CalcCustomCVForceKernel::Name(), context); kernel = context.getPlatform().createKernel(CalcCustomCVForceKernel::Name(), context);
kernel.getAs<CalcCustomCVForceKernel>().initialize(context.getSystem(), owner); kernel.getAs<CalcCustomCVForceKernel>().initialize(context.getSystem(), owner, getContextImpl(*innerContext));
} }
double CustomCVForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomCVForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
......
...@@ -163,7 +163,8 @@ public: ...@@ -163,7 +163,8 @@ public:
class ForcePostComputation; class ForcePostComputation;
static const int ThreadBlockSize; static const int ThreadBlockSize;
static const int TileSize; static const int TileSize;
OpenCLContext(const System& system, int platformIndex, int deviceIndex, const std::string& precision, OpenCLPlatform::PlatformData& platformData); OpenCLContext(const System& system, int platformIndex, int deviceIndex, const std::string& precision, OpenCLPlatform::PlatformData& platformData,
OpenCLContext* originalContext);
~OpenCLContext(); ~OpenCLContext();
/** /**
* This is called to initialize internal data structures after all Forces in the system * This is called to initialize internal data structures after all Forces in the system
......
...@@ -1222,8 +1222,9 @@ public: ...@@ -1222,8 +1222,9 @@ public:
* *
* @param system the System this kernel will be applied to * @param system the System this kernel will be applied to
* @param force the CustomCVForce this kernel will be used for * @param force the CustomCVForce this kernel will be used for
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/ */
void initialize(const System& system, const CustomCVForce& force); void initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext);
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
......
...@@ -53,6 +53,7 @@ public: ...@@ -53,6 +53,7 @@ public:
const std::string& getPropertyValue(const Context& context, const std::string& property) const; const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const; void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const; void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
/** /**
* This is the name of the parameter for selecting which OpenCL device or devices to use. * This is the name of the parameter for selecting which OpenCL device or devices to use.
...@@ -108,7 +109,7 @@ public: ...@@ -108,7 +109,7 @@ public:
class OPENMM_EXPORT_OPENCL OpenCLPlatform::PlatformData { class OPENMM_EXPORT_OPENCL OpenCLPlatform::PlatformData {
public: public:
PlatformData(const System& system, const std::string& platformPropValue, const std::string& deviceIndexProperty, const std::string& precisionProperty, PlatformData(const System& system, const std::string& platformPropValue, const std::string& deviceIndexProperty, const std::string& precisionProperty,
const std::string& cpuPmeProperty, const std::string& pmeStreamProperty, int numThreads); const std::string& cpuPmeProperty, const std::string& pmeStreamProperty, int numThreads, ContextImpl* originalContext);
~PlatformData(); ~PlatformData();
void initializeContexts(const System& system); void initializeContexts(const System& system);
void syncContexts(); void syncContexts();
......
...@@ -67,7 +67,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i ...@@ -67,7 +67,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i
std::cerr << "OpenCL internal error: " << errinfo << std::endl; std::cerr << "OpenCL internal error: " << errinfo << std::endl;
} }
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData) : OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData, OpenCLContext* originalContext) :
system(system), time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), stepsSinceReorder(99999), atomsWereReordered(false), posq(NULL), system(system), time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), stepsSinceReorder(99999), atomsWereReordered(false), posq(NULL),
posqCorrection(NULL), velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), energyParamDerivBuffer(NULL), atomIndexDevice(NULL), posqCorrection(NULL), velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), energyParamDerivBuffer(NULL), atomIndexDevice(NULL),
chargeBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), thread(NULL) { chargeBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), thread(NULL) {
...@@ -261,8 +261,14 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -261,8 +261,14 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
vector<cl::Device> contextDevices; vector<cl::Device> contextDevices;
contextDevices.push_back(device); contextDevices.push_back(device);
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0}; cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
context = cl::Context(contextDevices, cprops, errorCallback); if (originalContext == NULL) {
defaultQueue = cl::CommandQueue(context, device); context = cl::Context(contextDevices, cprops, errorCallback);
defaultQueue = cl::CommandQueue(context, device);
}
else {
context = originalContext->context;
defaultQueue = originalContext->defaultQueue;
}
currentQueue = defaultQueue; currentQueue = defaultQueue;
numAtoms = system.getNumParticles(); numAtoms = system.getNumParticles();
paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize); paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
......
...@@ -6879,12 +6879,11 @@ public: ...@@ -6879,12 +6879,11 @@ public:
ReorderListener(OpenCLContext& cl, OpenCLArray& invAtomOrder) : cl(cl), invAtomOrder(invAtomOrder) { ReorderListener(OpenCLContext& cl, OpenCLArray& invAtomOrder) : cl(cl), invAtomOrder(invAtomOrder) {
} }
void execute() { void execute() {
vector<cl_int> invOrder(cl.getNumAtoms()); vector<cl_int> invOrder(cl.getPaddedNumAtoms());
const vector<int>& order = cl.getAtomIndex(); const vector<int>& order = cl.getAtomIndex();
for (int i = 0; i < order.size(); i++) for (int i = 0; i < order.size(); i++)
invOrder[order[i]] = i; invOrder[order[i]] = i;
invAtomOrder.upload(invOrder); invAtomOrder.upload(invOrder);
cl.getQueue().finish();
} }
private: private:
OpenCLContext& cl; OpenCLContext& cl;
...@@ -6900,7 +6899,7 @@ OpenCLCalcCustomCVForceKernel::~OpenCLCalcCustomCVForceKernel() { ...@@ -6900,7 +6899,7 @@ OpenCLCalcCustomCVForceKernel::~OpenCLCalcCustomCVForceKernel() {
delete innerInvAtomOrder; delete innerInvAtomOrder;
} }
void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force) { void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) {
int numCVs = force.getNumCollectiveVariables(); int numCVs = force.getNumCollectiveVariables();
cl.addForce(new OpenCLForceInfo(1)); cl.addForce(new OpenCLForceInfo(1));
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
...@@ -6925,20 +6924,27 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo ...@@ -6925,20 +6924,27 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo
string name = force.getEnergyParameterDerivativeName(i); string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name); paramDerivNames.push_back(name);
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram()); paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
cl.addEnergyParameterDerivative(name);
} }
// Delete the custom functions. // Delete the custom functions.
for (auto& function : functions) for (auto& function : functions)
delete function.second; delete function.second;
// Copy parameter derivatives from the inner context.
OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
for (auto& param : cl2.getEnergyParamDerivNames())
cl.addEnergyParameterDerivative(param);
// Create arrays for storing information. // Create arrays for storing information.
int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float)); int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
for (int i = 0; i < numCVs; i++) for (int i = 0; i < numCVs; i++)
cvForces.push_back(new OpenCLArray(cl, cl.getNumAtoms(), elementSize, "cvForce")); cvForces.push_back(new OpenCLArray(cl, cl.getNumAtoms(), 4*elementSize, "cvForce"));
invAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getNumAtoms(), "invAtomOrder"); invAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getPaddedNumAtoms(), "invAtomOrder");
innerInvAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getNumAtoms(), "innerInvAtomOrder"); innerInvAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getPaddedNumAtoms(), "innerInvAtomOrder");
// Create the kernels. // Create the kernels.
...@@ -6984,8 +6990,8 @@ double OpenCLCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& ...@@ -6984,8 +6990,8 @@ double OpenCLCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
addForcesKernel.setArg<cl_double>(2*i+3, dEdV); addForcesKernel.setArg<cl_double>(2*i+3, dEdV);
else else
addForcesKernel.setArg<cl_float>(2*i+3, dEdV); addForcesKernel.setArg<cl_float>(2*i+3, dEdV);
cl.executeKernel(addForcesKernel, numAtoms);
} }
cl.executeKernel(addForcesKernel, numAtoms);
// Compute the energy parameter derivatives. // Compute the energy parameter derivatives.
...@@ -7002,12 +7008,12 @@ double OpenCLCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& ...@@ -7002,12 +7008,12 @@ double OpenCLCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& innerContext) { void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& innerContext) {
int numAtoms = cl.getNumAtoms(); int numAtoms = cl.getNumAtoms();
OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
if (!hasInitializedKernels) { if (!hasInitializedKernels) {
hasInitializedKernels = true; hasInitializedKernels = true;
// Initialize the listeners. // Initialize the listeners.
OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
ReorderListener* listener1 = new ReorderListener(cl, *invAtomOrder); ReorderListener* listener1 = new ReorderListener(cl, *invAtomOrder);
ReorderListener* listener2 = new ReorderListener(cl2, *innerInvAtomOrder); ReorderListener* listener2 = new ReorderListener(cl2, *innerInvAtomOrder);
cl.addReorderListener(listener1); cl.addReorderListener(listener1);
......
...@@ -180,7 +180,20 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri ...@@ -180,7 +180,20 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri
char* threadsEnv = getenv("OPENMM_CPU_THREADS"); char* threadsEnv = getenv("OPENMM_CPU_THREADS");
if (threadsEnv != NULL) if (threadsEnv != NULL)
stringstream(threadsEnv) >> threads; stringstream(threadsEnv) >> threads;
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue, pmeStreamPropValue, threads)); context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue,
pmeStreamPropValue, threads, NULL));
}
/**
 * Set up the platform data for a Context created with ContextImpl::createLinkedContext().
 *
 * The OpenCL property values are read from the original context instead of from a
 * property map, and the original ContextImpl is forwarded to the PlatformData
 * constructor.
 *
 * @param context          the newly created (linked) context
 * @param originalContext  the context the new one is linked to
 */
void OpenCLPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const {
// Every property is looked up on the owner of the original context.
Platform& platform = originalContext.getPlatform();
string platformPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLPlatformIndex());
string devicePropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLDeviceIndex());
string precisionPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLPrecision());
string cpuPmePropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLUseCpuPme());
string pmeStreamPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLDisablePmeStream());
// The thread count is taken from the original context's PlatformData.
int threads = reinterpret_cast<PlatformData*>(originalContext.getPlatformData())->threads.getNumThreads();
// Passing a non-NULL original context tells PlatformData that this is a linked context.
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue,
pmeStreamPropValue, threads, &originalContext));
} }
void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
...@@ -189,7 +202,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { ...@@ -189,7 +202,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
} }
OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty, OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty,
const string& precisionProperty, const string& cpuPmeProperty, const string& pmeStreamProperty, int numThreads) : const string& precisionProperty, const string& cpuPmeProperty, const string& pmeStreamProperty, int numThreads, ContextImpl* originalContext) :
removeCM(false), stepCount(0), computeForceCount(0), time(0.0), hasInitializedContexts(false), threads(numThreads) { removeCM(false), stepCount(0), computeForceCount(0), time(0.0), hasInitializedContexts(false), threads(numThreads) {
int platformIndex = -1; int platformIndex = -1;
if (platformPropValue.length() > 0) if (platformPropValue.length() > 0)
...@@ -201,16 +214,19 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p ...@@ -201,16 +214,19 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p
searchPos = nextPos+1; searchPos = nextPos+1;
} }
devices.push_back(deviceIndexProperty.substr(searchPos)); devices.push_back(deviceIndexProperty.substr(searchPos));
PlatformData* originalData = NULL;
if (originalContext != NULL)
originalData = reinterpret_cast<PlatformData*>(originalContext->getPlatformData());
try { try {
for (int i = 0; i < (int) devices.size(); i++) { for (int i = 0; i < (int) devices.size(); i++) {
if (devices[i].length() > 0) { if (devices[i].length() > 0) {
int deviceIndex; int deviceIndex;
stringstream(devices[i]) >> deviceIndex; stringstream(devices[i]) >> deviceIndex;
contexts.push_back(new OpenCLContext(system, platformIndex, deviceIndex, precisionProperty, *this)); contexts.push_back(new OpenCLContext(system, platformIndex, deviceIndex, precisionProperty, *this, (originalData == NULL ? NULL : originalData->contexts[i])));
} }
} }
if (contexts.size() == 0) if (contexts.size() == 0)
contexts.push_back(new OpenCLContext(system, platformIndex, -1, precisionProperty, *this)); contexts.push_back(new OpenCLContext(system, platformIndex, -1, precisionProperty, *this, (originalData == NULL ? NULL : originalData->contexts[0])));
} }
catch (...) { catch (...) {
// If an exception was thrown, do our best to clean up memory. // If an exception was thrown, do our best to clean up memory.
......
...@@ -28,7 +28,7 @@ __kernel void copyForces(__global real4* forces, __global int* restrict invAtomO ...@@ -28,7 +28,7 @@ __kernel void copyForces(__global real4* forces, __global int* restrict invAtomO
/** /**
* Add all the forces from the CVs. * Add all the forces from the CVs.
*/ */
__kernel void addForces(__global real4* forces, int numAtoms, int numCVs __kernel void addForces(__global real4* forces, int numAtoms
PARAMETER_ARGUMENTS) { PARAMETER_ARGUMENTS) {
for (int i = get_global_id(0); i < numAtoms; i += get_global_size(0)) { for (int i = get_global_id(0); i < numAtoms; i += get_global_size(0)) {
real4 f = forces[i]; real4 f = forces[i];
......
...@@ -54,7 +54,7 @@ template <class Real2> ...@@ -54,7 +54,7 @@ template <class Real2>
void testTransform(bool realToComplex, int xsize, int ysize, int zsize) { void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1); OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
......
...@@ -54,7 +54,7 @@ void testGaussian() { ...@@ -54,7 +54,7 @@ void testGaussian() {
System system; System system;
for (int i = 0; i < numAtoms; i++) for (int i = 0; i < numAtoms; i++)
system.addParticle(1.0); system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1); OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
context.getIntegrationUtilities().initRandomNumberGenerator(0); context.getIntegrationUtilities().initRandomNumberGenerator(0);
......
...@@ -64,7 +64,7 @@ void verifySorting(vector<float> array) { ...@@ -64,7 +64,7 @@ void verifySorting(vector<float> array) {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1); OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
OpenCLArray data(context, array.size(), sizeof(float), "sortData"); OpenCLArray data(context, array.size(), sizeof(float), "sortData");
......
...@@ -1020,8 +1020,9 @@ public: ...@@ -1020,8 +1020,9 @@ public:
* *
* @param system the System this kernel will be applied to * @param system the System this kernel will be applied to
* @param force the CustomCVForce this kernel will be used for * @param force the CustomCVForce this kernel will be used for
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/ */
void initialize(const System& system, const CustomCVForce& force); void initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext);
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
......
...@@ -2021,7 +2021,7 @@ ReferenceCalcCustomCVForceKernel::~ReferenceCalcCustomCVForceKernel() { ...@@ -2021,7 +2021,7 @@ ReferenceCalcCustomCVForceKernel::~ReferenceCalcCustomCVForceKernel() {
delete ixn; delete ixn;
} }
void ReferenceCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force) { void ReferenceCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) {
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment