swigInputBuilder.py 35.2 KB
Newer Older
Robert McGibbon's avatar
typo  
Robert McGibbon committed
1
#!/usr/bin/env python
"""Build a SWIG input file from XML-encoded header files (Doxygen XML output)."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"
5
6


7
8
9
10
import sys, os
import time
import getopt
import re
11
import xml.etree.ElementTree as etree
12

13
14
15
16
17
18
try:
    from html.parser import HTMLParser
except ImportError:
    # python 2
    from HTMLParser import HTMLParser

19
# Indentation unit used in the generated SWIG declarations.
INDENT = "   "

# XML markup tag -> HTML tag that getNodeText() wraps the element's text in.
docTags = {'emphasis':'i', 'bold':'b', 'itemizedlist':'ul', 'listitem':'li', 'preformatted':'pre', 'computeroutput':'tt',
           'superscript': 'sup', 'subscript':'sub', 'verbatim': 'verbatim'}

# C++ type name -> Python type name, used by docstringTypemap().
typeSubstitutions = {'double':'float', 'long long':'int', 'string':'str'}

# STL container patterns, matched after docstringTypemap() strips any "std::" prefix.
# Raw strings: the previous "\<" / "\," escapes were invalid string escapes
# (SyntaxWarning on Python 3.12+).  '<', '>' and ',' need no escaping in a regex,
# so these patterns are equivalent to the old ones.
vectorPattern = re.compile(r"vector<(.*)>")
setPattern = re.compile(r"set<(.*)>")
mapPattern = re.compile(r"map<(.*),(.*)>")
pairPattern = re.compile(r"pair<(.*),(.*)>")
27

28
29
30
def is_method_abstract(argstring):
    """Return True if a method's ``argsstring`` declares it pure virtual.

    Only the text after the last closing parenthesis (the trailing qualifiers)
    is examined, so default values inside the parameter list such as
    ``(int x=0)`` are ignored.  Whitespace is removed before searching, so both
    ``=0`` and ``= 0`` are recognized.
    """
    qualifiers = argstring.rsplit(")", 1)[-1]
    return "=0" in "".join(qualifiers.split())

31
32
33
34
35
36
def striphtmltags(s):
    """Strip a couple html tags used inside docstrings in the C++ source
    to produce something more easily read as plain text.

    ``<i>...</i>`` becomes ``_..._``, ``<b>...</b>`` becomes ``*...*``, and each
    ``<ul>...</ul>`` list (with surrounding whitespace) is replaced by its items,
    one per line, each prefixed with `` - ``.
    """
    class ConvertLists(HTMLParser):
        # Collects stripped text of an HTML fragment, starting a " - " line at each <li>.
        def reset(self):
            HTMLParser.reset(self)
            self.out = []

        def handle_starttag(self, tag, attrs):
            if tag == 'li':
                self.out.append('\n - ')

        def handle_data(self, data):
            self.out.append(data.strip())

    convertlists = ConvertLists()

    def replace_ul_tags(m):
        # Feed the matched text itself instead of slicing the enclosing ``s``,
        # so the callback does not depend on which string the closure sees.
        convertlists.reset()
        convertlists.feed(m.group(0))
        return '\n%s\n\n' % ''.join(convertlists.out)

    s = s.replace('<i>', '_').replace('</i>', '_')
    s = s.replace('<b>', '*').replace('</b>', '*')

    # Raw string: '\s' inside a normal string literal is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+).
    s = re.sub(r'\s*(<ul>.*?</ul>\s*)', replace_ul_tags, s, flags=re.MULTILINE | re.DOTALL)
    return s

62
63
64
65
66
67
68
69
70
71
72
def trimToSingleSpace(text):
    """Strip *text*, keeping one space on each side that originally had whitespace.

    None, the empty string and all-whitespace strings all yield "".
    """
    if not text:
        return ""
    core = text.strip()
    if not core:
        return core
    prefix = " " if text[0].isspace() else ""
    suffix = " " if text[-1].isspace() else ""
    return prefix + core + suffix
73

74
75
76
77
78
79
80
81
82
83
def getNodeText(node):
    """Flatten an XML documentation element into a string.

    The element's own text comes first.  Child elements are rendered by tag:
    ``para`` gives its text followed by a blank line, ``ref`` gives its text
    inline, an ``xrefsect`` whose ``xreftitle`` is "deprecated"
    (case-insensitive) gives an ``@deprecated`` line, and tags found in
    ``docTags`` are wrapped in the mapped HTML tag.  Other children contribute
    nothing themselves, but every child's tail text is kept.
    """
    pieces = [node.text if node.text is not None else ""]
    for child in node:
        tag = child.tag
        if tag == "para":
            pieces.append(getNodeText(child) + "\n\n")
        elif tag == "ref":
            pieces.append(getNodeText(child))
        elif tag == "xrefsect":
            titleNode = child.find("xreftitle")
            descriptionNode = child.find("xrefdescription")
            isDeprecated = (titleNode is not None
                            and descriptionNode is not None
                            and getNodeText(titleNode).lower() == "deprecated")
            if isDeprecated:
                pieces.append("\n@deprecated %s\n\n" % getNodeText(descriptionNode))
        elif tag in docTags:
            htmlTag = docTags[tag]
            pieces.append("<%s>%s</%s>" % (htmlTag, getNodeText(child), htmlTag))
        if child.tail is not None:
            pieces.append(child.tail)
    return "".join(pieces)

97
def getText(subNodePath, node):
    """Return the concatenated, whitespace-normalized text of every element
    under *node* matching *subNodePath*; ``para`` matches are followed by a
    blank line.  The combined result is stripped, and "" if nothing matches.
    """
    pieces = []
    for child in node.findall(subNodePath):
        pieces.append(trimToSingleSpace(getNodeText(child)))
        if child.tag == "para":
            pieces.append("\n\n")
    return "".join(pieces).strip()

105
OPENMM_RE_PATTERN=re.compile("(.*)OpenMM:[a-zA-Z0-9:]*:(.*)")

def stripOpenmmPrefix(name, rePattern=OPENMM_RE_PATTERN):
    """Remove an ``OpenMM::...::`` qualifier from *name*.

    With the default pattern the greedy leading group makes this the last
    such qualifier, e.g. ``"OpenMM::System::getForce"`` -> ``"getForce"`` and
    ``"const OpenMM::Vec3 & OpenMM::System::getX"`` ->
    ``"const OpenMM::Vec3 & getX"``.  Names the pattern does not match are
    returned unchanged.
    """
    # Explicit no-match check instead of a bare ``except:`` (which also hid
    # KeyboardInterrupt and real bugs); the old ``rValue.strip()`` discarded
    # its result and so had no effect.
    m = rePattern.search(name)
    if m is None:
        return name
    return m.group(1) + m.group(2)
114

115
116
117
118
119
120
121
122
123
124
def findNodes(parent, path, **args):
    """Return the elements under *parent* matching *path* whose attributes
    contain every ``name=value`` pair passed as a keyword argument.

    An element missing one of the requested attributes does not match.
    """
    def attributesMatch(element):
        attrib = element.attrib
        return all(key in attrib and attrib[key] == wanted
                   for key, wanted in args.items())

    return [element for element in parent.findall(path) if attributesMatch(element)]
125
126
127
128
129

def getClassMethodList(classNode, skipMethods):
    """Collect the public static and non-static member functions of a class
    ``compounddef`` node.

    A method is skipped when ``(shortClassName, methName)`` or
    ``(shortClassName, methName, numParams)`` is in *skipMethods*.

    Returns a list of 8-tuples
    ``(shortClassName, memberNode, shortMethDefinition, methName,
    isConstructor, isDestructor, templateType, templateName)``; the template
    fields are "" when the method has no template parameter list.
    """
    className = getText("compoundname", classNode)
    shortClassName = stripOpenmmPrefix(className)
    methodList = []
    sections = (findNodes(classNode, "sectiondef", kind="public-static-func")
                + findNodes(classNode, "sectiondef", kind="public-func"))
    for section in sections:
        for memberNode in findNodes(section, "memberdef", kind="function", prot="public"):
            methDefinition = getText("definition", memberNode)
            shortMethDefinition = stripOpenmmPrefix(methDefinition)
            # Attach '&' to the type: "TabulatedFunction & getX" -> "TabulatedFunction& getX".
            shortMethDefinition = shortMethDefinition.replace(' &', '&')

            # The method name is the last word of the definition.
            methName = shortMethDefinition.split()[-1]
            if (shortClassName, methName) in skipMethods:
                continue
            numParams = len(findNodes(memberNode, 'param'))
            if (shortClassName, methName, numParams) in skipMethods:
                continue

            # Classes with these suffixes are still included; only a warning is
            # printed.  (The old trailing ``continue`` only advanced this suffix
            # loop, so it never skipped anything and has been removed, as has a
            # duplicate of the first skip check above.)
            for catchString in ('Factory', 'Impl', 'Info', 'Kernel'):
                if shortClassName.endswith(catchString):
                    sys.stderr.write("Warning: Including class %s\n" %
                                     shortClassName)

            # Template parameter type and name ("" for non-template methods).
            templateType = getText("templateparamlist/param/type", memberNode)
            templateName = getText("templateparamlist/param/declname", memberNode)

            methodList.append((shortClassName,
                               memberNode,
                               shortMethDefinition,
                               methName,
                               shortClassName == methName,          # constructor
                               '~' + shortClassName == methName,    # destructor
                               templateType,
                               templateName))
    return methodList


Robert McGibbon's avatar
Robert McGibbon committed
161
162
163
164
165
def _splitTemplateArgs(text):
    """Split *text* at the first comma not nested inside ``<...>``.

    Returns a ``(first, rest)`` pair, or None if there is no such comma.
    """
    depth = 0
    for i, ch in enumerate(text):
        if ch == '<':
            depth += 1
        elif ch == '>':
            depth -= 1
        elif ch == ',' and depth == 0:
            return text[:i], text[i+1:]
    return None


def _templateArgPair(match):
    """Return the two template arguments from a mapPattern/pairPattern match.

    Those greedy patterns split at the *last* comma, which is wrong when the
    second argument itself contains a comma (``map<int, pair<int, int>>``).
    The captured pieces are rejoined and split at the first top-level comma
    instead; if that fails (unbalanced brackets) the regex split is used.
    """
    parts = _splitTemplateArgs("%s,%s" % (match[1], match[2]))
    return parts if parts is not None else (match[1], match[2])


def docstringTypemap(cpptype):
    """Translate a C++ type to Python for inclusion in the Python docstrings.
    This doesn't need to be perfectly accurate -- it's not used for generating
    the actual swig wrapper code. It's only used for generating the docstrings.

    A leading ``const``/``std::`` and surrounding ``&``/``*`` are dropped,
    simple names are mapped through ``typeSubstitutions``, and vector, set,
    map and pair are converted recursively to Sequence, set, Mapping and tuple.
    """
    pytype = cpptype.strip()
    if pytype.startswith('const '):
        pytype = pytype[6:]
    if pytype.startswith('std::'):
        pytype = pytype[5:]
    pytype = pytype.strip('&')
    pytype = pytype.strip('*')
    pytype = pytype.strip()
    if pytype in typeSubstitutions:
        pytype = typeSubstitutions[pytype]
    match = vectorPattern.match(pytype)
    if match is not None:
        pytype = f'Sequence[{docstringTypemap(match[1])}]'
    match = setPattern.match(pytype)
    if match is not None:
        pytype = f'set[{docstringTypemap(match[1])}]'
    match = mapPattern.match(pytype)
    if match is not None:
        key, value = _templateArgPair(match)
        pytype = f'Mapping[{docstringTypemap(key)}, {docstringTypemap(value)}]'
    match = pairPattern.match(pytype)
    if match is not None:
        first, second = _templateArgPair(match)
        pytype = f'tuple[{docstringTypemap(first)}, {docstringTypemap(second)}]'
    return pytype
Robert McGibbon's avatar
Robert McGibbon committed
189
190


191
192
class SwigInputBuilder:
    def __init__(self,
                 inputDirname,
                 configFilename,
                 outputFilename=None,
                 docstringFilename=None,
                 pythonprependFilename=None,
                 pythonappendFilename=None,
                 skipAdditionalMethods=None,
                 SWIG_VERSION='3.0.2'):
        """Load the configuration and XML input, and open the output files.

        inputDirname -- directory whose files ending in "xml" are merged into
            one document.
        configFilename -- Python file imported as a module (name without its
            extension) providing SKIP_METHODS and the other config tables.
        outputFilename -- SWIG interface output path; stdout if not given.
        docstringFilename, pythonprependFilename, pythonappendFilename --
            optional output paths; each output is disabled if not given.
        skipAdditionalMethods -- extra "Class::method" or
            "Class::method::numParams" strings appended to SKIP_METHODS.
        SWIG_VERSION -- accepted for backward compatibility; currently unused.
        """
        self.nodeByID = {}

        # Imported by module name, so the config file must be importable
        # (presumably on sys.path / in the working directory -- confirm).
        self.configModule = __import__(os.path.splitext(configFilename)[0])

        # Copy so the appends below don't modify the config module's list.
        self.skipMethods = self.configModule.SKIP_METHODS[:]
        for skipMethod in skipAdditionalMethods or ():
            items = skipMethod.split('::')
            if len(items) == 3:
                # Third field is a parameter count, selecting one overload.
                items[2] = int(items[2])
            self.skipMethods.append(tuple(items))

        # Read all the XML files and merge them into a single document.  The
        # directory listing is sorted because os.listdir order is arbitrary;
        # this keeps the merged document, and the generated output, reproducible.
        self.doc = etree.ElementTree(etree.Element('root'))
        mergedRoot = self.doc.getroot()
        for filename in sorted(os.listdir(inputDirname)):
            if filename.lower().endswith('xml'):
                fileRoot = etree.parse(os.path.join(inputDirname, filename)).getroot()
                mergedRoot.extend(list(fileRoot))

        self.fOut = open(outputFilename, 'w') if outputFilename else sys.stdout
        self.fOutDocstring = open(docstringFilename, 'w') if docstringFilename else None
        self.fOutPythonprepend = open(pythonprependFilename, 'w') if pythonprependFilename else None
        self.fOutPythonappend = open(pythonappendFilename, 'w') if pythonappendFilename else None

        self._enumByClassname = {}

        self._orderedClassNodes = self._buildOrderedClassNodes()

    def _getNodeByID(self, id):
        if id not in self.nodeByID:
246
247
            for node in findNodes(self.doc.getroot(), "compounddef", id=id):
                self.nodeByID[id] = node
248
249
250
251
        return self.nodeByID[id]

    def _buildOrderedClassNodes(self):
        """Return the public class nodes, each placed after the base classes
        that _findBaseNodes adds before it (skipped classes are omitted)."""
        ordered = []
        root = self.doc.getroot()
        for classNode in findNodes(root, "compounddef", kind="class", prot="public"):
            self._findBaseNodes(classNode, ordered)
        return ordered

    def _findBaseNodes(self, node, excludedClassNodes=[]):
        if node in excludedClassNodes: return
        nodeName = getText("compoundname", node)
        if (nodeName.split("::")[-1],) in self.skipMethods:
            return
261
        for baseNodePnt in findNodes(node, "basecompoundref", prot="public"):
262
263
264
265
            if "refid" in baseNodePnt.attrib:
                baseNodeID = baseNodePnt.attrib["refid"]
                baseNode = self._getNodeByID(baseNodeID)
                self._findBaseNodes(baseNode, excludedClassNodes)
266
267
268
        excludedClassNodes.append(node)


269
270
271
272
    def _writeFactory(self, declaration, subclassNames):
        """Write one ``%factory(<declaration>, OpenMM::<subclass>, ...);``
        directive listing *subclassNames* in sorted order."""
        self.fOut.write("%factory(" + declaration)
        for name in sorted(subclassNames):
            self.fOut.write(",\n         OpenMM::%s" % name)
        self.fOut.write(");\n\n")

    def writeFactories(self):
        """Write the SWIG ``%factory`` directives for functions returning a
        Force, Integrator, TabulatedFunction or VirtualSite.

        The Force, Integrator and TabulatedFunction directives list the public,
        non-skipped classes found in the XML that derive directly from those
        bases.  Integrator subclasses include those of DrudeIntegrator.
        """
        self.fOut.write("\n/* Declare factories */\n\n")
        forceSubclassList = []
        integratorSubclassList = []
        tabulatedFunctionSubclassList = []
        for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            className = getText("compoundname", classNode)
            shortClassName = stripOpenmmPrefix(className)
            if (className.split("::")[-1],) in self.skipMethods:
                continue
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
                if "refid" not in baseNodePnt.attrib:
                    continue
                baseNode = self._getNodeByID(baseNodePnt.attrib["refid"])
                baseName = getText("compoundname", baseNode)
                if baseName == 'OpenMM::Force':
                    forceSubclassList.append(shortClassName)
                elif baseName in ('OpenMM::Integrator', 'OpenMM::DrudeIntegrator'):
                    integratorSubclassList.append(shortClassName)
                elif baseName == 'OpenMM::TabulatedFunction':
                    tabulatedFunctionSubclassList.append(shortClassName)

        # We need to include subclasses of DrudeIntegrator, but not DrudeIntegrator
        # itself.  It may be absent (e.g. if its XML was not among the inputs), in
        # which case list.remove would raise ValueError, so check first.
        if 'DrudeIntegrator' in integratorSubclassList:
            integratorSubclassList.remove('DrudeIntegrator')

        for declaration in ("OpenMM::Force& OpenMM::System::getForce",
                            "OpenMM::Force* OpenMM_XmlSerializer__cloneForce",
                            "OpenMM::Force* OpenMM_XmlSerializer__deserializeForce",
                            "OpenMM::Force& OpenMM::CustomCVForce::getCollectiveVariable"):
            self._writeFactory(declaration, forceSubclassList)

        for declaration in ("OpenMM::Integrator* OpenMM_XmlSerializer__cloneIntegrator",
                            "OpenMM::Integrator* OpenMM_XmlSerializer__deserializeIntegrator",
                            "OpenMM::Integrator& OpenMM::Context::getIntegrator",
                            "OpenMM::Integrator& OpenMM::CompoundIntegrator::getIntegrator"):
            self._writeFactory(declaration, integratorSubclassList)

        for declaration in ("OpenMM::TabulatedFunction* OpenMM_XmlSerializer__cloneTabulatedFunction",
                            "OpenMM::TabulatedFunction* OpenMM_XmlSerializer__deserializeTabulatedFunction"):
            self._writeFactory(declaration, tabulatedFunctionSubclassList)

        # Every method declared as "TabulatedFunction& getTabulatedFunction" also
        # gets a factory over the TabulatedFunction subclasses.
        for classNode in self._orderedClassNodes:
            for items in getClassMethodList(classNode, self.skipMethods):
                shortClassName, shortMethDefinition, methName = items[0], items[2], items[3]
                if shortMethDefinition == 'TabulatedFunction& getTabulatedFunction':
                    self._writeFactory("OpenMM::TabulatedFunction& OpenMM::%s::%s" % (shortClassName, methName),
                                       tabulatedFunctionSubclassList)

        self.fOut.write("%factory(OpenMM::VirtualSite& OpenMM::System::getVirtualSite, OpenMM::TwoParticleAverageSite, OpenMM::ThreeParticleAverageSite, OpenMM::OutOfPlaneSite, OpenMM::LocalCoordinatesSite);\n\n")
        self.fOut.write("\n")

    def writeGlobalConstants(self):
        """Write each public, static, non-mutable variable of the ``OpenMM``
        namespace as a ``static <definition> = <value>;`` line.

        Raises StopIteration if no namespace named OpenMM is present.
        """
        self.fOut.write("/* Global Constants */\n\n")
        namespaces = findNodes(self.doc.getroot(), "compounddef", kind="namespace")
        openmmNamespace = next(ns for ns in namespaces if ns.findtext("compoundname") == "OpenMM")
        for section in findNodes(openmmNamespace, "sectiondef", kind="var"):
            variables = findNodes(section, "memberdef", kind="variable",
                                  mutable="no", prot="public", static="yes")
            for variable in variables:
                definition = stripOpenmmPrefix(getText("definition", variable))
                initializer = getText("initializer", variable)
                # The initializer text may carry a leading "="; drop it.
                value = initializer[1:] if initializer.startswith("=") else initializer
                self.fOut.write("static %s = %s;\n" % (definition, value))
        self.fOut.write("\n")

    def writeForwardDeclarations(self):
        """Write ``class X ;`` for every ordered class, preceded by
        ``%copyctor X ;`` when the class declares a constructor."""
        self.fOut.write("\n/* Forward Declarations */\n\n")
        for classNode in self._orderedClassNodes:
            methods = getClassMethodList(classNode, self.skipMethods)
            # Element 4 of each method tuple flags a constructor.
            hasConstructor = any(method[4] for method in methods)
            className = stripOpenmmPrefix(getText("compoundname", classNode))
            if hasConstructor:
                # Tell SWIG to make a copy constructor for this class.
                self.fOut.write("%%copyctor %s ;\n" % className)
            self.fOut.write("class %s ;\n" % className)
        self.fOut.write("\n")

    def writeClassDeclarations(self):
        """Write a SWIG class declaration for each class in ``self._orderedClassNodes``.

        Each declaration names the config's MISSING_BASE_CLASSES entry (if any)
        and the class's public base classes, then its enumerations and methods.
        When a docstring file is open, the class's detailed description is
        written there as a ``%feature("docstring")`` directive.
        """
        self.fOut.write("\n/* Class Declarations */\n\n")
        for classNode in self._orderedClassNodes:
            className = stripOpenmmPrefix(getText("compoundname", classNode))

            if self.fOutDocstring:
                dNode = classNode.find('detaileddescription')
                if dNode is not None:
                    docstring = getNodeText(dNode).strip().replace('"', '\\"')
                    docstring = striphtmltags(docstring)
                    self.fOutDocstring.write('%%feature("docstring") %s "%s";\n' % (className, docstring))

            # Collect all base classes first so that several bases produce valid
            # C++ ("class C : public A, public B"); the old code emitted a
            # separate " : public X" clause for each base, which is invalid.
            baseNames = []
            if className in self.configModule.MISSING_BASE_CLASSES:
                baseNames.append(self.configModule.MISSING_BASE_CLASSES[className])
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
                if "refid" in baseNodePnt.attrib:
                    baseNames.append(stripOpenmmPrefix(getText(".", baseNodePnt)))

            self.fOut.write("class %s" % className)
            if baseNames:
                self.fOut.write(" : %s" % ", ".join("public %s" % name for name in baseNames))
            self.fOut.write(" {\n")
            self.fOut.write("public:\n")
            self.writeEnumerations(classNode)
            self.writeMethods(classNode)
            self.fOut.write("};\n\n")
        self.fOut.write("\n")

    def writeEnumerations(self, classNode):
        """Write each public enum of *classNode* as a SWIG ``enum`` block and
        record its name in ``self._enumByClassname`` under the short class name.

        Writes nothing (and records nothing) when the class has no public enums.
        """
        enumNodes = [enumNode
                     for section in findNodes(classNode, "sectiondef", kind="public-type")
                     for enumNode in findNodes(section, "memberdef", kind="enum", prot="public")]
        if not enumNodes:
            return

        shortClassName = stripOpenmmPrefix(getText("compoundname", classNode))
        knownEnums = self._enumByClassname.setdefault(shortClassName, [])
        for enumNode in enumNodes:
            enumName = getText("name", enumNode)
            knownEnums.append(enumName)

            entries = []
            for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
                valueName = getText("name", valueNode)
                initializer = getText("initializer", valueNode)
                if initializer.startswith("="):
                    initializer = initializer[1:]
                entries.append("%s%s = %s" % (2*INDENT, valueName, initializer))

            self.fOut.write("%senum %s {" % (INDENT, enumName))
            if entries:
                self.fOut.write("\n" + ",\n".join(entries))
            self.fOut.write("\n%s};\n" % INDENT)
        self.fOut.write("\n")

    def writeMethods(self, classNode):
        methodList=getClassMethodList(classNode, self.skipMethods)

        #write only Constructors
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isConstructors:
                mArgsstring = getText("argsstring", memberNode)
                try:
                    pExceptions = " %s" % getText('exceptions', memberNode)
                except IndexError:
                    pExceptions = ""
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition,
                                                 mArgsstring, pExceptions))
        #write only Destructors
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isDestructor:
                mArgsstring = getText("argsstring", memberNode)
                try:
                    pExceptions = " %s" % getText('exceptions', memberNode)
                except IndexError:
                    pExceptions = ""
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition,
                                                 mArgsstring, pExceptions))

        #write only non Constructor and Destructor methods and python mods
        self.fOut.write("\n")
474
        methodsWithOutputArgs = set()
475
476
477
478
479
480
481
482
483
484
485
486
        for items in methodList:
            clearOutput=""
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isConstructors or isDestructor: continue

            key = (shortClassName, methName)
            if key in self.configModule.DOC_STRINGS:
                self.fOut.write('%%feature("autodoc", "%s") %s;\n' %
                                (self.configModule.DOC_STRINGS[key], methName))

487
            paramList=findNodes(memberNode, 'param')
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
            for pNode in paramList:
                try:
                    pType = getText('type', pNode)
                except IndexError:
                    pType = getText('type/ref', pNode)
                pName = getText('declname', pNode)
                key = (shortClassName, methName, pName)
                if pType.find('&')>=0 and \
                   'const' not in pType.split():
                    if key not in self.configModule.NO_OUTPUT_ARGS:
                        eType = pType.split()[0]
                        if shortClassName in self._enumByClassname and \
                           eType in self._enumByClassname[shortClassName]:
                            simpleType = re.sub(eType, 'int', pType)
                        else:
                            simpleType = pType
                        self.fOut.write("%s%%apply %s OUTPUT { %s %s };\n" %
                                        (INDENT, simpleType, pType, pName))
                        clearOutput = "%s%s%%clear %s %s;\n" \
                                     % (clearOutput, INDENT, pType, pName)
508
                        methodsWithOutputArgs.add((shortClassName, methName))
509
510
511
512
513
514

            mArgsstring = getText("argsstring", memberNode)
            try:
                pExceptions = " %s" % getText('exceptions', memberNode)
            except IndexError:
                pExceptions = ""
515
            if memberNode.attrib["virt"].strip()!='non-virtual':
516
517
518
519
520
521
522
523
524
525
526
527
528
529
                if 'virtual' not in shortMethDefinition.split():
                    shortMethDefinition="virtual %s" % shortMethDefinition
            if( len(templateType) > 0 ):
                self.fOut.write("%stemplate<%s %s> %s%s%s;\n" % (INDENT, templateType, templateName, shortMethDefinition, mArgsstring, pExceptions))
            else:
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition, mArgsstring, pExceptions))
            if clearOutput:
                self.fOut.write(clearOutput)

        #write python mod blocks
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
530
            paramList = findNodes(memberNode, 'param')
531

532
            # write pythonprepend blocks
533
534
535
536
            if isConstructors:
                mArgsstring = '' # specifying args to constructors seems to prevent append and prepend from working
            else:
                mArgsstring = getText("argsstring", memberNode)
537
538
            if self.fOutPythonprepend and \
               len(paramList) and \
539
               not is_method_abstract(mArgsstring):
540
541
542
543
544
545
                text = '''
%pythonprepend OpenMM::{shortClassName}::{methName}{mArgsstring} %{{{{{{0}}
%}}}}'''.format(shortClassName=shortClassName, methName=methName, mArgsstring=mArgsstring)
                textInside = ''
                key = (shortClassName, methName)
                for argNum in self.configModule.STEAL_OWNERSHIP.get(key, []):
546
                    argName = getText('declname', paramList[argNum])
547
548

                    textInside += '''
549
550
551
    if not {argName}.thisown:
        s = ("the %s object does not own its corresponding OpenMM object"
             % self.__class__.__name__)
552
        raise Exception(s)'''.format(argName=argName)
553

554

555
556
557
558
559
560
561
562
                # Convert input arguments to the proper units, if specified.
                if key not in methodsWithOutputArgs:
                    if key in self.configModule.UNITS:
                        argUnits=self.configModule.UNITS[key][1]
                    elif ("*", methName) in self.configModule.UNITS:
                        argUnits=self.configModule.UNITS[("*", methName)][1]
                    else:
                        argUnits = ()
563
                    if len(argUnits) > 0 and isConstructors:
564
565
566
567
                        textInside += '''
    args = list(args)'''
                    for i, units in enumerate(argUnits):
                        if units is not None:
568
                            if isConstructors:
569
570
571
572
573
574
575
                                argName = 'args[%s]' % i
                            else:
                                argName = getText('declname', paramList[i])
                            textInside += '''
    if unit.is_quantity({argName}):
        {argName} = {argName}.value_in_unit({units})'''.format(argName=argName, units=units)

576
                for argNum in self.configModule.REQUIRE_ORDERED_SET.get(key, []):
577
                    argName = getText('declname', paramList[argNum])
578
579
580
581
582
583
584

                    textInside += '''
    {argName} = list({argName})'''.format(argName=argName)
                if textInside:
                    self.fOutPythonprepend.write(text.format(textInside))

            # write pythonappend blocks
585
            if self.fOutPythonappend \
586
               and not is_method_abstract(mArgsstring):
587
                key = (shortClassName, methName)
588
589
                #sys.stdout.write("key %s %s \n" % (shortClassName, methName))

590
                addText=''
591
                returnType = getText("type", memberNode)
592
593
594
595
596

                if key in self.configModule.UNITS:
                    valueUnits=self.configModule.UNITS[key]
                elif ("*", methName) in self.configModule.UNITS:
                    valueUnits=self.configModule.UNITS[("*", methName)]
597
                elif methName.startswith('get') and returnType not in ('void', 'int', 'bool', 'std::string', 'const std::string &'):
598
599
                    s = 'do not know how to add units to %s %s::%s' \
                        % (returnType, shortClassName, methName)
600
601
602
603
604
                    raise Exception(s)
                else:
                    valueUnits=[None, ()]

                index=0
605
                if valueUnits[0] is not None:
606
607
608
609
610
611
612
613
614
615
616
617
618
                    sys.stdout.write("%s.%s() returns %s\n" %
                                     (shortClassName, methName, valueUnits[0]))
                    if len(valueUnits[1])>0:
                        addText = "%s%sval[%d]=unit.Quantity(val[%d], %s)\n" \
                                 % (addText, INDENT,
                                    index, index,
                                    valueUnits[0])
                        index+=1
                    else:
                        addText = "%s%sval=unit.Quantity(val, %s)\n" \
                                 % (addText, INDENT, valueUnits[0])

                for vUnit in valueUnits[1]:
619
                    if vUnit is not None and key in methodsWithOutputArgs:
620
                        addText = "%s%sval[%s]=unit.Quantity(val[%s], %s)\n" \
621
                                     % (addText, INDENT, index, index, vUnit)
622
                    index+=1
623
624
625

                if key in self.configModule.STEAL_OWNERSHIP:
                    for argNum in self.configModule.STEAL_OWNERSHIP[key]:
626
                        argName = getText('declname', paramList[argNum])
627
628
                        addText = "%s%s%s.thisown=0\n" \
                                % (addText, INDENT, argName)
629
630
631
632
633
634
635
636
637
638
639

                if addText:
                    self.fOutPythonappend.write("%pythonappend")
                    self.fOutPythonappend.write(" OpenMM::%s::%s(" % key)
                    sepChar=''
                    outputIndex=0
                    for pNode in paramList:
                        try:
                            pType = getText('type', pNode)
                        except IndexError:
                            pType = getText('type/ref', pNode)
640
641
642
643
644
645
646
                        # parse default arguments
                        try:
                            defaultValue = getText('defval', pNode)
                        except:
                            defaultValue = ""
                        if defaultValue != "":
                            defaultValue = "=%s" %defaultValue
647
                        pName = getText('declname', pNode)
648
                        self.fOutPythonappend.write("%s%s %s%s" % (sepChar, pType, pName, defaultValue))
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
                        sepChar=', '

                        if pType.find('&')>=0 and \
                          'const' not in pType.split() and \
                          key not in self.configModule.NO_OUTPUT_ARGS and \
                          len(valueUnits[1])>0:
                            try:
                                unitType=valueUnits[1][outputIndex]
                            except IndexError:
                                s = "missing unit type for %s.%s() arg named %s" \
                                   % (shortClassName, methName, pName)
                                raise Exception(s)
                            sys.stdout.write("%s.%s() returns %s as %s\n" %
                                             (shortClassName, methName,
                                              pName, unitType))
                            outputIndex+=1
665
                    if memberNode.attrib["const"]=="yes":
666
667
668
669
670
671
672
673
674
675
676
677
678
679
                        constString=' const'
                    else:
                        constString=''
                    self.fOutPythonappend.write(")%s %%{\n" % constString)
                    self.fOutPythonappend.write(addText)
                    self.fOutPythonappend.write("%}\n\n")


        #print "Done python mod blocks\n"
        #write Docstring info
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName ) = items
Robert McGibbon's avatar
Robert McGibbon committed
680

681
            if self.fOutDocstring:
Robert McGibbon's avatar
Robert McGibbon committed
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
                signatureParams = findNodes(memberNode, 'param')
                assert len(findNodes(memberNode, 'detaileddescription')) == 1
                dNode = findNodes(memberNode, 'detaileddescription')[0]

                try:
                    description=getText('para', dNode)
                    description.strip()
                except IndexError:
                    description = ''
                params = findNodes(dNode, 'para/parameterlist/parameteritem')

                paramString = ['Parameters', '----------']
                returnString = ['Returns', '-------']

                if len(params) > 0:
                    if len(signatureParams) != len(params):
                        raise ValueError('docstring in %s.%s does not match the signature' % (shortClassName, methName))

                    for pNode, pSignatureNode in zip(params, signatureParams):
                        parameterNameNode = findNodes(pNode, 'parameternamelist/parametername')[0]
702
                        argDoc = getText('parameterdescription/para', pNode)
Robert McGibbon's avatar
Robert McGibbon committed
703
704
                        argName = getNodeText(parameterNameNode)
                        argType = docstringTypemap(getText('type', pSignatureNode))
705

Robert McGibbon's avatar
Robert McGibbon committed
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
                        isOutput = parameterNameNode.get('direction') == 'out'
                        if isOutput:
                            returnString.extend(['%s : %s' % (argName, argType), '    %s' % argDoc])
                        else:
                            paramString.extend(['%s : %s' % (argName, argType), '    %s' % argDoc])


                returnSection = findNodes(dNode, 'para/simplesect')
                if len(returnSection) > 0:
                    returnNode = returnSection[0]
                    if returnNode.get('kind') == 'return':
                        argType = getNodeText(findNodes(memberNode, 'type')[0])
                        argType = docstringTypemap(argType)
                        returnString.extend([argType, '    %s' % getNodeText(returnNode).strip()])

                dString = '\n'.join(
                    ([description] + [''] if len(description) > 0 else []) +
                    (paramString + [''] if len(paramString) > 2 else []) +
                    (returnString if len(returnString) > 2 else [])).strip()
                if dString:
                    dString = re.sub(r'([^\\])"', r'\g<1>\"', dString)
                    s = '%%feature("docstring") OpenMM::%s::%s "%s";' \
                       % (shortClassName, methName, dString)
                    self.fOutDocstring.write("%s\n" % s)

                self.fOutDocstring.write("\n\n")
732
733
734
735
736
737


    def writeSwigFile(self):
        """Write the complete SWIG interface text to ``self.fOut``.

        Output order:
          1. a C comment banner naming the generating script (sys.argv[0])
             and the current time,
          2. an ``%include "factory.i"`` directive,
          3. the opening of ``namespace OpenMM {``,
          4. whatever writeFactories(), writeGlobalConstants(),
             writeForwardDeclarations() and writeClassDeclarations() emit,
             called in that order,
          5. the closing brace of the namespace.
        """
        headerChunks = (
            "/* Swig input file,\n",
            "%sgenerated by %s on %s\n*/\n\n\n"
            % (INDENT, sys.argv[0], time.asctime()),
            "%include \"factory.i\"\n",
            "\nnamespace OpenMM {\n\n",
        )
        for chunk in headerChunks:
            self.fOut.write(chunk)

        # The section writers run in a fixed order; each one is looked up at
        # call time, exactly as a direct self.<name>() call would be.
        for writerName in ("writeFactories",
                           "writeGlobalConstants",
                           "writeForwardDeclarations",
                           "writeClassDeclarations"):
            getattr(self, writerName)()

        self.fOut.write("\n} // namespace OpenMM\n\n")


def parseCommandLine():
    """Parse sys.argv into the settings used to build the SWIG input file.

    Recognized options (each fills one of the returned values):

      -h          print usage and exit (via usageError)
      -i DIR      inputDirname (required)
      -c FILE     configFilename (required)
      -o FILE     outputFilename (default "")
      -d FILE     docstringFilename (default "")
      -a FILE     pythonprependFilename (default "")
      -z FILE     pythonappendFilename (default "")
      -s NAME     appended to skipAdditionalMethods; may be repeated
      -v VERSION  swigVersion (default '3.0.2')

    If an option is given more than once (other than -s), the last value wins.

    Returns:
        The 9-tuple (args_proper, inputDirname, configFilename,
        outputFilename, docstringFilename, pythonprependFilename,
        pythonappendFilename, skipAdditionalMethods, swigVersion).
        args_proper holds the non-option arguments that getopt left over.

    usageError() is called on -h, on an unrecognized option, on an option
    missing its argument, and when -i or -c was not supplied.
    """
    try:
        opts, args_proper = getopt.getopt(sys.argv[1:], 'hi:c:o:d:a:z:s:v:')
    except getopt.GetoptError as err:
        # Report the bad option and show the usage summary instead of
        # letting a raw traceback escape to the user.
        sys.stdout.write("%s\n" % err)
        usageError()

    inputDirname = None
    configFilename = None
    outputFilename = ""
    docstringFilename = ""
    pythonprependFilename = ""
    pythonappendFilename = ""
    skipAdditionalMethods = []
    swigVersion = '3.0.2'

    for option, parameter in opts:
        if option == '-h':
            usageError()
        elif option == '-i':
            inputDirname = parameter
        elif option == '-c':
            configFilename = parameter
        elif option == '-o':
            outputFilename = parameter
        elif option == '-d':
            docstringFilename = parameter
        elif option == '-a':
            pythonprependFilename = parameter
        elif option == '-z':
            pythonappendFilename = parameter
        elif option == '-s':
            skipAdditionalMethods.append(parameter)
        elif option == '-v':
            swigVersion = parameter

    # Both the header directory and the configuration file are mandatory.
    if not inputDirname:
        usageError()
    if not configFilename:
        usageError()

    return (args_proper, inputDirname, configFilename, outputFilename,
            docstringFilename, pythonprependFilename, pythonappendFilename,
            skipAdditionalMethods, swigVersion)
772
773

def main():
    """Script entry point: read the command line, then write the SWIG file.

    The settings returned by parseCommandLine() (minus the leftover
    positional arguments) are forwarded positionally to SwigInputBuilder,
    and the resulting builder's writeSwigFile() produces the output.
    """
    settings = parseCommandLine()
    (_argsProper, inputDirname, configFilename, outputFilename,
     docstringFilename, pythonprependFilename, pythonappendFilename,
     skipAdditionalMethods, swigVersion) = settings

    builder = SwigInputBuilder(inputDirname, configFilename, outputFilename,
                               docstringFilename, pythonprependFilename,
                               pythonappendFilename, skipAdditionalMethods,
                               swigVersion)
    builder.writeSwigFile()


def usageError():
    """Print a command-line usage summary to stdout and exit with status 1.

    Continuation lines are padded by the length of the program name so the
    option flags line up under the first line's ``-i``. The placeholders
    match what parseCommandLine() does with each value: -i names a
    directory (inputDirname) and -o names the output file (outputFilename).
    """
    progName = os.path.basename(sys.argv[0])
    pad = ' ' * len(progName)

    # The first line carries a trailing backslash to show the command
    # continues; the remaining option texts follow, aligned under it.
    lines = ['usage: %s -i xmlHeadersDirname \\' % progName]
    for optionText in (' -c inputConfigFilename',
                       '[-o swigInputFilename]',
                       '[-d docstringFilename]',
                       '[-a pythonprependFilename]',
                       '[-z pythonappendFilename]',
                       '[-s skippedClasses]',
                       '[-v swigVersion]'):
        lines.append('       %s%s' % (pad, optionText))

    sys.stdout.write(''.join(line + '\n' for line in lines))
    sys.exit(1)

# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()