Commit c1d35c3e authored by peastman's avatar peastman
Browse files

Continuing to implement force field patches

parent e1f8d449
......@@ -38,6 +38,7 @@ import itertools
import xml.etree.ElementTree as etree
import math
from math import sqrt, cos
from copy import deepcopy
import simtk.openmm as mm
import simtk.unit as unit
from . import element as elem
......@@ -203,7 +204,7 @@ class ForceField(object):
template.addExternalBondByName(bond.attrib['atomName'])
else:
template.addExternalBond(int(bond.attrib['from']))
for patch in patch.findall('AllowPatch'):
for patch in residue.findall('AllowPatch'):
patchName = patch.attrib['name']
if ':' in name:
colonIndex = name.find(':')
......@@ -235,7 +236,7 @@ class ForceField(object):
allAtomNames.add(atomName)
atomDescription = ForceField._PatchAtomData(atomName)
typeName = atom.attrib['type']
patchData.addedAtoms.append(ForceField._TemplateAtomData(atomDescription, typeName, self._atomTypes[typeName].element, params))
patchData.addedAtoms[atomDescription.residue].append(ForceField._TemplateAtomData(atomDescription.name, typeName, self._atomTypes[typeName].element, params))
for atom in patch.findall('ChangeAtom'):
params = {}
for key in atom.attrib:
......@@ -247,7 +248,7 @@ class ForceField(object):
allAtomNames.add(atomName)
atomDescription = ForceField._PatchAtomData(atomName)
typeName = atom.attrib['type']
patchData.changedAtoms.append(ForceField._TemplateAtomData(atomDescription, typeName, self._atomTypes[typeName].element, params))
patchData.changedAtoms[atomDescription.residue].append(ForceField._TemplateAtomData(atomDescription.name, typeName, self._atomTypes[typeName].element, params))
for atom in patch.findall('RemoveAtom'):
atomName = atom.attrib['name']
if atomName in allAtomNames:
......@@ -586,14 +587,74 @@ class ForceField(object):
def __init__(self, name, numResidues):
self.name = name
self.numResidues = numResidues
self.addedAtoms = []
self.addedAtoms = [[] for i in range(numResidues)]
self.changedAtoms = [[] for i in range(numResidues)]
self.deletedAtoms = []
self.changedAtoms = []
self.addedBonds = []
self.deletedBonds = []
self.addedExternalBonds = []
self.deletedExternalBonds = []
def createPatchedTemplates(self, templates):
"""Apply this patch to a set of templates, creating new modified ones."""
if len(templates) != self.numResidues:
raise ValueError("Patch '%s' expected %d templates, received %d", (self.name, self.numResidues, len(templates)))
# Construct a new version of each template.
newTemplates = []
for index, template in enumerate(templates):
newTemplate = ForceField._TemplateData("%s-%s" % (template.name, self.name))
newTemplates.append(newTemplate)
# Build the list of atoms in it.
for atom in template.atoms:
if not any(deleted.name == atom.name and deleted.residue == index for deleted in self.deletedAtoms):
newTemplate.atoms.append(deepcopy(atom))
for atom in self.addedAtoms[index]:
newTemplate.atoms.append(deepcopy(atom))
oldAtomIndex = dict([(atom.name, i) for i, atom in enumerate(template.atoms)])
newAtomIndex = dict([(atom.name, i) for i, atom in enumerate(newTemplate.atoms)])
for atom in self.changedAtoms[index]:
if atom.name not in newAtomIndex:
raise ValueError("Patch '%s' modifies nonexistent atom '%s' in template '%s'" % (self.name, atom.name, template.name))
newTemplate.atoms[newAtomIndex[atom.name]] = deepcopy(atom)
# Copy over the virtual sites, translating the atom indices.
indexMap = dict([(oldAtomIndex[name], newAtomIndex[name]) for name in newAtomIndex if name in oldAtomIndex])
for site in template.virtualSites:
if site.index in indexMap and all(i in indexMap for i in site.atoms):
newSite = deepcopy(site)
newSite.index = indexMap[site.index]
newSite.atoms = [indexMap[i] for i in site.atoms]
newTemplate.virtualSites.append(newSite)
# Build the lists of bonds and external bonds.
atomMap = dict([(template.atoms[i], indexMap[i]) for i in indexMap])
deletedBonds = [(atom1.name, atom2.name) for atom1, atom2 in self.deletedBonds if atom1.residue == index and atom2.residue == index]
for atom1, atom2 in template.bonds:
a1 = template.atoms[atom1]
a2 = template.atoms[atom2]
if (a1.name, a2.name) not in deletedBonds and (a2.name, a1.name) not in deletedBonds:
newTemplate.addBond(atomMap[a1], atomMap[a2])
deletedExternalBonds = [atom.name for atom in self.deletedExternalBonds if atom.residue == index]
for atom in template.externalBonds:
if template.atoms[atom].name not in deletedExternalBonds:
newTemplate.addExternalBond(atomMap[atom])
for atom1, atom2 in self.addedBonds:
if atom1.residue == index and atom2.residue == index:
newTemplate.addBondByName(atom1.name, atom2.name)
elif atom1.residue == index:
newTemplate.addExternalBondByName(atom1.name)
elif atom2.residue == index:
newTemplate.addExternalBondByName(atom2.name)
for atom in self.addedExternalBonds:
newTemplate.addExternalBondByName(atom.name)
return newTemplates
class _PatchAtomData(object):
"""Inner class used to encapsulate data about an atom in a patch definition."""
def __init__(self, description):
......
......@@ -39,7 +39,9 @@ class TestForceField(unittest.TestCase):
self.assertEqual(1, len(ff._patches))
patch = ff._patches['Test']
self.assertEqual(1, len(patch.addedAtoms))
self.assertEqual(1, len(patch.addedAtoms[0]))
self.assertEqual(1, len(patch.changedAtoms))
self.assertEqual(1, len(patch.changedAtoms[0]))
self.assertEqual(1, len(patch.deletedAtoms))
self.assertEqual(1, len(patch.addedBonds))
self.assertEqual(1, len(patch.deletedBonds))
......@@ -47,12 +49,10 @@ class TestForceField(unittest.TestCase):
self.assertEqual(1, len(patch.deletedExternalBonds))
self.assertEqual(1, len(ff._templatePatches))
self.assertEqual(1, len(ff._templatePatches['RES']))
self.assertEqual('A', patch.addedAtoms[0].name.name)
self.assertEqual(0, patch.addedAtoms[0].name.residue)
self.assertEqual('A type', patch.addedAtoms[0].type)
self.assertEqual('B', patch.changedAtoms[0].name.name)
self.assertEqual(0, patch.changedAtoms[0].name.residue)
self.assertEqual('B type', patch.changedAtoms[0].type)
self.assertEqual('A', patch.addedAtoms[0][0].name)
self.assertEqual('A type', patch.addedAtoms[0][0].type)
self.assertEqual('B', patch.changedAtoms[0][0].name)
self.assertEqual('B type', patch.changedAtoms[0][0].type)
self.assertEqual('C', patch.deletedAtoms[0].name)
self.assertEqual(0, patch.deletedAtoms[0].residue)
self.assertEqual('A', patch.addedBonds[0][0].name)
......@@ -70,5 +70,111 @@ class TestForceField(unittest.TestCase):
self.assertEqual('Test', ff._templatePatches['RES'][0][0])
self.assertEqual(0, ff._templatePatches['RES'][0][1])
def testParseMultiresiduePatch(self):
"""Test parsing a <Patch> tag that affects two residues."""
xml = """
<ForceField>
<AtomTypes>
<Type name="A type" class="A class" element="O" mass="15.99943"/>
<Type name="B type" class="B class" element="H" mass="1.007947"/>
</AtomTypes>
<Patches>
<Patch name="Test" residues="2">
<AddAtom name="1:A" type="A type"/>
<ChangeAtom name="2:B" type="B type"/>
<AddBond atomName1="1:A" atomName2="2:B"/>
<ApplyToResidue name="1:RESA"/>
<ApplyToResidue name="2:RESB"/>
</Patch>
</Patches>
</ForceField>"""
ff = ForceField(StringIO(xml))
self.assertEqual(1, len(ff._patches))
patch = ff._patches['Test']
self.assertEqual(2, len(patch.addedAtoms))
self.assertEqual(1, len(patch.addedAtoms[0]))
self.assertEqual(0, len(patch.addedAtoms[1]))
self.assertEqual(2, len(patch.changedAtoms))
self.assertEqual(0, len(patch.changedAtoms[0]))
self.assertEqual(1, len(patch.changedAtoms[1]))
self.assertEqual(1, len(patch.addedBonds))
self.assertEqual(2, len(ff._templatePatches))
self.assertEqual(1, len(ff._templatePatches['RESA']))
self.assertEqual(1, len(ff._templatePatches['RESB']))
self.assertEqual('A', patch.addedAtoms[0][0].name)
self.assertEqual('A type', patch.addedAtoms[0][0].type)
self.assertEqual('B', patch.changedAtoms[1][0].name)
self.assertEqual('B type', patch.changedAtoms[1][0].type)
self.assertEqual('A', patch.addedBonds[0][0].name)
self.assertEqual(0, patch.addedBonds[0][0].residue)
self.assertEqual('B', patch.addedBonds[0][1].name)
self.assertEqual(1, patch.addedBonds[0][1].residue)
self.assertEqual('Test', ff._templatePatches['RESA'][0][0])
self.assertEqual(0, ff._templatePatches['RESA'][0][1])
self.assertEqual('Test', ff._templatePatches['RESB'][0][0])
self.assertEqual(1, ff._templatePatches['RESB'][0][1])
def testApplyPatch(self):
"""Test applying a patch to a template."""
xml = """
<ForceField>
<AtomTypes>
<Type name="A type" class="A class" element="O" mass="15.99943"/>
<Type name="B type" class="B class" element="H" mass="1.007947"/>
<Type name="C type" class="C class" element="H" mass="1.007947"/>
<Type name="D type" class="D class" element="C" mass="12.010000"/>
</AtomTypes>
<Residues>
<Residue name="RES">
<Atom name="A" type="A type"/>
<Atom name="B" type="B type"/>
<Atom name="C" type="C type"/>
<Bond atomName1="A" atomName2="B"/>
<Bond atomName1="B" atomName2="C"/>
<ExternalBond atomName="C"/>
<VirtualSite type="average2" siteName="C" atomName1="B" atomName2="C" weight1="0.6" weight2="0.4"/>
</Residue>
</Residues>
<Patches>
<Patch name="Test">
<AddAtom name="D" type="D type"/>
<ChangeAtom name="B" type="A type"/>
<RemoveAtom name="A"/>
<AddBond atomName1="B" atomName2="D"/>
<RemoveBond atomName1="A" atomName2="B"/>
<AddExternalBond atomName="D"/>
<RemoveExternalBond atomName="C"/>
<ApplyToResidue name="RES"/>
</Patch>
</Patches>
</ForceField>"""
ff = ForceField(StringIO(xml))
self.assertEqual(1, len(ff._patches))
patch = ff._patches['Test']
template = ff._templates['RES']
newTemplates = patch.createPatchedTemplates([template])
self.assertEqual(1, len(newTemplates))
t = newTemplates[0]
self.assertEqual(3, len(t.atoms))
self.assertTrue(any(a.name == 'B' and a.type == 'A type' for a in t.atoms))
self.assertTrue(any(a.name == 'C' and a.type == 'C type' for a in t.atoms))
self.assertTrue(any(a.name == 'D' and a.type == 'D type' for a in t.atoms))
indexMap = dict([(a.name, i) for i, a in enumerate(t.atoms)])
self.assertEqual(2, len(t.bonds))
self.assertTrue((indexMap['B'], indexMap['C']) in t.bonds)
self.assertTrue((indexMap['B'], indexMap['D']) in t.bonds)
self.assertEqual(1, len(t.externalBonds))
self.assertTrue(indexMap['D'] in t.externalBonds)
self.assertEqual(1, len(t.virtualSites))
v = t.virtualSites[0]
self.assertEqual('average2', v.type)
self.assertEqual(0.6, v.weights[0])
self.assertEqual(0.4, v.weights[1])
self.assertEqual(indexMap['C'], v.index)
self.assertEqual(indexMap['B'], v.atoms[0])
self.assertEqual(indexMap['C'], v.atoms[1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment