"...ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "6f8534dcb39f06f9fa22210ce9438b35eb60bb25"
Commit cfcebe26 authored by Peter Eastman's avatar Peter Eastman
Browse files

Added a version of Platform::getPlatform() that looks up platforms by name. ...

Added a version of Platform::getPlatform() that looks up platforms by name.  Added a Context constructor that allows platform-specific properties to be specified.
parent 3a6d79b3
...@@ -119,8 +119,11 @@ public: ...@@ -119,8 +119,11 @@ public:
/** /**
* This is called whenever a new Context is created. It gives the Platform a chance to initialize * This is called whenever a new Context is created. It gives the Platform a chance to initialize
* the context and store platform-specific data in it. * the context and store platform-specific data in it.
*
* @param context the newly created context
* @param properties a set of values for platform-specific properties. Keys are the property names.
*/ */
virtual void contextCreated(ContextImpl& context) const; virtual void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
/** /**
* This is called whenever a Context is deleted. It gives the Platform a chance to clean up * This is called whenever a Context is deleted. It gives the Platform a chance to clean up
* any platform-specific data that was stored in it. * any platform-specific data that was stored in it.
...@@ -168,6 +171,11 @@ public: ...@@ -168,6 +171,11 @@ public:
* Get a registered Platform by index. * Get a registered Platform by index.
*/ */
static Platform& getPlatform(int index); static Platform& getPlatform(int index);
/**
* Get the registered Platform with a particular name. If no Platform with that name has been
* registered, this throws an exception.
*/
static Platform& getPlatform(const std::string& name);
/** /**
* Find a Platform which can be used to perform a calculation. * Find a Platform which can be used to perform a calculation.
* *
......
...@@ -91,7 +91,7 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val ...@@ -91,7 +91,7 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val
defaultProperties[property] = value; defaultProperties[property] = value;
} }
void Platform::contextCreated(ContextImpl& context) const { void Platform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
} }
void Platform::contextDestroyed(ContextImpl& context) const { void Platform::contextDestroyed(ContextImpl& context) const {
...@@ -131,6 +131,13 @@ Platform& Platform::getPlatform(int index) { ...@@ -131,6 +131,13 @@ Platform& Platform::getPlatform(int index) {
return *getPlatforms()[index]; return *getPlatforms()[index];
} }
Platform& Platform::getPlatform(const string& name) {
for (int i = 0; i < getNumPlatforms(); i++)
if (getPlatform(i).getName() == name)
return getPlatform(i);
throw OpenMMException("There is no registered Platform called \""+name+"\"");
}
Platform& Platform::findPlatform(const vector<string>& kernelNames) { Platform& Platform::findPlatform(const vector<string>& kernelNames) {
Platform* best = 0; Platform* best = 0;
vector<Platform*>& platforms = getPlatforms(); vector<Platform*>& platforms = getPlatforms();
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "Integrator.h" #include "Integrator.h"
#include "State.h" #include "State.h"
#include "System.h" #include "System.h"
#include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "internal/windowsExport.h" #include "internal/windowsExport.h"
...@@ -79,6 +80,16 @@ public: ...@@ -79,6 +80,16 @@ public:
* @param platform the Platform to use for calculations * @param platform the Platform to use for calculations
*/ */
Context(System& system, Integrator& integrator, Platform& platform); Context(System& system, Integrator& integrator, Platform& platform);
/**
* Construct a new Context in which to run a simulation, explicitly specifying what Platform should be used
* to perform calculations and the values of platform-specific properties.
*
* @param system the System which will be simulated
* @param integrator the Integrator which will be used to simulate the System
* @param platform the Platform to use for calculations
* @param properties a set of values for platform-specific properties. Keys are the property names.
*/
Context(System& system, Integrator& integrator, Platform& platform, const std::map<std::string, std::string>& properties);
~Context(); ~Context();
/** /**
* Get System being simulated in this context. * Get System being simulated in this context.
...@@ -155,6 +166,7 @@ public: ...@@ -155,6 +166,7 @@ public:
private: private:
friend class Platform; friend class Platform;
ContextImpl* impl; ContextImpl* impl;
std::map<std::string, std::string> properties;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
/** /**
* Create an ContextImpl for a Context; * Create an ContextImpl for a Context;
*/ */
ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform); ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform, const std::map<std::string, std::string>& properties);
~ContextImpl(); ~ContextImpl();
/** /**
* Get the Context for which this is the implementation. * Get the Context for which this is the implementation.
......
...@@ -36,12 +36,16 @@ ...@@ -36,12 +36,16 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
Context::Context(System& system, Integrator& integrator) { Context::Context(System& system, Integrator& integrator) : properties(map<string, string>()) {
impl = new ContextImpl(*this, system, integrator, 0); impl = new ContextImpl(*this, system, integrator, 0, properties);
} }
Context::Context(System& system, Integrator& integrator, Platform& platform) { Context::Context(System& system, Integrator& integrator, Platform& platform) : properties(map<string, string>()) {
impl = new ContextImpl(*this, system, integrator, &platform); impl = new ContextImpl(*this, system, integrator, &platform, properties);
}
Context::Context(System& system, Integrator& integrator, Platform& platform, const map<string, string>& properties) : properties(properties) {
impl = new ContextImpl(*this, system, integrator, &platform, properties);
} }
Context::~Context() { Context::~Context() {
...@@ -121,5 +125,5 @@ void Context::reinitialize() { ...@@ -121,5 +125,5 @@ void Context::reinitialize() {
Integrator& integrator = impl->getIntegrator(); Integrator& integrator = impl->getIntegrator();
Platform& platform = impl->getPlatform(); Platform& platform = impl->getPlatform();
delete impl; delete impl;
impl = new ContextImpl(*this, system, integrator, &platform); impl = new ContextImpl(*this, system, integrator, &platform, properties);
} }
...@@ -44,7 +44,7 @@ using std::map; ...@@ -44,7 +44,7 @@ using std::map;
using std::vector; using std::vector;
using std::string; using std::string;
ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform) : ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) :
owner(owner), system(system), integrator(integrator), platform(platform), platformData(NULL) { owner(owner), system(system), integrator(integrator), platform(platform), platformData(NULL) {
vector<string> kernelNames; vector<string> kernelNames;
kernelNames.push_back(CalcKineticEnergyKernel::Name()); kernelNames.push_back(CalcKineticEnergyKernel::Name());
...@@ -63,7 +63,7 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, ...@@ -63,7 +63,7 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator,
this->platform = platform = &Platform::findPlatform(kernelNames); this->platform = platform = &Platform::findPlatform(kernelNames);
else if (!platform->supportsKernels(kernelNames)) else if (!platform->supportsKernels(kernelNames))
throw OpenMMException("Specified a Platform for a Context which does not support all required kernels"); throw OpenMMException("Specified a Platform for a Context which does not support all required kernels");
platform->contextCreated(*this); platform->contextCreated(*this, properties);
initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this); initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this);
dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl()).initialize(system); dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl()).initialize(system);
kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this); kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this);
......
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
bool supportsDoublePrecision() const; bool supportsDoublePrecision() const;
const std::string& getPropertyValue(const Context& context, const std::string& property) const; const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const; void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
void contextCreated(ContextImpl& context) const; void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
/** /**
* This is the name of the parameter for selecting which CUDA device to use. * This is the name of the parameter for selecting which CUDA device to use.
......
...@@ -88,13 +88,16 @@ const string& CudaPlatform::getPropertyValue(const Context& context, const strin ...@@ -88,13 +88,16 @@ const string& CudaPlatform::getPropertyValue(const Context& context, const strin
void CudaPlatform::setPropertyValue(Context& context, const string& property, const string& value) const { void CudaPlatform::setPropertyValue(Context& context, const string& property, const string& value) const {
} }
void CudaPlatform::contextCreated(ContextImpl& context) const { void CudaPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
unsigned int device = 0; unsigned int device = 0;
const string& devicePropValue = getPropertyDefaultValue(CudaDevice()); const string& devicePropValue = (properties.find(CudaDevice()) == properties.end() ?
getPropertyDefaultValue(CudaDevice()) : properties.find(CudaDevice())->second);
if (devicePropValue.length() > 0) if (devicePropValue.length() > 0)
stringstream(devicePropValue) >> device; stringstream(devicePropValue) >> device;
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
_gpuContext* gpu = (_gpuContext*) gpuInit(numParticles, device, getPropertyDefaultValue(CudaUseBlockingSync()) == "true"); const string& blockingSync = (properties.find(CudaUseBlockingSync()) == properties.end() ?
getPropertyDefaultValue(CudaUseBlockingSync()) : properties.find(CudaUseBlockingSync())->second);
_gpuContext* gpu = (_gpuContext*) gpuInit(numParticles, device, blockingSync == "true");
context.setPlatformData(new PlatformData(gpu)); context.setPlatformData(new PlatformData(gpu));
} }
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
bool supportsDoublePrecision() const; bool supportsDoublePrecision() const;
const std::string& getPropertyValue(const Context& context, const std::string& property) const; const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const; void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
void contextCreated(ContextImpl& context) const; void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
/** /**
* This is the name of the parameter for selecting which OpenCL device to use. * This is the name of the parameter for selecting which OpenCL device to use.
......
...@@ -84,9 +84,10 @@ const string& OpenCLPlatform::getPropertyValue(const Context& context, const str ...@@ -84,9 +84,10 @@ const string& OpenCLPlatform::getPropertyValue(const Context& context, const str
void OpenCLPlatform::setPropertyValue(Context& context, const string& property, const string& value) const { void OpenCLPlatform::setPropertyValue(Context& context, const string& property, const string& value) const {
} }
void OpenCLPlatform::contextCreated(ContextImpl& context) const { void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
unsigned int deviceIndex = -1; unsigned int deviceIndex = -1;
const string& devicePropValue = getPropertyDefaultValue(OpenCLDeviceIndex()); const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second);
if (devicePropValue.length() > 0) if (devicePropValue.length() > 0)
stringstream(devicePropValue) >> deviceIndex; stringstream(devicePropValue) >> deviceIndex;
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
return 1; return 1;
} }
bool supportsDoublePrecision() const; bool supportsDoublePrecision() const;
void contextCreated(ContextImpl& context) const; void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
}; };
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include "SimTKUtilities/SimTKOpenMMRealType.h" #include "SimTKUtilities/SimTKOpenMMRealType.h"
using namespace OpenMM; using namespace OpenMM;
using namespace std;
ReferencePlatform::ReferencePlatform() { ReferencePlatform::ReferencePlatform() {
ReferenceKernelFactory* factory = new ReferenceKernelFactory(); ReferenceKernelFactory* factory = new ReferenceKernelFactory();
...@@ -66,7 +67,7 @@ bool ReferencePlatform::supportsDoublePrecision() const { ...@@ -66,7 +67,7 @@ bool ReferencePlatform::supportsDoublePrecision() const {
return (sizeof(RealOpenMM) >= sizeof(double)); return (sizeof(RealOpenMM) >= sizeof(double));
} }
void ReferencePlatform::contextCreated(ContextImpl& context) const { void ReferencePlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
context.setPlatformData(new PlatformData(context.getSystem().getNumParticles())); context.setPlatformData(new PlatformData(context.getSystem().getNumParticles()));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment