Commit cfcebe26 authored by Peter Eastman's avatar Peter Eastman
Browse files

Added a version of Platform::getPlatform() that looks up platforms by name. ...

Added a version of Platform::getPlatform() that looks up platforms by name.  Added a Context constructor that allows platform-specific properties to be specified.
parent 3a6d79b3
......@@ -119,8 +119,11 @@ public:
/**
* This is called whenever a new Context is created. It gives the Platform a chance to initialize
* the context and store platform-specific data in it.
*
* @param context the newly created context
* @param properties a set of values for platform-specific properties. Keys are the property names.
*/
virtual void contextCreated(ContextImpl& context) const;
virtual void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
/**
* This is called whenever a Context is deleted. It gives the Platform a chance to clean up
* any platform-specific data that was stored in it.
......@@ -168,6 +171,11 @@ public:
* Get a registered Platform by index.
*/
static Platform& getPlatform(int index);
/**
* Get the registered Platform with a particular name. If no Platform with that name has been
* registered, this throws an exception.
*/
static Platform& getPlatform(const std::string& name);
/**
* Find a Platform which can be used to perform a calculation.
*
......
......@@ -91,7 +91,7 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val
defaultProperties[property] = value;
}
void Platform::contextCreated(ContextImpl& context) const {
void Platform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
}
void Platform::contextDestroyed(ContextImpl& context) const {
......@@ -131,6 +131,13 @@ Platform& Platform::getPlatform(int index) {
return *getPlatforms()[index];
}
Platform& Platform::getPlatform(const string& name) {
for (int i = 0; i < getNumPlatforms(); i++)
if (getPlatform(i).getName() == name)
return getPlatform(i);
throw OpenMMException("There is no registered Platform called \""+name+"\"");
}
Platform& Platform::findPlatform(const vector<string>& kernelNames) {
Platform* best = 0;
vector<Platform*>& platforms = getPlatforms();
......
......@@ -35,6 +35,7 @@
#include "Integrator.h"
#include "State.h"
#include "System.h"
#include <map>
#include <string>
#include <vector>
#include "internal/windowsExport.h"
......@@ -79,6 +80,16 @@ public:
* @param platform the Platform to use for calculations
*/
Context(System& system, Integrator& integrator, Platform& platform);
/**
* Construct a new Context in which to run a simulation, explicitly specifying what Platform should be used
* to perform calculations and the values of platform-specific properties.
*
* @param system the System which will be simulated
* @param integrator the Integrator which will be used to simulate the System
* @param platform the Platform to use for calculations
* @param properties a set of values for platform-specific properties. Keys are the property names.
*/
Context(System& system, Integrator& integrator, Platform& platform, const std::map<std::string, std::string>& properties);
~Context();
/**
* Get System being simulated in this context.
......@@ -155,6 +166,7 @@ public:
private:
friend class Platform;
ContextImpl* impl;
std::map<std::string, std::string> properties;
};
} // namespace OpenMM
......
......@@ -54,7 +54,7 @@ public:
/**
* Create an ContextImpl for a Context;
*/
ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform);
ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform, const std::map<std::string, std::string>& properties);
~ContextImpl();
/**
* Get the Context for which this is the implementation.
......
......@@ -36,12 +36,16 @@
using namespace OpenMM;
using namespace std;
Context::Context(System& system, Integrator& integrator) {
impl = new ContextImpl(*this, system, integrator, 0);
Context::Context(System& system, Integrator& integrator) : properties(map<string, string>()) {
impl = new ContextImpl(*this, system, integrator, 0, properties);
}
Context::Context(System& system, Integrator& integrator, Platform& platform) {
impl = new ContextImpl(*this, system, integrator, &platform);
Context::Context(System& system, Integrator& integrator, Platform& platform) : properties(map<string, string>()) {
impl = new ContextImpl(*this, system, integrator, &platform, properties);
}
Context::Context(System& system, Integrator& integrator, Platform& platform, const map<string, string>& properties) : properties(properties) {
impl = new ContextImpl(*this, system, integrator, &platform, properties);
}
Context::~Context() {
......@@ -121,5 +125,5 @@ void Context::reinitialize() {
Integrator& integrator = impl->getIntegrator();
Platform& platform = impl->getPlatform();
delete impl;
impl = new ContextImpl(*this, system, integrator, &platform);
impl = new ContextImpl(*this, system, integrator, &platform, properties);
}
......@@ -44,7 +44,7 @@ using std::map;
using std::vector;
using std::string;
ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform) :
ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) :
owner(owner), system(system), integrator(integrator), platform(platform), platformData(NULL) {
vector<string> kernelNames;
kernelNames.push_back(CalcKineticEnergyKernel::Name());
......@@ -63,7 +63,7 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator,
this->platform = platform = &Platform::findPlatform(kernelNames);
else if (!platform->supportsKernels(kernelNames))
throw OpenMMException("Specified a Platform for a Context which does not support all required kernels");
platform->contextCreated(*this);
platform->contextCreated(*this, properties);
initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this);
dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl()).initialize(system);
kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this);
......
......@@ -51,7 +51,7 @@ public:
bool supportsDoublePrecision() const;
const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
void contextCreated(ContextImpl& context) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const;
/**
* This is the name of the parameter for selecting which CUDA device to use.
......
......@@ -88,13 +88,16 @@ const string& CudaPlatform::getPropertyValue(const Context& context, const strin
void CudaPlatform::setPropertyValue(Context& context, const string& property, const string& value) const {
}
void CudaPlatform::contextCreated(ContextImpl& context) const {
void CudaPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
unsigned int device = 0;
const string& devicePropValue = getPropertyDefaultValue(CudaDevice());
const string& devicePropValue = (properties.find(CudaDevice()) == properties.end() ?
getPropertyDefaultValue(CudaDevice()) : properties.find(CudaDevice())->second);
if (devicePropValue.length() > 0)
stringstream(devicePropValue) >> device;
int numParticles = context.getSystem().getNumParticles();
_gpuContext* gpu = (_gpuContext*) gpuInit(numParticles, device, getPropertyDefaultValue(CudaUseBlockingSync()) == "true");
const string& blockingSync = (properties.find(CudaUseBlockingSync()) == properties.end() ?
getPropertyDefaultValue(CudaUseBlockingSync()) : properties.find(CudaUseBlockingSync())->second);
_gpuContext* gpu = (_gpuContext*) gpuInit(numParticles, device, blockingSync == "true");
context.setPlatformData(new PlatformData(gpu));
}
......
......@@ -53,7 +53,7 @@ public:
bool supportsDoublePrecision() const;
const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
void contextCreated(ContextImpl& context) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const;
/**
* This is the name of the parameter for selecting which OpenCL device to use.
......
......@@ -84,9 +84,10 @@ const string& OpenCLPlatform::getPropertyValue(const Context& context, const str
void OpenCLPlatform::setPropertyValue(Context& context, const string& property, const string& value) const {
}
void OpenCLPlatform::contextCreated(ContextImpl& context) const {
void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
unsigned int deviceIndex = -1;
const string& devicePropValue = getPropertyDefaultValue(OpenCLDeviceIndex());
const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second);
if (devicePropValue.length() > 0)
stringstream(devicePropValue) >> deviceIndex;
int numParticles = context.getSystem().getNumParticles();
......
......@@ -53,7 +53,7 @@ public:
return 1;
}
bool supportsDoublePrecision() const;
void contextCreated(ContextImpl& context) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void contextDestroyed(ContextImpl& context) const;
};
......
......@@ -36,6 +36,7 @@
#include "SimTKUtilities/SimTKOpenMMRealType.h"
using namespace OpenMM;
using namespace std;
ReferencePlatform::ReferencePlatform() {
ReferenceKernelFactory* factory = new ReferenceKernelFactory();
......@@ -66,7 +67,7 @@ bool ReferencePlatform::supportsDoublePrecision() const {
return (sizeof(RealOpenMM) >= sizeof(double));
}
void ReferencePlatform::contextCreated(ContextImpl& context) const {
void ReferencePlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
context.setPlatformData(new PlatformData(context.getSystem().getNumParticles()));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment