Commit b0baa00d authored by Robert McGibbon's avatar Robert McGibbon
Browse files

Support modern swig, strip units on the C++ side of the swig wrappers.

parent 04db8c60
......@@ -163,7 +163,6 @@ set(SWIG_INPUT_FILES2
"${SWIG_OPENMM_DIR}/swig_lib/python/header.i"
"${SWIG_OPENMM_DIR}/swig_lib/python/pythoncode.i"
"${SWIG_OPENMM_DIR}/swig_lib/python/typemaps.i"
"${SWIG_OPENMM_DIR}/swig_lib/python/pythonprepend_all.i"
)
# Create input files for swig
......@@ -176,6 +175,7 @@ add_custom_command(
-o OpenMMSwigHeaders.i
-a swig_lib/python/pythonprepend.i
-z swig_lib/python/pythonappend.i
-v "${SWIG_VERSION}"
WORKING_DIRECTORY "${SWIG_OPENMM_DIR}"
DEPENDS
"${SWIG_OPENMM_DIR}/swigInputConfig.py"
......
......@@ -123,17 +123,17 @@ class Element(object):
Element._elements_by_mass = OrderedDict()
for elem in sorted(Element._elements_by_symbol.values(),
key=lambda x: x.mass):
Element._elements_by_mass[elem.mass] = elem
Element._elements_by_mass[elem.mass.value_in_unit(daltons)] = elem
diff = mass
best_guess = None
for elemmass, element in Element._elements_by_mass.iteritems():
massdiff = abs(elemmass._value - mass)
massdiff = abs(elemmass - mass)
if massdiff < diff:
best_guess = element
diff = massdiff
if elemmass._value > mass:
if elemmass > mass:
# Elements are only getting heavier, so bail out early
return best_guess
......
......@@ -43,7 +43,9 @@ from simtk.openmm.app import Topology
def _convertParameterToNumber(param):
if unit.is_quantity(param):
return mm.stripUnits((param,))[0]
if param.unit.is_compatible(unit.bar):
return param / unit.bar
return param.value_in_unit_system(unit.md_unit_system)
return float(param)
# Enumerated values for nonbonded method
......@@ -115,10 +117,10 @@ class ForceField(object):
self._scripts = []
for file in files:
self.loadFile(file)
def loadFile(self, file):
"""Load an XML file and add the definitions from it to this FieldField.
Parameters:
- file (string or file) An XML file containing force field definitions. It may
be either an absolute file path, a path relative to the current working
......@@ -171,11 +173,11 @@ class ForceField(object):
def getGenerators(self):
"""Get the list of all registered generators."""
return self._forces
def registerGenerator(self, generator):
"""Register a new generator."""
self._forces.append(generator)
def registerAtomType(self, parameters):
"""Register a new atom type."""
name = parameters['name']
......@@ -196,7 +198,7 @@ class ForceField(object):
self._atomClasses[atomClass] = typeSet
typeSet.add(name)
self._atomClasses[''].add(name)
def registerResidueTemplate(self, template):
"""Register a new residue template."""
self._templates[template.name] = template
......@@ -205,7 +207,7 @@ class ForceField(object):
self._templateSignatures[signature].append(template)
else:
self._templateSignatures[signature] = [template]
def registerScript(self, script):
"""Register a new script to be executed after building the System."""
self._scripts.append(script)
......@@ -270,7 +272,7 @@ class ForceField(object):
self.virtualSites = []
self.bonds = []
self.externalBonds = []
def addBond(self, atom1, atom2):
self.bonds.append((atom1, atom2))
self.atoms[atom1].bondedTo.append(atom2)
......@@ -390,9 +392,9 @@ class ForceField(object):
sys = mm.System()
for atom in topology.atoms():
sys.addParticle(self._atomTypes[data.atomType[atom]][1])
# Adjust masses.
if hydrogenMass is not None:
for atom1, atom2 in topology.bonds():
if atom1.element == elem.hydrogen:
......@@ -643,46 +645,46 @@ def _findMatchErrors(forcefield, res):
residueCounts = _countResidueAtoms([atom.element for atom in res.atoms()])
numResidueAtoms = sum(residueCounts.itervalues())
numResidueHeavyAtoms = sum(residueCounts[element] for element in residueCounts if element not in (None, elem.hydrogen))
# Loop over templates and see how closely each one might match.
bestMatchName = None
numBestMatchAtoms = 3*numResidueAtoms
numBestMatchHeavyAtoms = 2*numResidueHeavyAtoms
for templateName in forcefield._templates:
template = forcefield._templates[templateName]
templateCounts = _countResidueAtoms([atom.element for atom in template.atoms])
# Does the residue have any atoms that clearly aren't in the template?
if any(element not in templateCounts or templateCounts[element] < residueCounts[element] for element in residueCounts):
continue
# If there are too many missing atoms, discard this template.
numTemplateAtoms = sum(templateCounts.itervalues())
numTemplateHeavyAtoms = sum(templateCounts[element] for element in templateCounts if element not in (None, elem.hydrogen))
if numTemplateAtoms > numBestMatchAtoms:
continue
if numTemplateHeavyAtoms > numBestMatchHeavyAtoms:
continue
# If this template has the same number of missing atoms as our previous best one, look at the name
# to decide which one to use.
if numTemplateAtoms == numBestMatchAtoms:
if bestMatchName == res.name or res.name not in templateName:
continue
# Accept this as our new best match.
bestMatchName = templateName
numBestMatchAtoms = numTemplateAtoms
numBestMatchHeavyAtoms = numTemplateHeavyAtoms
numBestMatchExtraParticles = len([atom for atom in template.atoms if atom.element is None])
# Return an appropriate error message.
if numBestMatchAtoms == numResidueAtoms:
chainResidues = list(res.chain.residues())
if len(chainResidues) > 1 and (res == chainResidues[0] or res == chainResidues[-1]):
......@@ -714,7 +716,7 @@ class HarmonicBondGenerator:
self.types2 = []
self.length = []
self.k = []
def registerBond(self, parameters):
types = self.ff._findAtomTypes(parameters, 2)
if None not in types:
......@@ -1183,7 +1185,7 @@ class NonbondedGenerator:
def postprocessSystem(self, sys, data, args):
# Create exceptions based on bonds.
bondIndices = []
for bond in data.bonds:
bondIndices.append((bond.atom1, bond.atom2))
......@@ -1195,10 +1197,10 @@ class NonbondedGenerator:
(site, atoms, excludeWith) = data.virtualSites[data.atoms[i]]
if excludeWith is None:
bondIndices.append((i, site.getParticle(0)))
# Certain particles, such as lone pairs and Drude particles, share exclusions with a parent atom.
# If the parent atom does not interact with an atom, the child particle does not either.
for atom1, atom2 in bondIndices:
for child1 in data.excludeAtomWith[atom1]:
bondIndices.append((child1, atom2))
......@@ -1629,7 +1631,7 @@ class CustomNonbondedGenerator:
def postprocessSystem(self, sys, data, args):
# Create exclusions based on bonds.
bondIndices = []
for bond in data.bonds:
bondIndices.append((bond.atom1, bond.atom2))
......@@ -1641,10 +1643,10 @@ class CustomNonbondedGenerator:
(site, atoms, excludeWith) = data.virtualSites[data.atoms[i]]
if excludeWith is None:
bondIndices.append((i, site.getParticle(0)))
# Certain particles, such as lone pairs and Drude particles, share exclusions with a parent atom.
# If the parent atom does not interact with an atom, the child particle does not either.
for atom1, atom2 in bondIndices:
for child1 in data.excludeAtomWith[atom1]:
bondIndices.append((child1, atom2))
......@@ -1654,7 +1656,7 @@ class CustomNonbondedGenerator:
bondIndices.append((atom1, child2))
# Create the exclusions.
nonbonded = [f for f in sys.getForces() if isinstance(f, mm.CustomNonbondedForce)][0]
nonbonded.createExclusionsFromBonds(bondIndices, self.bondCutoff)
......@@ -1765,7 +1767,7 @@ class CustomManyParticleGenerator:
self.perParticleParams = []
self.functions = []
self.typeFilters = []
@staticmethod
def parseElement(element, ff):
permutationMap = {"SinglePermutation" : mm.CustomManyParticleForce.SinglePermutation,
......@@ -1825,7 +1827,7 @@ class CustomManyParticleGenerator:
def postprocessSystem(self, sys, data, args):
# Create exclusions based on bonds.
bondIndices = []
for bond in data.bonds:
bondIndices.append((bond.atom1, bond.atom2))
......@@ -1837,10 +1839,10 @@ class CustomManyParticleGenerator:
(site, atoms, excludeWith) = data.virtualSites[data.atoms[i]]
if excludeWith is None:
bondIndices.append((i, site.getParticle(0)))
# Certain particles, such as lone pairs and Drude particles, share exclusions with a parent atom.
# If the parent atom does not interact with an atom, the child particle does not either.
for atom1, atom2 in bondIndices:
for child1 in data.excludeAtomWith[atom1]:
bondIndices.append((child1, atom2))
......@@ -1850,7 +1852,7 @@ class CustomManyParticleGenerator:
bondIndices.append((atom1, child2))
# Create the exclusions.
nonbonded = [f for f in sys.getForces() if isinstance(f, mm.CustomManyParticleForce)][0]
nonbonded.createExclusionsFromBonds(bondIndices, self.bondCutoff)
......@@ -3244,7 +3246,7 @@ class AmoebaVdwGenerator:
exclusionSet.add(i)
force.setParticleExclusions(i, exclusionSet)
force.setParticleExclusions(i, tuple(exclusionSet))
parsers["AmoebaVdwForce"] = AmoebaVdwGenerator.parseElement
......@@ -3896,12 +3898,12 @@ class AmoebaMultipoleGenerator:
newIndex = force.addMultipole(savedMultipoleDict['charge'], savedMultipoleDict['dipole'], savedMultipoleDict['quadrupole'], savedMultipoleDict['axisType'],
zaxis, xaxis, yaxis, savedMultipoleDict['thole'], savedMultipoleDict['pdamp'], savedMultipoleDict['polarizability'])
if (atomIndex == newIndex):
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent12, bonded12ParticleSets[atomIndex])
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent13, bonded13ParticleSets[atomIndex])
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent14, bonded14ParticleSets[atomIndex])
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent15, bonded15ParticleSets[atomIndex])
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent12, tuple(bonded12ParticleSets[atomIndex]))
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent13, tuple(bonded13ParticleSets[atomIndex]))
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent14, tuple(bonded14ParticleSets[atomIndex]))
force.setCovalentMap(atomIndex, mm.AmoebaMultipoleForce.Covalent15, tuple(bonded15ParticleSets[atomIndex]))
else:
raise ValueError("Atom %s of %s %d is out of synch!." %(atom.name, atom.residue.name, atom.residue.index))
raise ValueError("Atom %s of %s %d is out of sync!." %(atom.name, atom.residue.name, atom.residue.index))
else:
raise ValueError("Atom %s of %s %d was not assigned." %(atom.name, atom.residue.name, atom.residue.index))
else:
......
%module openmm
%include "factory.i"
%include "std_string.i"
%include "std_iostream.i"
%include "typemaps.i"
%include "std_map.i"
%include "std_pair.i"
%include "std_set.i"
%include "std_vector.i"
namespace std {
%template(pairii) pair<int,int>;
%template(vectord) vector<double>;
......@@ -23,6 +22,7 @@ namespace std {
%template(seti) set<int>;
};
%include "typemaps.i"
%include "windows.i"
%{
......
#!/usr/bin/env python
#
#
"""Build swig imput file from xml encoded header files (see gccxml)."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"
......@@ -12,6 +9,7 @@ import time
import getopt
import re
import xml.etree.ElementTree as etree
from distutils.version import LooseVersion
try:
from html.parser import HTMLParser
......@@ -19,6 +17,9 @@ except ImportError:
# python 2
from HTMLParser import HTMLParser
INDENT = " "
docTags = {'emphasis':'i', 'bold':'b', 'itemizedlist':'ul', 'listitem':'li', 'preformatted':'pre', 'computeroutput':'tt', 'subscript':'sub'}
def striphtmltags(s):
"""Strip a couple html tags used inside docstrings in the C++ source
......@@ -52,11 +53,6 @@ def striphtmltags(s):
return s
INDENT = " ";
docTags = {'emphasis':'i', 'bold':'b', 'itemizedlist':'ul', 'listitem':'li', 'preformatted':'pre', 'computeroutput':'tt', 'subscript':'sub'}
def trimToSingleSpace(text):
if text is None or len(text) == 0:
return ""
......@@ -158,8 +154,10 @@ class SwigInputBuilder:
docstringFilename=None,
pythonprependFilename=None,
pythonappendFilename=None,
skipAdditionalMethods=[]):
skipAdditionalMethods=[],
SWIG_VERSION='3.0.2'):
self.nodeByID={}
self.SWIG_COMPACT_ARGUMENTS = LooseVersion(SWIG_VERSION) < LooseVersion('3.0.6')
self.configModule = __import__(os.path.splitext(configFilename)[0])
......@@ -444,25 +442,20 @@ class SwigInputBuilder:
key=(shortClassName, methName)
if key in self.configModule.STEAL_OWNERSHIP:
for argNum in self.configModule.STEAL_OWNERSHIP[key]:
self.fOutPythonprepend.write("%pythonprepend")
self.fOutPythonprepend.write(" OpenMM::%s::%s%s %%{\n"
% (shortClassName,
methName,
mArgsstring))
self.fOutPythonprepend.write(
"%sif not args[%s].thisown:\n"
% (INDENT, argNum))
s = 's = "the %s object does not own its'
s = '%s corresponding OpenMM object" \\' % s
self.fOutPythonprepend.write("%s %s\n" % (INDENT, s))
s = ' %% args[%s].__class__.__name__' % argNum
self.fOutPythonprepend.write("%s %s\n" % (INDENT, s))
s = "raise Exception(s)"
self.fOutPythonprepend.write("%s %s\n" % (INDENT, s))
self.fOutPythonprepend.write("%}\n\n")
if self.SWIG_COMPACT_ARGUMENTS:
argName = 'args[%s]' % argNum
else:
argName = getText('declname', paramList[argNum])
text = '''
%pythonprepend OpenMM::{shortClassName}::{methName}{mArgsstring} %{{
if not {argName}.thisown:
s = ("the %s object does not own its corresponding OpenMM object"
% self.__class__.__name__)
raise Exception(s)
%}}'''.format(argName=argName, shortClassName=shortClassName, methName=methName, mArgsstring=mArgsstring)
self.fOutPythonprepend.write(text)
#write pythonappend blocks
if self.fOutPythonappend \
......@@ -505,8 +498,12 @@ class SwigInputBuilder:
if key in self.configModule.STEAL_OWNERSHIP:
for argNum in self.configModule.STEAL_OWNERSHIP[key]:
addText = "%s%sargs[%s].thisown=0\n" \
% (addText, INDENT, argNum)
if self.SWIG_COMPACT_ARGUMENTS:
argName = 'args[%s]' % argNum
else:
argName = getText('declname', paramList[argNum])
addText = "%s%s%s.thisown=0\n" \
% (addText, INDENT, argName)
if addText:
self.fOutPythonappend.write("%pythonappend")
......@@ -592,7 +589,7 @@ class SwigInputBuilder:
def parseCommandLine():
opts, args_proper = getopt.getopt(sys.argv[1:], 'hi:c:o:d:a:z:s:')
opts, args_proper = getopt.getopt(sys.argv[1:], 'hi:c:o:d:a:z:s:v:')
inputDirname = None
configFilename = None
outputFilename = ""
......@@ -600,6 +597,7 @@ def parseCommandLine():
pythonprependFilename = ""
pythonappendFilename = ""
skipAdditionalMethods = []
swigVersion = '3.0.2'
for option, parameter in opts:
if option=='-h': usageError()
if option=='-i': inputDirname = parameter
......@@ -609,19 +607,21 @@ def parseCommandLine():
if option=='-a': pythonprependFilename=parameter
if option=='-z': pythonappendFilename=parameter
if option=='-s': skipAdditionalMethods.append(parameter)
if option=='-v': swigVersion = parameter
if not inputDirname: usageError()
if not configFilename: usageError()
return (args_proper, inputDirname, configFilename, outputFilename,
docstringFilename,
pythonprependFilename, pythonappendFilename, skipAdditionalMethods)
docstringFilename, pythonprependFilename, pythonappendFilename,
skipAdditionalMethods, swigVersion)
def main():
(args_proper, inputDirname, configFilename, outputFilename,
docstringFilename, pythonprependFilename, pythonappendFilename,
skipAdditionalMethods) = parseCommandLine()
skipAdditionalMethods, swigVersion) = parseCommandLine()
sBuilder = SwigInputBuilder(inputDirname, configFilename, outputFilename,
docstringFilename, pythonprependFilename,
pythonappendFilename, skipAdditionalMethods)
pythonappendFilename, skipAdditionalMethods,
swigVersion)
#print "Calling writeSwigFile\n"
sBuilder.writeSwigFile()
#print "Done writeSwigFile\n"
......@@ -644,6 +644,8 @@ def usageError():
% (' '*len(os.path.basename(sys.argv[0]))))
sys.stdout.write(' %s[-s skippedClasses]\n' \
% (' '*len(os.path.basename(sys.argv[0]))))
sys.stdout.write(' %s[-v swigVersion]\n' \
% (' '*len(os.path.basename(sys.argv[0]))))
sys.exit(1)
if __name__=='__main__':
......
......@@ -307,7 +307,7 @@ UNITS = {
("Context", "getParameter") : (None, ()),
("Context", "getMolecules") : (None, ()),
("CMAPTorsionForce", "getMapParameters") : (None, ()),
("CMAPTorsionForce", "getMapParameters") : (None, (None, 'unit.kilojoule_per_mole')),
("CMAPTorsionForce", "getTorsionParameters") : (None, ()),
("CMMotionRemover", "getFrequency") : (None, ()),
("CustomAngleForce", "getNumPerAngleParameters") : (None, ()),
......
......@@ -2,16 +2,6 @@
%include exceptions.i
%include extend.i
%include header.i
%include pythonprepend_all.i
%include pythonprepend.i
%include pythonappend.i
%include typemaps.i
/* SWIG 3.x resolved a bug in which all wrapped C++ functions took *args as its
* default argument list. OpenMM then exploited this bug by doing stuff like
* passing args to stripUnits (and all added code assumed that the arguments
* were in an "args" list). So in order to restore this arguably buggy behavior
* from SWIG 2, enable the "compactdefaultargs" feature globally.
*
* See https://github.com/swig/swig/issues/387
*/
%feature("compactdefaultargs");
......@@ -218,69 +218,6 @@ class State(_object):
raise TypeError('Parameters were not requested in getState() call, so are not available.')
return self._paramMap
def stripUnits(args):
"""
getState(self, quantity)
-> value with *no* units
Examples
>>> import simtk
>>> x = 5
>>> print x
5
>>> x = stripUnits((5*simtk.unit.nanometer,))
>>> x
(5,)
>>> arg1 = 5*simtk.unit.angstrom
>>> x = stripUnits((arg1,))
>>> x
(0.5,)
>>> arg1 = 5
>>> x = stripUnits((arg1,))
>>> x
(5,)
>>> arg1 = (1*simtk.unit.angstrom, 5*simtk.unit.angstrom)
>>> x = stripUnits((arg1,))
>>> x
((0.10000000000000001, 0.5),)
>>> arg1 = (1*simtk.unit.angstrom,
... 5*simtk.unit.kilojoule_per_mole,
... 1*simtk.unit.kilocalorie_per_mole)
>>> y = stripUnits((arg1,))
>>> y
((0.10000000000000001, 5, 4.1840000000000002),)
"""
newArgList=[]
for arg in args:
if 'numpy' in sys.modules and isinstance(arg, numpy.ndarray):
arg = arg.tolist()
elif unit.is_quantity(arg):
# JDC: Ugly workaround for OpenMM using 'bar' for fundamental pressure unit.
if arg.unit.is_compatible(unit.bar):
arg = arg / unit.bar
else:
arg = arg.value_in_unit_system(unit.md_unit_system)
# JDC: End workaround.
elif isinstance(arg, dict):
newKeys = stripUnits(arg.keys())
newValues = stripUnits(arg.values())
arg = dict(zip(newKeys, newValues))
elif not isinstance(arg, _string_types):
try:
# Reclusively strip units from all quantities
arg=stripUnits(arg)
except TypeError:
pass
newArgList.append(arg)
return tuple(newArgList)
%}
%pythonappend OpenMM::Context::Context %{
......
%pythonprepend %{
try: args=stripUnits(args)
except UnboundLocalError: pass
%}
%fragment("Py_StripOpenMMUnits", "header") {
/* Convert python list of tuples to C++ std::vector of Vec3 objects */
%typemap(in) const std::vector<Vec3>& (std::vector<OpenMM::Vec3> vVec) {
// typemap -- %typemap(in) std::vector<Vec3>& (std::vector<OpenMM::Vec3> vVec)
int i, pLength, itemLength;
double x, y, z;
PyObject *o;
PyObject *o1;
pLength=(int)PySequence_Length($input);
for (i=0; i<pLength; i++) {
o=PySequence_GetItem($input, i);
itemLength = (int) PySequence_Length(o);
if (itemLength != 3) {
PyErr_SetString(PyExc_TypeError, "Item must have length 3");
return NULL;
static PyObject *__s_Quantity = NULL;
static PyObject *__s_md_unit_system_tuple = NULL;
static PyObject *__s_bar_tuple = NULL;
PyObject* Py_StripOpenMMUnits(PyObject *input) {
if (__s_Quantity == NULL) {
PyObject* module = NULL;
module = PyImport_ImportModule("simtk.unit");
if (!module) {
PyErr_SetString(PyExc_ImportError, "simtk.unit"); Py_CLEAR(module); return NULL;
}
__s_Quantity = PyObject_GetAttrString(module, "Quantity");
if (!__s_Quantity) {
PyErr_SetString(PyExc_AttributeError, "'module' object has no attribute 'Quantity'");
Py_CLEAR(module);
Py_CLEAR(__s_Quantity);
return NULL;
}
PyObject* bar = NULL;
bar = PyObject_GetAttrString(module, "bar");
if (!bar) {
PyErr_SetString(PyExc_AttributeError, "'module' object has no attribute 'bar'");
Py_CLEAR(module);
Py_CLEAR(__s_Quantity);
Py_CLEAR(bar);
return NULL;
}
PyObject* md_unit_system = NULL;
md_unit_system = PyObject_GetAttrString(module, "md_unit_system");
if (!md_unit_system) {
PyErr_SetString(PyExc_AttributeError, "'module' object has no attribute 'md_unit_system'");
Py_CLEAR(module);
Py_CLEAR(__s_Quantity);
Py_CLEAR(bar);
Py_CLEAR(md_unit_system);
}
__s_md_unit_system_tuple = PyTuple_Pack(1, md_unit_system);
__s_bar_tuple = PyTuple_Pack(1, bar);
Py_DECREF(md_unit_system);
Py_DECREF(bar);
Py_DECREF(module);
}
PyObject *val;
if (PyObject_IsInstance(input, __s_Quantity)) {
PyObject* input_unit = NULL, *is_compatible = NULL, *compatible_with_bar = NULL;
input_unit = PyObject_GetAttrString(input, "unit");
is_compatible = PyObject_GetAttrString(input_unit, "is_compatible");
compatible_with_bar = PyObject_Call(is_compatible, __s_bar_tuple, NULL);
if (PyObject_IsTrue(compatible_with_bar)) {
// input.in_units_of(unit.bar)
PyObject* value_in_unit = PyObject_GetAttrString(input, "value_in_unit");
val = PyObject_Call(value_in_unit, __s_bar_tuple, NULL);
Py_DECREF(value_in_unit);
} else {
// input.value_in_unit_system(md_unit_system_tuple)
PyObject* value_in_unit_system = PyObject_GetAttrString(input, "value_in_unit_system");
val = PyObject_Call(value_in_unit_system, __s_md_unit_system_tuple, NULL);
Py_DECREF(value_in_unit_system);
}
Py_CLEAR(input_unit);
Py_CLEAR(is_compatible);
Py_CLEAR(compatible_with_bar);
if (PyErr_Occurred() != NULL) {
return NULL;
}
} else {
val = input;
Py_INCREF(val);
}
return val;
}
}
o1=PySequence_GetItem(o, 0);
x=PyFloat_AsDouble(o1);
Py_DECREF(o1);
%fragment("Py_SequenceToVec3", "header", fragment="Py_StripOpenMMUnits") {
OpenMM::Vec3 Py_SequenceToVec3(PyObject* obj, int& status) {
PyObject* s, *o, *o1;
double x[3];
int i, length;
s = Py_StripOpenMMUnits(obj);
if (s == NULL) {
status = SWIG_ERROR;
return OpenMM::Vec3(0, 0, 0);
}
o1=PySequence_GetItem(o, 1);
y=PyFloat_AsDouble(o1);
Py_DECREF(o1);
length = (int) PySequence_Length(s);
if (length != 3) {
Py_DECREF(s);
PyErr_SetString(PyExc_TypeError, "Item must have length 3");
status = SWIG_ERROR;
return OpenMM::Vec3(0, 0, 0);
}
o1=PySequence_GetItem(o, 2);
z=PyFloat_AsDouble(o1);
Py_DECREF(o1);
for (i = 0; i < 3; i++ ) {
o = PySequence_GetItem(s, i);
o1 = Py_StripOpenMMUnits(o);
if (o1 != NULL) {
x[i] = PyFloat_AsDouble(o1);
}
if (o1 == NULL || PyErr_Occurred() != NULL) {
Py_DECREF(s);
Py_DECREF(o);
Py_XDECREF(o1);
status = SWIG_ERROR;
return OpenMM::Vec3(0, 0, 0);
}
Py_DECREF(o);
Py_DECREF(o1);
}
Py_DECREF(o);
vVec.push_back( OpenMM::Vec3(x, y, z) );
status = SWIG_OK;
Py_DECREF(s);
return OpenMM::Vec3(x[0], x[1], x[2]);
}
$1 = &vVec;
}
%fragment("Py_SequenceToVecDouble", "header", fragment="Py_StripOpenMMUnits") {
int Py_SequenceToVecDouble(PyObject* obj, std::vector<double>& out) {
PyObject* stripped = NULL;
PyObject* item = NULL;
PyObject* item1 = NULL;
PyObject* iterator = NULL;
stripped = Py_StripOpenMMUnits(obj);
iterator = PyObject_GetIter(stripped);
if (iterator == NULL) {
Py_DECREF(stripped);
return SWIG_ERROR;
}
while ((item = PyIter_Next(iterator))) {
item1 = Py_StripOpenMMUnits(item);
if (item1 == NULL) {
Py_DECREF(item);
return SWIG_ERROR;
}
double d = PyFloat_AsDouble(item1);
Py_DECREF(item);
Py_DECREF(item1);
if (PyErr_Occurred() != NULL) {
return SWIG_ERROR;
}
out.push_back(d);
}
Py_DECREF(iterator);
Py_DECREF(stripped);
return SWIG_OK;
}
}
%typemap(typecheck, precedence=SWIG_TYPECHECK_DOUBLE, fragment="Py_StripOpenMMUnits") double {
double argp = 0;
PyObject* s = NULL;
s = Py_StripOpenMMUnits($input);
$1 = (s != NULL) ? SWIG_IsOK(SWIG_AsVal_double(s, &argp)) : 0;
Py_DECREF(s);
}
%typemap(in, noblock=1, fragment="Py_StripOpenMMUnits") double (double argp = 0, int res = 0,
PyObject* stripped = NULL) {
stripped = Py_StripOpenMMUnits($input);
if (stripped == NULL) { SWIG_fail; }
res = SWIG_AsVal_double(stripped, &argp);
if (!SWIG_IsOK(res)) {
PyErr_SetString(PyExc_ValueError, "in method $symname, argument $argnum could not be converted to type $type");
SWIG_fail;
}
$1 = ($ltype)(argp);
Py_CLEAR(stripped);
}
%typemap(in, fragment="Py_SequenceToVec3") Vec3 (int res=0){
// typemap -- %typemap(in) Vec3
$1 = Py_SequenceToVec3($input, res);
if (!SWIG_IsOK(res)) {
PyErr_SetString(PyExc_ValueError, "in method $symname, argument $argnum could not be converted to type $type");
SWIG_fail;
}
}
%typemap(in, fragment="Py_SequenceToVec3") const Vec3& (OpenMM::Vec3 myVec, int res=0) {
// typemap -- %typemap(in) Vec3
myVec = Py_SequenceToVec3($input, res);
if (!SWIG_IsOK(res)) {
PyErr_SetString(PyExc_ValueError, "in method $symname, argument $argnum could not be converted to type $type");
SWIG_fail;
}
$1 = &myVec;
}
/* Convert python list of tuples to C++ std::vector of Vec3 objects */
%typemap(in, fragment="Py_SequenceToVec3") const std::vector<Vec3>& (std::vector<OpenMM::Vec3> vVec, PyObject* s=NULL, PyObject* o=NULL) {
int i, pLength, ret;
s = Py_StripOpenMMUnits($input);
pLength = (int)PySequence_Length(s);
for (i = 0; i < pLength; i++) {
o = PySequence_GetItem(s, i);
OpenMM::Vec3 v = Py_SequenceToVec3(o, ret);
if (!SWIG_IsOK(ret)) {
Py_DECREF(s);
Py_DECREF(o);
PyErr_SetString(PyExc_ValueError, "in method $symname, argument $argnum could not be converted to type $type");
SWIG_fail;
}
vVec.push_back(v);
}
$1 = &vVec;
Py_DECREF(s);
}
%typemap(in, fragment="Py_SequenceToVecDouble") const std::vector<double> & (std::vector<double> v, int res=0) {
res = Py_SequenceToVecDouble($input, v);
if (!SWIG_IsOK(res)) {
PyErr_SetString(PyExc_ValueError, "in method $symname, argument $argnum could not be converted to type $type");
SWIG_fail;
}
$1 = &v;
}
/* The following two typemaps cause a non-const vector<Vec3>& to become a return value. */
......@@ -39,6 +236,7 @@
$1 = &temp;
}
%typemap(argout) std::vector<Vec3>& {
int i, n;
PyObject *pyList;
......@@ -57,113 +255,84 @@
$result = pyList;
}
/* const vector<Vec3> should NOT become an output. */
%typemap(argout) const std::vector<Vec3>& {
}
/* Convert python tuple to C++ Vec3 object*/
%typemap(typecheck) Vec3 {
// typemap -- %typemap(typecheck) Vec3
$1 = (PySequence_Length($input) >= 3 ? 1 : 0);
// typemap -- %typemap(typecheck) Vec3
$1 = (PySequence_Length($input) >= 3 ? 1 : 0);
}
%typemap(in) Vec3 {
// typemap -- %typemap(in) Vec3
double x, y, z;
PyObject *o;
o=PySequence_GetItem($input, 0);
x=PyFloat_AsDouble(o);
Py_DECREF(o);
o=PySequence_GetItem($input, 1);
y=PyFloat_AsDouble(o);
Py_DECREF(o);
o=PySequence_GetItem($input, 2);
z=PyFloat_AsDouble(o);
Py_DECREF(o);
$1 = OpenMM::Vec3(x, y, z);
}
%typemap(typecheck) const Vec3& {
// typemap -- %typemap(typecheck) Vec3
$1 = (PySequence_Length($input) >= 3 ? 1 : 0);
// typemap -- %typemap(typecheck) Vec3
$1 = (PySequence_Length($input) >= 3 ? 1 : 0);
}
%typemap(in) const Vec3& (OpenMM::Vec3 myVec) {
// typemap -- %typemap(in) Vec3
double x, y, z;
PyObject *o;
o=PySequence_GetItem($input, 0);
x=PyFloat_AsDouble(o);
Py_DECREF(o);
o=PySequence_GetItem($input, 1);
y=PyFloat_AsDouble(o);
Py_DECREF(o);
o=PySequence_GetItem($input, 2);
z=PyFloat_AsDouble(o);
Py_DECREF(o);
myVec = OpenMM::Vec3(x, y, z);
$1 = &myVec;
}
%typemap(out) Vec3 {
PyObject* mm = PyImport_AddModule("simtk.openmm");
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
PyObject* args = Py_BuildValue("(d,d,d)", ($1)[0], ($1)[1], ($1)[2]);
$result = PyObject_CallObject(vec3, args);
Py_DECREF(args);
PyObject* mm = PyImport_AddModule("simtk.openmm");
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
PyObject* args = Py_BuildValue("(d,d,d)", ($1)[0], ($1)[1], ($1)[2]);
$result = PyObject_CallObject(vec3, args);
Py_DECREF(args);
}
%typemap(out) const Vec3& {
PyObject* mm = PyImport_AddModule("simtk.openmm");
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
PyObject* args = Py_BuildValue("(d,d,d)", (*$1)[0], (*$1)[1], (*$1)[2]);
$result = PyObject_CallObject(vec3, args);
Py_DECREF(args);
PyObject* mm = PyImport_AddModule("simtk.openmm");
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
PyObject* args = Py_BuildValue("(d,d,d)", (*$1)[0], (*$1)[1], (*$1)[2]);
$result = PyObject_CallObject(vec3, args);
Py_DECREF(args);
}
/* Convert C++ (Vec3&, Vec3&, Vec3&) object to python tuple or tuples */
%typemap(argout) (Vec3& a, Vec3& b, Vec3& c) {
// %typemap(argout) (Vec3& a, Vec3& b, Vec3& c)
PyObject* mm = PyImport_AddModule("simtk.openmm");
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
PyObject* args1 = Py_BuildValue("(d,d,d)", (*$1)[0], (*$1)[1], (*$1)[2]);
PyObject* args2 = Py_BuildValue("(d,d,d)", (*$2)[0], (*$2)[1], (*$2)[2]);
PyObject* args3 = Py_BuildValue("(d,d,d)", (*$3)[0], (*$3)[1], (*$3)[2]);
PyObject* pyVec1 = PyObject_CallObject(vec3, args1);
PyObject* pyVec2 = PyObject_CallObject(vec3, args2);
PyObject* pyVec3 = PyObject_CallObject(vec3, args3);
Py_DECREF(args1);
Py_DECREF(args2);
Py_DECREF(args3);
PyObject *o, *o2, *o3;
o = Py_BuildValue("[N, N, N]", pyVec1, pyVec2, pyVec3);
if ((!$result) || ($result == Py_None)) {
$result = o;
} else {
if (!PyTuple_Check($result)) {
PyObject *o2 = $result;
$result = PyTuple_New(1);
PyTuple_SetItem($result, 0, o2);
// %typemap(argout) (Vec3& a, Vec3& b, Vec3& c)
PyObject* mm = PyImport_AddModule("simtk.openmm");
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
PyObject* args1 = Py_BuildValue("(d,d,d)", (*$1)[0], (*$1)[1], (*$1)[2]);
PyObject* args2 = Py_BuildValue("(d,d,d)", (*$2)[0], (*$2)[1], (*$2)[2]);
PyObject* args3 = Py_BuildValue("(d,d,d)", (*$3)[0], (*$3)[1], (*$3)[2]);
PyObject* pyVec1 = PyObject_CallObject(vec3, args1);
PyObject* pyVec2 = PyObject_CallObject(vec3, args2);
PyObject* pyVec3 = PyObject_CallObject(vec3, args3);
Py_DECREF(args1);
Py_DECREF(args2);
Py_DECREF(args3);
Py_DECREF(mm);
Py_DECREF(vec3);
PyObject *o, *o2, *o3;
o = Py_BuildValue("[N, N, N]", pyVec1, pyVec2, pyVec3);
if ((!$result) || ($result == Py_None)) {
$result = o;
} else {
if (!PyTuple_Check($result)) {
PyObject *o2 = $result;
$result = PyTuple_New(1);
PyTuple_SetItem($result, 0, o2);
}
o3 = PyTuple_New(1);
PyTuple_SetItem(o3, 0, o);
o2 = $result;
$result = PySequence_Concat(o2, o3);
Py_DECREF(o2);
Py_DECREF(o3);
}
o3 = PyTuple_New(1);
PyTuple_SetItem(o3, 0, o);
o2 = $result;
$result = PySequence_Concat(o2, o3);
Py_DECREF(o2);
Py_DECREF(o3);
}
}
%typemap(in, numinputs=0) (Vec3& a, Vec3& b, Vec3& c) (Vec3 tempA, Vec3 tempB, Vec3 tempC) {
$1 = &tempA;
$2 = &tempB;
$3 = &tempC;
$1 = &tempA;
$2 = &tempB;
$3 = &tempC;
}
%typemap(out) std::string OpenMM::Context::createCheckpoint{
// createCheckpoint returns a bytes object
$result = PyBytes_FromStringAndSize($1.c_str(), $1.length());
......
......@@ -33,16 +33,17 @@ class TestElement(unittest.TestCase):
"""
min_diff = mass
closest_element = None
for symbol, elem in element.Element._elements_by_symbol.items():
for elem in sorted(element.Element._elements_by_symbol.values(),
key=lambda x:x.mass):
diff = abs(elem.mass._value - mass)
if diff < min_diff:
min_diff = diff
closest_element = elem
return closest_element
# Check 5000 random numbers between 0 and 300
for i in range(5000):
mass = random.random() * 300
# Check 500 random numbers between 0 and 200
for i in range(500):
mass = random.random() * 200
elem = element.Element.getByMass(mass)
self.assertTrue(elem is exhaustive_search(mass))
......
......@@ -74,6 +74,8 @@ class TestNumpyCompatibility(unittest.TestCase):
energy = np.random.randn(10*10)
f.addMap(10, energy)
size, energy_out = f.getMapParameters(0)
energy_out = energy_out.value_in_unit_system(unit.md_unit_system)
self.assertEqual(size, 10)
np.testing.assert_array_almost_equal(energy, np.asarray(energy_out))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment