Unverified Commit e62bdf6a authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

API improvements (#4437)

* Can use getPlatform() instead of getPlatformByName()

* More concise arguments for getState()
parent 71f4b3fc
...@@ -125,7 +125,7 @@ Both of these classes should be packaged into a dynamic library (.so on Linux, ...@@ -125,7 +125,7 @@ Both of these classes should be packaged into a dynamic library (.so on Linux,
must also implement the two functions from PluginInitializer.h. must also implement the two functions from PluginInitializer.h.
:code:`registerPlatforms()` will do nothing, since this plugin does not :code:`registerPlatforms()` will do nothing, since this plugin does not
implement any new Platforms. :code:`registerKernelFactories()` should call implement any new Platforms. :code:`registerKernelFactories()` should call
\ :code:`Platform::getPlatformByName("OpenCL")` to get the OpenCL Platform, \ :code:`Platform::getPlatform("OpenCL")` to get the OpenCL Platform,
then create a new OpenCLStringForceKernelFactory and call then create a new OpenCLStringForceKernelFactory and call
:code:`registerKernelFactory()` on the Platform to register it. If the OpenCL :code:`registerKernelFactory()` on the Platform to register it. If the OpenCL
Platform is not available, you should catch the exception then return without Platform is not available, you should catch the exception then return without
......
...@@ -428,7 +428,7 @@ of the :class:`Platform` to use. This overrides the default logic. ...@@ -428,7 +428,7 @@ of the :class:`Platform` to use. This overrides the default logic.
:class:`Simulation`. The following lines specify to use the :class:`CUDA` platform: :class:`Simulation`. The following lines specify to use the :class:`CUDA` platform:
:: ::
platform = Platform.getPlatformByName('CUDA') platform = Platform.getPlatform('CUDA')
simulation = Simulation(prmtop.topology, system, integrator, platform) simulation = Simulation(prmtop.topology, system, integrator, platform)
The platform name should be one of :code:`OpenCL`, :code:`CUDA`, :code:`CPU`, or The platform name should be one of :code:`OpenCL`, :code:`CUDA`, :code:`CPU`, or
...@@ -441,7 +441,7 @@ work across two different GPUs (CUDA devices 0 and 1), doing all computations in ...@@ -441,7 +441,7 @@ work across two different GPUs (CUDA devices 0 and 1), doing all computations in
double precision: double precision:
:: ::
platform = Platform.getPlatformByName('CUDA') platform = Platform.getPlatform('CUDA')
properties = {'DeviceIndex': '0,1', 'Precision': 'double'} properties = {'DeviceIndex': '0,1', 'Precision': 'double'}
simulation = Simulation(prmtop.topology, system, integrator, platform, properties) simulation = Simulation(prmtop.topology, system, integrator, platform, properties)
......
...@@ -248,7 +248,7 @@ PDB file. ...@@ -248,7 +248,7 @@ PDB file.
simulation.context.setPositions(modeller.positions) simulation.context.setPositions(modeller.positions)
simulation.minimizeEnergy(maxIterations=100) simulation.minimizeEnergy(maxIterations=100)
print('Saving...') print('Saving...')
positions = simulation.context.getState(getPositions=True).getPositions() positions = simulation.context.getState(positions=True).getPositions()
PDBFile.writeFile(simulation.topology, positions, open('output.pdb', 'w')) PDBFile.writeFile(simulation.topology, positions, open('output.pdb', 'w'))
print('Done') print('Done')
......
...@@ -209,7 +209,7 @@ the potential energy of each one. Assume we have already created our :class:`Sy ...@@ -209,7 +209,7 @@ the potential energy of each one. Assume we have already created our :class:`Sy
for file in os.listdir('structures'): for file in os.listdir('structures'):
pdb = PDBFile(os.path.join('structures', file)) pdb = PDBFile(os.path.join('structures', file))
simulation.context.setPositions(pdb.positions) simulation.context.setPositions(pdb.positions)
state = simulation.context.getState(getEnergy=True) state = simulation.context.getState(energy=True)
print(file, state.getPotentialEnergy()) print(file, state.getPotentialEnergy())
.. caption:: .. caption::
...@@ -220,6 +220,6 @@ We use Python’s :code:`listdir()` function to list all the files in the ...@@ -220,6 +220,6 @@ We use Python’s :code:`listdir()` function to list all the files in the
directory. We create a :class:`PDBFile` object for each one and call directory. We create a :class:`PDBFile` object for each one and call
:meth:`setPositions()` on the Context to specify the particle positions loaded :meth:`setPositions()` on the Context to specify the particle positions loaded
from the PDB file. We then compute the energy by calling :meth:`getState()` from the PDB file. We then compute the energy by calling :meth:`getState()`
with the option :code:`getEnergy=True`\ , and print it to the console along with the option :code:`energy=True`\ , and print it to the console along
with the name of the file. with the name of the file.
...@@ -12,7 +12,7 @@ Context constructor: ...@@ -12,7 +12,7 @@ Context constructor:
.. code-block:: c .. code-block:: c
Platform& platform = Platform::getPlatformByName("OpenCL"); Platform& platform = Platform::getPlatform("OpenCL");
map<string, string> properties; map<string, string> properties;
properties["DeviceIndex"] = "1"; properties["DeviceIndex"] = "1";
Context context(system, integrator, platform, properties); Context context(system, integrator, platform, properties);
......
...@@ -588,7 +588,7 @@ notable differences: ...@@ -588,7 +588,7 @@ notable differences:
available. For example: available. For example:
:: ::
myContext.getState(getEnergy=True, getForce=False, …) myContext.getState(energy=True, force=False, …)
#. Wherever the C++ API uses references to return multiple values from a method, #. Wherever the C++ API uses references to return multiple values from a method,
the Python API returns a tuple. For example, in C++ you would query a the Python API returns a tuple. For example, in C++ you would query a
......
...@@ -129,7 +129,7 @@ for (lambda_index, lambda_value) in enumerate(lambda_values): ...@@ -129,7 +129,7 @@ for (lambda_index, lambda_value) in enumerate(lambda_values):
integrator.step(nprod_steps) integrator.step(nprod_steps)
# Get coordinates. # Get coordinates.
state = context.getState(getPositions=True) state = context.getState(positions=True)
positions = state.getPositions(asNumpy=True) positions = state.getPositions(asNumpy=True)
# Store positions. # Store positions.
...@@ -149,7 +149,7 @@ for (lambda_index, lambda_value) in enumerate(lambda_values): ...@@ -149,7 +149,7 @@ for (lambda_index, lambda_value) in enumerate(lambda_values):
# Compute reduced potentials of all snapshots. # Compute reduced potentials of all snapshots.
for n in range(nprod_iterations): for n in range(nprod_iterations):
context.setPositions(position_history[n]) context.setPositions(position_history[n])
state = context.getState(getEnergy=True) state = context.getState(energy=True)
u_kln[lambda_index, l, n] = beta * state.getPotentialEnergy() u_kln[lambda_index, l, n] = beta * state.getPotentialEnergy()
# Clean up. # Clean up.
......
...@@ -99,10 +99,10 @@ def printTestResult(test_result, options): ...@@ -99,10 +99,10 @@ def printTestResult(test_result, options):
def timeIntegration(context, steps, initialSteps): def timeIntegration(context, steps, initialSteps):
"""Integrate a Context for a specified number of steps, then return how many seconds it took.""" """Integrate a Context for a specified number of steps, then return how many seconds it took."""
context.getIntegrator().step(initialSteps) # Make sure everything is fully initialized context.getIntegrator().step(initialSteps) # Make sure everything is fully initialized
context.getState(getEnergy=True) context.getState(energy=True)
start = datetime.now() start = datetime.now()
context.getIntegrator().step(steps) context.getIntegrator().step(steps)
context.getState(getEnergy=True) context.getState(energy=True)
end = datetime.now() end = datetime.now()
elapsed = end-start elapsed = end-start
return elapsed.seconds + elapsed.microseconds*1e-6 return elapsed.seconds + elapsed.microseconds*1e-6
...@@ -347,7 +347,7 @@ def runOneTest(testName, options): ...@@ -347,7 +347,7 @@ def runOneTest(testName, options):
test_result['timestep_in_fs'] = dt.value_in_unit(unit.femtoseconds) test_result['timestep_in_fs'] = dt.value_in_unit(unit.femtoseconds)
properties = {} properties = {}
initialSteps = 5 initialSteps = 5
platform = mm.Platform.getPlatformByName(options.platform) platform = mm.Platform.getPlatform(options.platform)
if options.device is not None and 'DeviceIndex' in platform.getPropertyNames(): if options.device is not None and 'DeviceIndex' in platform.getPropertyNames():
properties['DeviceIndex'] = options.device properties['DeviceIndex'] = options.device
if ',' in options.device or ' ' in options.device: if ',' in options.device or ' ' in options.device:
...@@ -384,7 +384,7 @@ def runOneTest(testName, options): ...@@ -384,7 +384,7 @@ def runOneTest(testName, options):
tol = 1.0e-8 tol = 1.0e-8
context.applyConstraints(tol) context.applyConstraints(tol)
context.applyVelocityConstraints(tol) context.applyVelocityConstraints(tol)
state = context.getState(getPositions=True, getVelocities=True, getEnergy=True, getForces=True, getParameters=True) state = context.getState(positions=True, velocities=True, energy=True, forces=True, parameters=True)
# Time integration, ensuring we trigger kernel compilation before we start timing # Time integration, ensuring we trigger kernel compilation before we start timing
steps = 20 steps = 20
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2024 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -185,6 +185,11 @@ public: ...@@ -185,6 +185,11 @@ public:
* Get a registered Platform by index. * Get a registered Platform by index.
*/ */
static Platform& getPlatform(int index); static Platform& getPlatform(int index);
/**
* Get a registered Platform by name. If no Platform with that name has been
* registered, this throws an exception.
*/
static Platform& getPlatform(const std::string& name);
/** /**
* Get any failures caused during the last call to loadPluginsFromDirectory * Get any failures caused during the last call to loadPluginsFromDirectory
*/ */
...@@ -192,6 +197,9 @@ public: ...@@ -192,6 +197,9 @@ public:
/** /**
* Get the registered Platform with a particular name. If no Platform with that name has been * Get the registered Platform with a particular name. If no Platform with that name has been
* registered, this throws an exception. * registered, this throws an exception.
*
* This is identical to the version of getPlatform() that takes a name. It
* is here for backward compatibility.
*/ */
static Platform& getPlatformByName(const std::string& name); static Platform& getPlatformByName(const std::string& name);
/** /**
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2024 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -163,6 +163,10 @@ Platform& Platform::getPlatform(int index) { ...@@ -163,6 +163,10 @@ Platform& Platform::getPlatform(int index) {
throw OpenMMException("Invalid platform index"); throw OpenMMException("Invalid platform index");
} }
Platform& Platform::getPlatform(const string& name) {
return getPlatformByName(name);
}
std::vector<std::string> Platform::getPluginLoadFailures() { std::vector<std::string> Platform::getPluginLoadFailures() {
return pluginLoadFailures; return pluginLoadFailures;
} }
......
...@@ -603,7 +603,7 @@ void testLargeSystem() { ...@@ -603,7 +603,7 @@ void testLargeSystem() {
system.addForce(force); system.addForce(force);
VerletIntegrator integrator1(0.001); VerletIntegrator integrator1(0.001);
VerletIntegrator integrator2(0.001); VerletIntegrator integrator2(0.001);
Context context1(system, integrator1, Platform::getPlatformByName("Reference")); Context context1(system, integrator1, Platform::getPlatform("Reference"));
Context context2(system, integrator2, platform); Context context2(system, integrator2, platform);
context1.setPositions(positions); context1.setPositions(positions);
context2.setPositions(positions); context2.setPositions(positions);
...@@ -698,7 +698,7 @@ void testCentralParticleModeLargeSystem() { ...@@ -698,7 +698,7 @@ void testCentralParticleModeLargeSystem() {
system.addForce(force); system.addForce(force);
VerletIntegrator integrator1(0.001); VerletIntegrator integrator1(0.001);
VerletIntegrator integrator2(0.001); VerletIntegrator integrator2(0.001);
Context context1(system, integrator1, Platform::getPlatformByName("Reference")); Context context1(system, integrator1, Platform::getPlatform("Reference"));
Context context2(system, integrator2, platform); Context context2(system, integrator2, platform);
context1.setPositions(positions); context1.setPositions(positions);
context2.setPositions(positions); context2.setPositions(positions);
......
...@@ -70,7 +70,7 @@ void testTruncatedOctahedron() { ...@@ -70,7 +70,7 @@ void testTruncatedOctahedron() {
system.addForce(force); system.addForce(force);
VerletIntegrator integrator(0.01); VerletIntegrator integrator(0.01);
Context context(system, integrator, Platform::getPlatformByName("Reference")); Context context(system, integrator, Platform::getPlatform("Reference"));
context.setPositions(positions); context.setPositions(positions);
State initialState = context.getState(State::Positions | State::Energy, true); State initialState = context.getState(State::Positions | State::Energy, true);
for (int i = 0; i < numMolecules; i++) { for (int i = 0; i < numMolecules; i++) {
......
...@@ -304,7 +304,7 @@ void testTriclinic2() { ...@@ -304,7 +304,7 @@ void testTriclinic2() {
Context context1(system, integ1, platform); Context context1(system, integ1, platform);
context1.setPositions(positions); context1.setPositions(positions);
VerletIntegrator integ2(0.001); VerletIntegrator integ2(0.001);
Context context2(system, integ2, Platform::getPlatformByName("Reference")); Context context2(system, integ2, Platform::getPlatform("Reference"));
context2.setPositions(positions); context2.setPositions(positions);
State state1 = context1.getState(State::Forces | State::Energy, false, 2); State state1 = context1.getState(State::Forces | State::Energy, false, 2);
State state2 = context2.getState(State::Forces | State::Energy, false, 2); State state2 = context2.getState(State::Forces | State::Energy, false, 2);
......
...@@ -57,7 +57,7 @@ void testFindMolecules() { ...@@ -57,7 +57,7 @@ void testFindMolecules() {
bonds->addBond(index, index-1, 1.0, 1.0); bonds->addBond(index, index-1, 1.0, 1.0);
} }
VerletIntegrator integrator(1.0); VerletIntegrator integrator(1.0);
Context context(system, integrator, Platform::getPlatformByName("Reference")); Context context(system, integrator, Platform::getPlatform("Reference"));
ContextImpl* contextImpl = *reinterpret_cast<ContextImpl**>(&context); ContextImpl* contextImpl = *reinterpret_cast<ContextImpl**>(&context);
const vector<vector<int> >& molecules = contextImpl->getMolecules(); const vector<vector<int> >& molecules = contextImpl->getMolecules();
ASSERT_EQUAL(numMolecules, molecules.size()); ASSERT_EQUAL(numMolecules, molecules.size());
......
...@@ -468,7 +468,7 @@ void testOverlappingSites() { ...@@ -468,7 +468,7 @@ void testOverlappingSites() {
} }
VerletIntegrator i1(0.002); VerletIntegrator i1(0.002);
VerletIntegrator i2(0.002); VerletIntegrator i2(0.002);
Context c1(system, i1, Platform::getPlatformByName("Reference")); Context c1(system, i1, Platform::getPlatform("Reference"));
Context c2(system, i2, platform); Context c2(system, i2, platform);
c1.setPositions(positions); c1.setPositions(positions);
c2.setPositions(positions); c2.setPositions(positions);
......
...@@ -149,7 +149,7 @@ class CheckpointReporter(object): ...@@ -149,7 +149,7 @@ class CheckpointReporter(object):
self._file.seek(0) self._file.seek(0)
if self._writeState: if self._writeState:
state = simulation.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True) state = simulation.context.getState(positions=True, velocities=True, parameters=True, integratorParameters=True)
self._file.write(mm.XmlSerializer.serialize(state)) self._file.write(mm.XmlSerializer.serialize(state))
else: else:
self._file.write(simulation.context.createCheckpoint()) self._file.write(simulation.context.createCheckpoint())
......
...@@ -172,7 +172,7 @@ class Metadynamics(object): ...@@ -172,7 +172,7 @@ class Metadynamics(object):
simulation.step(nextSteps) simulation.step(nextSteps)
if simulation.currentStep % self.frequency == 0: if simulation.currentStep % self.frequency == 0:
position = self._force.getCollectiveVariableValues(simulation.context) position = self._force.getCollectiveVariableValues(simulation.context)
energy = simulation.context.getState(getEnergy=True, groups={forceGroup}).getPotentialEnergy() energy = simulation.context.getState(energy=True, groups={forceGroup}).getPotentialEnergy()
height = self.height*np.exp(-energy/(unit.MOLAR_GAS_CONSTANT_R*self._deltaT)) height = self.height*np.exp(-energy/(unit.MOLAR_GAS_CONSTANT_R*self._deltaT))
self._addGaussian(position, height, simulation.context) self._addGaussian(position, height, simulation.context)
if self.saveFrequency is not None and simulation.currentStep % self.saveFrequency == 0: if self.saveFrequency is not None and simulation.currentStep % self.saveFrequency == 0:
......
...@@ -1065,7 +1065,7 @@ class Modeller(object): ...@@ -1065,7 +1065,7 @@ class Modeller(object):
context.setPositions(newPositions) context.setPositions(newPositions)
LocalEnergyMinimizer.minimize(context, 1.0, 50) LocalEnergyMinimizer.minimize(context, 1.0, 50)
self.topology = newTopology self.topology = newTopology
self.positions = context.getState(getPositions=True).getPositions() self.positions = context.getState(positions=True).getPositions()
del context del context
return actualVariants return actualVariants
...@@ -1528,7 +1528,7 @@ class Modeller(object): ...@@ -1528,7 +1528,7 @@ class Modeller(object):
for i in range(steps): for i in range(steps):
weight1 = i/(steps-1) weight1 = i/(steps-1)
weight2 = 1.0-weight1 weight2 = 1.0-weight1
mergedPositions = context.getState(getPositions=True).getPositions(asNumpy=hasNumpy).value_in_unit(nanometer) mergedPositions = context.getState(positions=True).getPositions(asNumpy=hasNumpy).value_in_unit(nanometer)
if hasNumpy: if hasNumpy:
mergedPositions[numMembraneParticles:] = weight1*proteinPosArray + weight2*scaledProteinPosArray mergedPositions[numMembraneParticles:] = weight1*proteinPosArray + weight2*scaledProteinPosArray
else: else:
...@@ -1540,7 +1540,7 @@ class Modeller(object): ...@@ -1540,7 +1540,7 @@ class Modeller(object):
# Add the membrane to the protein. # Add the membrane to the protein.
modeller = Modeller(self.topology, self.positions) modeller = Modeller(self.topology, self.positions)
modeller.add(membraneTopology, context.getState(getPositions=True).getPositions()[:numMembraneParticles]) modeller.add(membraneTopology, context.getState(positions=True).getPositions()[:numMembraneParticles])
modeller.topology.setPeriodicBoxVectors(membraneTopology.getPeriodicBoxVectors()) modeller.topology.setPeriodicBoxVectors(membraneTopology.getPeriodicBoxVectors())
del context del context
del system del system
......
...@@ -262,8 +262,8 @@ class Simulation(object): ...@@ -262,8 +262,8 @@ class Simulation(object):
getForces = True getForces = True
if next[4]: if next[4]:
getEnergy = True getEnergy = True
state = self.context.getState(getPositions=getPositions, getVelocities=getVelocities, getForces=getForces, state = self.context.getState(positions=getPositions, velocities=getVelocities, forces=getForces,
getEnergy=getEnergy, getParameters=True, enforcePeriodicBox=periodic, energy=getEnergy, parameters=True, enforcePeriodicBox=periodic,
groups=self.context.getIntegrator().getIntegrationForceGroups()) groups=self.context.getIntegrator().getIntegrationForceGroups())
for reporter, next in reports: for reporter, next in reports:
reporter.report(self, state) reporter.report(self, state)
...@@ -325,7 +325,7 @@ class Simulation(object): ...@@ -325,7 +325,7 @@ class Simulation(object):
a File-like object to write the state to, or alternatively a a File-like object to write the state to, or alternatively a
filename filename
""" """
state = self.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True) state = self.context.getState(positions=True, velocities=True, parameters=True, integratorParameters=True)
xml = mm.XmlSerializer.serialize(state) xml = mm.XmlSerializer.serialize(state)
if isinstance(file, str): if isinstance(file, str):
with open(file, 'w') as f: with open(file, 'w') as f:
......
...@@ -52,7 +52,7 @@ def run_tests(): ...@@ -52,7 +52,7 @@ def run_tests():
try: try:
simulation = Simulation(pdb.topology, system, integrator, platform) simulation = Simulation(pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions) simulation.context.setPositions(pdb.positions)
forces[i] = simulation.context.getState(getForces=True).getForces() forces[i] = simulation.context.getState(forces=True).getForces()
del simulation del simulation
print("- Successfully computed forces") print("- Successfully computed forces")
except: except:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment