Commit 27b802cf authored by Robert McGibbon's avatar Robert McGibbon
Browse files

bugfixes to previous commit

parent 576eb902
...@@ -191,6 +191,12 @@ public: ...@@ -191,6 +191,12 @@ public:
int getDeviceIndex() { int getDeviceIndex() {
return deviceIndex; return deviceIndex;
} }
/**
* Get the index of the cl::Platform associated with this object.
*/
int getPlatformIndex() {
return platformIndex;
}
/** /**
* Get the PlatformData object this context is part of. * Get the PlatformData object this context is part of.
*/ */
...@@ -604,6 +610,7 @@ private: ...@@ -604,6 +610,7 @@ private:
double time; double time;
OpenCLPlatform::PlatformData& platformData; OpenCLPlatform::PlatformData& platformData;
int deviceIndex; int deviceIndex;
int platformIndex;
int contextIndex; int contextIndex;
int stepCount; int stepCount;
int computeForceCount; int computeForceCount;
......
...@@ -88,19 +88,19 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -88,19 +88,19 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
contextIndex = platformData.contexts.size(); contextIndex = platformData.contexts.size();
std::vector<cl::Platform> platforms; std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms); cl::Platform::get(&platforms);
const int minThreadBlockSize = 32;
int bestSpeed = -1; int bestSpeed = -1;
int bestDevice = -1; int bestDevice = -1;
int bestPlatform = -1; int bestPlatform = -1;
for (j = 0; i < platforms.size(); j++) { for (int j = 0; j < platforms.size(); j++) {
// if they supplied a valid platformIndex, we only look through that platform // if they supplied a valid platformIndex, we only look through that platform
if (j != plaformIndex && platformIndex >= 0 && platformIndex < (int) platforms.size()) if (j != platformIndex && platformIndex >= 0 && platformIndex < (int) platforms.size())
continue; continue;
string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>(); string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
vector<cl::Device> devices; vector<cl::Device> devices;
platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices); platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
const int minThreadBlockSize = 32;
for (int i = 0; i < (int) devices.size(); i++) { for (int i = 0; i < (int) devices.size(); i++) {
// if they supplied a valid deviceIndex, we only look through that one // if they supplied a valid deviceIndex, we only look through that one
...@@ -160,9 +160,11 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -160,9 +160,11 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
vector<cl::Device> devices; vector<cl::Device> devices;
platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices); platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
device = devices[bestDevice]; device = devices[bestDevice];
this->deviceIndex = bestDevice; this->deviceIndex = bestDevice;
this->platformIndex = bestPlatform;
if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize) if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
throw OpenMMException("The specified OpenCL device is not compatible with OpenMM"); throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize); compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
...@@ -244,7 +246,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -244,7 +246,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)"; compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";
vector<cl::Device> contextDevices; vector<cl::Device> contextDevices;
contextDevices.push_back(device); contextDevices.push_back(device);
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0}; cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
context = cl::Context(contextDevices, cprops, errorCallback); context = cl::Context(contextDevices, cprops, errorCallback);
queue = cl::CommandQueue(context, device); queue = cl::CommandQueue(context, device);
numAtoms = system.getNumParticles(); numAtoms = system.getNumParticles();
......
...@@ -161,6 +161,8 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p ...@@ -161,6 +161,8 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p
deviceIndex << contexts[i]->getDeviceIndex(); deviceIndex << contexts[i]->getDeviceIndex();
deviceName << contexts[i]->getDevice().getInfo<CL_DEVICE_NAME>(); deviceName << contexts[i]->getDevice().getInfo<CL_DEVICE_NAME>();
} }
platformIndex = contexts[0]->getPlatformIndex();
useCpuPme = (cpuPmeProperty == "true" && !contexts[0]->getUseDoublePrecision()); useCpuPme = (cpuPmeProperty == "true" && !contexts[0]->getUseDoublePrecision());
propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex.str(); propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex.str();
propertyValues[OpenCLPlatform::OpenCLDeviceName()] = deviceName.str(); propertyValues[OpenCLPlatform::OpenCLDeviceName()] = deviceName.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment