"vscode:/vscode.git/clone" did not exist on "802be281a29a0b1e32b829c98141e71cadcab3d4"
Commit f915b68a authored by peastman
Browse files

Merge pull request #118 from rmcgibbo/clplatformselect

Iterate through available OpenCL platforms when searching for the fastest OpenCL device
parents acf7db67 eed3efbf
......@@ -191,6 +191,12 @@ public:
int getDeviceIndex() {
// Plain accessor: reports the stored deviceIndex member. The constructor records
// there the position of the chosen cl::Device within its platform's device list.
return deviceIndex;
}
/**
* Get the index of the cl::Platform associated with this object.
*
* The constructor stores here the position of the platform it selected within
* the list filled by cl::Platform::get(), so the value identifies the platform
* actually in use.
*
* @return the stored platform index
*/
int getPlatformIndex() {
return platformIndex;
}
/**
* Get the PlatformData object this context is part of.
*/
......@@ -604,6 +610,7 @@ private:
double time;
OpenCLPlatform::PlatformData& platformData;
int deviceIndex;
int platformIndex;
int contextIndex;
int stepCount;
int computeForceCount;
......
......@@ -88,17 +88,25 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
contextIndex = platformData.contexts.size();
std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms);
if (platformIndex < 0 || platformIndex >= (int) platforms.size())
throw OpenMMException("Illegal value for OpenCL platform index");
string platformVendor = platforms[platformIndex].getInfo<CL_PLATFORM_VENDOR>();
vector<cl::Device> devices;
platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices);
const int minThreadBlockSize = 32;
if (deviceIndex < 0 || deviceIndex >= (int) devices.size()) {
// Try to figure out which device is the fastest.
int bestSpeed = -1;
int bestDevice = -1;
int bestPlatform = -1;
for (int j = 0; j < platforms.size(); j++) {
// if they supplied a valid platformIndex, we only look through that platform
if (j != platformIndex && platformIndex >= 0 && platformIndex < (int) platforms.size())
continue;
string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
vector<cl::Device> devices;
platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
for (int i = 0; i < (int) devices.size(); i++) {
// if they supplied a valid deviceIndex, we only look through that one
if (i != deviceIndex && deviceIndex >= 0 && deviceIndex < (int) devices.size())
continue;
if (platformVendor == "Apple" && devices[i].getInfo<CL_DEVICE_VENDOR>() == "AMD")
continue; // Don't use AMD GPUs on OS X due to serious bugs.
int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
......@@ -137,15 +145,26 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
}
int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
if (maxSize >= minThreadBlockSize && speed > bestSpeed) {
deviceIndex = i;
bestDevice = i;
bestSpeed = speed;
bestPlatform = j;
}
}
}
if (deviceIndex == -1)
if (bestPlatform == -1)
throw OpenMMException("No compatible OpenCL platform is available");
if (bestDevice == -1)
throw OpenMMException("No compatible OpenCL device is available");
device = devices[deviceIndex];
this->deviceIndex = deviceIndex;
vector<cl::Device> devices;
platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
device = devices[bestDevice];
this->deviceIndex = bestDevice;
this->platformIndex = bestPlatform;
if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
......@@ -227,7 +246,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";
vector<cl::Device> contextDevices;
contextDevices.push_back(device);
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0};
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
context = cl::Context(contextDevices, cprops, errorCallback);
queue = cl::CommandQueue(context, device);
numAtoms = system.getNumParticles();
......
......@@ -133,7 +133,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty,
const string& precisionProperty, const string& cpuPmeProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
int platformIndex = 0;
int platformIndex = -1;
if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex;
vector<string> devices;
......@@ -161,6 +161,8 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p
deviceIndex << contexts[i]->getDeviceIndex();
deviceName << contexts[i]->getDevice().getInfo<CL_DEVICE_NAME>();
}
platformIndex = contexts[0]->getPlatformIndex();
useCpuPme = (cpuPmeProperty == "true" && !contexts[0]->getUseDoublePrecision());
propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex.str();
propertyValues[OpenCLPlatform::OpenCLDeviceName()] = deviceName.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment