Unverified Commit e62bdf6a authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

API improvements (#4437)

* Can use getPlatform() instead of getPlatformByName()

* More concise arguments for getState()
parent 71f4b3fc
......@@ -125,7 +125,7 @@ Both of these classes should be packaged into a dynamic library (.so on Linux,
must also implement the two functions from PluginInitializer.h.
:code:`registerPlatforms()` will do nothing, since this plugin does not
implement any new Platforms. :code:`registerKernelFactories()` should call
\ :code:`Platform::getPlatformByName("OpenCL")` to get the OpenCL Platform,
\ :code:`Platform::getPlatform("OpenCL")` to get the OpenCL Platform,
then create a new OpenCLStringForceKernelFactory and call
:code:`registerKernelFactory()` on the Platform to register it. If the OpenCL
Platform is not available, you should catch the exception then return without
......
......@@ -428,7 +428,7 @@ of the :class:`Platform` to use. This overrides the default logic.
:class:`Simulation`. The following lines specify to use the :class:`CUDA` platform:
::
platform = Platform.getPlatformByName('CUDA')
platform = Platform.getPlatform('CUDA')
simulation = Simulation(prmtop.topology, system, integrator, platform)
The platform name should be one of :code:`OpenCL`, :code:`CUDA`, :code:`CPU`, or
......@@ -441,7 +441,7 @@ work across two different GPUs (CUDA devices 0 and 1), doing all computations in
double precision:
::
platform = Platform.getPlatformByName('CUDA')
platform = Platform.getPlatform('CUDA')
properties = {'DeviceIndex': '0,1', 'Precision': 'double'}
simulation = Simulation(prmtop.topology, system, integrator, platform, properties)
......
......@@ -248,7 +248,7 @@ PDB file.
simulation.context.setPositions(modeller.positions)
simulation.minimizeEnergy(maxIterations=100)
print('Saving...')
positions = simulation.context.getState(getPositions=True).getPositions()
positions = simulation.context.getState(positions=True).getPositions()
PDBFile.writeFile(simulation.topology, positions, open('output.pdb', 'w'))
print('Done')
......
......@@ -209,7 +209,7 @@ the potential energy of each one. Assume we have already created our :class:`Sy
for file in os.listdir('structures'):
pdb = PDBFile(os.path.join('structures', file))
simulation.context.setPositions(pdb.positions)
state = simulation.context.getState(getEnergy=True)
state = simulation.context.getState(energy=True)
print(file, state.getPotentialEnergy())
.. caption::
......@@ -220,6 +220,6 @@ We use Python’s :code:`listdir()` function to list all the files in the
directory. We create a :class:`PDBFile` object for each one and call
:meth:`setPositions()` on the Context to specify the particle positions loaded
from the PDB file. We then compute the energy by calling :meth:`getState()`
with the option :code:`getEnergy=True`\ , and print it to the console along
with the option :code:`energy=True`\ , and print it to the console along
with the name of the file.
......@@ -12,7 +12,7 @@ Context constructor:
.. code-block:: c
Platform& platform = Platform::getPlatformByName("OpenCL");
Platform& platform = Platform::getPlatform("OpenCL");
map<string, string> properties;
properties["DeviceIndex"] = "1";
Context context(system, integrator, platform, properties);
......
......@@ -588,7 +588,7 @@ notable differences:
available. For example:
::
myContext.getState(getEnergy=True, getForce=False, …)
myContext.getState(energy=True, force=False, …)
#. Wherever the C++ API uses references to return multiple values from a method,
the Python API returns a tuple. For example, in C++ you would query a
......
......@@ -129,7 +129,7 @@ for (lambda_index, lambda_value) in enumerate(lambda_values):
integrator.step(nprod_steps)
# Get coordinates.
state = context.getState(getPositions=True)
state = context.getState(positions=True)
positions = state.getPositions(asNumpy=True)
# Store positions.
......@@ -149,7 +149,7 @@ for (lambda_index, lambda_value) in enumerate(lambda_values):
# Compute reduced potentials of all snapshots.
for n in range(nprod_iterations):
context.setPositions(position_history[n])
state = context.getState(getEnergy=True)
state = context.getState(energy=True)
u_kln[lambda_index, l, n] = beta * state.getPotentialEnergy()
# Clean up.
......
......@@ -99,10 +99,10 @@ def printTestResult(test_result, options):
def timeIntegration(context, steps, initialSteps):
"""Integrate a Context for a specified number of steps, then return how many seconds it took."""
context.getIntegrator().step(initialSteps) # Make sure everything is fully initialized
context.getState(getEnergy=True)
context.getState(energy=True)
start = datetime.now()
context.getIntegrator().step(steps)
context.getState(getEnergy=True)
context.getState(energy=True)
end = datetime.now()
elapsed = end-start
return elapsed.seconds + elapsed.microseconds*1e-6
......@@ -347,7 +347,7 @@ def runOneTest(testName, options):
test_result['timestep_in_fs'] = dt.value_in_unit(unit.femtoseconds)
properties = {}
initialSteps = 5
platform = mm.Platform.getPlatformByName(options.platform)
platform = mm.Platform.getPlatform(options.platform)
if options.device is not None and 'DeviceIndex' in platform.getPropertyNames():
properties['DeviceIndex'] = options.device
if ',' in options.device or ' ' in options.device:
......@@ -384,7 +384,7 @@ def runOneTest(testName, options):
tol = 1.0e-8
context.applyConstraints(tol)
context.applyVelocityConstraints(tol)
state = context.getState(getPositions=True, getVelocities=True, getEnergy=True, getForces=True, getParameters=True)
state = context.getState(positions=True, velocities=True, energy=True, forces=True, parameters=True)
# Time integration, ensuring we trigger kernel compilation before we start timing
steps = 20
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -185,6 +185,11 @@ public:
* Get a registered Platform by index.
*/
static Platform& getPlatform(int index);
/**
* Get a registered Platform by name. If no Platform with that name has been
* registered, this throws an exception.
*/
static Platform& getPlatform(const std::string& name);
/**
* Get any failures caused during the last call to loadPluginsFromDirectory
*/
......@@ -192,6 +197,9 @@ public:
/**
* Get the registered Platform with a particular name. If no Platform with that name has been
* registered, this throws an exception.
*
* This is identical to the version of getPlatform() that takes a name. It
* is here for backward compatibility.
*/
static Platform& getPlatformByName(const std::string& name);
/**
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -163,6 +163,10 @@ Platform& Platform::getPlatform(int index) {
throw OpenMMException("Invalid platform index");
}
Platform& Platform::getPlatform(const string& name) {
return getPlatformByName(name);
}
std::vector<std::string> Platform::getPluginLoadFailures() {
return pluginLoadFailures;
}
......
......@@ -603,7 +603,7 @@ void testLargeSystem() {
system.addForce(force);
VerletIntegrator integrator1(0.001);
VerletIntegrator integrator2(0.001);
Context context1(system, integrator1, Platform::getPlatformByName("Reference"));
Context context1(system, integrator1, Platform::getPlatform("Reference"));
Context context2(system, integrator2, platform);
context1.setPositions(positions);
context2.setPositions(positions);
......@@ -698,7 +698,7 @@ void testCentralParticleModeLargeSystem() {
system.addForce(force);
VerletIntegrator integrator1(0.001);
VerletIntegrator integrator2(0.001);
Context context1(system, integrator1, Platform::getPlatformByName("Reference"));
Context context1(system, integrator1, Platform::getPlatform("Reference"));
Context context2(system, integrator2, platform);
context1.setPositions(positions);
context2.setPositions(positions);
......
......@@ -70,7 +70,7 @@ void testTruncatedOctahedron() {
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, Platform::getPlatformByName("Reference"));
Context context(system, integrator, Platform::getPlatform("Reference"));
context.setPositions(positions);
State initialState = context.getState(State::Positions | State::Energy, true);
for (int i = 0; i < numMolecules; i++) {
......
......@@ -304,7 +304,7 @@ void testTriclinic2() {
Context context1(system, integ1, platform);
context1.setPositions(positions);
VerletIntegrator integ2(0.001);
Context context2(system, integ2, Platform::getPlatformByName("Reference"));
Context context2(system, integ2, Platform::getPlatform("Reference"));
context2.setPositions(positions);
State state1 = context1.getState(State::Forces | State::Energy, false, 2);
State state2 = context2.getState(State::Forces | State::Energy, false, 2);
......
......@@ -57,7 +57,7 @@ void testFindMolecules() {
bonds->addBond(index, index-1, 1.0, 1.0);
}
VerletIntegrator integrator(1.0);
Context context(system, integrator, Platform::getPlatformByName("Reference"));
Context context(system, integrator, Platform::getPlatform("Reference"));
ContextImpl* contextImpl = *reinterpret_cast<ContextImpl**>(&context);
const vector<vector<int> >& molecules = contextImpl->getMolecules();
ASSERT_EQUAL(numMolecules, molecules.size());
......
......@@ -468,7 +468,7 @@ void testOverlappingSites() {
}
VerletIntegrator i1(0.002);
VerletIntegrator i2(0.002);
Context c1(system, i1, Platform::getPlatformByName("Reference"));
Context c1(system, i1, Platform::getPlatform("Reference"));
Context c2(system, i2, platform);
c1.setPositions(positions);
c2.setPositions(positions);
......
......@@ -149,7 +149,7 @@ class CheckpointReporter(object):
self._file.seek(0)
if self._writeState:
state = simulation.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True)
state = simulation.context.getState(positions=True, velocities=True, parameters=True, integratorParameters=True)
self._file.write(mm.XmlSerializer.serialize(state))
else:
self._file.write(simulation.context.createCheckpoint())
......
......@@ -172,7 +172,7 @@ class Metadynamics(object):
simulation.step(nextSteps)
if simulation.currentStep % self.frequency == 0:
position = self._force.getCollectiveVariableValues(simulation.context)
energy = simulation.context.getState(getEnergy=True, groups={forceGroup}).getPotentialEnergy()
energy = simulation.context.getState(energy=True, groups={forceGroup}).getPotentialEnergy()
height = self.height*np.exp(-energy/(unit.MOLAR_GAS_CONSTANT_R*self._deltaT))
self._addGaussian(position, height, simulation.context)
if self.saveFrequency is not None and simulation.currentStep % self.saveFrequency == 0:
......
......@@ -1065,7 +1065,7 @@ class Modeller(object):
context.setPositions(newPositions)
LocalEnergyMinimizer.minimize(context, 1.0, 50)
self.topology = newTopology
self.positions = context.getState(getPositions=True).getPositions()
self.positions = context.getState(positions=True).getPositions()
del context
return actualVariants
......@@ -1528,7 +1528,7 @@ class Modeller(object):
for i in range(steps):
weight1 = i/(steps-1)
weight2 = 1.0-weight1
mergedPositions = context.getState(getPositions=True).getPositions(asNumpy=hasNumpy).value_in_unit(nanometer)
mergedPositions = context.getState(positions=True).getPositions(asNumpy=hasNumpy).value_in_unit(nanometer)
if hasNumpy:
mergedPositions[numMembraneParticles:] = weight1*proteinPosArray + weight2*scaledProteinPosArray
else:
......@@ -1540,7 +1540,7 @@ class Modeller(object):
# Add the membrane to the protein.
modeller = Modeller(self.topology, self.positions)
modeller.add(membraneTopology, context.getState(getPositions=True).getPositions()[:numMembraneParticles])
modeller.add(membraneTopology, context.getState(positions=True).getPositions()[:numMembraneParticles])
modeller.topology.setPeriodicBoxVectors(membraneTopology.getPeriodicBoxVectors())
del context
del system
......
......@@ -262,8 +262,8 @@ class Simulation(object):
getForces = True
if next[4]:
getEnergy = True
state = self.context.getState(getPositions=getPositions, getVelocities=getVelocities, getForces=getForces,
getEnergy=getEnergy, getParameters=True, enforcePeriodicBox=periodic,
state = self.context.getState(positions=getPositions, velocities=getVelocities, forces=getForces,
energy=getEnergy, parameters=True, enforcePeriodicBox=periodic,
groups=self.context.getIntegrator().getIntegrationForceGroups())
for reporter, next in reports:
reporter.report(self, state)
......@@ -325,7 +325,7 @@ class Simulation(object):
a File-like object to write the state to, or alternatively a
filename
"""
state = self.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True)
state = self.context.getState(positions=True, velocities=True, parameters=True, integratorParameters=True)
xml = mm.XmlSerializer.serialize(state)
if isinstance(file, str):
with open(file, 'w') as f:
......
......@@ -52,7 +52,7 @@ def run_tests():
try:
simulation = Simulation(pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
forces[i] = simulation.context.getState(getForces=True).getForces()
forces[i] = simulation.context.getState(forces=True).getForces()
del simulation
print("- Successfully computed forces")
except:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment