swigInputBuilder.py 34.2 KB
Newer Older
Robert McGibbon's avatar
typo  
Robert McGibbon committed
1
#!/usr/bin/env python
2
3
4
"""Build swig imput file from xml encoded header files (see gccxml)."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"
5
6


7
8
9
10
import sys, os
import time
import getopt
import re
11
import xml.etree.ElementTree as etree
12

13
14
15
16
17
18
try:
    from html.parser import HTMLParser
except ImportError:
    # python 2
    from HTMLParser import HTMLParser

19
INDENT = "   "
20
# Maps XML documentation element names to the HTML-style tag names that
# getNodeText() wraps around the text of such an element.
docTags = {'emphasis':'i', 'bold':'b', 'itemizedlist':'ul', 'listitem':'li', 'preformatted':'pre', 'computeroutput':'tt', 'subscript':'sub', 'verbatim': 'verbatim'}
21

22
23
24
def is_method_abstract(argstring):
    """Return True when *argstring* carries a pure-virtual marker.

    Only the text after the last ``)`` is examined; the method counts as
    abstract when that trailing text contains ``=0``.
    """
    trailer = argstring.rsplit(")", 1)[-1]
    return "=0" in trailer

25
26
27
28
29
30
def striphtmltags(s):
    """Strip a couple html tags used inside docstrings in the C++ source
    to produce something more easily read as plain text.

    ``<i>``/``</i>`` become ``_``, ``<b>``/``</b>`` become ``*``, and every
    ``<ul>...</ul>`` fragment (together with the whitespace around it) is
    replaced by one `` - item`` line per ``<li>``, framed by blank lines.

    Parameters
    ----------
    s : str
        Docstring text that may contain the tags above.

    Returns
    -------
    str
        The text with those tags rewritten.
    """
    class ConvertLists(HTMLParser):
        """Collect the text of a ``<ul>`` fragment, one bullet per ``<li>``."""

        def reset(self):
            # HTMLParser.__init__ calls reset(), so ``out`` always exists.
            HTMLParser.reset(self)
            self.out = []

        def handle_starttag(self, tag, attrs):
            if tag == 'li':
                self.out.append('\n - ')

        def handle_data(self, data):
            self.out.append(data.strip())

    convertlists = ConvertLists()

    def replace_ul_tags(m):
        # Parse exactly the matched fragment; the parser is reused, so clear
        # whatever the previous match left behind first.
        convertlists.reset()
        convertlists.feed(m.group(0))
        return '\n%s\n\n' % ''.join(convertlists.out)

    s = s.replace('<i>', '_').replace('</i>', '_')
    s = s.replace('<b>', '*').replace('</b>', '*')

    # Raw string: '\s' in a normal literal is an invalid escape sequence.
    s = re.sub(r'\s*(<ul>.*?</ul>\s*)', replace_ul_tags, s,
               flags=re.MULTILINE | re.DOTALL)
    return s

56
57
58
59
60
61
62
63
64
65
66
def trimToSingleSpace(text):
    """Collapse the leading and trailing whitespace of *text* to one space each.

    ``None``, an empty string, or a whitespace-only string yields ``""``.
    Otherwise the stripped text is returned, prefixed with a single space if
    the input began with whitespace and suffixed with one if it ended with
    whitespace.
    """
    if text is None or not len(text):
        return ""
    core = text.strip()
    if not core:
        return core
    lead = " " if text[0].isspace() else ""
    trail = " " if text[-1].isspace() else ""
    return lead + core + trail
67

68
69
70
71
72
73
74
75
76
77
def getNodeText(node):
    """Return the text of *node* and its children, flattened to one string.

    Children are handled by tag:

    * ``para``     -- its text followed by a blank line;
    * ``ref``      -- its text, inline;
    * ``xrefsect`` -- emitted as ``@deprecated <description>`` only when it
      has both an ``xreftitle`` equal to "deprecated" (case-insensitive) and
      an ``xrefdescription``; otherwise it contributes nothing;
    * a tag listed in ``docTags`` -- its text wrapped in the mapped tag;
    * anything else -- its own text is dropped.

    In every case the child's tail text is kept.

    Parameters
    ----------
    node : xml.etree.ElementTree.Element

    Returns
    -------
    str
    """
    # Collect the pieces and join once instead of rebuilding the string
    # on every child.
    parts = [node.text if node.text is not None else ""]
    for child in node:
        if child.tag == "para":
            parts.append("%s\n\n" % getNodeText(child))
        elif child.tag == "ref":
            parts.append(getNodeText(child))
        elif child.tag == "xrefsect":
            title = child.find("xreftitle")
            description = child.find("xrefdescription")
            if (title is not None and description is not None
                    and getNodeText(title).lower() == "deprecated"):
                parts.append("\n@deprecated %s\n\n" % getNodeText(description))
        elif child.tag in docTags:
            tag = docTags[child.tag]
            parts.append("<%s>%s</%s>" % (tag, getNodeText(child), tag))
        if child.tail is not None:
            parts.append(child.tail)
    return "".join(parts)

91
def getText(subNodePath, node):
    """Return the combined text of every element of *node* matching a path.

    Each match is flattened with getNodeText() and passed through
    trimToSingleSpace(); a ``para`` match is followed by a blank line.
    The joined result is stripped of leading and trailing whitespace, so
    the function returns ``""`` when nothing matches.

    Parameters
    ----------
    subNodePath : str
        Path passed to ``node.findall``.
    node : xml.etree.ElementTree.Element

    Returns
    -------
    str
    """
    parts = []
    for match in node.findall(subNodePath):
        parts.append(trimToSingleSpace(getNodeText(match)))
        if match.tag == "para":
            parts.append("\n\n")
    return "".join(parts).strip()

99
# Captures the text before and after an ``OpenMM:...:`` qualifier.
OPENMM_RE_PATTERN = re.compile(r"(.*)OpenMM:[a-zA-Z0-9:]*:(.*)")


def stripOpenmmPrefix(name, rePattern=OPENMM_RE_PATTERN):
    """Remove the ``OpenMM::...::`` qualifier from *name*.

    Because the leading group is greedy, only the last qualifier in the
    string is removed. The text before and after it is joined and stripped
    of surrounding whitespace.

    Parameters
    ----------
    name : str
        A possibly qualified C++ name or definition.
    rePattern : compiled regular expression
        Must expose two groups: the text to keep before and after the
        qualifier.

    Returns
    -------
    The shortened string, or *name* itself when the pattern does not match
    or *name* is not something the pattern can search (e.g. ``None``).
    """
    try:
        m = rePattern.search(name)
    except TypeError:
        # Non-string input: hand it back untouched, as before.
        return name
    if m is None:
        return name
    # The original called .strip() but discarded the result.
    return ("%s%s" % m.group(1, 2)).strip()
108

109
110
111
112
113
114
115
116
117
118
def findNodes(parent, path, **args):
    """Return the elements under *parent* matching *path* and all attributes.

    Every keyword argument names an XML attribute that a node must carry
    with exactly the given value. Nodes are returned in ``findall`` order.
    """
    def has_all_attributes(node):
        attributes = node.attrib
        return all(key in attributes and attributes[key] == wanted
                   for key, wanted in args.items())

    return [node for node in parent.findall(path) if has_all_attributes(node)]
119
120
121
122
123

def getClassMethodList(classNode, skipMethods):
    """List the public member functions of a class that should be wrapped.

    Static functions (``public-static-func`` sections) come first, then
    the ``public-func`` sections. A method is left out when
    ``(class, method)`` or ``(class, method, numParams)`` is in *skipMethods*.

    Parameters
    ----------
    classNode : xml.etree.ElementTree.Element
        The class element to scan.
    skipMethods : container of tuples
        Entries identifying methods to omit.

    Returns
    -------
    list of tuple
        ``(shortClassName, memberNode, shortMethDefinition, methName,
        isConstructor, isDestructor, templateType, templateName)`` per
        method.
    """
    className = getText("compoundname", classNode)
    shortClassName = stripOpenmmPrefix(className)
    methodList = []
    sections = (findNodes(classNode, "sectiondef", kind="public-static-func")
                + findNodes(classNode, "sectiondef", kind="public-func"))
    for section in sections:
        for memberNode in findNodes(section, "memberdef", kind="function", prot="public"):
            methDefinition = getText("definition", memberNode)
            shortMethDefinition = stripOpenmmPrefix(methDefinition)
            shortMethDefinition = shortMethDefinition.replace(' &', '&')
            methName = shortMethDefinition.split()[-1]
            if (shortClassName, methName) in skipMethods:
                continue
            numParams = len(findNodes(memberNode, 'param'))
            if (shortClassName, methName, numParams) in skipMethods:
                continue

            # Warning only: these classes are still included. The original
            # `continue` here applied to its own inner loop and skipped
            # nothing.
            if shortClassName.endswith(('Factory', 'Impl', 'Info', 'Kernel')):
                sys.stderr.write("Warning: Including class %s\n" %
                                 shortClassName)

            # set template info
            templateType = getText("templateparamlist/param/type", memberNode)
            templateName = getText("templateparamlist/param/declname", memberNode)

            methodList.append((shortClassName,
                               memberNode,
                               shortMethDefinition,
                               methName,
                               shortClassName == methName,
                               '~' + shortClassName == methName,
                               templateType,
                               templateName))
    return methodList


Robert McGibbon's avatar
Robert McGibbon committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def docstringTypemap(cpptype):
    """Render a C++ type name in a friendlier form for Python docstrings.

    A leading ``const `` and then a leading ``std::`` are dropped, after
    which ``&`` characters and whitespace are trimmed from both ends. The
    result is used only in docstrings, never in generated wrapper code,
    so it does not have to be exact.
    """
    pytype = cpptype
    # Order matters: 'const ' is removed before 'std::' is looked for.
    for prefix in ('const ', 'std::'):
        if pytype.startswith(prefix):
            pytype = pytype[len(prefix):]
    return pytype.strip('&').strip()


169
170
class SwigInputBuilder:
    def __init__(self,
                 inputDirname,
                 configFilename,
                 outputFilename=None,
                 docstringFilename=None,
                 pythonprependFilename=None,
                 pythonappendFilename=None,
                 skipAdditionalMethods=(),
                 SWIG_VERSION='3.0.2'):
        """Load the configuration and XML input and open the output streams.

        Parameters
        ----------
        inputDirname : str
            Directory whose files ending in ``xml`` (case-insensitive) are
            parsed and merged into a single document.
        configFilename : str
            File name of the configuration module; the name without its
            extension is imported. The module must define ``SKIP_METHODS``.
        outputFilename : str, optional
            Main output file; ``sys.stdout`` is used when omitted.
        docstringFilename, pythonprependFilename, pythonappendFilename : str, optional
            Extra output files; the matching stream is ``None`` when omitted.
        skipAdditionalMethods : iterable of str
            Extra skip entries written as ``A::B`` or ``A::B::N``. For
            three-part entries the last part is converted to ``int``.
        SWIG_VERSION : str
            Accepted for compatibility; not used by this constructor.
        """
        self.nodeByID = {}

        self.configModule = __import__(os.path.splitext(configFilename)[0])

        # Copy so the configuration module's own list is never modified.
        self.skipMethods = self.configModule.SKIP_METHODS[:]
        for skipMethod in skipAdditionalMethods:
            items = skipMethod.split('::')
            if len(items) == 3:
                items[2] = int(items[2])
            self.skipMethods.append(tuple(items))

        # Read all the XML files and merge them into a single document.
        # os.listdir() order is arbitrary, so sort it to make the merged
        # document (and the generated output) reproducible.
        self.doc = etree.ElementTree(etree.Element('root'))
        mergedRoot = self.doc.getroot()
        for filename in sorted(os.listdir(inputDirname)):
            if filename.lower().endswith('xml'):
                root = etree.parse(os.path.join(inputDirname, filename)).getroot()
                for node in root:
                    mergedRoot.append(node)

        if outputFilename:
            self.fOut = open(outputFilename, 'w')
        else:
            self.fOut = sys.stdout

        if docstringFilename:
            self.fOutDocstring = open(docstringFilename, 'w')
        else:
            self.fOutDocstring = None

        if pythonprependFilename:
            self.fOutPythonprepend = open(pythonprependFilename, 'w')
        else:
            self.fOutPythonprepend = None

        if pythonappendFilename:
            self.fOutPythonappend = open(pythonappendFilename, 'w')
        else:
            self.fOutPythonappend = None

        self._enumByClassname = {}

        self._orderedClassNodes = self._buildOrderedClassNodes()

    def _getNodeByID(self, id):
        """Return the ``compounddef`` element whose ``id`` attribute is *id*.

        Lookups are cached in ``self.nodeByID``. When several elements
        share the id, the last one in document order is used.

        Raises
        ------
        KeyError
            If no ``compounddef`` element carries that id.
        """
        try:
            return self.nodeByID[id]
        except KeyError:
            pass
        matches = findNodes(self.doc.getroot(), "compounddef", id=id)
        if not matches:
            raise KeyError("no compounddef element with id %r" % (id,))
        self.nodeByID[id] = matches[-1]
        return matches[-1]

    def _buildOrderedClassNodes(self):
        """Return the public class nodes, each placed after its base classes.

        Every public class ``compounddef`` is passed to ``_findBaseNodes``,
        which appends the class's public bases before the class itself and
        leaves out classes named in ``self.skipMethods``.
        """
        orderedClassNodes = []
        for node in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            self._findBaseNodes(node, orderedClassNodes)
        return orderedClassNodes

    def _findBaseNodes(self, node, excludedClassNodes=None):
        """Append *node* to *excludedClassNodes*, after its public base classes.

        Nothing is done when *node* is already in the list, or when its
        unqualified class name appears in ``self.skipMethods`` as a
        one-element tuple. Base references without a ``refid`` attribute
        are ignored.

        Parameters
        ----------
        node : xml.etree.ElementTree.Element
            A class ``compounddef`` element.
        excludedClassNodes : list, optional
            Already-ordered nodes, extended in place. A fresh list is used
            when omitted; the original default was a single list shared
            between calls.
        """
        if excludedClassNodes is None:
            excludedClassNodes = []
        if node in excludedClassNodes:
            return
        nodeName = getText("compoundname", node)
        if (nodeName.split("::")[-1],) in self.skipMethods:
            return
        for baseNodePnt in findNodes(node, "basecompoundref", prot="public"):
            if "refid" in baseNodePnt.attrib:
                baseNodeID = baseNodePnt.attrib["refid"]
                baseNode = self._getNodeByID(baseNodeID)
                self._findBaseNodes(baseNode, excludedClassNodes)
        # Appended last so that every base precedes the class itself.
        excludedClassNodes.append(node)


247
248
249
250
    def writeFactories(self):
        """Write the ``%factory`` directives to the main output.

        Public classes are grouped by direct public base class: Force,
        Integrator (DrudeIntegrator counts as an integrator base) and
        TabulatedFunction. One directive listing the sorted subclasses is
        then written for each function that returns one of those bases.
        """
        self.fOut.write("\n/* Declare factories */\n\n")
        forceSubclassList = []
        integratorSubclassList = []
        tabulatedFunctionSubclassList = []
        subclassListByBase = {
            'OpenMM::Force': forceSubclassList,
            'OpenMM::Integrator': integratorSubclassList,
            'OpenMM::DrudeIntegrator': integratorSubclassList,
            'OpenMM::TabulatedFunction': tabulatedFunctionSubclassList,
        }
        for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            className = getText("compoundname", classNode)
            shortClassName = stripOpenmmPrefix(className)
            if (className.split("::")[-1],) in self.skipMethods:
                continue
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
                if "refid" not in baseNodePnt.attrib:
                    continue
                baseNode = self._getNodeByID(baseNodePnt.attrib["refid"])
                baseName = getText("compoundname", baseNode)
                if baseName in subclassListByBase:
                    subclassListByBase[baseName].append(shortClassName)

        # We need to include subclasses of DrudeIntegrator, but not DrudeIntegrator itself.
        # Guarded so that input without DrudeIntegrator does not raise ValueError.
        if 'DrudeIntegrator' in integratorSubclassList:
            integratorSubclassList.remove('DrudeIntegrator')

        for signature in ("OpenMM::Force& OpenMM::System::getForce",
                          "OpenMM::Force* OpenMM_XmlSerializer__cloneForce",
                          "OpenMM::Force* OpenMM_XmlSerializer__deserializeForce",
                          "OpenMM::Force& OpenMM::CustomCVForce::getCollectiveVariable"):
            self._writeFactoryDirective(signature, forceSubclassList)

        for signature in ("OpenMM::Integrator* OpenMM_XmlSerializer__cloneIntegrator",
                          "OpenMM::Integrator* OpenMM_XmlSerializer__deserializeIntegrator",
                          "OpenMM::Integrator& OpenMM::Context::getIntegrator",
                          "OpenMM::Integrator& OpenMM::CompoundIntegrator::getIntegrator"):
            self._writeFactoryDirective(signature, integratorSubclassList)

        for signature in ("OpenMM::TabulatedFunction* OpenMM_XmlSerializer__cloneTabulatedFunction",
                          "OpenMM::TabulatedFunction* OpenMM_XmlSerializer__deserializeTabulatedFunction"):
            self._writeFactoryDirective(signature, tabulatedFunctionSubclassList)

        # Every wrapped getTabulatedFunction() method gets its own directive.
        for classNode in self._orderedClassNodes:
            for items in getClassMethodList(classNode, self.skipMethods):
                shortClassName, shortMethDefinition, methName = items[0], items[2], items[3]
                if shortMethDefinition == 'TabulatedFunction& getTabulatedFunction':
                    self._writeFactoryDirective(
                        "OpenMM::TabulatedFunction& OpenMM::%s::%s" % (shortClassName, methName),
                        tabulatedFunctionSubclassList)

        self.fOut.write("%factory(OpenMM::VirtualSite& OpenMM::System::getVirtualSite, OpenMM::TwoParticleAverageSite, OpenMM::ThreeParticleAverageSite, OpenMM::OutOfPlaneSite, OpenMM::LocalCoordinatesSite);\n\n")
        self.fOut.write("\n")

    def _writeFactoryDirective(self, signature, subclassNames):
        """Write one ``%factory(signature, OpenMM::A, ...)`` directive, names sorted."""
        self.fOut.write("%%factory(%s" % signature)
        for name in sorted(subclassNames):
            self.fOut.write(",\n         OpenMM::%s" % name)
        self.fOut.write(");\n\n")

    def writeGlobalConstants(self):
        """Write the static constants of the ``OpenMM`` namespace.

        Each public, static, non-mutable variable in the namespace's
        ``var`` sections is written as ``static <definition> = <initializer>;``
        with the leading ``=`` of the initializer removed.

        Raises
        ------
        ValueError
            If the document has no namespace ``compounddef`` named
            ``OpenMM``. The original let a bare ``StopIteration`` escape.
        """
        self.fOut.write("/* Global Constants */\n\n")
        node = None
        for candidate in findNodes(self.doc.getroot(), "compounddef", kind="namespace"):
            if candidate.findtext("compoundname") == "OpenMM":
                node = candidate
                break
        if node is None:
            raise ValueError("no namespace compounddef named 'OpenMM' in the input XML")

        for section in findNodes(node, "sectiondef", kind="var"):
            for memberNode in findNodes(section, "memberdef", kind="variable", mutable="no", prot="public", static="yes"):
                vDef = stripOpenmmPrefix(getText("definition", memberNode))
                iDef = getText("initializer", memberNode)
                if iDef.startswith("="):
                    iDef = iDef[1:]
                self.fOut.write("static %s = %s;\n" % (vDef, iDef))
        self.fOut.write("\n")

    def writeForwardDeclarations(self):
        """Write a forward declaration for every ordered class node.

        A class with at least one wrapped constructor is preceded by a
        ``%copyctor`` directive.
        """
        out = self.fOut
        out.write("\n/* Forward Declarations */\n\n")

        for classNode in self._orderedClassNodes:
            methods = getClassMethodList(classNode, self.skipMethods)
            # Element 4 of each method tuple is the is-constructor flag.
            hasConstructor = any(entry[4] for entry in methods)

            className = stripOpenmmPrefix(getText("compoundname", classNode))
            # If the class has a constructor, tell swig to make a copy method.
            if hasConstructor:
                out.write("%%copyctor %s ;\n" % className)
            out.write("class %s ;\n" % className)
        out.write("\n")

    def writeClassDeclarations(self):
        """Write the full declaration of every ordered class node.

        For each class this writes, in order: its docstring feature (only
        when a docstring stream is open and the class has a
        ``detaileddescription``), the class header with its base classes,
        its enumerations, and its methods.

        Base classes come from ``configModule.MISSING_BASE_CLASSES`` first,
        then from the public ``basecompoundref`` elements that carry a
        ``refid``. They are written as one comma-separated list; the
        original repeated `` : public X`` per base, which is not valid C++
        for more than one base.
        """
        self.fOut.write("\n/* Class Declarations */\n\n")
        for classNode in self._orderedClassNodes:
            className = stripOpenmmPrefix(getText("compoundname", classNode))
            if self.fOutDocstring:
                dNode = classNode.find('detaileddescription')
                if dNode is not None:
                    # Escape double quotes: the text goes inside a quoted string.
                    docstring = getNodeText(dNode).strip().replace('"', '\\"')
                    docstring = striphtmltags(docstring)
                    self.fOutDocstring.write('%%feature("docstring") %s "%s";\n' % (className, docstring))

            baseNames = []
            if className in self.configModule.MISSING_BASE_CLASSES:
                baseNames.append(self.configModule.MISSING_BASE_CLASSES[className])
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
                if "refid" in baseNodePnt.attrib:
                    baseNames.append(stripOpenmmPrefix(getText(".", baseNodePnt)))

            self.fOut.write("class %s" % className)
            if baseNames:
                self.fOut.write(" : %s" % ", ".join("public %s" % base for base in baseNames))
            self.fOut.write(" {\n")
            self.fOut.write("public:\n")
            self.writeEnumerations(classNode)
            self.writeMethods(classNode)
            self.fOut.write("};\n\n")
        self.fOut.write("\n")

    def writeEnumerations(self, classNode):
        """Write the public enums of *classNode* and record their names.

        Each enum name is added to ``self._enumByClassname`` under the
        short class name. Enumerators are written one per line as
        ``NAME = value``. An enumerator with no initializer is written as
        plain ``NAME``; the original wrote ``NAME = `` with nothing after it.
        A blank line follows when at least one enum was written.
        """
        enumNodes = []
        for section in findNodes(classNode, "sectiondef", kind="public-type"):
            enumNodes.extend(findNodes(section, "memberdef", kind="enum", prot="public"))
        if not enumNodes:
            return

        # The class name is the same for every enum, so look it up once.
        shortClassName = stripOpenmmPrefix(getText("compoundname", classNode))
        for enumNode in enumNodes:
            enumName = getText("name", enumNode)
            self._enumByClassname.setdefault(shortClassName, []).append(enumName)
            self.fOut.write("%senum %s {" % (INDENT, enumName))
            argSep = "\n"
            for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
                vName = getText("name", valueNode)
                vInit = getText("initializer", valueNode)
                if vInit.startswith("="):
                    vInit = vInit[1:]
                if vInit:
                    self.fOut.write("%s%s%s = %s" % (argSep, 2*INDENT, vName, vInit))
                else:
                    self.fOut.write("%s%s%s" % (argSep, 2*INDENT, vName))
                argSep = ",\n"
            self.fOut.write("\n%s};\n" % INDENT)
        self.fOut.write("\n")

    def writeMethods(self, classNode):
        methodList=getClassMethodList(classNode, self.skipMethods)

        #write only Constructors
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isConstructors:
                mArgsstring = getText("argsstring", memberNode)
                try:
                    pExceptions = " %s" % getText('exceptions', memberNode)
                except IndexError:
                    pExceptions = ""
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition,
                                                 mArgsstring, pExceptions))
        #write only Destructors
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isDestructor:
                mArgsstring = getText("argsstring", memberNode)
                try:
                    pExceptions = " %s" % getText('exceptions', memberNode)
                except IndexError:
                    pExceptions = ""
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition,
                                                 mArgsstring, pExceptions))

        #write only non Constructor and Destructor methods and python mods
        self.fOut.write("\n")
452
        methodsWithOutputArgs = set()
453
454
455
456
457
458
459
460
461
462
463
464
        for items in methodList:
            clearOutput=""
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isConstructors or isDestructor: continue

            key = (shortClassName, methName)
            if key in self.configModule.DOC_STRINGS:
                self.fOut.write('%%feature("autodoc", "%s") %s;\n' %
                                (self.configModule.DOC_STRINGS[key], methName))

465
            paramList=findNodes(memberNode, 'param')
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
            for pNode in paramList:
                try:
                    pType = getText('type', pNode)
                except IndexError:
                    pType = getText('type/ref', pNode)
                pName = getText('declname', pNode)
                key = (shortClassName, methName, pName)
                if pType.find('&')>=0 and \
                   'const' not in pType.split():
                    if key not in self.configModule.NO_OUTPUT_ARGS:
                        eType = pType.split()[0]
                        if shortClassName in self._enumByClassname and \
                           eType in self._enumByClassname[shortClassName]:
                            simpleType = re.sub(eType, 'int', pType)
                        else:
                            simpleType = pType
                        self.fOut.write("%s%%apply %s OUTPUT { %s %s };\n" %
                                        (INDENT, simpleType, pType, pName))
                        clearOutput = "%s%s%%clear %s %s;\n" \
                                     % (clearOutput, INDENT, pType, pName)
486
                        methodsWithOutputArgs.add((shortClassName, methName))
487
488
489
490
491
492

            mArgsstring = getText("argsstring", memberNode)
            try:
                pExceptions = " %s" % getText('exceptions', memberNode)
            except IndexError:
                pExceptions = ""
493
            if memberNode.attrib["virt"].strip()!='non-virtual':
494
495
496
497
498
499
500
501
502
503
504
505
506
507
                if 'virtual' not in shortMethDefinition.split():
                    shortMethDefinition="virtual %s" % shortMethDefinition
            if( len(templateType) > 0 ):
                self.fOut.write("%stemplate<%s %s> %s%s%s;\n" % (INDENT, templateType, templateName, shortMethDefinition, mArgsstring, pExceptions))
            else:
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition, mArgsstring, pExceptions))
            if clearOutput:
                self.fOut.write(clearOutput)

        #write python mod blocks
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
508
            paramList = findNodes(memberNode, 'param')
509

510
            # write pythonprepend blocks
511
512
513
514
            if isConstructors:
                mArgsstring = '' # specifying args to constructors seems to prevent append and prepend from working
            else:
                mArgsstring = getText("argsstring", memberNode)
515
516
            if self.fOutPythonprepend and \
               len(paramList) and \
517
               not is_method_abstract(mArgsstring):
518
519
520
521
522
523
                text = '''
%pythonprepend OpenMM::{shortClassName}::{methName}{mArgsstring} %{{{{{{0}}
%}}}}'''.format(shortClassName=shortClassName, methName=methName, mArgsstring=mArgsstring)
                textInside = ''
                key = (shortClassName, methName)
                for argNum in self.configModule.STEAL_OWNERSHIP.get(key, []):
524
                    argName = getText('declname', paramList[argNum])
525
526

                    textInside += '''
527
528
529
    if not {argName}.thisown:
        s = ("the %s object does not own its corresponding OpenMM object"
             % self.__class__.__name__)
530
        raise Exception(s)'''.format(argName=argName)
531

532

533
534
535
536
537
538
539
540
                # Convert input arguments to the proper units, if specified.
                if key not in methodsWithOutputArgs:
                    if key in self.configModule.UNITS:
                        argUnits=self.configModule.UNITS[key][1]
                    elif ("*", methName) in self.configModule.UNITS:
                        argUnits=self.configModule.UNITS[("*", methName)][1]
                    else:
                        argUnits = ()
541
                    if len(argUnits) > 0 and isConstructors:
542
543
544
545
                        textInside += '''
    args = list(args)'''
                    for i, units in enumerate(argUnits):
                        if units is not None:
546
                            if isConstructors:
547
548
549
550
551
552
553
                                argName = 'args[%s]' % i
                            else:
                                argName = getText('declname', paramList[i])
                            textInside += '''
    if unit.is_quantity({argName}):
        {argName} = {argName}.value_in_unit({units})'''.format(argName=argName, units=units)

554
                for argNum in self.configModule.REQUIRE_ORDERED_SET.get(key, []):
555
                    argName = getText('declname', paramList[argNum])
556
557
558
559
560
561
562

                    textInside += '''
    {argName} = list({argName})'''.format(argName=argName)
                if textInside:
                    self.fOutPythonprepend.write(text.format(textInside))

            # write pythonappend blocks
563
            if self.fOutPythonappend \
564
               and not is_method_abstract(mArgsstring):
565
                key = (shortClassName, methName)
566
567
                #sys.stdout.write("key %s %s \n" % (shortClassName, methName))

568
                addText=''
569
                returnType = getText("type", memberNode)
570
571
572
573
574

                if key in self.configModule.UNITS:
                    valueUnits=self.configModule.UNITS[key]
                elif ("*", methName) in self.configModule.UNITS:
                    valueUnits=self.configModule.UNITS[("*", methName)]
575
                elif methName.startswith('get') and returnType not in ('void', 'int', 'bool', 'std::string', 'const std::string &'):
576
577
                    s = 'do not know how to add units to %s %s::%s' \
                        % (returnType, shortClassName, methName)
578
579
580
581
582
                    raise Exception(s)
                else:
                    valueUnits=[None, ()]

                index=0
583
                if valueUnits[0] is not None:
584
585
586
587
588
589
590
591
592
593
594
595
596
                    sys.stdout.write("%s.%s() returns %s\n" %
                                     (shortClassName, methName, valueUnits[0]))
                    if len(valueUnits[1])>0:
                        addText = "%s%sval[%d]=unit.Quantity(val[%d], %s)\n" \
                                 % (addText, INDENT,
                                    index, index,
                                    valueUnits[0])
                        index+=1
                    else:
                        addText = "%s%sval=unit.Quantity(val, %s)\n" \
                                 % (addText, INDENT, valueUnits[0])

                for vUnit in valueUnits[1]:
597
                    if vUnit is not None and key in methodsWithOutputArgs:
598
                        addText = "%s%sval[%s]=unit.Quantity(val[%s], %s)\n" \
599
                                     % (addText, INDENT, index, index, vUnit)
600
                    index+=1
601
602
603

                if key in self.configModule.STEAL_OWNERSHIP:
                    for argNum in self.configModule.STEAL_OWNERSHIP[key]:
604
                        argName = getText('declname', paramList[argNum])
605
606
                        addText = "%s%s%s.thisown=0\n" \
                                % (addText, INDENT, argName)
607
608
609
610
611
612
613
614
615
616
617

                if addText:
                    self.fOutPythonappend.write("%pythonappend")
                    self.fOutPythonappend.write(" OpenMM::%s::%s(" % key)
                    sepChar=''
                    outputIndex=0
                    for pNode in paramList:
                        try:
                            pType = getText('type', pNode)
                        except IndexError:
                            pType = getText('type/ref', pNode)
618
619
620
621
622
623
624
                        # parse default arguments
                        try:
                            defaultValue = getText('defval', pNode)
                        except:
                            defaultValue = ""
                        if defaultValue != "":
                            defaultValue = "=%s" %defaultValue
625
                        pName = getText('declname', pNode)
626
                        self.fOutPythonappend.write("%s%s %s%s" % (sepChar, pType, pName, defaultValue))
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
                        sepChar=', '

                        if pType.find('&')>=0 and \
                          'const' not in pType.split() and \
                          key not in self.configModule.NO_OUTPUT_ARGS and \
                          len(valueUnits[1])>0:
                            try:
                                unitType=valueUnits[1][outputIndex]
                            except IndexError:
                                s = "missing unit type for %s.%s() arg named %s" \
                                   % (shortClassName, methName, pName)
                                raise Exception(s)
                            sys.stdout.write("%s.%s() returns %s as %s\n" %
                                             (shortClassName, methName,
                                              pName, unitType))
                            outputIndex+=1
643
                    if memberNode.attrib["const"]=="yes":
644
645
646
647
648
649
650
651
652
653
654
655
656
657
                        constString=' const'
                    else:
                        constString=''
                    self.fOutPythonappend.write(")%s %%{\n" % constString)
                    self.fOutPythonappend.write(addText)
                    self.fOutPythonappend.write("%}\n\n")


        #print "Done python mod blocks\n"
        #write Docstring info
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName ) = items
Robert McGibbon's avatar
Robert McGibbon committed
658

659
            if self.fOutDocstring:
Robert McGibbon's avatar
Robert McGibbon committed
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
                signatureParams = findNodes(memberNode, 'param')
                assert len(findNodes(memberNode, 'detaileddescription')) == 1
                dNode = findNodes(memberNode, 'detaileddescription')[0]

                try:
                    description=getText('para', dNode)
                    description.strip()
                except IndexError:
                    description = ''
                params = findNodes(dNode, 'para/parameterlist/parameteritem')

                paramString = ['Parameters', '----------']
                returnString = ['Returns', '-------']

                if len(params) > 0:
                    if len(signatureParams) != len(params):
                        raise ValueError('docstring in %s.%s does not match the signature' % (shortClassName, methName))

                    for pNode, pSignatureNode in zip(params, signatureParams):
                        parameterNameNode = findNodes(pNode, 'parameternamelist/parametername')[0]
680
                        argDoc = getText('parameterdescription/para', pNode)
Robert McGibbon's avatar
Robert McGibbon committed
681
682
                        argName = getNodeText(parameterNameNode)
                        argType = docstringTypemap(getText('type', pSignatureNode))
683

Robert McGibbon's avatar
Robert McGibbon committed
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
                        isOutput = parameterNameNode.get('direction') == 'out'
                        if isOutput:
                            returnString.extend(['%s : %s' % (argName, argType), '    %s' % argDoc])
                        else:
                            paramString.extend(['%s : %s' % (argName, argType), '    %s' % argDoc])


                returnSection = findNodes(dNode, 'para/simplesect')
                if len(returnSection) > 0:
                    returnNode = returnSection[0]
                    if returnNode.get('kind') == 'return':
                        argType = getNodeText(findNodes(memberNode, 'type')[0])
                        argType = docstringTypemap(argType)
                        returnString.extend([argType, '    %s' % getNodeText(returnNode).strip()])

                dString = '\n'.join(
                    ([description] + [''] if len(description) > 0 else []) +
                    (paramString + [''] if len(paramString) > 2 else []) +
                    (returnString if len(returnString) > 2 else [])).strip()
                if dString:
                    dString = re.sub(r'([^\\])"', r'\g<1>\"', dString)
                    s = '%%feature("docstring") OpenMM::%s::%s "%s";' \
                       % (shortClassName, methName, dString)
                    self.fOutDocstring.write("%s\n" % s)

                self.fOutDocstring.write("\n\n")
710
711
712
713
714
715


    def writeSwigFile(self):
        """Write the complete SWIG input file to self.fOut.

        The output consists of a header comment (naming the generating
        script, sys.argv[0], and the current time), an include of
        "factory.i", and an "OpenMM" namespace block.  The namespace body
        is produced by the writeFactories, writeGlobalConstants,
        writeForwardDeclarations and writeClassDeclarations methods,
        called in that order.
        """
        out = self.fOut

        # Header comment identifying how and when the file was generated.
        out.write("/* Swig input file,\n")
        out.write("%sgenerated by %s on %s\n*/\n\n\n"
                  % (INDENT, sys.argv[0], time.asctime()))

        out.write("%include \"factory.i\"\n")

        # Everything the section writers emit lands inside the namespace.
        out.write("\nnamespace OpenMM {\n\n")
        self.writeFactories()
        self.writeGlobalConstants()
        self.writeForwardDeclarations()
        self.writeClassDeclarations()
        out.write("\n} // namespace OpenMM\n\n")


def parseCommandLine():
    """Parse sys.argv[1:] and return the settings for SwigInputBuilder.

    Recognized options:
        -h        print the usage message and exit
        -i VALUE  input name (required)
        -c VALUE  configuration file name (required)
        -o VALUE  output file name (default "")
        -d VALUE  docstring file name (default "")
        -a VALUE  pythonprepend file name (default "")
        -z VALUE  pythonappend file name (default "")
        -s VALUE  name to skip; may be repeated, values are collected
        -v VALUE  SWIG version string (default '3.0.2')

    Returns a tuple
        (args_proper, inputDirname, configFilename, outputFilename,
         docstringFilename, pythonprependFilename, pythonappendFilename,
         skipAdditionalMethods, swigVersion)
    where args_proper is the list of leftover positional arguments.

    Calls usageError() (which exits) when -h is given, when -i or -c is
    missing, or when the command line cannot be parsed.
    """
    try:
        opts, args_proper = getopt.getopt(sys.argv[1:], 'hi:c:o:d:a:z:s:v:')
    except getopt.GetoptError as err:
        # Unknown option or missing option argument: report it and show
        # the usage text instead of dying with a traceback.
        sys.stderr.write("%s\n" % err)
        usageError()

    inputDirname = None
    configFilename = None
    outputFilename = ""
    docstringFilename = ""
    pythonprependFilename = ""
    pythonappendFilename = ""
    skipAdditionalMethods = []
    swigVersion = '3.0.2'

    for option, parameter in opts:
        if option == '-h':
            usageError()
        elif option == '-i':
            inputDirname = parameter
        elif option == '-c':
            configFilename = parameter
        elif option == '-o':
            outputFilename = parameter
        elif option == '-d':
            docstringFilename = parameter
        elif option == '-a':
            pythonprependFilename = parameter
        elif option == '-z':
            pythonappendFilename = parameter
        elif option == '-s':
            # -s is the only repeatable option; keep every occurrence.
            skipAdditionalMethods.append(parameter)
        elif option == '-v':
            swigVersion = parameter

    # Both the input and the configuration file are mandatory.
    if not inputDirname:
        usageError()
    if not configFilename:
        usageError()

    return (args_proper, inputDirname, configFilename, outputFilename,
            docstringFilename, pythonprependFilename, pythonappendFilename,
            skipAdditionalMethods, swigVersion)
750
751

def main():
    """Script entry point: parse the command line and write the SWIG file.

    The values returned by parseCommandLine() are passed to the
    SwigInputBuilder constructor, and the builder's writeSwigFile()
    method is then called.  Leftover positional arguments are ignored.
    """
    (args_proper, inputDirname, configFilename, outputFilename,
     docstringFilename, pythonprependFilename, pythonappendFilename,
     skipAdditionalMethods, swigVersion) = parseCommandLine()

    sBuilder = SwigInputBuilder(inputDirname, configFilename, outputFilename,
                                docstringFilename, pythonprependFilename,
                                pythonappendFilename, skipAdditionalMethods,
                                swigVersion)
    sBuilder.writeSwigFile()


def usageError():
    """Write the command-line usage summary to stdout and exit with status 1.

    The program name shown is the basename of sys.argv[0].
    """
    progName = os.path.basename(sys.argv[0])
    # Continuation lines are padded by the width of the program name so the
    # options line up under the "-i" on the first line.
    pad = ' ' * len(progName)

    # Required options.
    sys.stdout.write('usage: %s -i xmlHeadersFilename \\\n' % progName)
    sys.stdout.write('       %s -c inputConfigFilename\n' % pad)

    # Optional options, one per line, each shown in square brackets.
    optionalArgs = ('-o swigInputDirname',
                    '-d docstringFilename',
                    '-a pythonprependFilename',
                    '-z pythonappendFilename',
                    '-s skippedClasses',
                    '-v swigVersion')
    for optionalArg in optionalArgs:
        sys.stdout.write('       %s[%s]\n' % (pad, optionalArg))

    sys.exit(1)

# Run the builder only when executed as a script, not when imported.
if '__main__' == __name__:
    main()