Commit aee84c82 authored by peastman's avatar peastman
Browse files

Completed OpenCL implementation of CustomCVForce

parent 48378da7
......@@ -127,6 +127,14 @@ public:
* @param properties a set of values for platform-specific properties. Keys are the property names.
*/
virtual void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
/**
* This is called whenever a new Context is created using ContextImpl::createLinkedContext(). It gives the
* Platform a chance to initialize the context and store platform-specific data in it.
*
* @param context the newly created context
* @param originalContext the original context it is linked to
*/
virtual void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const;
/**
* This is called whenever a Context is deleted. It gives the Platform a chance to clean up
* any platform-specific data that was stored in it.
......
......@@ -959,8 +959,9 @@ public:
*
* @param system the System this kernel will be applied to
* @param force the CustomCVForce this kernel will be used for
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/
virtual void initialize(const System& system, const CustomCVForce& force) = 0;
virtual void initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) = 0;
/**
* Execute the kernel to calculate the forces and/or energy.
*
......
......@@ -34,6 +34,7 @@
#include "openmm/OpenMMException.h"
#include "openmm/Kernel.h"
#include "openmm/KernelFactory.h"
#include "openmm/internal/ContextImpl.h"
#ifdef WIN32
#include <windows.h>
#include <sstream>
......@@ -113,6 +114,16 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val
void Platform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
}
void Platform::linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const {
// The default implementation just copies over the properties and calls contextCreated().
// Subclasses may override this to do something different.
map<string, string> properties;
for (auto& name : getPropertyNames())
properties[name] = getPropertyValue(originalContext.getOwner(), name);
contextCreated(context, properties);
}
void Platform::contextDestroyed(ContextImpl& context) const {
}
......
......@@ -263,9 +263,11 @@ public:
*/
const std::vector<std::vector<int> >& getMolecules() const;
private:
friend class ContextImpl;
friend class Force;
friend class ForceImpl;
friend class Platform;
Context(const System& system, Integrator& integrator, ContextImpl& linked);
ContextImpl& getImpl();
const ContextImpl& getImpl() const;
ContextImpl* impl;
......
......@@ -55,7 +55,8 @@ public:
/**
* Create an ContextImpl for a Context;
*/
ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const std::map<std::string, std::string>& properties);
ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const std::map<std::string, std::string>& properties,
ContextImpl* originalContext=NULL);
~ContextImpl();
/**
* Get the Context for which this is the implementation.
......@@ -264,6 +265,15 @@ public:
* you should never call it. It is exposed here because the same logic is useful to other classes too.
*/
static std::vector<std::vector<int> > findMolecules(int numParticles, std::vector<std::vector<int> >& particleBonds);
/**
* Create a new Context based on this one. The new context will use the same Platform, device, and property
* values as this one. With the CUDA and OpenCL platforms, it also shares the same GPU context, allowing data
* to be transferred between them without leaving the GPU.
*
* This method exists for very specialized purposes. If you aren't certain whether you should use it, that probably
* means you shouldn't.
*/
Context* createLinkedContext(const System& system, Integrator& integrator);
private:
friend class Context;
void initialize();
......
......@@ -40,6 +40,12 @@
using namespace OpenMM;
using namespace std;
Context::Context(const System& system, Integrator& integrator, ContextImpl& linked) : properties(linked.getOwner().properties) {
// This is used by ContextImpl::createLinkedContext().
impl = new ContextImpl(*this, system, integrator, &linked.getPlatform(), properties, &linked);
impl->initialize();
}
Context::Context(const System& system, Integrator& integrator) : properties(map<string, string>()) {
impl = new ContextImpl(*this, system, integrator, 0, properties);
impl->initialize();
......
......@@ -53,7 +53,7 @@ using namespace std;
const static char CHECKPOINT_MAGIC_BYTES[] = "OpenMM Binary Checkpoint\n";
ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) :
ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties, ContextImpl* originalContext) :
owner(owner), system(system), integrator(integrator), hasInitializedForces(false), hasSetPositions(false), integratorIsDeleted(false),
lastForceGroups(-1), platform(platform), platformData(NULL) {
int numParticles = system.getNumParticles();
......@@ -152,7 +152,10 @@ ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integ
for (int i = candidatePlatforms.size()-1; i >= 0; i--) {
try {
this->platform = platform = candidatePlatforms[i].second;
platform->contextCreated(*this, validatedProperties);
if (originalContext == NULL)
platform->contextCreated(*this, validatedProperties);
else
platform->linkedContextCreated(*this, *originalContext);
break;
}
catch (...) {
......@@ -481,3 +484,7 @@ void ContextImpl::loadCheckpoint(istream& stream) {
void ContextImpl::systemChanged() {
integrator.stateChanged(State::Energy);
}
Context* ContextImpl::createLinkedContext(const System& system, Integrator& integrator) {
return new Context(system, integrator, *this);
}
......@@ -66,18 +66,14 @@ void CustomCVForceImpl::initialize(ContextImpl& context) {
// Create the inner context.
Platform& platform = context.getPlatform();
map<string, string> properties;
for (auto& name : platform.getPropertyNames())
properties[name] = platform.getPropertyValue(context.getOwner(), name);
innerContext = new Context(innerSystem, innerIntegrator, platform, properties);
innerContext = context.createLinkedContext(innerSystem, innerIntegrator);
vector<Vec3> positions(system.getNumParticles(), Vec3());
innerContext->setPositions(positions);
// Create the kernel.
kernel = context.getPlatform().createKernel(CalcCustomCVForceKernel::Name(), context);
kernel.getAs<CalcCustomCVForceKernel>().initialize(context.getSystem(), owner);
kernel.getAs<CalcCustomCVForceKernel>().initialize(context.getSystem(), owner, getContextImpl(*innerContext));
}
double CustomCVForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
......
......@@ -163,7 +163,8 @@ public:
class ForcePostComputation;
static const int ThreadBlockSize;
static const int TileSize;
OpenCLContext(const System& system, int platformIndex, int deviceIndex, const std::string& precision, OpenCLPlatform::PlatformData& platformData);
OpenCLContext(const System& system, int platformIndex, int deviceIndex, const std::string& precision, OpenCLPlatform::PlatformData& platformData,
OpenCLContext* originalContext);
~OpenCLContext();
/**
* This is called to initialize internal data structures after all Forces in the system
......
......@@ -1222,8 +1222,9 @@ public:
*
* @param system the System this kernel will be applied to
* @param force the CustomCVForce this kernel will be used for
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/
void initialize(const System& system, const CustomCVForce& force);
void initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext);
/**
* Execute the kernel to calculate the forces and/or energy.
*
......
......@@ -53,6 +53,7 @@ public:
const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const;
void contextDestroyed(ContextImpl& context) const;
/**
* This is the name of the parameter for selecting which OpenCL device or devices to use.
......@@ -108,7 +109,7 @@ public:
class OPENMM_EXPORT_OPENCL OpenCLPlatform::PlatformData {
public:
PlatformData(const System& system, const std::string& platformPropValue, const std::string& deviceIndexProperty, const std::string& precisionProperty,
const std::string& cpuPmeProperty, const std::string& pmeStreamProperty, int numThreads);
const std::string& cpuPmeProperty, const std::string& pmeStreamProperty, int numThreads, ContextImpl* originalContext);
~PlatformData();
void initializeContexts(const System& system);
void syncContexts();
......
......@@ -67,7 +67,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i
std::cerr << "OpenCL internal error: " << errinfo << std::endl;
}
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData) :
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData, OpenCLContext* originalContext) :
system(system), time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), stepsSinceReorder(99999), atomsWereReordered(false), posq(NULL),
posqCorrection(NULL), velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), energyParamDerivBuffer(NULL), atomIndexDevice(NULL),
chargeBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), thread(NULL) {
......@@ -261,8 +261,14 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
vector<cl::Device> contextDevices;
contextDevices.push_back(device);
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
context = cl::Context(contextDevices, cprops, errorCallback);
defaultQueue = cl::CommandQueue(context, device);
if (originalContext == NULL) {
context = cl::Context(contextDevices, cprops, errorCallback);
defaultQueue = cl::CommandQueue(context, device);
}
else {
context = originalContext->context;
defaultQueue = originalContext->defaultQueue;
}
currentQueue = defaultQueue;
numAtoms = system.getNumParticles();
paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
......
......@@ -6879,12 +6879,11 @@ public:
ReorderListener(OpenCLContext& cl, OpenCLArray& invAtomOrder) : cl(cl), invAtomOrder(invAtomOrder) {
}
void execute() {
vector<cl_int> invOrder(cl.getNumAtoms());
vector<cl_int> invOrder(cl.getPaddedNumAtoms());
const vector<int>& order = cl.getAtomIndex();
for (int i = 0; i < order.size(); i++)
invOrder[order[i]] = i;
invAtomOrder.upload(invOrder);
cl.getQueue().finish();
}
private:
OpenCLContext& cl;
......@@ -6900,7 +6899,7 @@ OpenCLCalcCustomCVForceKernel::~OpenCLCalcCustomCVForceKernel() {
delete innerInvAtomOrder;
}
void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force) {
void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) {
int numCVs = force.getNumCollectiveVariables();
cl.addForce(new OpenCLForceInfo(1));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
......@@ -6925,20 +6924,27 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo
string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name);
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
cl.addEnergyParameterDerivative(name);
}
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
// Copy parameter derivatives from the inner context.
OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
for (auto& param : cl2.getEnergyParamDerivNames())
cl.addEnergyParameterDerivative(param);
// Create arrays for storing information.
int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
for (int i = 0; i < numCVs; i++)
cvForces.push_back(new OpenCLArray(cl, cl.getNumAtoms(), elementSize, "cvForce"));
invAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getNumAtoms(), "invAtomOrder");
innerInvAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getNumAtoms(), "innerInvAtomOrder");
cvForces.push_back(new OpenCLArray(cl, cl.getNumAtoms(), 4*elementSize, "cvForce"));
invAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getPaddedNumAtoms(), "invAtomOrder");
innerInvAtomOrder = OpenCLArray::create<cl_int>(cl, cl.getPaddedNumAtoms(), "innerInvAtomOrder");
// Create the kernels.
......@@ -6984,8 +6990,8 @@ double OpenCLCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
addForcesKernel.setArg<cl_double>(2*i+3, dEdV);
else
addForcesKernel.setArg<cl_float>(2*i+3, dEdV);
cl.executeKernel(addForcesKernel, numAtoms);
}
cl.executeKernel(addForcesKernel, numAtoms);
// Compute the energy parameter derivatives.
......@@ -7002,12 +7008,12 @@ double OpenCLCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& innerContext) {
int numAtoms = cl.getNumAtoms();
OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
if (!hasInitializedKernels) {
hasInitializedKernels = true;
// Initialize the listeners.
OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
ReorderListener* listener1 = new ReorderListener(cl, *invAtomOrder);
ReorderListener* listener2 = new ReorderListener(cl2, *innerInvAtomOrder);
cl.addReorderListener(listener1);
......
......@@ -180,7 +180,20 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri
char* threadsEnv = getenv("OPENMM_CPU_THREADS");
if (threadsEnv != NULL)
stringstream(threadsEnv) >> threads;
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue, pmeStreamPropValue, threads));
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue,
pmeStreamPropValue, threads, NULL));
}
void OpenCLPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const {
Platform& platform = originalContext.getPlatform();
string platformPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLPlatformIndex());
string devicePropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLDeviceIndex());
string precisionPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLPrecision());
string cpuPmePropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLUseCpuPme());
string pmeStreamPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLDisablePmeStream());
int threads = reinterpret_cast<PlatformData*>(originalContext.getPlatformData())->threads.getNumThreads();
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue,
pmeStreamPropValue, threads, &originalContext));
}
void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
......@@ -189,7 +202,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
}
OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty,
const string& precisionProperty, const string& cpuPmeProperty, const string& pmeStreamProperty, int numThreads) :
const string& precisionProperty, const string& cpuPmeProperty, const string& pmeStreamProperty, int numThreads, ContextImpl* originalContext) :
removeCM(false), stepCount(0), computeForceCount(0), time(0.0), hasInitializedContexts(false), threads(numThreads) {
int platformIndex = -1;
if (platformPropValue.length() > 0)
......@@ -201,16 +214,19 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p
searchPos = nextPos+1;
}
devices.push_back(deviceIndexProperty.substr(searchPos));
PlatformData* originalData = NULL;
if (originalContext != NULL)
originalData = reinterpret_cast<PlatformData*>(originalContext->getPlatformData());
try {
for (int i = 0; i < (int) devices.size(); i++) {
if (devices[i].length() > 0) {
int deviceIndex;
stringstream(devices[i]) >> deviceIndex;
contexts.push_back(new OpenCLContext(system, platformIndex, deviceIndex, precisionProperty, *this));
contexts.push_back(new OpenCLContext(system, platformIndex, deviceIndex, precisionProperty, *this, (originalData == NULL ? NULL : originalData->contexts[i])));
}
}
if (contexts.size() == 0)
contexts.push_back(new OpenCLContext(system, platformIndex, -1, precisionProperty, *this));
contexts.push_back(new OpenCLContext(system, platformIndex, -1, precisionProperty, *this, (originalData == NULL ? NULL : originalData->contexts[0])));
}
catch (...) {
// If an exception was thrown, do our best to clean up memory.
......
......@@ -28,7 +28,7 @@ __kernel void copyForces(__global real4* forces, __global int* restrict invAtomO
/**
* Add all the forces from the CVs.
*/
__kernel void addForces(__global real4* forces, int numAtoms, int numCVs
__kernel void addForces(__global real4* forces, int numAtoms
PARAMETER_ARGUMENTS) {
for (int i = get_global_id(0); i < numAtoms; i += get_global_size(0)) {
real4 f = forces[i];
......
......@@ -54,7 +54,7 @@ template <class Real2>
void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0];
context.initialize();
OpenMM_SFMT::SFMT sfmt;
......
......@@ -54,7 +54,7 @@ void testGaussian() {
System system;
for (int i = 0; i < numAtoms; i++)
system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0];
context.initialize();
context.getIntegrationUtilities().initRandomNumberGenerator(0);
......
......@@ -64,7 +64,7 @@ void verifySorting(vector<float> array) {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0];
context.initialize();
OpenCLArray data(context, array.size(), sizeof(float), "sortData");
......
......@@ -1020,8 +1020,9 @@ public:
*
* @param system the System this kernel will be applied to
* @param force the CustomCVForce this kernel will be used for
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/
void initialize(const System& system, const CustomCVForce& force);
void initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext);
/**
* Execute the kernel to calculate the forces and/or energy.
*
......
......@@ -2021,7 +2021,7 @@ ReferenceCalcCustomCVForceKernel::~ReferenceCalcCustomCVForceKernel() {
delete ixn;
}
void ReferenceCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force) {
void ReferenceCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) {
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment