Commit d5f5193f authored by peastman's avatar peastman
Browse files

Throw an exception if an illegal device or platform index is specified

parent 8dece9f1
...@@ -120,7 +120,9 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking ...@@ -120,7 +120,9 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
int numDevices; int numDevices;
string errorMessage = "Error initializing Context"; string errorMessage = "Error initializing Context";
CHECK_RESULT(cuDeviceGetCount(&numDevices)); CHECK_RESULT(cuDeviceGetCount(&numDevices));
if (deviceIndex < 0 || deviceIndex >= numDevices) { if (deviceIndex < -1 || deviceIndex >= numDevices)
throw OpenMMException("Illegal value for CudaDeviceIndex: "+intToString(deviceIndex));
if (deviceIndex == -1) {
// Try to figure out which device is the fastest. // Try to figure out which device is the fastest.
int bestSpeed = -1; int bestSpeed = -1;
......
...@@ -89,23 +89,27 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -89,23 +89,27 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
contextIndex = platformData.contexts.size(); contextIndex = platformData.contexts.size();
std::vector<cl::Platform> platforms; std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms); cl::Platform::get(&platforms);
if (platformIndex < -1 || platformIndex >= (int) platforms.size())
throw OpenMMException("Illegal value for OpenCLPlatformIndex: "+intToString(platformIndex));
const int minThreadBlockSize = 32; const int minThreadBlockSize = 32;
int bestSpeed = -1; int bestSpeed = -1;
int bestDevice = -1; int bestDevice = -1;
int bestPlatform = -1; int bestPlatform = -1;
for (int j = 0; j < platforms.size(); j++) { for (int j = 0; j < platforms.size(); j++) {
// if they supplied a valid platformIndex, we only look through that platform // If they supplied a valid platformIndex, we only look through that platform
if (j != platformIndex && platformIndex >= 0 && platformIndex < (int) platforms.size()) if (j != platformIndex && platformIndex != -1)
continue; continue;
string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>(); string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
vector<cl::Device> devices; vector<cl::Device> devices;
platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices); platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
if (deviceIndex < -1 || deviceIndex >= (int) devices.size())
throw OpenMMException("Illegal value for OpenCLDeviceIndex: "+intToString(deviceIndex));
for (int i = 0; i < (int) devices.size(); i++) { for (int i = 0; i < (int) devices.size(); i++) {
// if they supplied a valid deviceIndex, we only look through that one // If they supplied a valid deviceIndex, we only look through that one
if (i != deviceIndex && deviceIndex >= 0 && deviceIndex < (int) devices.size()) if (i != deviceIndex && deviceIndex != -1)
continue; continue;
if (platformVendor == "Apple" && (devices[i].getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU)) if (platformVendor == "Apple" && (devices[i].getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU))
continue; // The CPU device on OS X won't work correctly. continue; // The CPU device on OS X won't work correctly.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment