"...ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "a468fa3a31c83b6262e8be29b842662734fee98c"
Commit 9532c446 authored by Peter Eastman
Browse files

Select the fastest OpenCL device automatically

parent ad4e1203
...@@ -55,13 +55,6 @@ public: ...@@ -55,13 +55,6 @@ public:
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const; void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
void contextCreated(ContextImpl& context) const; void contextCreated(ContextImpl& context) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
/**
* This is the name of the parameter for selecting which OpenCL platform to use.
*/
static const std::string& OpenCLPlatformIndex() {
static const std::string key = "OpenCLPlatformIndex";
return key;
}
/** /**
* This is the name of the parameter for selecting which OpenCL device to use. * This is the name of the parameter for selecting which OpenCL device to use.
*/ */
...@@ -73,7 +66,7 @@ public: ...@@ -73,7 +66,7 @@ public:
class OpenCLPlatform::PlatformData { class OpenCLPlatform::PlatformData {
public: public:
PlatformData(int numParticles, int platformIndex, int deviceIndex); PlatformData(int numParticles, int deviceIndex);
OpenCLContext* context; OpenCLContext* context;
bool removeCM; bool removeCM;
int cmMotionFrequency; int cmMotionFrequency;
......
...@@ -36,16 +36,32 @@ ...@@ -36,16 +36,32 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceIndex) { OpenCLContext::OpenCLContext(int numParticles, int deviceIndex) {
// TODO Select the platform and device correctly context = cl::Context(CL_DEVICE_TYPE_ALL);
context = cl::Context(CL_DEVICE_TYPE_GPU); vector<cl::Device> devices = context.getInfo<CL_CONTEXT_DEVICES>();
device = context.getInfo<CL_CONTEXT_DEVICES>()[0]; const int minThreadBlockSize = 32;
if (deviceIndex < 0 || deviceIndex >= devices.size()) {
// Try to figure out which device is the fastest.
int bestSpeed = 0;
for (int i = 0; i < devices.size(); i++) {
int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
if (maxSize >= minThreadBlockSize && speed > bestSpeed)
deviceIndex = i;
}
}
if (deviceIndex == -1)
throw OpenMMException("No compatible OpenCL device is available");
device = devices[deviceIndex];
if (device.getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0] < minThreadBlockSize)
throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
queue = cl::CommandQueue(context, device); queue = cl::CommandQueue(context, device);
numAtoms = numParticles; numAtoms = numParticles;
paddedNumAtoms = TileSize*((numParticles+TileSize-1)/TileSize); paddedNumAtoms = TileSize*((numParticles+TileSize-1)/TileSize);
numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize; numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
numTiles = numAtomBlocks*(numAtomBlocks+1)/2; numTiles = numAtomBlocks*(numAtomBlocks+1)/2;
numThreadBlocks = 8*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(); numThreadBlocks = device.getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0]/ThreadBlockSize;
// Create utility kernels that are used in multiple places. // Create utility kernels that are used in multiple places.
......
...@@ -64,7 +64,7 @@ class OpenCLContext { ...@@ -64,7 +64,7 @@ class OpenCLContext {
public: public:
static const int ThreadBlockSize = 64; static const int ThreadBlockSize = 64;
static const int TileSize = 32; static const int TileSize = 32;
OpenCLContext(int numParticles, int platformIndex, int deviceIndex); OpenCLContext(int numParticles, int deviceIndex);
~OpenCLContext(); ~OpenCLContext();
/** /**
* This is called to initialize internal data structures after all Forces in the system * This is called to initialize internal data structures after all Forces in the system
......
...@@ -62,10 +62,8 @@ OpenCLPlatform::OpenCLPlatform() { ...@@ -62,10 +62,8 @@ OpenCLPlatform::OpenCLPlatform() {
// registerKernelFactory(ApplyAndersenThermostatKernel::Name(), factory); // registerKernelFactory(ApplyAndersenThermostatKernel::Name(), factory);
registerKernelFactory(CalcKineticEnergyKernel::Name(), factory); registerKernelFactory(CalcKineticEnergyKernel::Name(), factory);
// registerKernelFactory(RemoveCMMotionKernel::Name(), factory); // registerKernelFactory(RemoveCMMotionKernel::Name(), factory);
platformProperties.push_back(OpenCLPlatformIndex());
platformProperties.push_back(OpenCLDeviceIndex()); platformProperties.push_back(OpenCLDeviceIndex());
setPropertyDefaultValue(OpenCLPlatformIndex(), "0"); setPropertyDefaultValue(OpenCLDeviceIndex(), "");
setPropertyDefaultValue(OpenCLDeviceIndex(), "0");
} }
bool OpenCLPlatform::supportsDoublePrecision() const { bool OpenCLPlatform::supportsDoublePrecision() const {
...@@ -85,16 +83,12 @@ void OpenCLPlatform::setPropertyValue(Context& context, const string& property, ...@@ -85,16 +83,12 @@ void OpenCLPlatform::setPropertyValue(Context& context, const string& property,
} }
void OpenCLPlatform::contextCreated(ContextImpl& context) const { void OpenCLPlatform::contextCreated(ContextImpl& context) const {
unsigned int platformIndex = 0; unsigned int deviceIndex = -1;
const string& platformPropValue = getPropertyDefaultValue(OpenCLPlatformIndex());
if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex;
unsigned int deviceIndex = 0;
const string& devicePropValue = getPropertyDefaultValue(OpenCLDeviceIndex()); const string& devicePropValue = getPropertyDefaultValue(OpenCLDeviceIndex());
if (devicePropValue.length() > 0) if (devicePropValue.length() > 0)
stringstream(devicePropValue) >> deviceIndex; stringstream(devicePropValue) >> deviceIndex;
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
context.setPlatformData(new PlatformData(numParticles, platformIndex, deviceIndex)); context.setPlatformData(new PlatformData(numParticles, deviceIndex));
} }
void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
...@@ -103,12 +97,9 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { ...@@ -103,12 +97,9 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
delete data; delete data;
} }
OpenCLPlatform::PlatformData::PlatformData(int numParticles, int platformIndex, int deviceIndex) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) { OpenCLPlatform::PlatformData::PlatformData(int numParticles, int deviceIndex) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
context = new OpenCLContext(numParticles, platformIndex, deviceIndex); context = new OpenCLContext(numParticles, deviceIndex);
stringstream platform;
// device << gpu->platform;
stringstream device; stringstream device;
// device << gpu->device; // device << gpu->device;
propertyValues[OpenCLPlatform::OpenCLPlatformIndex()] = platform.str();
propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = device.str(); propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = device.str();
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment