Unverified Commit add95438 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

API for querying devices (#5192)

* API for querying devices

* CUDA and HIP implementations of getDevices()

* Fix test failures

* Fix test failures

* CUDA returns correct devices even if no context has been created

* Return a single device for Reference and CPU

* Fix CI failure
parent 8eeee16d
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -122,6 +122,27 @@ public: ...@@ -122,6 +122,27 @@ public:
* @param value the value to set for the property * @param value the value to set for the property
*/ */
void setPropertyDefaultValue(const std::string& property, const std::string& value); void setPropertyDefaultValue(const std::string& property, const std::string& value);
/**
* Get a list of available devices for this Platform.
*
* This method is relevant to Platforms that offer a choice of devices to run calculations on.
* For example, if a Platform does calculations on GPUs and it is running on a computer with
* multiple GPUs, you can choose which one to use by specifying Platform-specific properties.
*
* The returned value contains one entry for each available device. Each entry contains the
* values of any properties that are specific to the device, such as its name or index. You
* can select a device for a Context by passing the entry to its constructor.
*
* You can optionally pass one or more property values to this method, in which case they act
* as filters. Devices are only returned if they are compatible with the specified values.
* For example, some Platforms offer a choice of precision modes, but not all devices support
* all precision modes. If you specify double precision, only devices that support double
* precision will be returned.
*
* Some Platforms do not offer a choice of devices. In those cases, this method returns a
* single entry that contains no properties.
*/
virtual std::vector<std::map<std::string, std::string> > getDevices(const std::map<std::string, std::string>& filters={}) const;
/** /**
* This is called whenever a new Context is created. It gives the Platform a chance to initialize * This is called whenever a new Context is created. It gives the Platform a chance to initialize
* the context and store platform-specific data in it. * the context and store platform-specific data in it.
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -109,6 +109,10 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val ...@@ -109,6 +109,10 @@ void Platform::setPropertyDefaultValue(const string& property, const string& val
throw OpenMMException("setPropertyDefaultValue: Illegal property name"); throw OpenMMException("setPropertyDefaultValue: Illegal property name");
} }
vector<map<string, string> > Platform::getDevices(const map<string, string>& filters) const {
return {{}};
}
void Platform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void Platform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
} }
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2009-2025 Stanford University and the Authors. * * Portions copyright (c) 2009-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -578,6 +578,12 @@ public: ...@@ -578,6 +578,12 @@ public:
* Get the flags that should be used when creating CUevent objects. * Get the flags that should be used when creating CUevent objects.
*/ */
unsigned int getEventFlags(); unsigned int getEventFlags();
/**
* Ensure that CUDA has been initialized. This usually does not need to be called directly, because
* it is called automatically when a CudaContext is created. You can call it if you want to be sure
* CUDA has been initialized without creating a CudaContext.
*/
static void ensureCudaInitialized();
private: private:
/** /**
* Compute a sorted list of device indices in decreasing order of desirability * Compute a sorted list of device indices in decreasing order of desirability
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2023 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -50,6 +50,7 @@ public: ...@@ -50,6 +50,7 @@ public:
bool supportsDoublePrecision() const; bool supportsDoublePrecision() const;
const std::string& getPropertyValue(const Context& context, const std::string& property) const; const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const; void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
std::vector<std::map<std::string, std::string> > getDevices(const std::map<std::string, std::string>& filters={}) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const; void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const; void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2009-2025 Stanford University and the Authors. * * Portions copyright (c) 2009-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -88,10 +88,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking ...@@ -88,10 +88,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
pinnedBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), useBlockingSync(useBlockingSync) { pinnedBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), useBlockingSync(useBlockingSync) {
int cudaDriverVersion; int cudaDriverVersion;
cuDriverGetVersion(&cudaDriverVersion); cuDriverGetVersion(&cudaDriverVersion);
if (!hasInitializedCuda) { ensureCudaInitialized();
CHECK_RESULT2(cuInit(0), "Error initializing CUDA");
hasInitializedCuda = true;
}
if (precision == "single") { if (precision == "single") {
useDoublePrecision = false; useDoublePrecision = false;
useMixedPrecision = false; useMixedPrecision = false;
...@@ -882,3 +879,10 @@ unsigned int CudaContext::getEventFlags() { ...@@ -882,3 +879,10 @@ unsigned int CudaContext::getEventFlags() {
flags += CU_EVENT_BLOCKING_SYNC; flags += CU_EVENT_BLOCKING_SYNC;
return flags; return flags;
} }
void CudaContext::ensureCudaInitialized() {
if (!hasInitializedCuda) {
CHECK_RESULT2(cuInit(0), "Error initializing CUDA");
hasInitializedCuda = true;
}
}
\ No newline at end of file
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -163,6 +163,48 @@ const string& CudaPlatform::getPropertyValue(const Context& context, const strin ...@@ -163,6 +163,48 @@ const string& CudaPlatform::getPropertyValue(const Context& context, const strin
void CudaPlatform::setPropertyValue(Context& context, const string& property, const string& value) const { void CudaPlatform::setPropertyValue(Context& context, const string& property, const string& value) const {
} }
vector<map<string, string> > CudaPlatform::getDevices(const map<string, string>& filters) const {
try {
CudaContext::ensureCudaInitialized();
}
catch (...) {
// CUDA couldn't be initialized, so report no devices.
return {};
}
// Check for properties that might act as filters.
int deviceIndex = -1;
if (filters.find(CudaDeviceIndex()) != filters.end())
stringstream(filters.at(CudaDeviceIndex())) >> deviceIndex;
string deviceName = (filters.find(CudaDeviceName()) == filters.end() ? "" : filters.at(CudaDeviceName()));
// Loop over devices.
vector<map<string, string> > results;
int numDevices;
if (cuDeviceGetCount(&numDevices) != CUDA_SUCCESS)
numDevices = 0;
for (int i = 0; i < numDevices; i++) {
if (deviceIndex != -1 && deviceIndex != i)
continue;
char name[1000];
CUdevice device;
CHECK_RESULT(cuDeviceGet(&device, i), "Error querying device");
CHECK_RESULT(cuDeviceGetName(name, 1000, device), "Error querying device name");
stringstream deviceNameStr;
deviceNameStr << name;
if (deviceName.size() > 0 && deviceName != deviceNameStr.str())
continue;
stringstream deviceIndexStr;
deviceIndexStr << i;
map<string, string> properties = {{CudaDeviceIndex(), deviceIndexStr.str()},
{CudaDeviceName(), deviceNameStr.str()}};
results.push_back(properties);
}
return results;
}
void CudaPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void CudaPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
const string& devicePropValue = (properties.find(CudaDeviceIndex()) == properties.end() ? const string& devicePropValue = (properties.find(CudaDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(CudaDeviceIndex()) : properties.find(CudaDeviceIndex())->second); getPropertyDefaultValue(CudaDeviceIndex()) : properties.find(CudaDeviceIndex())->second);
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Portions copyright (c) 2020 Advanced Micro Devices, Inc. * * Portions copyright (c) 2020 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis * * Authors: Peter Eastman, Nicholas Curtis *
* Contributors: * * Contributors: *
...@@ -51,6 +51,7 @@ public: ...@@ -51,6 +51,7 @@ public:
bool supportsDoublePrecision() const; bool supportsDoublePrecision() const;
const std::string& getPropertyValue(const Context& context, const std::string& property) const; const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const; void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
std::vector<std::map<std::string, std::string> > getDevices(const std::map<std::string, std::string>& filters={}) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const; void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const; void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Portions copyright (c) 2020 Advanced Micro Devices, Inc. * * Portions copyright (c) 2020 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis * * Authors: Peter Eastman, Nicholas Curtis *
* Contributors: * * Contributors: *
...@@ -163,6 +163,40 @@ const string& HipPlatform::getPropertyValue(const Context& context, const string ...@@ -163,6 +163,40 @@ const string& HipPlatform::getPropertyValue(const Context& context, const string
void HipPlatform::setPropertyValue(Context& context, const string& property, const string& value) const { void HipPlatform::setPropertyValue(Context& context, const string& property, const string& value) const {
} }
vector<map<string, string> > HipPlatform::getDevices(const map<string, string>& filters) const {
// Check for properties that might act as filters.
int deviceIndex = -1;
if (filters.find(HipDeviceIndex()) != filters.end())
stringstream(filters.at(HipDeviceIndex())) >> deviceIndex;
string deviceName = (filters.find(HipDeviceName()) == filters.end() ? "" : filters.at(HipDeviceName()));
// Loop over devices.
vector<map<string, string> > results;
int numDevices;
if (hipGetDeviceCount(&numDevices) != hipSuccess)
numDevices = 0;
for (int i = 0; i < numDevices; i++) {
if (deviceIndex != -1 && deviceIndex != i)
continue;
char name[1000];
hipDevice_t device;
CHECK_RESULT(hipDeviceGet(&device, i), "Error querying device");
CHECK_RESULT(hipDeviceGetName(name, 1000, device), "Error querying device name");
stringstream deviceNameStr;
deviceNameStr << name;
if (deviceName.size() > 0 && deviceName != deviceNameStr.str())
continue;
stringstream deviceIndexStr;
deviceIndexStr << i;
map<string, string> properties = {{HipDeviceIndex(), deviceIndexStr.str()},
{HipDeviceName(), deviceNameStr.str()}};
results.push_back(properties);
}
return results;
}
void HipPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void HipPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
const string& devicePropValue = (properties.find(HipDeviceIndex()) == properties.end() ? const string& devicePropValue = (properties.find(HipDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(HipDeviceIndex()) : properties.find(HipDeviceIndex())->second); getPropertyDefaultValue(HipDeviceIndex()) : properties.find(HipDeviceIndex())->second);
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -50,6 +50,7 @@ public: ...@@ -50,6 +50,7 @@ public:
static bool isPlatformSupported(); static bool isPlatformSupported();
const std::string& getPropertyValue(const Context& context, const std::string& property) const; const std::string& getPropertyValue(const Context& context, const std::string& property) const;
void setPropertyValue(Context& context, const std::string& property, const std::string& value) const; void setPropertyValue(Context& context, const std::string& property, const std::string& value) const;
std::vector<std::map<std::string, std::string> > getDevices(const std::map<std::string, std::string>& filters={}) const;
void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const; void contextCreated(ContextImpl& context, const std::map<std::string, std::string>& properties) const;
void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const; void linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const;
void contextDestroyed(ContextImpl& context) const; void contextDestroyed(ContextImpl& context) const;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -178,6 +178,66 @@ const string& OpenCLPlatform::getPropertyValue(const Context& context, const str ...@@ -178,6 +178,66 @@ const string& OpenCLPlatform::getPropertyValue(const Context& context, const str
void OpenCLPlatform::setPropertyValue(Context& context, const string& property, const string& value) const { void OpenCLPlatform::setPropertyValue(Context& context, const string& property, const string& value) const {
} }
vector<map<string, string> > OpenCLPlatform::getDevices(const map<string, string>& filters) const {
// Check for properties that might act as filters.
int platformIndex = -1;
if (filters.find(OpenCLPlatformIndex()) != filters.end())
stringstream(filters.at(OpenCLPlatformIndex())) >> platformIndex;
string platformName = (filters.find(OpenCLPlatformName()) == filters.end() ? "" : filters.at(OpenCLPlatformName()));
int deviceIndex = -1;
if (filters.find(OpenCLDeviceIndex()) != filters.end())
stringstream(filters.at(OpenCLDeviceIndex())) >> deviceIndex;
string deviceName = (filters.find(OpenCLDeviceName()) == filters.end() ? "" : filters.at(OpenCLDeviceName()));
bool needsDouble = false;
if (filters.find(OpenCLPrecision()) != filters.end()) {
string precision = filters.at(OpenCLPrecision());
transform(precision.begin(), precision.end(), precision.begin(), ::tolower);
needsDouble = (precision != "single");
}
// Loop over platforms.
vector<map<string, string> > results;
vector<cl::Platform> platforms;
cl::Platform::get(&platforms);
for (int i = 0; i < platforms.size(); i++) {
if (platformIndex != -1 && platformIndex != i)
continue;
if (platformName.size() > 0 && platformName != platforms[i].getInfo<CL_PLATFORM_NAME>())
continue;
// Loop over devices for the platform.
vector<cl::Device> devices;
try {
platforms[i].getDevices(CL_DEVICE_TYPE_GPU | CL_DEVICE_TYPE_CPU, &devices);
}
catch (...) {
// There are no devices available for this platform.
continue;
}
for (int j = 0; j < devices.size(); j++) {
if (deviceIndex != -1 && deviceIndex != j)
continue;
if (deviceName.size() > 0 && deviceName != devices[j].getInfo<CL_DEVICE_NAME>())
continue;
bool supportsDouble = (devices[j].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
if (needsDouble && !supportsDouble)
continue;
stringstream platformIndexStr, deviceIndexStr;
platformIndexStr << i;
deviceIndexStr << j;
map<string, string> properties = {{OpenCLPlatformIndex(), platformIndexStr.str()},
{OpenCLPlatformName(), platforms[i].getInfo<CL_PLATFORM_NAME>()},
{OpenCLDeviceIndex(), deviceIndexStr.str()},
{OpenCLDeviceName(), devices[j].getInfo<CL_DEVICE_NAME>()}};
results.push_back(properties);
}
}
return results;
}
void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
const string& platformPropValue = (properties.find(OpenCLPlatformIndex()) == properties.end() ? const string& platformPropValue = (properties.find(OpenCLPlatformIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLPlatformIndex()) : properties.find(OpenCLPlatformIndex())->second); getPropertyDefaultValue(OpenCLPlatformIndex()) : properties.find(OpenCLPlatformIndex())->second);
......
...@@ -73,6 +73,7 @@ class WrapperGenerator: ...@@ -73,6 +73,7 @@ class WrapperGenerator:
'const std::vector<std::vector<int> >& OpenMM::Context::getMolecules', 'const std::vector<std::vector<int> >& OpenMM::Context::getMolecules',
'static std::vector<std::string> OpenMM::Platform::getPluginLoadFailures', 'static std::vector<std::string> OpenMM::Platform::getPluginLoadFailures',
'static std::vector<std::string> OpenMM::Platform::loadPluginsFromDirectory', 'static std::vector<std::string> OpenMM::Platform::loadPluginsFromDirectory',
'virtual std::vector<std::map<std::string, std::string> > OpenMM::Platform::getDevices',
'Vec3 OpenMM::LocalCoordinatesSite::getOriginWeights', 'Vec3 OpenMM::LocalCoordinatesSite::getOriginWeights',
'Vec3 OpenMM::LocalCoordinatesSite::getXWeights', 'Vec3 OpenMM::LocalCoordinatesSite::getXWeights',
'Vec3 OpenMM::LocalCoordinatesSite::getYWeights', 'Vec3 OpenMM::LocalCoordinatesSite::getYWeights',
......
...@@ -18,6 +18,7 @@ namespace std { ...@@ -18,6 +18,7 @@ namespace std {
%template(vectorstring) vector<string>; %template(vectorstring) vector<string>;
%template(mapstringstring) map<string,string>; %template(mapstringstring) map<string,string>;
%template(mapstringdouble) map<string,double>; %template(mapstringdouble) map<string,double>;
%template(vectormapstringstring) vector<map<string,string> >;
%template(mapii) map<int,int>; %template(mapii) map<int,int>;
%template(seti) set<int>; %template(seti) set<int>;
}; };
......
...@@ -472,6 +472,7 @@ UNITS = { ...@@ -472,6 +472,7 @@ UNITS = {
("GayBerneForce", "addParticle") : (None, ("unit.nanometer", "unit.kilojoule_per_mole", None, None, "unit.nanometer", "unit.nanometer", "unit.nanometer", None, None, None)), ("GayBerneForce", "addParticle") : (None, ("unit.nanometer", "unit.kilojoule_per_mole", None, None, "unit.nanometer", "unit.nanometer", "unit.nanometer", None, None, None)),
("GayBerneForce", "getParticleParameters") : (None, ("unit.nanometer", "unit.kilojoule_per_mole", None, None, "unit.nanometer", "unit.nanometer", "unit.nanometer", None, None, None)), ("GayBerneForce", "getParticleParameters") : (None, ("unit.nanometer", "unit.kilojoule_per_mole", None, None, "unit.nanometer", "unit.nanometer", "unit.nanometer", None, None, None)),
("GayBerneForce", "setParticleParameters") : (None, (None, "unit.nanometer", "unit.kilojoule_per_mole", None, None, "unit.nanometer", "unit.nanometer", "unit.nanometer", None, None, None)), ("GayBerneForce", "setParticleParameters") : (None, (None, "unit.nanometer", "unit.kilojoule_per_mole", None, None, "unit.nanometer", "unit.nanometer", "unit.nanometer", None, None, None)),
("Platform", "getDevices") : (None, (None,)),
("Platform", "getDefaultPluginsDirectory") : (None, ()), ("Platform", "getDefaultPluginsDirectory") : (None, ()),
("Platform", "getPropertyDefaultValue") : (None, ()), ("Platform", "getPropertyDefaultValue") : (None, ()),
("Platform", "getPropertyNames") : (None, ()), ("Platform", "getPropertyNames") : (None, ()),
......
...@@ -296,6 +296,30 @@ class TestSimulation(unittest.TestCase): ...@@ -296,6 +296,30 @@ class TestSimulation(unittest.TestCase):
simulation.minimizeEnergy(reporter=reporter) simulation.minimizeEnergy(reporter=reporter)
assert not reporter.error assert not reporter.error
def testSelectDevice(self):
"""Test querying and selecting devices to run on."""
pdb = PDBFile('systems/alanine-dipeptide-implicit.pdb')
ff = ForceField('amber99sb.xml', 'tip3p.xml')
system = ff.createSystem(pdb.topology)
for i in range(Platform.getNumPlatforms()):
platform = Platform.getPlatform(i)
devices = platform.getDevices()
if platform.getName() in ['Reference', 'CPU']:
assert len(devices) == 1
else:
for device in devices:
integrator = LangevinIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)
try:
simulation = Simulation(pdb.topology, system, integrator, platform, device)
except:
# This can happen if a device can't be supported.
continue
for key, value in device.items():
assert platform.getPropertyValue(simulation.context, key) == value
for j in range(len(devices)):
for k in range(j):
assert devices[j] != devices[k]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment