Commit d19984d6 authored by peastman's avatar peastman
Browse files

Merge pull request #1373 from rafwiewiora/forcefield_xml

Residue template overloading / User can specify template if multiple match / Refactor XML loading / No multiple gens
parents 1ffb9e0a 0c68e831
...@@ -120,21 +120,27 @@ class ForceField(object): ...@@ -120,21 +120,27 @@ class ForceField(object):
self._forces = [] self._forces = []
self._scripts = [] self._scripts = []
self._templateGenerators = [] self._templateGenerators = []
for file in files: self.loadFile(files)
self.loadFile(file)
def loadFile(self, file): def loadFile(self, files):
"""Load an XML file and add the definitions from it to this ForceField. """Load an XML file and add the definitions from it to this ForceField.
Parameters Parameters
---------- ----------
file : string or file files : string or file or tuple
An XML file containing force field definitions. It may be either an An XML file or tuple of XML files containing force field definitions.
absolute file path, a path relative to the current working Each entry may be either an absolute file path, a path relative to the current working
directory, a path relative to this module's data subdirectory (for directory, a path relative to this module's data subdirectory (for
built in force fields), or an open file-like object with a read() built in force fields), or an open file-like object with a read()
method from which the forcefield XML data can be loaded. method from which the forcefield XML data can be loaded.
""" """
if not isinstance(files, tuple):
files = (files,)
trees = []
for file in files:
try: try:
# this handles either filenames or open file-like objects # this handles either filenames or open file-like objects
tree = etree.parse(file) tree = etree.parse(file)
...@@ -152,20 +158,25 @@ class ForceField(object): ...@@ -152,20 +158,25 @@ class ForceField(object):
msg += "ForceField.loadFile() encountered an error reading file '%s'\n" % filename msg += "ForceField.loadFile() encountered an error reading file '%s'\n" % filename
raise Exception(msg) raise Exception(msg)
root = tree.getroot() trees.append(tree)
# Load the atom types. # Load the atom types.
for tree in trees:
if tree.getroot().find('AtomTypes') is not None: if tree.getroot().find('AtomTypes') is not None:
for type in tree.getroot().find('AtomTypes').findall('Type'): for type in tree.getroot().find('AtomTypes').findall('Type'):
self.registerAtomType(type.attrib) self.registerAtomType(type.attrib)
# Load the residue templates. # Load the residue templates.
for tree in trees:
if tree.getroot().find('Residues') is not None: if tree.getroot().find('Residues') is not None:
for residue in root.find('Residues').findall('Residue'): for residue in tree.getroot().find('Residues').findall('Residue'):
resName = residue.attrib['name'] resName = residue.attrib['name']
template = ForceField._TemplateData(resName) template = ForceField._TemplateData(resName)
if 'overload' in residue.attrib:
template.overloadLevel = int(residue.attrib['overload'])
atomIndices = {} atomIndices = {}
for atom in residue.findall('Atom'): for atom in residue.findall('Atom'):
params = {} params = {}
...@@ -194,12 +205,14 @@ class ForceField(object): ...@@ -194,12 +205,14 @@ class ForceField(object):
# Load force definitions # Load force definitions
for child in root: for tree in trees:
for child in tree.getroot():
if child.tag in parsers: if child.tag in parsers:
parsers[child.tag](child, self) parsers[child.tag](child, self)
# Load scripts # Load scripts
for tree in trees:
for node in tree.getroot().findall('Script'): for node in tree.getroot().findall('Script'):
self.registerScript(node.text) self.registerScript(node.text)
...@@ -237,6 +250,24 @@ class ForceField(object): ...@@ -237,6 +250,24 @@ class ForceField(object):
self._templates[template.name] = template self._templates[template.name] = template
signature = _createResidueSignature([atom.element for atom in template.atoms]) signature = _createResidueSignature([atom.element for atom in template.atoms])
if signature in self._templateSignatures: if signature in self._templateSignatures:
registered = False
for regtemplate in self._templateSignatures[signature]:
if regtemplate.name == template.name:
if regtemplate.overloadLevel > template.overloadLevel:
# ok to break - this is done every time a template is
# registered so there can only be one already existing
# with same name at a time
registered = True
break
elif regtemplate.overloadLevel < template.overloadLevel:
self._templateSignatures[signature].remove(regtemplate)
self._templateSignatures[signature].append(template)
registered = True
else:
raise Exception('Residue template %s with the same overloadLevel %d already exists.' %
(template.name, template.overloadLevel)
)
if not registered:
self._templateSignatures[signature].append(template) self._templateSignatures[signature].append(template)
else: else:
self._templateSignatures[signature] = [template] self._templateSignatures[signature] = [template]
...@@ -379,6 +410,7 @@ class ForceField(object): ...@@ -379,6 +410,7 @@ class ForceField(object):
self.virtualSites = [] self.virtualSites = []
self.bonds = [] self.bonds = []
self.externalBonds = [] self.externalBonds = []
self.overloadLevel = 0
def getAtomIndexByName(self, atom_name): def getAtomIndexByName(self, atom_name):
"""Look up an atom index by atom name, providing a helpful error message if not found.""" """Look up an atom index by atom name, providing a helpful error message if not found."""
...@@ -566,11 +598,16 @@ class ForceField(object): ...@@ -566,11 +598,16 @@ class ForceField(object):
matches = None matches = None
signature = _createResidueSignature([atom.element for atom in res.atoms()]) signature = _createResidueSignature([atom.element for atom in res.atoms()])
if signature in self._templateSignatures: if signature in self._templateSignatures:
allMatches = []
for t in self._templateSignatures[signature]: for t in self._templateSignatures[signature]:
matches = _matchResidue(res, t, bondedToAtom) match = _matchResidue(res, t, bondedToAtom)
if matches is not None: if match is not None:
template = t allMatches.append((t, match))
break if len(allMatches) == 1:
template = allMatches[0][0]
matches = allMatches[0][1]
elif len(allMatches) > 1:
raise Exception('Multiple matching templates found for residue %d (%s).' % (res.index+1, res.name))
return [template, matches] return [template, matches]
def _buildBondedToAtomList(self, topology): def _buildBondedToAtomList(self, topology):
...@@ -703,7 +740,7 @@ class ForceField(object): ...@@ -703,7 +740,7 @@ class ForceField(object):
return [templates, unique_unmatched_residues] return [templates, unique_unmatched_residues]
def createSystem(self, topology, nonbondedMethod=NoCutoff, nonbondedCutoff=1.0*unit.nanometer, def createSystem(self, topology, nonbondedMethod=NoCutoff, nonbondedCutoff=1.0*unit.nanometer,
constraints=None, rigidWater=True, removeCMMotion=True, hydrogenMass=None, **args): constraints=None, rigidWater=True, removeCMMotion=True, hydrogenMass=None, residueTemplates=dict(), **args):
"""Construct an OpenMM System representing a Topology with this force field. """Construct an OpenMM System representing a Topology with this force field.
Parameters Parameters
...@@ -727,6 +764,13 @@ class ForceField(object): ...@@ -727,6 +764,13 @@ class ForceField(object):
The mass to use for hydrogen atoms bound to heavy atoms. Any mass The mass to use for hydrogen atoms bound to heavy atoms. Any mass
added to a hydrogen is subtracted from the heavy atom to keep added to a hydrogen is subtracted from the heavy atom to keep
their total mass the same. their total mass the same.
residueTemplates : dict=dict()
Key: Topology Residue object
Value: string, name of _TemplateData residue template object to use for
(Key) residue
This allows user to specify which template to apply to particular Residues
in the event that multiple matching templates are available (e.g Fe2+ and Fe3+
templates in the ForceField for a monoatomic iron ion in the topology).
args args
Arbitrary additional keyword arguments may also be specified. Arbitrary additional keyword arguments may also be specified.
This allows extra parameters to be specified that are specific to This allows extra parameters to be specified that are specific to
...@@ -765,6 +809,13 @@ class ForceField(object): ...@@ -765,6 +809,13 @@ class ForceField(object):
for chain in topology.chains(): for chain in topology.chains():
for res in chain.residues(): for res in chain.residues():
if res in residueTemplates:
tname = residueTemplates[res]
template = self._templates[tname]
matches = _matchResidue(res, template, bondedToAtom)
if matches is None:
raise Exception('User-supplied template %s does not match the residue %d (%s)' % (tname, res.index+1, res.name))
else:
# Attempt to match one of the existing templates. # Attempt to match one of the existing templates.
[template, matches] = self._getResidueTemplateMatches(res, bondedToAtom) [template, matches] = self._getResidueTemplateMatches(res, bondedToAtom)
if matches is None: if matches is None:
...@@ -1183,8 +1234,12 @@ class HarmonicBondGenerator(object): ...@@ -1183,8 +1234,12 @@ class HarmonicBondGenerator(object):
@staticmethod @staticmethod
def parseElement(element, ff): def parseElement(element, ff):
existing = [f for f in ff._forces if isinstance(f, HarmonicBondGenerator)]
if len(existing) == 0:
generator = HarmonicBondGenerator(ff) generator = HarmonicBondGenerator(ff)
ff.registerGenerator(generator) ff.registerGenerator(generator)
else:
generator = existing[0]
for bond in element.findall('Bond'): for bond in element.findall('Bond'):
generator.registerBond(bond.attrib) generator.registerBond(bond.attrib)
...@@ -1236,8 +1291,12 @@ class HarmonicAngleGenerator(object): ...@@ -1236,8 +1291,12 @@ class HarmonicAngleGenerator(object):
@staticmethod @staticmethod
def parseElement(element, ff): def parseElement(element, ff):
existing = [f for f in ff._forces if isinstance(f, HarmonicAngleGenerator)]
if len(existing) == 0:
generator = HarmonicAngleGenerator(ff) generator = HarmonicAngleGenerator(ff)
ff.registerGenerator(generator) ff.registerGenerator(generator)
else:
generator = existing[0]
for angle in element.findall('Angle'): for angle in element.findall('Angle'):
generator.registerAngle(angle.attrib) generator.registerAngle(angle.attrib)
...@@ -1320,8 +1379,12 @@ class PeriodicTorsionGenerator(object): ...@@ -1320,8 +1379,12 @@ class PeriodicTorsionGenerator(object):
@staticmethod @staticmethod
def parseElement(element, ff): def parseElement(element, ff):
existing = [f for f in ff._forces if isinstance(f, PeriodicTorsionGenerator)]
if len(existing) == 0:
generator = PeriodicTorsionGenerator(ff) generator = PeriodicTorsionGenerator(ff)
ff.registerGenerator(generator) ff.registerGenerator(generator)
else:
generator = existing[0]
for torsion in element.findall('Proper'): for torsion in element.findall('Proper'):
generator.registerProperTorsion(torsion.attrib) generator.registerProperTorsion(torsion.attrib)
for torsion in element.findall('Improper'): for torsion in element.findall('Improper'):
...@@ -1419,8 +1482,12 @@ class RBTorsionGenerator(object): ...@@ -1419,8 +1482,12 @@ class RBTorsionGenerator(object):
@staticmethod @staticmethod
def parseElement(element, ff): def parseElement(element, ff):
existing = [f for f in ff._forces if isinstance(f, RBTorsionGenerator)]
if len(existing) == 0:
generator = RBTorsionGenerator(ff) generator = RBTorsionGenerator(ff)
ff.registerGenerator(generator) ff.registerGenerator(generator)
else:
generator = existing[0]
for torsion in element.findall('Proper'): for torsion in element.findall('Proper'):
types = ff._findAtomTypes(torsion.attrib, 4) types = ff._findAtomTypes(torsion.attrib, 4)
if None not in types: if None not in types:
...@@ -1523,8 +1590,12 @@ class CMAPTorsionGenerator(object): ...@@ -1523,8 +1590,12 @@ class CMAPTorsionGenerator(object):
@staticmethod @staticmethod
def parseElement(element, ff): def parseElement(element, ff):
existing = [f for f in ff._forces if isinstance(f, CMAPTorsionGenerator)]
if len(existing) == 0:
generator = CMAPTorsionGenerator(ff) generator = CMAPTorsionGenerator(ff)
ff.registerGenerator(generator) ff.registerGenerator(generator)
else:
generator = existing[0]
for map in element.findall('Map'): for map in element.findall('Map'):
values = [float(x) for x in map.text.split()] values = [float(x) for x in map.text.split()]
size = sqrt(len(values)) size = sqrt(len(values))
......
...@@ -448,6 +448,183 @@ class TestForceField(unittest.TestCase): ...@@ -448,6 +448,183 @@ class TestForceField(unittest.TestCase):
self.assertEqual(templates[1].name, 'ALA') self.assertEqual(templates[1].name, 'ALA')
self.assertEqual(templates[2].name, 'CALA') self.assertEqual(templates[2].name, 'CALA')
def test_Wildcard(self):
"""Test that PeriodicTorsionForces using wildcard ('') for atom types / classes in the ffxml are correctly registered"""
# Use wildcards in types
xml = """
<ForceField>
<AtomTypes>
<Type name="C" class="C" element="C" mass="12.010000"/>
<Type name="O" class="O" element="O" mass="16.000000"/>
</AtomTypes>
<PeriodicTorsionForce>
<Proper type1="" type2="C" type3="C" type4="" periodicity1="2" phase1="3.141593" k1="15.167000"/>
<Improper type1="C" type2="" type3="" type4="O" periodicity1="2" phase1="3.141593" k1="43.932000"/>
</PeriodicTorsionForce>
</ForceField>"""
ff = ForceField(StringIO(xml))
self.assertEqual(len(ff._forces[0].proper), 1)
self.assertEqual(len(ff._forces[0].improper), 1)
# Use wildcards in classes
xml = """
<ForceField>
<AtomTypes>
<Type name="C" class="C" element="C" mass="12.010000"/>
<Type name="O" class="O" element="O" mass="16.000000"/>
</AtomTypes>
<PeriodicTorsionForce>
<Proper class1="" class2="C" class3="C" class4="" periodicity1="2" phase1="3.141593" k1="15.167000"/>
<Improper class1="C" class2="" class3="" class4="O" periodicity1="2" phase1="3.141593" k1="43.932000"/>
</PeriodicTorsionForce>
</ForceField>"""
ff = ForceField(StringIO(xml))
self.assertEqual(len(ff._forces[0].proper), 1)
self.assertEqual(len(ff._forces[0].improper), 1)
def test_ScalingFactorCombining(self):
""" Tests that FFs can be combined if their scaling factors are very close """
forcefield = ForceField('amber99sb.xml', os.path.join('systems', 'test_amber_ff.xml'))
# This would raise an exception if it didn't work
def test_MultipleFilesandForceTags(self):
"""Test that the order of listing of multiple ffxmls does not matter.
Tests that one generator per force type is created and that the ffxml
defining atom types does not have to be listed first"""
ffxml = """<ForceField>
<Residues>
<Residue name="ACE-Test">
<Atom name="HH31" type="710"/>
<Atom name="CH3" type="711"/>
<Atom name="HH32" type="710"/>
<Atom name="HH33" type="710"/>
<Atom name="C" type="712"/>
<Atom name="O" type="713"/>
<Bond from="0" to="1"/>
<Bond from="1" to="2"/>
<Bond from="1" to="3"/>
<Bond from="1" to="4"/>
<Bond from="4" to="5"/>
<ExternalBond from="4"/>
</Residue>
</Residues>
<PeriodicTorsionForce>
<Proper class1="C" class2="C" class3="C" class4="C" periodicity1="2" phase1="3.14159265359" k1="10.46"/>
<Improper class1="C" class2="C" class3="C" class4="C" periodicity1="2" phase1="3.14159265359" k1="43.932"/>
</PeriodicTorsionForce>
</ForceField>"""
ff1 = ForceField(StringIO(ffxml), 'amber99sbildn.xml')
ff2 = ForceField('amber99sbildn.xml', StringIO(ffxml))
self.assertEqual(len(ff1._forces), 4)
self.assertEqual(len(ff2._forces), 4)
pertorsion1 = ff1._forces[0]
pertorsion2 = ff2._forces[2]
self.assertEqual(len(pertorsion1.proper), 110)
self.assertEqual(len(pertorsion1.improper), 42)
self.assertEqual(len(pertorsion2.proper), 110)
self.assertEqual(len(pertorsion2.improper), 42)
def test_ResidueTemplateUserChoice(self):
"""Test createSystem does not allow multiple matching templates, unless
user has specified which template to use via residueTemplates arg"""
ffxml = """<ForceField>
<AtomTypes>
<Type name="Fe2+" class="Fe2+" element="Fe" mass="55.85"/>
<Type name="Fe3+" class="Fe3+" element="Fe" mass="55.85"/>
</AtomTypes>
<Residues>
<Residue name="FE2">
<Atom name="FE2" type="Fe2+" charge="2.0"/>
</Residue>
<Residue name="FE">
<Atom name="FE" type="Fe3+" charge="3.0"/>
</Residue>
</Residues>
<NonbondedForce coulomb14scale="0.833333333333" lj14scale="0.5">
<UseAttributeFromResidue name="charge"/>
<Atom type="Fe2+" sigma="0.227535532613" epsilon="0.0150312292"/>
<Atom type="Fe3+" sigma="0.192790482606" epsilon="0.00046095128"/>
</NonbondedForce>
</ForceField>"""
pdb_string = "ATOM 1 FE FE A 1 20.956 27.448 -29.067 1.00 0.00 Fe"
ff = ForceField(StringIO(ffxml))
pdb = PDBFile(StringIO(pdb_string))
self.assertRaises(Exception, lambda: ff.createSystem(pdb.topology))
sys = ff.createSystem(pdb.topology, residueTemplates={list(pdb.topology.residues())[0] : 'FE2'})
# confirm charge
self.assertEqual(sys.getForce(0).getParticleParameters(0)[0]._value, 2.0)
sys = ff.createSystem(pdb.topology, residueTemplates={list(pdb.topology.residues())[0] : 'FE'})
# confirm charge
self.assertEqual(sys.getForce(0).getParticleParameters(0)[0]._value, 3.0)
def test_ResidueOverloading(self):
"""Test residue overloading via overload tag in the XML"""
ffxml1 = """<ForceField>
<AtomTypes>
<Type name="Fe2+_tip3p_HFE" class="Fe2+_tip3p_HFE" element="Fe" mass="55.85"/>
</AtomTypes>
<Residues>
<Residue name="FE2">
<Atom name="FE2" type="Fe2+_tip3p_HFE" charge="2.0"/>
</Residue>
</Residues>
<NonbondedForce coulomb14scale="0.833333333333" lj14scale="0.5">
<UseAttributeFromResidue name="charge"/>
<Atom type="Fe2+_tip3p_HFE" sigma="0.227535532613" epsilon="0.0150312292"/>
</NonbondedForce>
</ForceField>"""
ffxml2 = """<ForceField>
<AtomTypes>
<Type name="Fe2+_tip3p_standard" class="Fe2+_tip3p_standard" element="Fe" mass="55.85"/>
</AtomTypes>
<Residues>
<Residue name="FE2">
<Atom name="FE2" type="Fe2+_tip3p_standard" charge="2.0"/>
</Residue>
</Residues>
<NonbondedForce coulomb14scale="0.833333333333" lj14scale="0.5">
<UseAttributeFromResidue name="charge"/>
<Atom type="Fe2+_tip3p_standard" sigma="0.241077193129" epsilon="0.03940482832"/>
</NonbondedForce>
</ForceField>"""
ffxml3 = """<ForceField>
<AtomTypes>
<Type name="Fe2+_tip3p_standard" class="Fe2+_tip3p_standard" element="Fe" mass="55.85"/>
</AtomTypes>
<Residues>
<Residue name="FE2" overload="1">
<Atom name="FE2" type="Fe2+_tip3p_standard" charge="2.0"/>
</Residue>
</Residues>
<NonbondedForce coulomb14scale="0.833333333333" lj14scale="0.5">
<UseAttributeFromResidue name="charge"/>
<Atom type="Fe2+_tip3p_standard" sigma="0.241077193129" epsilon="0.03940482832"/>
</NonbondedForce>
</ForceField>"""
pdb_string = "ATOM 1 FE FE A 1 20.956 27.448 -29.067 1.00 0.00 Fe"
pdb = PDBFile(StringIO(pdb_string))
self.assertRaises(Exception, lambda: ForceField(StringIO(ffxml1), StringIO(ffxml2)))
ff = ForceField(StringIO(ffxml1), StringIO(ffxml3))
self.assertEqual(ff._templates['FE2'].atoms[0].type, 'Fe2+_tip3p_standard')
ff.createSystem(pdb.topology)
class AmoebaTestForceField(unittest.TestCase): class AmoebaTestForceField(unittest.TestCase):
"""Test the ForceField.createSystem() method with the AMOEBA forcefield.""" """Test the ForceField.createSystem() method with the AMOEBA forcefield."""
...@@ -535,49 +712,5 @@ class AmoebaTestForceField(unittest.TestCase): ...@@ -535,49 +712,5 @@ class AmoebaTestForceField(unittest.TestCase):
diff = norm(f1-f2) diff = norm(f1-f2)
self.assertTrue(diff < 0.1 or diff/norm(f1) < 1e-3) self.assertTrue(diff < 0.1 or diff/norm(f1) < 1e-3)
def test_Wildcard(self):
"""Test that PeriodicTorsionForces using wildcard ('') for atom types / classes in the ffxml are correctly registered"""
# Use wildcards in types
xml = """
<ForceField>
<AtomTypes>
<Type name="C" class="C" element="C" mass="12.010000"/>
<Type name="O" class="O" element="O" mass="16.000000"/>
</AtomTypes>
<PeriodicTorsionForce>
<Proper type1="" type2="C" type3="C" type4="" periodicity1="2" phase1="3.141593" k1="15.167000"/>
<Improper type1="C" type2="" type3="" type4="O" periodicity1="2" phase1="3.141593" k1="43.932000"/>
</PeriodicTorsionForce>
</ForceField>"""
ff = ForceField(StringIO(xml))
self.assertEqual(len(ff._forces[0].proper), 1)
self.assertEqual(len(ff._forces[0].improper), 1)
# Use wildcards in classes
xml = """
<ForceField>
<AtomTypes>
<Type name="C" class="C" element="C" mass="12.010000"/>
<Type name="O" class="O" element="O" mass="16.000000"/>
</AtomTypes>
<PeriodicTorsionForce>
<Proper class1="" class2="C" class3="C" class4="" periodicity1="2" phase1="3.141593" k1="15.167000"/>
<Improper class1="C" class2="" class3="" class4="O" periodicity1="2" phase1="3.141593" k1="43.932000"/>
</PeriodicTorsionForce>
</ForceField>"""
ff = ForceField(StringIO(xml))
self.assertEqual(len(ff._forces[0].proper), 1)
self.assertEqual(len(ff._forces[0].improper), 1)
def test_ScalingFactorCombining(self):
""" Tests that FFs can be combined if their scaling factors are very close """
forcefield = ForceField('amber99sb.xml', os.path.join('systems', 'test_amber_ff.xml'))
# This would raise an exception if it didn't work
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment