Commit f915b68a authored by peastman's avatar peastman
Browse files

Merge pull request #118 from rmcgibbo/clplatformselect

Iterate through available OpenCL platforms when searching for the fastest OpenCL device
parents acf7db67 eed3efbf
...@@ -191,6 +191,12 @@ public: ...@@ -191,6 +191,12 @@ public:
int getDeviceIndex() { int getDeviceIndex() {
return deviceIndex; return deviceIndex;
} }
/**
* Get the index of the cl::Platform associated with this object.
*/
int getPlatformIndex() {
return platformIndex;
}
/** /**
* Get the PlatformData object this context is part of. * Get the PlatformData object this context is part of.
*/ */
...@@ -604,6 +610,7 @@ private: ...@@ -604,6 +610,7 @@ private:
double time; double time;
OpenCLPlatform::PlatformData& platformData; OpenCLPlatform::PlatformData& platformData;
int deviceIndex; int deviceIndex;
int platformIndex;
int contextIndex; int contextIndex;
int stepCount; int stepCount;
int computeForceCount; int computeForceCount;
......
...@@ -88,17 +88,25 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -88,17 +88,25 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
contextIndex = platformData.contexts.size(); contextIndex = platformData.contexts.size();
std::vector<cl::Platform> platforms; std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms); cl::Platform::get(&platforms);
if (platformIndex < 0 || platformIndex >= (int) platforms.size())
throw OpenMMException("Illegal value for OpenCL platform index");
string platformVendor = platforms[platformIndex].getInfo<CL_PLATFORM_VENDOR>();
vector<cl::Device> devices;
platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices);
const int minThreadBlockSize = 32; const int minThreadBlockSize = 32;
if (deviceIndex < 0 || deviceIndex >= (int) devices.size()) {
// Try to figure out which device is the fastest.
int bestSpeed = -1; int bestSpeed = -1;
int bestDevice = -1;
int bestPlatform = -1;
for (int j = 0; j < platforms.size(); j++) {
// if they supplied a valid platformIndex, we only look through that platform
if (j != platformIndex && platformIndex >= 0 && platformIndex < (int) platforms.size())
continue;
string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
vector<cl::Device> devices;
platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
for (int i = 0; i < (int) devices.size(); i++) { for (int i = 0; i < (int) devices.size(); i++) {
// if they supplied a valid deviceIndex, we only look through that one
if (i != deviceIndex && deviceIndex >= 0 && deviceIndex < (int) devices.size())
continue;
if (platformVendor == "Apple" && devices[i].getInfo<CL_DEVICE_VENDOR>() == "AMD") if (platformVendor == "Apple" && devices[i].getInfo<CL_DEVICE_VENDOR>() == "AMD")
continue; // Don't use AMD GPUs on OS X due to serious bugs. continue; // Don't use AMD GPUs on OS X due to serious bugs.
int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0]; int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
...@@ -137,15 +145,26 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -137,15 +145,26 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
} }
int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>(); int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
if (maxSize >= minThreadBlockSize && speed > bestSpeed) { if (maxSize >= minThreadBlockSize && speed > bestSpeed) {
deviceIndex = i; bestDevice = i;
bestSpeed = speed; bestSpeed = speed;
bestPlatform = j;
} }
} }
} }
if (deviceIndex == -1)
if (bestPlatform == -1)
throw OpenMMException("No compatible OpenCL platform is available");
if (bestDevice == -1)
throw OpenMMException("No compatible OpenCL device is available"); throw OpenMMException("No compatible OpenCL device is available");
device = devices[deviceIndex];
this->deviceIndex = deviceIndex; vector<cl::Device> devices;
platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
device = devices[bestDevice];
this->deviceIndex = bestDevice;
this->platformIndex = bestPlatform;
if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize) if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
throw OpenMMException("The specified OpenCL device is not compatible with OpenMM"); throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize); compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
...@@ -227,7 +246,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -227,7 +246,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)"; compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";
vector<cl::Device> contextDevices; vector<cl::Device> contextDevices;
contextDevices.push_back(device); contextDevices.push_back(device);
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0}; cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
context = cl::Context(contextDevices, cprops, errorCallback); context = cl::Context(contextDevices, cprops, errorCallback);
queue = cl::CommandQueue(context, device); queue = cl::CommandQueue(context, device);
numAtoms = system.getNumParticles(); numAtoms = system.getNumParticles();
......
...@@ -133,7 +133,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { ...@@ -133,7 +133,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty, OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty,
const string& precisionProperty, const string& cpuPmeProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) { const string& precisionProperty, const string& cpuPmeProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
int platformIndex = 0; int platformIndex = -1;
if (platformPropValue.length() > 0) if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex; stringstream(platformPropValue) >> platformIndex;
vector<string> devices; vector<string> devices;
...@@ -161,6 +161,8 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p ...@@ -161,6 +161,8 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p
deviceIndex << contexts[i]->getDeviceIndex(); deviceIndex << contexts[i]->getDeviceIndex();
deviceName << contexts[i]->getDevice().getInfo<CL_DEVICE_NAME>(); deviceName << contexts[i]->getDevice().getInfo<CL_DEVICE_NAME>();
} }
platformIndex = contexts[0]->getPlatformIndex();
useCpuPme = (cpuPmeProperty == "true" && !contexts[0]->getUseDoublePrecision()); useCpuPme = (cpuPmeProperty == "true" && !contexts[0]->getUseDoublePrecision());
propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex.str(); propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex.str();
propertyValues[OpenCLPlatform::OpenCLDeviceName()] = deviceName.str(); propertyValues[OpenCLPlatform::OpenCLDeviceName()] = deviceName.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment