Unverified commit ce9fcace, authored by Peter Eastman, committed by GitHub
Browse files

Created ExpandedEnsembleSampler (#5265)

* Created ExpandedEnsembleSampler

* Attempt at fixing test errors on Windows

* Another attempt at fixing test errors on Windows

* More output options

* Minor fixes

* Still trying to fix Windows errors

* Debugging

* Just skip the test on Windows

* Fix error on older Python
parent 14f8b061
...@@ -39,6 +39,7 @@ from .simulatedtempering import SimulatedTempering ...@@ -39,6 +39,7 @@ from .simulatedtempering import SimulatedTempering
from .metadynamics import Metadynamics, BiasVariable from .metadynamics import Metadynamics, BiasVariable
from .replicaexchangesampler import ReplicaExchangeSampler from .replicaexchangesampler import ReplicaExchangeSampler
from .replicaexchangereporter import ReplicaExchangeReporter from .replicaexchangereporter import ReplicaExchangeReporter
from .expandedensemblesampler import ExpandedEnsembleSampler
# Enumerated values # Enumerated values
......
This diff is collapsed.
...@@ -209,7 +209,7 @@ class MultistateSampler(object): ...@@ -209,7 +209,7 @@ class MultistateSampler(object):
------- -------
an array containing the potential energies of all states in the order they appear in self.states an array containing the potential energies of all states in the order they appear in self.states
""" """
energies = [0*kilojoules_per_mole for _ in self.states] energies = [0 for _ in self.states]*kilojoules_per_mole
for i, subset in enumerate(self.subsets): for i, subset in enumerate(self.subsets):
if self.groups is None: if self.groups is None:
# States don't depend on force groups, so we can just evaluate the energy of one state. # States don't depend on force groups, so we can just evaluate the energy of one state.
...@@ -241,6 +241,6 @@ class MultistateSampler(object): ...@@ -241,6 +241,6 @@ class MultistateSampler(object):
an array containing the shifted potential energies of all states in the order they appear in self.states an array containing the shifted potential energies of all states in the order they appear in self.states
""" """
if len(self.subsets) == 1 and self.groups is None: if len(self.subsets) == 1 and self.groups is None:
return [0*kilojoules_per_mole]*len(self.states) return [0]*len(self.states)*kilojoules_per_mole
energies = self.computeAllEnergies() energies = self.computeAllEnergies()
return [e-energies[0] for e in energies] return [e-energies[0] for e in energies]
\ No newline at end of file
...@@ -4,9 +4,9 @@ safesave.py: Helper module to ensure atomic overwrite/backup of existing files. ...@@ -4,9 +4,9 @@ safesave.py: Helper module to ensure atomic overwrite/backup of existing files.
This is part of the OpenMM molecular simulation toolkit. This is part of the OpenMM molecular simulation toolkit.
See https://openmm.org/development. See https://openmm.org/development.
Portions copyright (c) 2025 Stanford University and the Authors. Portions copyright (c) 2025-2026 Stanford University and the Authors.
Authors: Evan Pretti Authors: Evan Pretti
Contributors: Contributors: Peter Eastman
Permission is hereby granted, free of charge, to any person obtaining a Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"), copy of this software and associated documentation files (the "Software"),
...@@ -51,11 +51,8 @@ def _getTempFilename(prefix): ...@@ -51,11 +51,8 @@ def _getTempFilename(prefix):
for index in itertools.count(): for index in itertools.count():
name = f'{prefix}.{index}.tmp' name = f'{prefix}.{index}.tmp'
try: if not os.path.exists(name):
with open(name, 'x'): return name
return name
except FileExistsError:
pass
def save(data, filename): def save(data, filename):
""" """
......
from openmm import *
from openmm.app import *
from openmm.unit import *
import numpy as np
import os
import tempfile
import unittest
class TestExpandedEnsembleSampler(unittest.TestCase):
    """Tests for ExpandedEnsembleSampler: sampling over temperature and
    parameter states, and writing/resuming from log, energy and checkpoint files."""

    @staticmethod
    def _removeIfPresent(path):
        """Best-effort removal of a temporary file created by a test."""
        try:
            os.remove(path)
        except OSError:
            # Cleanup only: the file may already be gone, or (on Windows) may
            # still be held open.  Neither should turn a passing test into a failure.
            pass

    def _createTempFile(self):
        """Create an empty temporary file, close it, and schedule its removal.

        The file is closed immediately so that this test never holds an open
        handle while the sampler opens or replaces the same path by name.

        Returns
        -------
        the path of the newly created file
        """
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp:
            path = tmp.name
        self.addCleanup(self._removeIfPresent, path)
        return path

    def testTemperature(self):
        """Test a set of states that differ in temperature."""
        system = System()
        system.addParticle(1.0)
        force = CustomExternalForce('x*x+y*y+z*z')
        force.addParticle(0)
        system.addForce(force)
        states = [{'temperature':t*kelvin} for t in np.geomspace(300.0, 600.0, 5)]
        for reinitialize in [False, True]:
            integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.01*picosecond)
            simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference'))
            simulation.context.setPositions([Vec3(0, 0, 0)])
            sampler = ExpandedEnsembleSampler(states, simulation, 10, reinitialize)

            # Run for a little while to let the weights stabilize.

            sampler.step(10000)

            # Run for a while and record the states and energies.

            energies = [[] for _ in states]
            iterations = 20000
            for _ in range(iterations):
                sampler.step(10)
                potential = simulation.context.getState(energy=True).getPotentialEnergy()
                energies[sampler.currentStateIndex].append(potential)

            # Check that it spent roughly equal time in each state, and that the energies are correct.
            # The potential is quadratic in x, y and z, so the expected mean energy is 3/2 RT.

            for energy, state in zip(energies, states):
                n = len(energy)
                self.assertTrue(iterations/10 < n < iterations/2)
                average = sum(energy)/n
                expected = 1.5*(state['temperature']*MOLAR_GAS_CONSTANT_R)
                self.assertTrue(0.7 < average/expected < 1.3)

    def testParameter(self):
        """Test a set of states that differ in a force parameter."""
        system = System()
        system.addParticle(1.0)
        force = CustomExternalForce('0.5*k*x*x')
        force.addGlobalParameter('k', 1.0)
        force.addParticle(0)
        system.addForce(force)
        states = [{'k':k*kilojoules_per_mole/(nanometer**2)} for k in np.geomspace(10.0, 100.0, 5)]
        for reinitialize in [False, True]:
            integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.01*picosecond)
            simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference'))
            simulation.context.setPositions([Vec3(0, 0, 0)])
            sampler = ExpandedEnsembleSampler(states, simulation, 10, reinitialize)

            # Run for a little while to let the weights stabilize.

            sampler.step(10000)

            # Run for a while and record the states and displacements.

            r2 = [[] for _ in states]
            iterations = 20000
            for _ in range(iterations):
                sampler.step(10)
                x = simulation.context.getState(positions=True).getPositions()[0][0]
                r2[sampler.currentStateIndex].append(x*x)

            # Check that it spent roughly equal time in each state, and that the energies are correct.
            # The potential is quadratic in x only, so the expected mean energy is 1/2 RT.

            expected = 0.5*integrator.getTemperature()*MOLAR_GAS_CONSTANT_R
            for state, squares in zip(states, r2):
                n = len(squares)
                self.assertTrue(iterations/10 < n < iterations/2)
                average = 0.5*state['k']*sum(squares)/n
                self.assertTrue(0.7 < average/expected < 1.3)

    def testReporter(self):
        """Test reporting output from an expanded ensemble simulation."""
        system = System()
        force = CustomExternalForce('0.5*k*(x*x+y*y+z*z)')
        force.addGlobalParameter('k', 1.0)
        system.addForce(force)
        for i in range(3):
            system.addParticle(1.0)
            # Apply the force to each particle in turn (not to particle 0 three times).
            force.addParticle(i)
        states = [{'k':k} for k in (200.0, 300.0, 400.0)]

        # The output files exist (empty) before the sampler is created.  All of them
        # are removed automatically when the test finishes, whatever the outcome.

        logPath = self._createTempFile()
        energyPath = self._createTempFile()
        checkpointPath = self._createTempFile()
        integrator = LangevinIntegrator(300*kelvin, 1/picosecond, 0.001*picosecond)
        simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference'))
        simulation.context.setPositions([Vec3(0, 0, 0)]*3)
        sampler = ExpandedEnsembleSampler(states, simulation, 5, reportInterval=5, logFile=logPath,
                                          energyFile=energyPath, checkpointFile=checkpointPath)

        # Run a simulation.

        step = []
        iteration = []
        stateIndex = []
        weights = []
        energies = []

        def runIteration():
            # Uses whichever simulation and sampler are currently bound in the
            # enclosing scope, so it works both before and after the restart below.
            simulation.step(5)
            step.append(simulation.currentStep)
            iteration.append(sampler.currentIteration)
            stateIndex.append(sampler.currentStateIndex)
            weights.append(sampler.weights)
            kT = MOLAR_GAS_CONSTANT_R*simulation.integrator.getTemperature()
            energies.append(sampler._sampler.computeAllEnergies()/kT)
            # Re-apply the current state after evaluating the energies of all states.
            sampler._sampler.applyState(sampler.currentStateIndex)

        try:
            for _ in range(4):
                runIteration()
        except PermissionError:
            # tempfile is kind of broken on Windows.  Report the test as skipped
            # instead of silently passing.
            self.skipTest('Could not write to temporary files')
        state1 = simulation.context.getState(positions=True, velocities=True, parameters=True)

        # Delete all objects from the simulation and create a new one, telling it to resume from the files.

        del sampler
        del simulation
        del integrator
        integrator = LangevinIntegrator(300*kelvin, 1/picosecond, 0.001*picosecond)
        simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference'))
        sampler = ExpandedEnsembleSampler(states, simulation, 5, reportInterval=5, logFile=logPath,
                                          energyFile=energyPath, checkpointFile=checkpointPath,
                                          resume=True)

        # Make sure everything was loaded correctly.

        state2 = simulation.context.getState(positions=True, velocities=True, parameters=True)
        self.assertEqual(XmlSerializer.serialize(state1), XmlSerializer.serialize(state2))
        self.assertEqual(step[-1], simulation.currentStep)
        self.assertEqual(iteration[-1], sampler.currentIteration)
        self.assertEqual(stateIndex[-1], sampler.currentStateIndex)
        self.assertEqual(weights[-1], sampler.weights)

        # Generate some more output.

        for _ in range(4):
            runIteration()

        # Check the log file.  The first line is skipped; each remaining line holds the
        # step, iteration, state index and weights for one report.

        with open(logPath) as stream:
            lines = stream.readlines()[1:]
        self.assertEqual(8, len(lines))
        for i, line in enumerate(lines):
            fields = line.split(',')
            self.assertEqual(int(fields[0]), step[i])
            self.assertEqual(int(fields[1]), iteration[i])
            self.assertEqual(int(fields[2]), stateIndex[i])
            self.assertTrue(np.allclose([float(x) for x in fields[3:]], weights[i]))

        # Check the energy file.  The first line is skipped; each remaining line holds the
        # step followed by the reduced energy of every state.

        with open(energyPath) as stream:
            lines = stream.readlines()[1:]
        self.assertEqual(8, len(lines))
        for i, line in enumerate(lines):
            fields = line.split(',')
            self.assertEqual(int(fields[0]), step[i])
            self.assertTrue(np.allclose([float(x) for x in fields[1:]], energies[i]))
...@@ -19,6 +19,7 @@ class TestReplicaExchangeSampler(unittest.TestCase): ...@@ -19,6 +19,7 @@ class TestReplicaExchangeSampler(unittest.TestCase):
for reinitialize in [False, True]: for reinitialize in [False, True]:
integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.01*picosecond) integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.01*picosecond)
simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference')) simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference'))
simulation.context.setPositions([Vec3(0, 0, 0)])
repex = ReplicaExchangeSampler(states, simulation, 20, reinitialize) repex = ReplicaExchangeSampler(states, simulation, 20, reinitialize)
energies = [0.0*kilojoules_per_mole]*len(states) energies = [0.0*kilojoules_per_mole]*len(states)
exchanged = False exchanged = False
...@@ -51,10 +52,11 @@ class TestReplicaExchangeSampler(unittest.TestCase): ...@@ -51,10 +52,11 @@ class TestReplicaExchangeSampler(unittest.TestCase):
force.addGlobalParameter('k', 1.0) force.addGlobalParameter('k', 1.0)
force.addParticle(0) force.addParticle(0)
system.addForce(force) system.addForce(force)
states = [{'k':k*kilojoules_per_mole/(nanometer**2)} for k in np.geomspace(5.0, 100.0, 5)] states = [{'k':k*kilojoules_per_mole/(nanometer**2)} for k in np.geomspace(10.0, 100.0, 5)]
for reinitialize in [False, True]: for reinitialize in [False, True]:
integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.01*picosecond) integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.01*picosecond)
simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference')) simulation = Simulation(Topology(), system, integrator, Platform.getPlatform('Reference'))
simulation.context.setPositions([Vec3(0, 0, 0)])
repex = ReplicaExchangeSampler(states, simulation, 20, reinitialize) repex = ReplicaExchangeSampler(states, simulation, 20, reinitialize)
r2 = [0.0*nanometer**2]*len(states) r2 = [0.0*nanometer**2]*len(states)
exchanged = False exchanged = False
...@@ -135,7 +137,8 @@ class TestReplicaExchangeSampler(unittest.TestCase): ...@@ -135,7 +137,8 @@ class TestReplicaExchangeSampler(unittest.TestCase):
# Check the log file. # Check the log file.
lines = open(os.path.join(directory, 'log.csv')).readlines()[1:] with open(os.path.join(directory, 'log.csv')) as input:
lines = input.readlines()[1:]
for i, line in enumerate(lines): for i, line in enumerate(lines):
fields = [int(x) for x in line.split(',')] fields = [int(x) for x in line.split(',')]
self.assertEqual(fields[0], 3*(i+1)) self.assertEqual(fields[0], 3*(i+1))
...@@ -176,4 +179,7 @@ class TestReplicaExchangeSampler(unittest.TestCase): ...@@ -176,4 +179,7 @@ class TestReplicaExchangeSampler(unittest.TestCase):
# Creating a new reporter for the same directory should fail. # Creating a new reporter for the same directory should fail.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ReplicaExchangeReporter(directory, 3, sampler) ReplicaExchangeReporter(directory, 3, sampler)
\ No newline at end of file del sampler
del simulation
del integrator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment