Unverified Commit 7abdccbb authored by Evan Pretti's avatar Evan Pretti Committed by GitHub
Browse files

Reduce LennardJonesGenerator memory consumption with many NBFixPair-matching pairs (#4770)

* Include more cases in TestForceField.test_NBFix

* Reduce LennardJonesGenerator NBFIX memory consumption

* Restored test case to use multiple elements to avoid ambiguity
parent 1062810a
...@@ -6,9 +6,9 @@ Simbios, the NIH National Center for Physics-Based Simulation of ...@@ -6,9 +6,9 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org. Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2012-2024 Stanford University and the Authors. Portions copyright (c) 2012-2025 Stanford University and the Authors.
Authors: Peter Eastman, Mark Friedrichs Authors: Peter Eastman, Mark Friedrichs
Contributors: Contributors: Evan Pretti
Permission is hereby granted, free of charge, to any person obtaining a Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"), copy of this software and associated documentation files (the "Software"),
...@@ -2553,20 +2553,38 @@ class LennardJonesGenerator(object): ...@@ -2553,20 +2553,38 @@ class LennardJonesGenerator(object):
def __init__(self, forcefield, lj14scale, useDispersionCorrection): def __init__(self, forcefield, lj14scale, useDispersionCorrection):
self.ff = forcefield self.ff = forcefield
self.nbfixTypes = {} self.nbfixParameters = []
self.nbfixTypes1 = defaultdict(set)
self.nbfixTypes2 = defaultdict(set)
self.lj14scale = lj14scale self.lj14scale = lj14scale
self.useDispersionCorrection = useDispersionCorrection self.useDispersionCorrection = useDispersionCorrection
self.ljTypes = ForceField._AtomTypeParameters(forcefield, 'LennardJonesForce', 'Atom', ('sigma', 'epsilon')) self.ljTypes = ForceField._AtomTypeParameters(forcefield, 'LennardJonesForce', 'Atom', ('sigma', 'epsilon'))
def registerNBFIX(self, parameters): def registerNBFIX(self, parameters):
types = self.ff._findAtomTypes(parameters, 2) types = self.ff._findAtomTypes(parameters, 2)
if None not in types: if None not in types:
sigma = _convertParameterToNumber(parameters['sigma'])
epsilon = _convertParameterToNumber(parameters['epsilon'])
# Retrieve the index of nbfixParameters into which this sigma and
# epsilon will be stored, then register this index with the atom
# types that should have this sigma and epsilon applied.
nbfixIndex = len(self.nbfixParameters)
self.nbfixParameters.append([sigma, epsilon])
for type1 in types[0]: for type1 in types[0]:
self.nbfixTypes1[type1].add(nbfixIndex)
for type2 in types[1]: for type2 in types[1]:
epsilon = _convertParameterToNumber(parameters['epsilon']) self.nbfixTypes2[type2].add(nbfixIndex)
sigma = _convertParameterToNumber(parameters['sigma'])
self.nbfixTypes[(type1, type2)] = [sigma, epsilon] def getNBFIX(self, type1, type2):
self.nbfixTypes[(type2, type1)] = [sigma, epsilon] nbfixIndices = (self.nbfixTypes1[type1] & self.nbfixTypes2[type2]) | (self.nbfixTypes2[type1] & self.nbfixTypes1[type2])
if nbfixIndices:
if len(nbfixIndices) > 1:
raise ValueError('Multiple NBFixPair entries match atom types %s-%s.' % (type1, type2))
return self.nbfixParameters[nbfixIndices.pop()]
else:
return None
def registerLennardJones(self, parameters): def registerLennardJones(self, parameters):
self.ljTypes.registerAtom(parameters) self.ljTypes.registerAtom(parameters)
...@@ -2605,7 +2623,7 @@ class LennardJonesGenerator(object): ...@@ -2605,7 +2623,7 @@ class LennardJonesGenerator(object):
# First derive the lookup tables. We need to include entries for every type # First derive the lookup tables. We need to include entries for every type
# that a) appears in the system and b) has unique parameters. # that a) appears in the system and b) has unique parameters.
nbfixTypeSet = set().union(*self.nbfixTypes) nbfixTypeSet = {t for nbfixTypes in (self.nbfixTypes1, self.nbfixTypes2) for t in nbfixTypes if nbfixTypes[t]}
allTypes = set(data.atomType[atom] for atom in data.atoms) allTypes = set(data.atomType[atom] for atom in data.atoms)
mergedTypes = [] mergedTypes = []
mergedTypeParams = [] mergedTypeParams = []
...@@ -2636,10 +2654,9 @@ class LennardJonesGenerator(object): ...@@ -2636,10 +2654,9 @@ class LennardJonesGenerator(object):
bcoef = acoef[:] bcoef = acoef[:]
for m in range(numLjTypes): for m in range(numLjTypes):
for n in range(numLjTypes): for n in range(numLjTypes):
pair = (mergedTypes[m], mergedTypes[n]) nbfix = self.getNBFIX(mergedTypes[m], mergedTypes[n])
if pair in self.nbfixTypes: if nbfix is not None:
epsilon = self.nbfixTypes[pair][1] sigma, epsilon = nbfix
sigma = self.nbfixTypes[pair][0]
sigma6 = sigma**6 sigma6 = sigma**6
acoef[m+numLjTypes*n] = 4*epsilon*sigma6*sigma6 acoef[m+numLjTypes*n] = 4*epsilon*sigma6*sigma6
bcoef[m+numLjTypes*n] = 4*epsilon*sigma6 bcoef[m+numLjTypes*n] = 4*epsilon*sigma6
...@@ -2709,10 +2726,9 @@ class LennardJonesGenerator(object): ...@@ -2709,10 +2726,9 @@ class LennardJonesGenerator(object):
a1 = data.atoms[p1] a1 = data.atoms[p1]
a2 = data.atoms[p2] a2 = data.atoms[p2]
if (p1,p2) not in skip and (p2,p1) not in skip: if (p1,p2) not in skip and (p2,p1) not in skip:
type1 = data.atomType[a1] nbfix = self.getNBFIX(data.atomType[a1], data.atomType[a2])
type2 = data.atomType[a2] if nbfix is not None:
if (type1, type2) in self.nbfixTypes: sigma, epsilon = nbfix
sigma, epsilon = self.nbfixTypes[(type1, type2)]
else: else:
values1 = self.ljTypes.getAtomParameters(a1, data) values1 = self.ljTypes.getAtomParameters(a1, data)
values2 = self.ljTypes.getAtomParameters(a2, data) values2 = self.ljTypes.getAtomParameters(a2, data)
......
...@@ -914,68 +914,81 @@ class TestForceField(unittest.TestCase): ...@@ -914,68 +914,81 @@ class TestForceField(unittest.TestCase):
def test_NBFix(self): def test_NBFix(self):
"""Test using LennardJonesGenerator to implement NBFix terms.""" """Test using LennardJonesGenerator to implement NBFix terms."""
# Create a chain of five atoms. # Create a chain of seven atoms.
top = Topology() top = Topology()
chain = top.addChain() chain = top.addChain()
res = top.addResidue('RES', chain) res = top.addResidue('RES', chain)
top.addAtom('A', elem.oxygen, res) top.addAtom('A', elem.carbon, res)
top.addAtom('B', elem.carbon, res) top.addAtom('B', elem.nitrogen, res)
top.addAtom('C', elem.carbon, res) top.addAtom('C', elem.nitrogen, res)
top.addAtom('D', elem.carbon, res) top.addAtom('D', elem.oxygen, res)
top.addAtom('E', elem.nitrogen, res) top.addAtom('E', elem.carbon, res)
top.addAtom('F', elem.nitrogen, res)
top.addAtom('G', elem.oxygen, res)
atoms = list(top.atoms()) atoms = list(top.atoms())
top.addBond(atoms[0], atoms[1]) top.addBond(atoms[0], atoms[1])
top.addBond(atoms[1], atoms[2]) top.addBond(atoms[1], atoms[2])
top.addBond(atoms[2], atoms[3]) top.addBond(atoms[2], atoms[3])
top.addBond(atoms[3], atoms[4]) top.addBond(atoms[3], atoms[4])
top.addBond(atoms[4], atoms[5])
top.addBond(atoms[5], atoms[6])
# Create the force field and system. # Create the force field and system.
xml = """ xml = """
<ForceField> <ForceField>
<AtomTypes> <AtomTypes>
<Type name="A" class="A" element="O" mass="1"/> <Type name="A" class="A" element="C" mass="1"/>
<Type name="B" class="B" element="C" mass="1"/> <Type name="B" class="B" element="N" mass="1"/>
<Type name="C" class="C" element="C" mass="1"/> <Type name="C" class="C" element="O" mass="1"/>
<Type name="D" class="D" element="C" mass="1"/>
<Type name="E" class="E" element="N" mass="1"/>
</AtomTypes> </AtomTypes>
<Residues> <Residues>
<Residue name="RES"> <Residue name="RES">
<Atom name="A" type="A"/> <Atom name="A" type="A"/>
<Atom name="B" type="B"/> <Atom name="B" type="B"/>
<Atom name="C" type="C"/> <Atom name="C" type="B"/>
<Atom name="D" type="D"/> <Atom name="D" type="C"/>
<Atom name="E" type="E"/> <Atom name="E" type="A"/>
<Atom name="F" type="B"/>
<Atom name="G" type="C"/>
<Bond atomName1="A" atomName2="B"/> <Bond atomName1="A" atomName2="B"/>
<Bond atomName1="B" atomName2="C"/> <Bond atomName1="B" atomName2="C"/>
<Bond atomName1="C" atomName2="D"/> <Bond atomName1="C" atomName2="D"/>
<Bond atomName1="D" atomName2="E"/> <Bond atomName1="D" atomName2="E"/>
<Bond atomName1="E" atomName2="F"/>
<Bond atomName1="F" atomName2="G"/>
</Residue> </Residue>
</Residues> </Residues>
<LennardJonesForce lj14scale="0.3"> <LennardJonesForce lj14scale="0.3">
<Atom type="A" sigma="1" epsilon="0.1"/> <Atom type="A" sigma="2.1" epsilon="1.1"/>
<Atom type="B" sigma="2" epsilon="0.2"/> <Atom type="B" sigma="2.2" epsilon="1.2"/>
<Atom type="C" sigma="3" epsilon="0.3"/> <Atom type="C" sigma="2.4" epsilon="1.4"/>
<Atom type="D" sigma="4" epsilon="0.4"/> <NBFixPair type1="C" type2="C" sigma="3.1" epsilon="4.1"/>
<Atom type="E" sigma="4" epsilon="0.4"/> <NBFixPair type1="A" type2="A" sigma="3.2" epsilon="4.2"/>
<NBFixPair type1="A" type2="D" sigma="2.5" epsilon="1.1"/> <NBFixPair type1="B" type2="A" sigma="3.4" epsilon="4.4"/>
<NBFixPair type1="A" type2="E" sigma="3.5" epsilon="1.5"/>
</LennardJonesForce> </LennardJonesForce>
</ForceField> """ </ForceField> """
ff = ForceField(StringIO(xml)) ff = ForceField(StringIO(xml))
system = ff.createSystem(top) system = ff.createSystem(top)
# Check that it produces the correct energy. # Check that it produces the correct energy.
# The chain is A-B-B-C-A-B-C, and the pairs that are evaluated are:
# A0-C3, A0-A4, A0-B5, A0-C6,
# B1-A4, B1-B5, B1-C6,
# B2-B5, B2-C6,
# C3-C6.
integrator = VerletIntegrator(0.001) integrator = VerletIntegrator(0.001)
context = Context(system, integrator, Platform.getPlatform(0)) context = Context(system, integrator, Platform.getPlatform(0))
positions = [Vec3(i, 0, 0) for i in range(5)]*nanometers positions = [Vec3(i, 0, 0) for i in range(7)]*nanometers
context.setPositions(positions) context.setPositions(positions)
def ljEnergy(sigma, epsilon, r): def ljEnergy(sigma, epsilon, r):
return 4*epsilon*((sigma/r)**12-(sigma/r)**6) return 4*epsilon*((sigma/r)**12-(sigma/r)**6)
expected = 0.3*ljEnergy(2.5, 1.1, 3) + 0.3*ljEnergy(3.0, sqrt(0.08), 3) + ljEnergy(3.5, 1.5, 4) expected = 0.3*ljEnergy(2.25, math.sqrt(1.54), 3) + ljEnergy(3.2, 4.2, 4) + ljEnergy(3.4, 4.4, 5) + ljEnergy(2.25, math.sqrt(1.54), 6) \
+ 0.3*ljEnergy(3.4, 4.4, 3) + ljEnergy(2.2, 1.2, 4) + ljEnergy(2.3, math.sqrt(1.68), 5) \
+ 0.3*ljEnergy(2.2, 1.2, 3) + ljEnergy(2.3, math.sqrt(1.68), 4) \
+ 0.3*ljEnergy(3.1, 4.1, 3)
self.assertAlmostEqual(expected, context.getState(getEnergy=True).getPotentialEnergy().value_in_unit(kilojoules_per_mole)) self.assertAlmostEqual(expected, context.getState(getEnergy=True).getPotentialEnergy().value_in_unit(kilojoules_per_mole))
def test_IgnoreExternalBonds(self): def test_IgnoreExternalBonds(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment