Commit 6660be2f authored by Peter Eastman's avatar Peter Eastman
Browse files

Added OpenCLPlatformIndex property for selecting platform

parent 0308dbf8
......@@ -61,11 +61,18 @@ public:
static const std::string key = "OpenCLDeviceIndex";
return key;
}
/**
* This is the name of the parameter for selecting which OpenCL platform to use.
*/
static const std::string& OpenCLPlatformIndex() {
static const std::string key = "OpenCLPlatformIndex";
return key;
}
};
class OPENMM_EXPORT OpenCLPlatform::PlatformData {
public:
PlatformData(int numParticles, const std::string& deviceIndexProperty);
PlatformData(int numParticles, const std::string& platformPropValue, const std::string& deviceIndexProperty);
~PlatformData();
void initializeContexts(const System& system);
void syncContexts();
......
......@@ -60,7 +60,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i
std::cerr << "OpenCL internal error: " << errinfo << std::endl;
}
OpenCLContext::OpenCLContext(int numParticles, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL),
velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL),
bonded(NULL), nonbonded(NULL), thread(NULL) {
......@@ -68,7 +68,9 @@ OpenCLContext::OpenCLContext(int numParticles, int deviceIndex, OpenCLPlatform::
contextIndex = platformData.contexts.size();
std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms);
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[0](), 0};
if (platformIndex < 0 || platformIndex >= platforms.size())
throw OpenMMException("Illegal value for OpenCL platform index");
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0};
context = cl::Context(CL_DEVICE_TYPE_ALL, cprops, errorCallback);
vector<cl::Device> devices = context.getInfo<CL_CONTEXT_DEVICES>();
const int minThreadBlockSize = 32;
......
......@@ -145,7 +145,7 @@ public:
class WorkThread;
static const int ThreadBlockSize;
static const int TileSize;
OpenCLContext(int numParticles, int deviceIndex, OpenCLPlatform::PlatformData& platformData);
OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData);
~OpenCLContext();
/**
* This is called to initialize internal data structures after all Forces in the system
......
......@@ -72,7 +72,9 @@ OpenCLPlatform::OpenCLPlatform() {
registerKernelFactory(CalcKineticEnergyKernel::Name(), factory);
registerKernelFactory(RemoveCMMotionKernel::Name(), factory);
platformProperties.push_back(OpenCLDeviceIndex());
platformProperties.push_back(OpenCLPlatformIndex());
setPropertyDefaultValue(OpenCLDeviceIndex(), "");
setPropertyDefaultValue(OpenCLPlatformIndex(), "");
}
bool OpenCLPlatform::supportsDoublePrecision() const {
......@@ -92,10 +94,12 @@ void OpenCLPlatform::setPropertyValue(Context& context, const string& property,
}
void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
const string& platformPropValue = (properties.find(OpenCLPlatformIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLPlatformIndex()) : properties.find(OpenCLPlatformIndex())->second);
const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second);
int numParticles = context.getSystem().getNumParticles();
context.setPlatformData(new PlatformData(numParticles, devicePropValue));
context.setPlatformData(new PlatformData(numParticles, platformPropValue, devicePropValue));
}
void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
......@@ -103,7 +107,10 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
delete data;
}
OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& platformPropValue, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
int platformIndex = 0;
if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex;
vector<string> devices;
size_t searchPos = 0, nextPos;
while ((nextPos = deviceIndexProperty.find_first_of(", ", searchPos)) != string::npos) {
......@@ -115,11 +122,11 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& devic
if (devices[i].length() > 0) {
unsigned int deviceIndex;
stringstream(devices[i]) >> deviceIndex;
contexts.push_back(new OpenCLContext(numParticles, deviceIndex, *this));
contexts.push_back(new OpenCLContext(numParticles, platformIndex, deviceIndex, *this));
}
}
if (contexts.size() == 0)
contexts.push_back(new OpenCLContext(numParticles, -1, *this));
contexts.push_back(new OpenCLContext(numParticles, platformIndex, -1, *this));
stringstream device;
for (int i = 0; i < (int) contexts.size(); i++) {
if (i > 0)
......@@ -127,6 +134,7 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& devic
device << contexts[i]->getDeviceIndex();
}
propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = device.str();
propertyValues[OpenCLPlatform::OpenCLPlatformIndex()] = OpenCLExpressionUtilities::intToString(platformIndex);
contextEnergy.resize(contexts.size());
}
......
......@@ -51,7 +51,7 @@ using namespace std;
void testTransform() {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, "");
OpenCLPlatform::PlatformData platformData(1, "", "");
OpenCLContext& context = *platformData.contexts[0];
context.initialize(system);
OpenMM_SFMT::SFMT sfmt;
......
......@@ -48,7 +48,7 @@ void testGaussian() {
System system;
for (int i = 0; i < numAtoms; i++)
system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(numAtoms, "");
OpenCLPlatform::PlatformData platformData(numAtoms, "", "");
OpenCLContext& context = *platformData.contexts[0];
context.initialize(system);
context.getIntegrationUtilities().initRandomNumberGenerator(0);
......
......@@ -51,7 +51,7 @@ void verifySorting(vector<float> array) {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, "");
OpenCLPlatform::PlatformData platformData(1, "", "");
OpenCLContext& context = *platformData.contexts[0];
context.initialize(system);
OpenCLArray<float> data(context, array.size(), "sortData");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment