Unverified Commit 9f17d0f8 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Infrastructure for multistate sampling (#5231)

* Infrastructure for multistate sampling

* Added computeRelativeEnergies()
parent e4728a21
"""
multistatesampler.py: Utilities for multistate sampling algorithms.
This is part of the OpenMM molecular simulation toolkit.
See https://openmm.org/development.
Portions copyright (c) 2026 Stanford University and the Authors.
Authors: Peter Eastman
Contributors:
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
import openmm
from openmm.unit import kilojoules_per_mole
"""This is a utility for internal use by algorithms that perform multistate sampling. Given the definitions of a set of
states, it can put a context into any state and efficiently evaluate the potential energies of states.
Each state (not to be confused with a State object) is represented by a dict containing property values. All states must
specify values for the same properties. The following properties are supported.
- 'temperature': the simulation temperature. It knows how to set the temperature for all standard integrators and forces.
- 'groups': a set containing the force groups to include when computing the energy, for example {0, 2}. It also may be
a weighted sum specified as a dict. For example, {0:1.0, 2:0.5} means to return the energy of group 0 plus half the
energy of group 2.
- Context parameters
- Global variables defined by a CustomIntegrator
"""
class MultistateSampler(object):
    """Apply predefined states to a Context and evaluate the potential energies of those states.

    Each state is a dict of property values.  The constructor translates every state into the
    integrator temperature, context parameters, and CustomIntegrator global variables needed to
    apply it, and precomputes which states can share energy evaluations.

    Attributes
    ----------
    states: the list of state dicts passed to the constructor
    context: the Context the states are applied to
    parameters: for each state, a dict of context parameter values to set
    variables: for each state, a dict of CustomIntegrator global variable values to set
    temperature: for each state, the value passed to the integrator's setTemperature(), or None if
        it is not used
    groups: for each state, a dict mapping force group to weight, or None if states do not specify
        force groups
    subsets: lists of state indices; states in one subset have identical energy-affecting context
        parameters
    groups_of_groups: for each subset, a list of sets of force groups that can be evaluated
        together.  Only defined when groups is not None.
    """

    def __init__(self, states, context):
        """Create a MultistateSampler.

        Parameters
        ----------
        states: list
            the states to sample.  Each entry should be a dict containing the property values for one state.
        context: Context
            the Context to apply the states to

        Raises
        ------
        ValueError
            if no states are given, if the states do not all define the same set of properties, or
            if a property is not recognized
        """
        if len(states) == 0:
            raise ValueError('At least one state must be specified')
        self.states = states
        self.context = context
        keys = set(states[0].keys())
        for s in states:
            if set(s.keys()) != keys:
                raise ValueError('All states must include the same set of properties')

        # Process the keys and determine how to set the simulation to a state.

        self.parameters = [{} for _ in states]
        energy_parameters = [{} for _ in states]
        self.variables = [{} for _ in states]
        self.temperature = None
        self.groups = None

        # Look these up once.  Nothing in the constructor modifies them.
        integrator = context.getIntegrator()
        context_parameters = context.getParameters()
        if isinstance(integrator, openmm.CustomIntegrator):
            variables = {integrator.getGlobalVariableName(i) for i in range(integrator.getNumGlobalVariables())}
        else:
            variables = set()
        for key in keys:
            if key == 'temperature':
                # Setting the temperature can involve calling setTemperature() on an integrator and
                # setting context parameters.
                temperatures = [s['temperature'] for s in states]
                if hasattr(integrator, 'setTemperature'):
                    self.temperature = temperatures
                for force in context.getSystem().getForces():
                    if hasattr(type(force), 'Temperature'):
                        param = force.Temperature()
                        if param in context_parameters:
                            # Deliberately not recorded in energy_parameters: these parameters are
                            # ignored when grouping states into subsets.
                            for i, t in enumerate(temperatures):
                                self.parameters[i][param] = t
                if 'temperature' in variables:
                    for i, t in enumerate(temperatures):
                        self.variables[i]['temperature'] = t
            elif key == 'groups':
                # Force groups can be specified with either a collection of groups or a dict of weights.
                self.groups = [self._normalizeGroups(s['groups']) for s in states]
            elif key in context_parameters:
                # A context parameter.
                for i, s in enumerate(states):
                    self.parameters[i][key] = s[key]
                    energy_parameters[i][key] = s[key]
            elif key in variables:
                # A global variable of a CustomIntegrator.
                for i, s in enumerate(states):
                    self.variables[i][key] = s[key]
            else:
                raise ValueError(f'Unknown property "{key}"')

        # We now identify subsets of states whose energies can be evaluated more efficiently.  Two states are in the
        # same subset if they have identical context parameters (ignoring parameters used only to set temperature).

        self.subsets = []
        for i in range(len(states)):
            for subset in self.subsets:
                if energy_parameters[i] == energy_parameters[subset[0]]:
                    subset.append(i)
                    break
            else:
                # No existing subset matches, so this state starts a new one.
                self.subsets.append([i])

        # States within a subset may still vary in what force groups they include.  Figure out the most efficient way
        # of evaluating all states in a subset.

        if self.groups is not None:
            self.groups_of_groups = [self._partitionGroups(subset) for subset in self.subsets]

    @staticmethod
    def _normalizeGroups(groups):
        """Convert a force group specification into a new dict mapping each group to its weight.

        A mapping is copied, so later changes to the caller's object cannot desynchronize the
        precomputed partitions.  Any other iterable (set, frozenset, list, tuple) is taken as a
        collection of groups that each have weight 1.0.
        """
        if hasattr(groups, 'keys'):
            return dict(groups)
        return {g: 1.0 for g in groups}

    def _partitionGroups(self, subset):
        """Split the force groups used by a subset of states into sets that can be computed together.

        Two force groups belong to the same set if, in every state of the subset, they are either
        both absent or both present with the same weight.

        Parameters
        ----------
        subset: list
            indices of the states in the subset

        Returns
        -------
        a list of sets of force groups
        """
        by_signature = {}
        for group in {g for j in subset for g in self.groups[j]}:
            # The weight of this group in each state (None where absent) identifies its set.
            signature = tuple(self.groups[j].get(group) for j in subset)
            by_signature.setdefault(signature, set()).add(group)
        return list(by_signature.values())

    def applyState(self, index):
        """Modify the Context to match one of the states.

        Parameters
        ----------
        index: int
            the index of the state in self.states
        """
        if self.temperature is not None:
            self.context.getIntegrator().setTemperature(self.temperature[index])
        for name, value in self.parameters[index].items():
            self.context.setParameter(name, value)
        for name, value in self.variables[index].items():
            self.context.getIntegrator().setGlobalVariableByName(name, value)

    def computeEnergy(self, index):
        """Compute the potential energy of one of the states.  This method calls applyState(), so when it returns, the
        Context will be in the specified state.

        If you need to compute the energies of all states, computeAllEnergies() is much more efficient than calling this
        method for each one.

        Parameters
        ----------
        index: int
            the index of the state in self.states

        Returns
        -------
        the potential energy of the state
        """
        self.applyState(index)
        if self.groups is None:
            return self.context.getState(energy=True).getPotentialEnergy()

        # Force groups that share a weight are evaluated with a single getState() call.
        groups_by_weight = {}
        for group, weight in self.groups[index].items():
            groups_by_weight.setdefault(weight, set()).add(group)
        energy = 0*kilojoules_per_mole
        for weight, groups in groups_by_weight.items():
            energy += weight*self.context.getState(energy=True, groups=groups).getPotentialEnergy()
        return energy

    def computeAllEnergies(self):
        """Compute the potential energies of all states.  This is much more efficient than calling computeEnergy() for
        each state individually.

        If you only care about energy differences between states, not absolute energies, use computeRelativeEnergies()
        instead.  It can be even more efficient.

        When this method returns, it is undefined which state the Context will be in.

        Returns
        -------
        a list containing the potential energies of all states in the order they appear in self.states
        """
        energies = [0*kilojoules_per_mole for _ in self.states]
        for i, subset in enumerate(self.subsets):
            if self.groups is None:
                # States don't depend on force groups, so we can just evaluate the energy of one state.
                energy = self.computeEnergy(subset[0])
                for j in subset:
                    energies[j] = energy
            else:
                # Compute the energy of each set of groups, and add them with the proper weights for each state.
                self.applyState(subset[0])
                for groups in self.groups_of_groups[i]:
                    energy = self.context.getState(energy=True, groups=groups).getPotentialEnergy()
                    # Every group in the set has the same weight in a given state, so any one
                    # member can be used to look the weight up.
                    g = next(iter(groups))
                    for j in subset:
                        if g in self.groups[j]:
                            energies[j] += energy*self.groups[j][g]
        return energies

    def computeRelativeEnergies(self):
        """This is similar to computeAllEnergies(), but the energies are shifted by a constant to make the energy of the
        first state exactly zero.  The advantage is that if states differ only in temperature or other ways that do not
        affect energy, there is no need to compute any energies at all.

        When this method returns, it is undefined which state the Context will be in.

        Returns
        -------
        a list containing the shifted potential energies of all states in the order they appear in self.states
        """
        if len(self.subsets) == 1 and self.groups is None:
            # All states have the same energy, so every relative energy is zero.
            return [0*kilojoules_per_mole]*len(self.states)
        energies = self.computeAllEnergies()
        return [e-energies[0] for e in energies]
\ No newline at end of file
from openmm import *
from openmm.unit import *
from openmm.app.internal.multistatesampler import MultistateSampler
import unittest
class TestMultistateSampler(unittest.TestCase):
    """Tests of MultistateSampler on a two-particle System with three force groups and a barostat."""

    def setUp(self):
        """Build the System and Context shared by the tests.

        With the positions set here, the force groups should have the following energies:
            1: 4*k1
            2: 2*k2
            3: 10
        """
        system = System()
        for _ in range(2):
            system.addParticle(1.0)

        # Two bond forces between the particles, each scaled by its own global parameter.
        bond_forces = [('k1*r^2', 'k1'), ('k2*r', 'k2')]
        for group, (expression, parameter) in enumerate(bond_forces, start=1):
            bond = CustomBondForce(expression)
            bond.addGlobalParameter(parameter, 1.0)
            bond.addBond(0, 1)
            bond.setForceGroup(group)
            system.addForce(bond)

        external = CustomExternalForce('10+periodicdistance(x, y, z, 0, 0, 0)')
        external.addParticle(0)
        external.setForceGroup(3)
        system.addForce(external)
        system.addForce(MonteCarloBarostat(1.0*bar, 300.0*kelvin))
        self.system = system

        integrator = LangevinIntegrator(300*kelvin, 1.0/picosecond, 0.004*picoseconds)
        self.context = Context(system, integrator, Platform.getPlatform('Reference'))
        self.context.setPositions([Vec3(0, 0, 0), Vec3(0, 2, 0)])

    def _assertTemperature(self, expected):
        """Check that the integrator and the barostat parameter of self.context both hold a temperature."""
        actual = self.context.getIntegrator().getTemperature().value_in_unit(kelvin)
        self.assertAlmostEqual(expected, actual, 5)
        self.assertAlmostEqual(expected, self.context.getParameter('MonteCarloTemperature'), 5)

    def validateEnergies(self, sampler, expected):
        """Verify the energies of every state through each of the sampler's energy methods."""
        num_states = len(sampler.states)

        # One state at a time.
        for index in range(num_states):
            single = sampler.computeEnergy(index)
            self.assertAlmostEqual(expected[index], single.value_in_unit(kilojoules_per_mole), 5)

        # All states at once.
        absolute = sampler.computeAllEnergies()
        self.assertEqual(num_states, len(absolute))
        for want, got in zip(expected, absolute):
            self.assertAlmostEqual(want, got.value_in_unit(kilojoules_per_mole), 5)

        # Relative energies are compared after shifting both sides by their first entry.
        relative = sampler.computeRelativeEnergies()
        self.assertEqual(num_states, len(relative))
        for want, got in zip(expected, relative):
            self.assertAlmostEqual(want-expected[0], (got-relative[0]).value_in_unit(kilojoules_per_mole), 6)

        if sampler.groups is not None:
            return

        # Without force groups, the total energy of the Context after applying a state must match too.
        for index in range(num_states):
            sampler.applyState(index)
            total = sampler.context.getState(energy=True).getPotentialEnergy()
            self.assertAlmostEqual(expected[index], total.value_in_unit(kilojoules_per_mole), 5)

    def testTemperatures(self):
        """Test states whose only difference is the temperature."""
        states = [{'temperature':t} for t in [300, 350, 400, 450, 500]]
        sampler = MultistateSampler(states, self.context)
        self.assertEqual(1, len(sampler.subsets))
        for index, state in enumerate(states):
            sampler.applyState(index)
            self._assertTemperature(state['temperature'])
        self.validateEnergies(sampler, [16.0]*len(states))

    def testGroups(self):
        """Test states that select different force groups."""
        states = [
            {'temperature':300, 'groups':{1}},
            {'temperature':300, 'groups':{1, 2, 3}},
            {'temperature':500, 'groups':{1:0.5}},
            {'temperature':500, 'groups':{1:1.0, 2:0.5, 3:0.5}},
        ]
        sampler = MultistateSampler(states, self.context)
        self.assertEqual(1, len(sampler.subsets))
        self.assertEqual(2, len(sampler.groups_of_groups[0]))
        for index, state in enumerate(states):
            sampler.applyState(index)
            self._assertTemperature(state['temperature'])
        self.validateEnergies(sampler, [4.0, 16.0, 2.0, 10.0])

    def testParameters(self):
        """Test states that assign different values to context parameters."""
        states = [
            {'k1':1.0, 'k2':1.0},
            {'k1':1.0, 'k2':2.0},
            {'k1':2.0, 'k2':1.0},
        ]
        sampler = MultistateSampler(states, self.context)
        self.assertEqual(3, len(sampler.subsets))
        for index, state in enumerate(states):
            sampler.applyState(index)
            for name in ('k1', 'k2'):
                self.assertAlmostEqual(state[name], self.context.getParameter(name), 5)
        self.validateEnergies(sampler, [16.0, 18.0, 20.0])

    def testParametersAndGroups(self):
        """Test states that vary both a context parameter and the force groups."""
        states = [
            {'k1':1.0, 'groups':{1}},
            {'k1':1.0, 'groups':{1, 2, 3}},
            {'k1':2.0, 'groups':{1:0.5}},
            {'k1':2.0, 'groups':{1:1.0, 2:0.5, 3:0.5}},
        ]
        sampler = MultistateSampler(states, self.context)
        self.assertEqual(2, len(sampler.subsets))
        for subset_index in (0, 1):
            self.assertEqual(2, len(sampler.groups_of_groups[subset_index]))
        for index, state in enumerate(states):
            sampler.applyState(index)
            self.assertAlmostEqual(state['k1'], self.context.getParameter('k1'), 5)
            # k2 is not part of any state, so it should keep its default value.
            self.assertAlmostEqual(1.0, self.context.getParameter('k2'), 5)
        self.validateEnergies(sampler, [4.0, 16.0, 4.0, 14.0])

    def testCustomIntegrator(self):
        """Test states that assign global variables of a CustomIntegrator."""
        integrator = CustomIntegrator(0.001)
        integrator.addGlobalVariable('temperature', 300.0)
        integrator.addGlobalVariable('friction', 1.0)
        context = Context(self.system, integrator, Platform.getPlatform('Reference'))
        context.setPositions([Vec3(0, 0, 0), Vec3(0, 2, 0)])
        states = [
            {'temperature':300, 'friction':1.0},
            {'temperature':300, 'friction':2.0},
            {'temperature':500, 'friction':1.0},
            {'temperature':500, 'friction':2.0},
        ]
        sampler = MultistateSampler(states, context)
        self.assertEqual(1, len(sampler.subsets))
        for index, state in enumerate(states):
            sampler.applyState(index)
            for name in ('temperature', 'friction'):
                self.assertAlmostEqual(state[name], integrator.getGlobalVariableByName(name), 5)
            self.assertAlmostEqual(state['temperature'], context.getParameter('MonteCarloTemperature'), 5)
        self.validateEnergies(sampler, [16.0]*len(states))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment