swigInputBuilder.py 34.2 KB
Newer Older
Robert McGibbon's avatar
typo  
Robert McGibbon committed
1
#!/usr/bin/env python
2
3
4
"""Build swig imput file from xml encoded header files (see gccxml)."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"
5
6


7
8
9
10
import sys, os
import time
import getopt
import re
11
import xml.etree.ElementTree as etree
12

13
14
15
16
17
18
try:
    from html.parser import HTMLParser
except ImportError:
    # python 2
    from HTMLParser import HTMLParser

19
INDENT = "   "
20
21
# Maps XML documentation element names to the tag name that getNodeText()
# wraps around the element's text, i.e. "<tag>text</tag>".
docTags = {'emphasis':'i', 'bold':'b', 'itemizedlist':'ul', 'listitem':'li', 'preformatted':'pre', 'computeroutput':'tt', 
           'superscript': 'sup', 'subscript':'sub', 'verbatim': 'verbatim'}
22

23
24
25
def is_method_abstract(argstring):
    """Report whether a method's argument string carries a "=0" marker.

    Only the text after the last ")" is inspected, so a "=0" default value
    inside the parameter list is not mistaken for the marker.
    """
    trailer = argstring.rsplit(")", 1)[-1]
    return "=0" in trailer

26
27
28
29
30
31
def striphtmltags(s):
    """Strip a couple html tags used inside docstrings in the C++ source
    to produce something more easily read as plain text.

    <i>...</i> becomes _..._, <b>...</b> becomes *...*, and every
    <ul>...</ul> block (together with the whitespace around it) is replaced
    by one " - " line per <li> item, surrounded by blank lines.

    Returns the converted string.
    """
    class ConvertLists(HTMLParser):
        """Collects the text of an HTML fragment, starting a bullet at each <li>."""

        def reset(self):
            HTMLParser.reset(self)
            # Output fragments gathered during the current feed().
            self.out = []

        def handle_starttag(self, tag, attrs):
            if tag == 'li':
                self.out.append('\n - ')

        def handle_data(self, data):
            self.out.append(data.strip())

    convertlists = ConvertLists()

    def replace_ul_tags(m):
        # The parser is reused for every match, so clear its buffer first.
        convertlists.reset()
        convertlists.feed(m.group(0))
        return '\n%s\n\n' % ''.join(convertlists.out)

    s = s.replace('<i>', '_').replace('</i>', '_')
    s = s.replace('<b>', '*').replace('</b>', '*')

    # Raw string: "\s" in a normal literal is an invalid escape sequence.
    s = re.sub(r'\s*(<ul>.*?</ul>\s*)', replace_ul_tags, s, flags=re.MULTILINE | re.DOTALL)
    return s

57
58
59
60
61
62
63
64
65
66
67
def trimToSingleSpace(text):
    """Collapse the leading and trailing whitespace of text to one space each.

    Returns "" for None, for an empty string and for a string made only of
    whitespace. Otherwise the stripped text is returned, prefixed with a
    single space if the input started with whitespace and suffixed with a
    single space if it ended with whitespace.
    """
    if not text:
        return ""
    core = text.strip()
    if not core:
        return core
    lead = " " if text[0].isspace() else ""
    trail = " " if text[-1].isspace() else ""
    return lead + core + trail
68

69
70
71
72
73
74
75
76
77
78
def getNodeText(node):
    """Return the text of an XML documentation node and its children.

    Children are handled by tag:
      * "para"     - its text followed by a blank line
      * "ref"      - its text, unchanged
      * "xrefsect" - emitted as "@deprecated <description>" only when its
                     xreftitle reads "deprecated" (case-insensitive);
                     any other xrefsect contributes nothing
      * a tag listed in docTags - its text wrapped in the mapped tag
      * anything else - contributes nothing
    The tail text following each child is always kept.
    """
    # Fragments are collected and joined once instead of re-formatting the
    # whole accumulated string for every child.
    parts = [node.text if node.text is not None else ""]
    for child in node:
        if child.tag == "para":
            parts.append(getNodeText(child))
            parts.append("\n\n")
        elif child.tag == "ref":
            parts.append(getNodeText(child))
        elif child.tag == "xrefsect":
            title = child.find("xreftitle")
            description = child.find("xrefdescription")
            if (title is not None and description is not None
                    and getNodeText(title).lower() == "deprecated"):
                parts.append("\n@deprecated %s\n\n" % getNodeText(description))
        elif child.tag in docTags:
            tag = docTags[child.tag]
            parts.append("<%s>%s</%s>" % (tag, getNodeText(child), tag))
        if child.tail is not None:
            parts.append(child.tail)
    return "".join(parts)

92
def getText(subNodePath, node):
    """Return the concatenated text of every sub-node matching subNodePath.

    Each match is converted with getNodeText() and passed through
    trimToSingleSpace(); matches whose tag is "para" are followed by a blank
    line. The combined string is stripped before it is returned, so the
    result is "" when nothing matches.
    """
    pieces = []
    for match in node.findall(subNodePath):
        pieces.append(trimToSingleSpace(getNodeText(match)))
        if match.tag == "para":
            pieces.append("\n\n")
    return "".join(pieces).strip()

100
# Splits a name around its last "OpenMM:...:" scope qualifier: group 1 is the
# text before the qualifier, group 2 the text after it.
OPENMM_RE_PATTERN = re.compile("(.*)OpenMM:[a-zA-Z0-9:]*:(.*)")


def stripOpenmmPrefix(name, rePattern=OPENMM_RE_PATTERN):
    """Remove the "OpenMM::...::" scope qualifier from name.

    Returns the text before the qualifier joined with the text after it.
    name is returned unchanged when the pattern does not match, when name is
    not something the pattern can search (e.g. None), or when a custom
    rePattern lacks the two groups.
    """
    try:
        m = rePattern.search(name)
        if m is None:
            return name
        return "%s%s" % m.group(1, 2)
    except (TypeError, IndexError):
        # TypeError: name is not a string; IndexError: pattern has < 2 groups.
        return name
109

110
111
112
113
114
115
116
117
118
119
def findNodes(parent, path, **args):
    """Return the elements under parent matching path and all given attributes.

    path is handed to parent.findall(). Each keyword argument names an XML
    attribute that must be present on the element with exactly that value.
    Matches are returned as a list in findall() order.
    """
    def hasAllAttributes(element):
        attributes = element.attrib
        return all(key in attributes and attributes[key] == wanted
                   for key, wanted in args.items())

    return [element for element in parent.findall(path) if hasAllAttributes(element)]
120
121
122
123
124

def getClassMethodList(classNode, skipMethods):
    """Collect the public member functions of a class node.

    Public static functions are listed first, then public functions. A method
    is left out when (className, methodName) or
    (className, methodName, numberOfParams) appears in skipMethods.

    Returns a list of tuples:
        (shortClassName, memberNode, shortMethDefinition, methName,
         isConstructor, isDestructor, templateType, templateName)
    """
    className = getText("compoundname", classNode)
    shortClassName = stripOpenmmPrefix(className)
    sections = (findNodes(classNode, "sectiondef", kind="public-static-func")
                + findNodes(classNode, "sectiondef", kind="public-func"))
    methodList = []
    for section in sections:
        for memberNode in findNodes(section, "memberdef", kind="function", prot="public"):
            methDefinition = getText("definition", memberNode)
            shortMethDefinition = stripOpenmmPrefix(methDefinition)
            shortMethDefinition = shortMethDefinition.replace(' &', '&')
            # The method name is the last whitespace-separated token.
            methName = shortMethDefinition.split()[-1]
            if (shortClassName, methName) in skipMethods:
                continue
            numParams = len(findNodes(memberNode, 'param'))
            if (shortClassName, methName, numParams) in skipMethods:
                continue
            # Classes with these suffixes are still included; only warn.
            for catchString in ['Factory', 'Impl', 'Info', 'Kernel']:
                if shortClassName.endswith(catchString):
                    sys.stderr.write("Warning: Including class %s\n" %
                                     shortClassName)

            # set template info
            templateType = getText("templateparamlist/param/type", memberNode)
            templateName = getText("templateparamlist/param/declname", memberNode)

            methodList.append((shortClassName,
                               memberNode,
                               shortMethDefinition,
                               methName,
                               shortClassName == methName,
                               '~' + shortClassName == methName,
                               templateType,
                               templateName))
    return methodList


Robert McGibbon's avatar
Robert McGibbon committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def docstringTypemap(cpptype):
    """Turn a C++ type name into the spelling shown in the Python docstrings.

    A leading "const " and then a leading "std::" are dropped, '&' characters
    at either end are removed and the result is stripped of whitespace.
    Accuracy is not critical: the value is used for docstrings only, never
    for the generated swig wrapper code.
    """
    result = cpptype
    for prefix in ('const ', 'std::'):
        if result.startswith(prefix):
            result = result[len(prefix):]
    return result.strip('&').strip()


170
171
class SwigInputBuilder:
    def __init__(self,
                 inputDirname,
                 configFilename,
                 outputFilename=None,
                 docstringFilename=None,
                 pythonprependFilename=None,
                 pythonappendFilename=None,
                 skipAdditionalMethods=[],
                 SWIG_VERSION='3.0.2'):
        """Load the configuration and XML input and open the output files.

        inputDirname          -- directory whose files ending in "xml" are
                                 parsed and merged into a single document
        configFilename        -- file name of the configuration module; it is
                                 imported by its name without the extension
        outputFilename        -- main output file (standard output if unset)
        docstringFilename     -- optional docstring output file
        pythonprependFilename -- optional pythonprepend output file
        pythonappendFilename  -- optional pythonappend output file
        skipAdditionalMethods -- extra skip entries written as "A::b" or
                                 "A::b::<numParams>"; only read, never
                                 modified here
        SWIG_VERSION          -- accepted but not used by this constructor
        """
        self.nodeByID = {}

        self.configModule = __import__(os.path.splitext(configFilename)[0])

        # Copy so the configuration module's own list is left untouched.
        self.skipMethods = self.configModule.SKIP_METHODS[:]
        for skipMethod in skipAdditionalMethods:
            items = skipMethod.split('::')
            if len(items) == 3:
                items[2] = int(items[2])
            self.skipMethods.append(tuple(items))

        # Read all the XML files and merge them into a single document.
        # The directory listing is sorted so the merged order (and therefore
        # the generated output) does not depend on the file system.
        self.doc = etree.ElementTree(etree.Element('root'))
        mergedRoot = self.doc.getroot()
        for fileName in sorted(os.listdir(inputDirname)):
            if fileName.lower().endswith('xml'):
                root = etree.parse(os.path.join(inputDirname, fileName)).getroot()
                for node in root:
                    mergedRoot.append(node)

        if outputFilename:
            self.fOut = open(outputFilename, 'w')
        else:
            self.fOut = sys.stdout

        if docstringFilename:
            self.fOutDocstring = open(docstringFilename, 'w')
        else:
            self.fOutDocstring = None

        if pythonprependFilename:
            self.fOutPythonprepend = open(pythonprependFilename, 'w')
        else:
            self.fOutPythonprepend = None

        if pythonappendFilename:
            self.fOutPythonappend = open(pythonappendFilename, 'w')
        else:
            self.fOutPythonappend = None

        self._enumByClassname = {}

        self._orderedClassNodes = self._buildOrderedClassNodes()

    def _getNodeByID(self, id):
        if id not in self.nodeByID:
225
226
            for node in findNodes(self.doc.getroot(), "compounddef", id=id):
                self.nodeByID[id] = node
227
228
229
230
        return self.nodeByID[id]

    def _buildOrderedClassNodes(self):
        """Return the public class nodes with base classes before subclasses.

        Every public "compounddef" class node of the merged document is passed
        to _findBaseNodes(), which appends a node's public bases before the
        node itself and leaves out classes listed in self.skipMethods.
        """
        orderedClassNodes = []
        for node in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            self._findBaseNodes(node, orderedClassNodes)
        return orderedClassNodes

    def _findBaseNodes(self, node, excludedClassNodes=[]):
        if node in excludedClassNodes: return
        nodeName = getText("compoundname", node)
        if (nodeName.split("::")[-1],) in self.skipMethods:
            return
240
        for baseNodePnt in findNodes(node, "basecompoundref", prot="public"):
241
242
243
244
            if "refid" in baseNodePnt.attrib:
                baseNodeID = baseNodePnt.attrib["refid"]
                baseNode = self._getNodeByID(baseNodeID)
                self._findBaseNodes(baseNode, excludedClassNodes)
245
246
247
        excludedClassNodes.append(node)


248
249
250
251
    def writeFactories(self):
        """Write the %factory directives to the main output file.

        Public classes are grouped by the name of their direct public base:
        OpenMM::Force, OpenMM::Integrator / OpenMM::DrudeIntegrator, or
        OpenMM::TabulatedFunction. One directive is then written for each
        fixed accessor signature below, plus one for every method defined as
        "TabulatedFunction& getTabulatedFunction", each listing the matching
        subclass names in sorted order. Classes listed in self.skipMethods
        are ignored.
        """
        def writeFactory(signature, sortedNames):
            # One directive: "%factory(<signature>, OpenMM::<name>, ...);"
            self.fOut.write("%factory(" + signature)
            for name in sortedNames:
                self.fOut.write(",\n         OpenMM::%s" % name)
            self.fOut.write(");\n\n")

        self.fOut.write("\n/* Declare factories */\n\n")
        forceSubclassList = []
        integratorSubclassList = []
        tabulatedFunctionSubclassList = []
        for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            className = getText("compoundname", classNode)
            shortClassName = stripOpenmmPrefix(className)
            if (className.split("::")[-1],) in self.skipMethods:
                continue
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
                if "refid" in baseNodePnt.attrib:
                    baseNodeID = baseNodePnt.attrib["refid"]
                    baseNode = self._getNodeByID(baseNodeID)
                    baseName = getText("compoundname", baseNode)
                    if baseName == 'OpenMM::Force':
                        forceSubclassList.append(shortClassName)
                    elif baseName in ('OpenMM::Integrator', 'OpenMM::DrudeIntegrator'):
                        integratorSubclassList.append(shortClassName)
                    elif baseName == 'OpenMM::TabulatedFunction':
                        tabulatedFunctionSubclassList.append(shortClassName)
        # We need to include subclasses of DrudeIntegrator, but not DrudeIntegrator itself.
        # Guarded so that input without DrudeIntegrator does not raise ValueError.
        if 'DrudeIntegrator' in integratorSubclassList:
            integratorSubclassList.remove('DrudeIntegrator')

        forceNames = sorted(forceSubclassList)
        integratorNames = sorted(integratorSubclassList)
        tabulatedFunctionNames = sorted(tabulatedFunctionSubclassList)

        for signature in ("OpenMM::Force& OpenMM::System::getForce",
                          "OpenMM::Force* OpenMM_XmlSerializer__cloneForce",
                          "OpenMM::Force* OpenMM_XmlSerializer__deserializeForce",
                          "OpenMM::Force& OpenMM::CustomCVForce::getCollectiveVariable"):
            writeFactory(signature, forceNames)

        for signature in ("OpenMM::Integrator* OpenMM_XmlSerializer__cloneIntegrator",
                          "OpenMM::Integrator* OpenMM_XmlSerializer__deserializeIntegrator",
                          "OpenMM::Integrator& OpenMM::Context::getIntegrator",
                          "OpenMM::Integrator& OpenMM::CompoundIntegrator::getIntegrator"):
            writeFactory(signature, integratorNames)

        for signature in ("OpenMM::TabulatedFunction* OpenMM_XmlSerializer__cloneTabulatedFunction",
                          "OpenMM::TabulatedFunction* OpenMM_XmlSerializer__deserializeTabulatedFunction"):
            writeFactory(signature, tabulatedFunctionNames)

        for classNode in self._orderedClassNodes:
            methodList = getClassMethodList(classNode, self.skipMethods)
            for items in methodList:
                (shortClassName, memberNode,
                 shortMethDefinition, methName,
                 isConstructors, isDestructor, templateType, templateName) = items
                if shortMethDefinition == 'TabulatedFunction& getTabulatedFunction':
                    writeFactory("OpenMM::TabulatedFunction& OpenMM::%s::%s"
                                 % (shortClassName, methName),
                                 tabulatedFunctionNames)

        self.fOut.write("%factory(OpenMM::VirtualSite& OpenMM::System::getVirtualSite, OpenMM::TwoParticleAverageSite, OpenMM::ThreeParticleAverageSite, OpenMM::OutOfPlaneSite, OpenMM::LocalCoordinatesSite);\n\n")
        self.fOut.write("\n")

    def writeGlobalConstants(self):
        """Write the OpenMM namespace's constants to the main output file.

        For every public, static, non-mutable variable in a "var" section of
        the OpenMM namespace node a line "static <definition> = <initializer>;"
        is written, with the OpenMM prefix removed from the definition and a
        leading "=" removed from the initializer.

        Raises ValueError when the merged document has no namespace node
        named "OpenMM".
        """
        self.fOut.write("/* Global Constants */\n\n")
        namespaceNodes = findNodes(self.doc.getroot(), "compounddef", kind="namespace")
        node = next((x for x in namespaceNodes
                     if x.findtext("compoundname") == "OpenMM"), None)
        if node is None:
            # A bare next() would surface as an unexplained StopIteration.
            raise ValueError('no "OpenMM" namespace found in the XML input')
        for section in findNodes(node, "sectiondef", kind="var"):
            for memberNode in findNodes(section, "memberdef", kind="variable", mutable="no", prot="public", static="yes"):
                vDef = stripOpenmmPrefix(getText("definition", memberNode))
                iDef = getText("initializer", memberNode)
                if iDef.startswith("="):
                    iDef = iDef[1:]
                self.fOut.write("static %s = %s;\n" % (vDef, iDef))
        self.fOut.write("\n")

    def writeForwardDeclarations(self):
        self.fOut.write("\n/* Forward Declarations */\n\n")

        for classNode in self._orderedClassNodes:
            hasConstructor=False
            methodList=getClassMethodList(classNode, self.skipMethods)
            for items in methodList:
                (shortClassName, memberNode,
                 shortMethDefinition, methName,
                 isConstructors, isDestructor, templateType, templateName) = items
                if isConstructors:
                    hasConstructor=True

            className = stripOpenmmPrefix(getText("compoundname", classNode))
            # If has a constructor then tell swig tell to make a copy method
            if hasConstructor:
                self.fOut.write("%%copyctor %s ;\n" % className)
            self.fOut.write("class %s ;\n" % className)
        self.fOut.write("\n")

    def writeClassDeclarations(self):
        self.fOut.write("\n/* Class Declarations */\n\n")
        for classNode in self._orderedClassNodes:
            className = stripOpenmmPrefix(getText("compoundname", classNode))
374
375
376
377
            if self.fOutDocstring:
                dNode = classNode.find('detaileddescription')
                if dNode is not None:
                    docstring = getNodeText(dNode).strip().replace('"', '\\"')
378
                    docstring = striphtmltags(docstring)
379
                    self.fOutDocstring.write('%%feature("docstring") %s "%s";\n' % (className, docstring))
380
381
382
383
384
            self.fOut.write("class %s" % className)
            if className in self.configModule.MISSING_BASE_CLASSES:
                self.fOut.write(" : public %s" %
                                self.configModule.MISSING_BASE_CLASSES[className])

385
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
386
387
388
                if "refid" in baseNodePnt.attrib:
                    baseName = stripOpenmmPrefix(getText(".", baseNodePnt))
                    self.fOut.write(" : public %s" % baseName)
389
390
391
392
393
394
395
396
            self.fOut.write(" {\n")
            self.fOut.write("public:\n")
            self.writeEnumerations(classNode)
            self.writeMethods(classNode)
            self.fOut.write("};\n\n")
        self.fOut.write("\n")

    def writeEnumerations(self, classNode):
        """Write the public enumerations of classNode to the main output file.

        Each public enum found in a "public-type" section is written as an
        indented "enum <name> { ... };" block with one value per line. The
        enum name is also recorded in self._enumByClassname under the short
        class name. A blank line follows if at least one enum was written.
        """
        enumNodes = []
        for section in findNodes(classNode, "sectiondef", kind="public-type"):
            enumNodes.extend(findNodes(section, "memberdef", kind="enum", prot="public"))

        shortClassName = stripOpenmmPrefix(getText("compoundname", classNode))
        for enumNode in enumNodes:
            enumName = getText("name", enumNode)
            self._enumByClassname.setdefault(shortClassName, []).append(enumName)
            self.fOut.write("%senum %s {" % (INDENT, enumName))
            argSep = "\n"
            for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
                vName = getText("name", valueNode)
                vInit = getText("initializer", valueNode)
                if vInit.startswith("="):
                    vInit = vInit[1:]
                if vInit:
                    self.fOut.write("%s%s%s = %s" % (argSep, 2*INDENT, vName, vInit))
                else:
                    # No explicit value: emit the bare name, not "name = ".
                    self.fOut.write("%s%s%s" % (argSep, 2*INDENT, vName))
                argSep = ",\n"
            self.fOut.write("\n%s};\n" % INDENT)
        if enumNodes:
            self.fOut.write("\n")

    def writeMethods(self, classNode):
        methodList=getClassMethodList(classNode, self.skipMethods)

        #write only Constructors
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isConstructors:
                mArgsstring = getText("argsstring", memberNode)
                try:
                    pExceptions = " %s" % getText('exceptions', memberNode)
                except IndexError:
                    pExceptions = ""
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition,
                                                 mArgsstring, pExceptions))
        #write only Destructors
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isDestructor:
                mArgsstring = getText("argsstring", memberNode)
                try:
                    pExceptions = " %s" % getText('exceptions', memberNode)
                except IndexError:
                    pExceptions = ""
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition,
                                                 mArgsstring, pExceptions))

        #write only non Constructor and Destructor methods and python mods
        self.fOut.write("\n")
453
        methodsWithOutputArgs = set()
454
455
456
457
458
459
460
461
462
463
464
465
        for items in methodList:
            clearOutput=""
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isConstructors or isDestructor: continue

            key = (shortClassName, methName)
            if key in self.configModule.DOC_STRINGS:
                self.fOut.write('%%feature("autodoc", "%s") %s;\n' %
                                (self.configModule.DOC_STRINGS[key], methName))

466
            paramList=findNodes(memberNode, 'param')
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
            for pNode in paramList:
                try:
                    pType = getText('type', pNode)
                except IndexError:
                    pType = getText('type/ref', pNode)
                pName = getText('declname', pNode)
                key = (shortClassName, methName, pName)
                if pType.find('&')>=0 and \
                   'const' not in pType.split():
                    if key not in self.configModule.NO_OUTPUT_ARGS:
                        eType = pType.split()[0]
                        if shortClassName in self._enumByClassname and \
                           eType in self._enumByClassname[shortClassName]:
                            simpleType = re.sub(eType, 'int', pType)
                        else:
                            simpleType = pType
                        self.fOut.write("%s%%apply %s OUTPUT { %s %s };\n" %
                                        (INDENT, simpleType, pType, pName))
                        clearOutput = "%s%s%%clear %s %s;\n" \
                                     % (clearOutput, INDENT, pType, pName)
487
                        methodsWithOutputArgs.add((shortClassName, methName))
488
489
490
491
492
493

            mArgsstring = getText("argsstring", memberNode)
            try:
                pExceptions = " %s" % getText('exceptions', memberNode)
            except IndexError:
                pExceptions = ""
494
            if memberNode.attrib["virt"].strip()!='non-virtual':
495
496
497
498
499
500
501
502
503
504
505
506
507
508
                if 'virtual' not in shortMethDefinition.split():
                    shortMethDefinition="virtual %s" % shortMethDefinition
            if( len(templateType) > 0 ):
                self.fOut.write("%stemplate<%s %s> %s%s%s;\n" % (INDENT, templateType, templateName, shortMethDefinition, mArgsstring, pExceptions))
            else:
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition, mArgsstring, pExceptions))
            if clearOutput:
                self.fOut.write(clearOutput)

        #write python mod blocks
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
509
            paramList = findNodes(memberNode, 'param')
510

511
            # write pythonprepend blocks
512
513
514
515
            if isConstructors:
                mArgsstring = '' # specifying args to constructors seems to prevent append and prepend from working
            else:
                mArgsstring = getText("argsstring", memberNode)
516
517
            if self.fOutPythonprepend and \
               len(paramList) and \
518
               not is_method_abstract(mArgsstring):
519
520
521
522
523
524
                text = '''
%pythonprepend OpenMM::{shortClassName}::{methName}{mArgsstring} %{{{{{{0}}
%}}}}'''.format(shortClassName=shortClassName, methName=methName, mArgsstring=mArgsstring)
                textInside = ''
                key = (shortClassName, methName)
                for argNum in self.configModule.STEAL_OWNERSHIP.get(key, []):
525
                    argName = getText('declname', paramList[argNum])
526
527

                    textInside += '''
528
529
530
    if not {argName}.thisown:
        s = ("the %s object does not own its corresponding OpenMM object"
             % self.__class__.__name__)
531
        raise Exception(s)'''.format(argName=argName)
532

533

534
535
536
537
538
539
540
541
                # Convert input arguments to the proper units, if specified.
                if key not in methodsWithOutputArgs:
                    if key in self.configModule.UNITS:
                        argUnits=self.configModule.UNITS[key][1]
                    elif ("*", methName) in self.configModule.UNITS:
                        argUnits=self.configModule.UNITS[("*", methName)][1]
                    else:
                        argUnits = ()
542
                    if len(argUnits) > 0 and isConstructors:
543
544
545
546
                        textInside += '''
    args = list(args)'''
                    for i, units in enumerate(argUnits):
                        if units is not None:
547
                            if isConstructors:
548
549
550
551
552
553
554
                                argName = 'args[%s]' % i
                            else:
                                argName = getText('declname', paramList[i])
                            textInside += '''
    if unit.is_quantity({argName}):
        {argName} = {argName}.value_in_unit({units})'''.format(argName=argName, units=units)

555
                for argNum in self.configModule.REQUIRE_ORDERED_SET.get(key, []):
556
                    argName = getText('declname', paramList[argNum])
557
558
559
560
561
562
563

                    textInside += '''
    {argName} = list({argName})'''.format(argName=argName)
                if textInside:
                    self.fOutPythonprepend.write(text.format(textInside))

            # write pythonappend blocks
564
            if self.fOutPythonappend \
565
               and not is_method_abstract(mArgsstring):
566
                key = (shortClassName, methName)
567
568
                #sys.stdout.write("key %s %s \n" % (shortClassName, methName))

569
                addText=''
570
                returnType = getText("type", memberNode)
571
572
573
574
575

                if key in self.configModule.UNITS:
                    valueUnits=self.configModule.UNITS[key]
                elif ("*", methName) in self.configModule.UNITS:
                    valueUnits=self.configModule.UNITS[("*", methName)]
576
                elif methName.startswith('get') and returnType not in ('void', 'int', 'bool', 'std::string', 'const std::string &'):
577
578
                    s = 'do not know how to add units to %s %s::%s' \
                        % (returnType, shortClassName, methName)
579
580
581
582
583
                    raise Exception(s)
                else:
                    valueUnits=[None, ()]

                index=0
584
                if valueUnits[0] is not None:
585
586
587
588
589
590
591
592
593
594
595
596
597
                    sys.stdout.write("%s.%s() returns %s\n" %
                                     (shortClassName, methName, valueUnits[0]))
                    if len(valueUnits[1])>0:
                        addText = "%s%sval[%d]=unit.Quantity(val[%d], %s)\n" \
                                 % (addText, INDENT,
                                    index, index,
                                    valueUnits[0])
                        index+=1
                    else:
                        addText = "%s%sval=unit.Quantity(val, %s)\n" \
                                 % (addText, INDENT, valueUnits[0])

                for vUnit in valueUnits[1]:
598
                    if vUnit is not None and key in methodsWithOutputArgs:
599
                        addText = "%s%sval[%s]=unit.Quantity(val[%s], %s)\n" \
600
                                     % (addText, INDENT, index, index, vUnit)
601
                    index+=1
602
603
604

                if key in self.configModule.STEAL_OWNERSHIP:
                    for argNum in self.configModule.STEAL_OWNERSHIP[key]:
605
                        argName = getText('declname', paramList[argNum])
606
607
                        addText = "%s%s%s.thisown=0\n" \
                                % (addText, INDENT, argName)
608
609
610
611
612
613
614
615
616
617
618

                if addText:
                    self.fOutPythonappend.write("%pythonappend")
                    self.fOutPythonappend.write(" OpenMM::%s::%s(" % key)
                    sepChar=''
                    outputIndex=0
                    for pNode in paramList:
                        try:
                            pType = getText('type', pNode)
                        except IndexError:
                            pType = getText('type/ref', pNode)
619
620
621
622
623
624
625
                        # parse default arguments
                        try:
                            defaultValue = getText('defval', pNode)
                        except:
                            defaultValue = ""
                        if defaultValue != "":
                            defaultValue = "=%s" %defaultValue
626
                        pName = getText('declname', pNode)
627
                        self.fOutPythonappend.write("%s%s %s%s" % (sepChar, pType, pName, defaultValue))
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
                        sepChar=', '

                        if pType.find('&')>=0 and \
                          'const' not in pType.split() and \
                          key not in self.configModule.NO_OUTPUT_ARGS and \
                          len(valueUnits[1])>0:
                            try:
                                unitType=valueUnits[1][outputIndex]
                            except IndexError:
                                s = "missing unit type for %s.%s() arg named %s" \
                                   % (shortClassName, methName, pName)
                                raise Exception(s)
                            sys.stdout.write("%s.%s() returns %s as %s\n" %
                                             (shortClassName, methName,
                                              pName, unitType))
                            outputIndex+=1
644
                    if memberNode.attrib["const"]=="yes":
645
646
647
648
649
650
651
652
653
654
655
656
657
658
                        constString=' const'
                    else:
                        constString=''
                    self.fOutPythonappend.write(")%s %%{\n" % constString)
                    self.fOutPythonappend.write(addText)
                    self.fOutPythonappend.write("%}\n\n")


        #print "Done python mod blocks\n"
        #write Docstring info
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName ) = items
Robert McGibbon's avatar
Robert McGibbon committed
659

660
            if self.fOutDocstring:
Robert McGibbon's avatar
Robert McGibbon committed
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
                signatureParams = findNodes(memberNode, 'param')
                assert len(findNodes(memberNode, 'detaileddescription')) == 1
                dNode = findNodes(memberNode, 'detaileddescription')[0]

                try:
                    description=getText('para', dNode)
                    description.strip()
                except IndexError:
                    description = ''
                params = findNodes(dNode, 'para/parameterlist/parameteritem')

                paramString = ['Parameters', '----------']
                returnString = ['Returns', '-------']

                if len(params) > 0:
                    if len(signatureParams) != len(params):
                        raise ValueError('docstring in %s.%s does not match the signature' % (shortClassName, methName))

                    for pNode, pSignatureNode in zip(params, signatureParams):
                        parameterNameNode = findNodes(pNode, 'parameternamelist/parametername')[0]
681
                        argDoc = getText('parameterdescription/para', pNode)
Robert McGibbon's avatar
Robert McGibbon committed
682
683
                        argName = getNodeText(parameterNameNode)
                        argType = docstringTypemap(getText('type', pSignatureNode))
684

Robert McGibbon's avatar
Robert McGibbon committed
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
                        isOutput = parameterNameNode.get('direction') == 'out'
                        if isOutput:
                            returnString.extend(['%s : %s' % (argName, argType), '    %s' % argDoc])
                        else:
                            paramString.extend(['%s : %s' % (argName, argType), '    %s' % argDoc])


                returnSection = findNodes(dNode, 'para/simplesect')
                if len(returnSection) > 0:
                    returnNode = returnSection[0]
                    if returnNode.get('kind') == 'return':
                        argType = getNodeText(findNodes(memberNode, 'type')[0])
                        argType = docstringTypemap(argType)
                        returnString.extend([argType, '    %s' % getNodeText(returnNode).strip()])

                dString = '\n'.join(
                    ([description] + [''] if len(description) > 0 else []) +
                    (paramString + [''] if len(paramString) > 2 else []) +
                    (returnString if len(returnString) > 2 else [])).strip()
                if dString:
                    dString = re.sub(r'([^\\])"', r'\g<1>\"', dString)
                    s = '%%feature("docstring") OpenMM::%s::%s "%s";' \
                       % (shortClassName, methName, dString)
                    self.fOutDocstring.write("%s\n" % s)

                self.fOutDocstring.write("\n\n")
711
712
713
714
715
716


    def writeSwigFile(self):
        """Write the whole SWIG input file to ``self.fOut``.

        Output order: a banner comment naming the invoking script
        (``sys.argv[0]``) and the current time, the ``factory.i`` include,
        the opening of the ``OpenMM`` namespace, whatever the four section
        writers emit, and finally the closing of the namespace.
        """
        self.fOut.write("/* Swig input file,\n")
        banner = "%sgenerated by %s on %s\n*/\n\n\n" % (
            INDENT, sys.argv[0], time.asctime())
        self.fOut.write(banner)
        self.fOut.write("%include \"factory.i\"\n")
        self.fOut.write("\nnamespace OpenMM {\n\n")

        # Section writers run in this fixed order inside the namespace block.
        sectionWriters = ("writeFactories",
                          "writeGlobalConstants",
                          "writeForwardDeclarations",
                          "writeClassDeclarations")
        for writerName in sectionWriters:
            getattr(self, writerName)()

        self.fOut.write("\n} // namespace OpenMM\n\n")


def parseCommandLine():
    """Parse ``sys.argv`` into the settings used to build the SWIG input.

    Recognised options (each except ``-h`` takes one argument):
      -h  show usage and exit
      -i  input directory name (required)
      -c  configuration file name (required)
      -o  output file name (default "")
      -d  docstring file name (default "")
      -a  pythonprepend file name (default "")
      -z  pythonappend file name (default "")
      -s  entry appended to the skip list; may be given several times
      -v  SWIG version string (default '3.0.2')

    Returns a 9-tuple:
      (args_proper, inputDirname, configFilename, outputFilename,
       docstringFilename, pythonprependFilename, pythonappendFilename,
       skipAdditionalMethods, swigVersion)
    where ``args_proper`` is the list of leftover positional arguments.

    Calls ``usageError()`` when ``-h`` is given, when an option is not
    recognised or lacks its argument, or when ``-i`` or ``-c`` is missing
    or empty.
    """
    try:
        opts, args_proper = getopt.getopt(sys.argv[1:], 'hi:c:o:d:a:z:s:v:')
    except getopt.GetoptError as err:
        # Report the bad option and show the usage text instead of letting
        # the exception escape as a traceback.
        sys.stderr.write("%s\n" % err)
        usageError()

    # Single-valued options keyed by flag; a later occurrence overrides an
    # earlier one, exactly as sequential assignment would.
    settings = {
        '-i': None,
        '-c': None,
        '-o': "",
        '-d': "",
        '-a': "",
        '-z': "",
        '-v': '3.0.2',
    }
    skipAdditionalMethods = []

    for option, parameter in opts:
        if option == '-h':
            usageError()
        elif option == '-s':
            # -s accumulates rather than overrides.
            skipAdditionalMethods.append(parameter)
        elif option in settings:
            settings[option] = parameter

    # Both the input directory and the config file are mandatory.
    if not settings['-i']:
        usageError()
    if not settings['-c']:
        usageError()

    return (args_proper, settings['-i'], settings['-c'], settings['-o'],
            settings['-d'], settings['-a'], settings['-z'],
            skipAdditionalMethods, settings['-v'])
751
752

def main():
    """Script entry point: parse the command line and write the SWIG file."""
    parsed = parseCommandLine()
    # Element 0 holds the leftover positional arguments, which are not used;
    # the remaining eight values are the builder's constructor arguments,
    # already in the order it expects.
    builderArgs = parsed[1:]
    builder = SwigInputBuilder(*builderArgs)
    builder.writeSwigFile()


def usageError():
    """Print the command-line synopsis to stdout and exit with status 1."""
    progName = os.path.basename(sys.argv[0])
    # Continuation lines are padded by the program name's width so that the
    # flags line up under the first one.
    pad = ' ' * len(progName)

    usageLines = [
        'usage: %s -i xmlHeadersFilename \\\n' % progName,
        '       %s -c inputConfigFilename\n' % pad,
    ]
    optionalFlags = (
        '-o swigInputDirname',
        '-d docstringFilename',
        '-a pythonprependFilename',
        '-z pythonappendFilename',
        '-s skippedClasses',
        '-v swigVersion',
    )
    for flagText in optionalFlags:
        usageLines.append('       %s[%s]\n' % (pad, flagText))

    for line in usageLines:
        sys.stdout.write(line)
    sys.exit(1)

# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()