Commit 6660be2f authored by Peter Eastman's avatar Peter Eastman
Browse files

Added OpenCLPlatformIndex property for selecting platform

parent 0308dbf8
...@@ -61,11 +61,18 @@ public: ...@@ -61,11 +61,18 @@ public:
static const std::string key = "OpenCLDeviceIndex"; static const std::string key = "OpenCLDeviceIndex";
return key; return key;
} }
/**
* This is the name of the parameter for selecting which OpenCL platform to use.
*/
static const std::string& OpenCLPlatformIndex() {
static const std::string key = "OpenCLPlatformIndex";
return key;
}
}; };
class OPENMM_EXPORT OpenCLPlatform::PlatformData { class OPENMM_EXPORT OpenCLPlatform::PlatformData {
public: public:
PlatformData(int numParticles, const std::string& deviceIndexProperty); PlatformData(int numParticles, const std::string& platformPropValue, const std::string& deviceIndexProperty);
~PlatformData(); ~PlatformData();
void initializeContexts(const System& system); void initializeContexts(const System& system);
void syncContexts(); void syncContexts();
......
...@@ -60,7 +60,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i ...@@ -60,7 +60,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i
std::cerr << "OpenCL internal error: " << errinfo << std::endl; std::cerr << "OpenCL internal error: " << errinfo << std::endl;
} }
OpenCLContext::OpenCLContext(int numParticles, int deviceIndex, OpenCLPlatform::PlatformData& platformData) : OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL), time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL),
velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL), velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL),
bonded(NULL), nonbonded(NULL), thread(NULL) { bonded(NULL), nonbonded(NULL), thread(NULL) {
...@@ -68,7 +68,9 @@ OpenCLContext::OpenCLContext(int numParticles, int deviceIndex, OpenCLPlatform:: ...@@ -68,7 +68,9 @@ OpenCLContext::OpenCLContext(int numParticles, int deviceIndex, OpenCLPlatform::
contextIndex = platformData.contexts.size(); contextIndex = platformData.contexts.size();
std::vector<cl::Platform> platforms; std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms); cl::Platform::get(&platforms);
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[0](), 0}; if (platformIndex < 0 || platformIndex >= platforms.size())
throw OpenMMException("Illegal value for OpenCL platform index");
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0};
context = cl::Context(CL_DEVICE_TYPE_ALL, cprops, errorCallback); context = cl::Context(CL_DEVICE_TYPE_ALL, cprops, errorCallback);
vector<cl::Device> devices = context.getInfo<CL_CONTEXT_DEVICES>(); vector<cl::Device> devices = context.getInfo<CL_CONTEXT_DEVICES>();
const int minThreadBlockSize = 32; const int minThreadBlockSize = 32;
......
...@@ -145,7 +145,7 @@ public: ...@@ -145,7 +145,7 @@ public:
class WorkThread; class WorkThread;
static const int ThreadBlockSize; static const int ThreadBlockSize;
static const int TileSize; static const int TileSize;
OpenCLContext(int numParticles, int deviceIndex, OpenCLPlatform::PlatformData& platformData); OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData);
~OpenCLContext(); ~OpenCLContext();
/** /**
* This is called to initialize internal data structures after all Forces in the system * This is called to initialize internal data structures after all Forces in the system
......
...@@ -72,7 +72,9 @@ OpenCLPlatform::OpenCLPlatform() { ...@@ -72,7 +72,9 @@ OpenCLPlatform::OpenCLPlatform() {
registerKernelFactory(CalcKineticEnergyKernel::Name(), factory); registerKernelFactory(CalcKineticEnergyKernel::Name(), factory);
registerKernelFactory(RemoveCMMotionKernel::Name(), factory); registerKernelFactory(RemoveCMMotionKernel::Name(), factory);
platformProperties.push_back(OpenCLDeviceIndex()); platformProperties.push_back(OpenCLDeviceIndex());
platformProperties.push_back(OpenCLPlatformIndex());
setPropertyDefaultValue(OpenCLDeviceIndex(), ""); setPropertyDefaultValue(OpenCLDeviceIndex(), "");
setPropertyDefaultValue(OpenCLPlatformIndex(), "");
} }
bool OpenCLPlatform::supportsDoublePrecision() const { bool OpenCLPlatform::supportsDoublePrecision() const {
...@@ -92,10 +94,12 @@ void OpenCLPlatform::setPropertyValue(Context& context, const string& property, ...@@ -92,10 +94,12 @@ void OpenCLPlatform::setPropertyValue(Context& context, const string& property,
} }
void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
const string& platformPropValue = (properties.find(OpenCLPlatformIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLPlatformIndex()) : properties.find(OpenCLPlatformIndex())->second);
const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ? const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second); getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second);
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
context.setPlatformData(new PlatformData(numParticles, devicePropValue)); context.setPlatformData(new PlatformData(numParticles, platformPropValue, devicePropValue));
} }
void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
...@@ -103,7 +107,10 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { ...@@ -103,7 +107,10 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
delete data; delete data;
} }
OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) { OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& platformPropValue, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
int platformIndex = 0;
if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex;
vector<string> devices; vector<string> devices;
size_t searchPos = 0, nextPos; size_t searchPos = 0, nextPos;
while ((nextPos = deviceIndexProperty.find_first_of(", ", searchPos)) != string::npos) { while ((nextPos = deviceIndexProperty.find_first_of(", ", searchPos)) != string::npos) {
...@@ -115,11 +122,11 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& devic ...@@ -115,11 +122,11 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& devic
if (devices[i].length() > 0) { if (devices[i].length() > 0) {
unsigned int deviceIndex; unsigned int deviceIndex;
stringstream(devices[i]) >> deviceIndex; stringstream(devices[i]) >> deviceIndex;
contexts.push_back(new OpenCLContext(numParticles, deviceIndex, *this)); contexts.push_back(new OpenCLContext(numParticles, platformIndex, deviceIndex, *this));
} }
} }
if (contexts.size() == 0) if (contexts.size() == 0)
contexts.push_back(new OpenCLContext(numParticles, -1, *this)); contexts.push_back(new OpenCLContext(numParticles, platformIndex, -1, *this));
stringstream device; stringstream device;
for (int i = 0; i < (int) contexts.size(); i++) { for (int i = 0; i < (int) contexts.size(); i++) {
if (i > 0) if (i > 0)
...@@ -127,6 +134,7 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& devic ...@@ -127,6 +134,7 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& devic
device << contexts[i]->getDeviceIndex(); device << contexts[i]->getDeviceIndex();
} }
propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = device.str(); propertyValues[OpenCLPlatform::OpenCLDeviceIndex()] = device.str();
propertyValues[OpenCLPlatform::OpenCLPlatformIndex()] = OpenCLExpressionUtilities::intToString(platformIndex);
contextEnergy.resize(contexts.size()); contextEnergy.resize(contexts.size());
} }
......
...@@ -51,7 +51,7 @@ using namespace std; ...@@ -51,7 +51,7 @@ using namespace std;
void testTransform() { void testTransform() {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, ""); OpenCLPlatform::PlatformData platformData(1, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize(system);
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
......
...@@ -48,7 +48,7 @@ void testGaussian() { ...@@ -48,7 +48,7 @@ void testGaussian() {
System system; System system;
for (int i = 0; i < numAtoms; i++) for (int i = 0; i < numAtoms; i++)
system.addParticle(1.0); system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(numAtoms, ""); OpenCLPlatform::PlatformData platformData(numAtoms, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize(system);
context.getIntegrationUtilities().initRandomNumberGenerator(0); context.getIntegrationUtilities().initRandomNumberGenerator(0);
......
...@@ -51,7 +51,7 @@ void verifySorting(vector<float> array) { ...@@ -51,7 +51,7 @@ void verifySorting(vector<float> array) {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, ""); OpenCLPlatform::PlatformData platformData(1, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize(system);
OpenCLArray<float> data(context, array.size(), "sortData"); OpenCLArray<float> data(context, array.size(), "sortData");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment