#!/usr/bin/env python
"""Build swig input file from xml encoded header files (see gccxml)."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"

import sys, os
import time
import getopt
import re
import xml.etree.ElementTree as etree

try:
    from html.parser import HTMLParser
except ImportError:
    # python 2
    from HTMLParser import HTMLParser

# Indentation unit used for the text written to the generated SWIG and
# %pythonappend files.
INDENT = "   "
# Maps XML documentation element names to the HTML tag that getNodeText()
# wraps their text in.
docTags = {'emphasis':'i', 'bold':'b', 'itemizedlist':'ul', 'listitem':'li', 'preformatted':'pre', 'computeroutput':'tt', 'subscript':'sub', 'verbatim': 'verbatim'}

def is_method_abstract(argstring):
    """Return True when a doxygen argsstring marks a pure virtual method.

    Only the text after the last ')' (the trailing qualifiers) is examined,
    so a default value such as "(int x=0)" is not mistaken for "=0".
    """
    _, _, qualifiers = argstring.rpartition(")")
    return "=0" in qualifiers

def striphtmltags(s):
    """Strip a couple html tags used inside docstrings in the C++ source
    to produce something more easily read as plain text.

    ``<i>x</i>`` becomes ``_x_`` and ``<b>x</b>`` becomes ``*x*``.  Each
    ``<ul>...</ul>`` list (with the whitespace around it) is rewritten as
    plain-text bullet lines ("\\n - item"); other markup inside the list
    is dropped and each text run is stripped.
    """
    class ConvertLists(HTMLParser):
        """Collect the text of one <ul> fragment, one bullet per <li>."""

        def reset(self):
            HTMLParser.reset(self)
            self.out = []

        def handle_starttag(self, tag, attrs):
            if tag == 'li':
                self.out.append('\n - ')

        def handle_data(self, data):
            self.out.append(data.strip())

    convertlists = ConvertLists()

    def replace_ul_tags(m):
        # Feed exactly the matched text (including surrounding whitespace)
        # rather than re-slicing the enclosing string.
        convertlists.reset()
        convertlists.feed(m.group(0))
        return '\n%s\n\n' % ''.join(convertlists.out)

    s = s.replace('<i>', '_').replace('</i>', '_')
    s = s.replace('<b>', '*').replace('</b>', '*')

    # Raw string: '\s' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).  DOTALL lets a
    # list span several lines; the pattern has no ^/$, so MULTILINE is moot.
    s = re.sub(r'\s*(<ul>.*?</ul>\s*)', replace_ul_tags, s, flags=re.DOTALL)
    return s

def trimToSingleSpace(text):
    """Collapse leading/trailing whitespace of *text* to at most one space.

    Whitespace at either end is replaced by a single space on that side;
    whitespace inside the text is left untouched.  None, "" and
    all-whitespace input give "".
    """
    if not text:
        return ""
    core = text.strip()
    if not core:
        return ""
    prefix = " " if text[0].isspace() else ""
    suffix = " " if text[-1].isspace() else ""
    return prefix + core + suffix

def getNodeText(node):
    """Flatten a documentation XML element into a string.

    - The element's own text and every child's tail are kept verbatim.
    - ``para`` children contribute their text followed by a blank line.
    - ``ref`` children contribute their text.
    - An ``xrefsect`` whose title is "deprecated" (case-insensitive) becomes
      "\\n@deprecated <description>\\n\\n"; other xrefsects are dropped.
    - Children whose tag is in ``docTags`` are wrapped in the mapped HTML
      tag; any other child's own text is dropped (its tail is kept).
    """
    # Collect pieces and join once instead of re-concatenating the whole
    # string for every child (which was quadratic in the output length).
    parts = [node.text or ""]
    for child in node:
        tag = child.tag
        if tag == "para":
            parts.append(getNodeText(child))
            parts.append("\n\n")
        elif tag == "ref":
            parts.append(getNodeText(child))
        elif tag == "xrefsect":
            title = child.find("xreftitle")
            description = child.find("xrefdescription")
            if (title is not None and description is not None
                    and getNodeText(title).lower() == "deprecated"):
                parts.append("\n@deprecated %s\n\n" % getNodeText(description))
        elif tag in docTags:
            htmlTag = docTags[tag]
            parts.append("<%s>%s</%s>" % (htmlTag, getNodeText(child), htmlTag))
        if child.tail is not None:
            parts.append(child.tail)
    return "".join(parts)

def getText(subNodePath, node):
    """Return the combined text of every element under *node* at *subNodePath*.

    Each match contributes trimToSingleSpace(getNodeText(match)); a match
    whose tag is "para" is followed by a blank line.  The combined string is
    stripped, so a path with no matches yields "".
    """
    pieces = []
    for match in node.findall(subNodePath):
        pieces.append(trimToSingleSpace(getNodeText(match)))
        if match.tag == "para":
            pieces.append("\n\n")
    return "".join(pieces).strip()

OPENMM_RE_PATTERN = re.compile("(.*)OpenMM:[a-zA-Z:]*:(.*)")


def stripOpenmmPrefix(name, rePattern=OPENMM_RE_PATTERN):
    """Remove the last ``OpenMM::...::`` qualifier from *name*.

    The text before and after the qualifier is kept.  The leading group of
    the pattern is greedy, so only the final qualifier is removed::

        "const OpenMM::Vec3 &"                     -> "const Vec3 &"
        "OpenMM::Force& OpenMM::System::getForce"  -> "OpenMM::Force& getForce"

    A name the pattern does not match is returned unchanged.  (Previously a
    bare ``except:`` also swallowed unrelated errors such as a non-string
    argument; only the no-match case is handled now.)
    """
    m = rePattern.search(name)
    if m is None:
        return name
    return "%s%s" % m.group(1, 2)

def findNodes(parent, path, **attrs):
    """Return the elements under *parent* matching *path* whose attributes
    include every keyword in *attrs* with exactly that value."""
    def matches(node):
        attrib = node.attrib
        return all(key in attrib and attrib[key] == value
                   for key, value in attrs.items())

    return [node for node in parent.findall(path) if matches(node)]

def getClassMethodList(classNode, skipMethods):
    """Collect the public methods of a class that should be wrapped.

    Parameters
    ----------
    classNode : Element
        A ``compounddef`` element describing a class.
    skipMethods : container of tuples
        A method is skipped when ``(className, methName)`` or
        ``(className, methName, numParams)`` is in this container.

    Returns
    -------
    list of tuple
        ``(shortClassName, memberNode, shortMethDefinition, methName,
        isConstructor, isDestructor, templateType, templateName)`` for each
        kept method: public static functions first, then public member
        functions, each in document order.
    """
    className = getText("compoundname", classNode)
    shortClassName = stripOpenmmPrefix(className)
    # Classes with these suffixes are still included, but each wrapped
    # method triggers a warning on stderr.  (The old per-suffix loop had a
    # `continue` that only advanced that inner loop and so did nothing.)
    warnOnInclude = shortClassName.endswith(('Factory', 'Impl', 'Info', 'Kernel'))

    sections = (findNodes(classNode, "sectiondef", kind="public-static-func") +
                findNodes(classNode, "sectiondef", kind="public-func"))
    methodList = []
    for section in sections:
        for memberNode in findNodes(section, "memberdef", kind="function", prot="public"):
            shortMethDefinition = stripOpenmmPrefix(getText("definition", memberNode))
            methName = shortMethDefinition.split()[-1]
            if (shortClassName, methName) in skipMethods:
                continue
            numParams = len(findNodes(memberNode, 'param'))
            if (shortClassName, methName, numParams) in skipMethods:
                continue
            if warnOnInclude:
                sys.stderr.write("Warning: Including class %s\n" %
                                 shortClassName)

            # set template info
            templateType = getText("templateparamlist/param/type", memberNode)
            templateName = getText("templateparamlist/param/declname", memberNode)

            methodList.append((shortClassName,
                               memberNode,
                               shortMethDefinition,
                               methName,
                               shortClassName == methName,
                               '~' + shortClassName == methName,
                               templateType,
                               templateName))
    return methodList


def docstringTypemap(cpptype):
    """Translate a C++ type to Python for inclusion in the Python docstrings.

    This doesn't need to be perfectly accurate -- it's not used for
    generating the actual swig wrapper code, only the docstrings.  A leading
    "const " and then a leading "std::" are dropped, followed by '&' at
    either end and surrounding whitespace.
    """
    for prefix in ('const ', 'std::'):
        if cpptype.startswith(prefix):
            cpptype = cpptype[len(prefix):]
    return cpptype.strip('&').strip()


class SwigInputBuilder:
    """Generate a SWIG interface file for the OpenMM API from XML headers.

    Every file in the input directory whose name ends in "xml" (case
    insensitive) is parsed, and the top-level elements are merged into one
    in-memory document.  A Python config module, imported by file name,
    steers the output.  It must define SKIP_METHODS, MISSING_BASE_CLASSES,
    DOC_STRINGS, NO_OUTPUT_ARGS, STEAL_OWNERSHIP, UNITS and
    REQUIRE_ORDERED_SET.

    Besides the main SWIG file, optional side files receive
    ``%feature("docstring")`` directives, ``%pythonprepend`` blocks and
    ``%pythonappend`` blocks.  Diagnostics go to stderr, so they never mix
    with SWIG output written to stdout.
    """

    # %factory directives written by writeFactories(), grouped by the base
    # class whose direct public subclasses each factory may return.  The
    # tuple order is the output order.
    _FACTORY_TARGETS = (
        ('OpenMM::Force', (
            "OpenMM::Force& OpenMM::System::getForce",
            "OpenMM::Force* OpenMM::Force::__copy__",
            "OpenMM::Force* OpenMM_XmlSerializer__deserializeForce",
            "OpenMM::Force& OpenMM::CustomCVForce::getCollectiveVariable",
        )),
        ('OpenMM::Integrator', (
            "OpenMM::Integrator* OpenMM::Integrator::__copy__",
            "OpenMM::Integrator* OpenMM_XmlSerializer__deserializeIntegrator",
            "OpenMM::Integrator& OpenMM::Context::getIntegrator",
            "OpenMM::Integrator& OpenMM::CompoundIntegrator::getIntegrator",
        )),
        ('OpenMM::TabulatedFunction', (
            "OpenMM::TabulatedFunction* OpenMM::TabulatedFunction::__copy__",
            "OpenMM::TabulatedFunction* OpenMM_XmlSerializer__deserializeTabulatedFunction",
        )),
    )

    def __init__(self,
                 inputDirname,
                 configFilename,
                 outputFilename=None,
                 docstringFilename=None,
                 pythonprependFilename=None,
                 pythonappendFilename=None,
                 skipAdditionalMethods=(),
                 SWIG_VERSION='3.0.2'):
        """
        Parameters
        ----------
        inputDirname : str
            Directory scanned for files whose name ends in "xml".
        configFilename : str
            Config file.  Its name without the extension is imported as a
            module, so it must be importable from sys.path.
        outputFilename, docstringFilename, pythonprependFilename, pythonappendFilename : str or None
            Output paths.  A false value sends the SWIG file to stdout and
            disables the corresponding side file.
        skipAdditionalMethods : iterable of str
            Extra skip entries, '::'-separated: "Class", "Class::method" or
            "Class::method::numParams".
        SWIG_VERSION : str
            Accepted for interface compatibility; not used by this class.
        """
        self.nodeByID = {}

        self.configModule = __import__(os.path.splitext(configFilename)[0])

        # Skip entries are tuples: (className,), (className, methName) or
        # (className, methName, numParams).
        self.skipMethods = list(self.configModule.SKIP_METHODS)
        for skipMethod in skipAdditionalMethods:
            items = skipMethod.split('::')
            if len(items) == 3:
                items[2] = int(items[2])
            self.skipMethods.append(tuple(items))

        # Read all the XML files and merge them into a single document.  The
        # listing is sorted so the output does not depend on directory order.
        self.doc = etree.ElementTree(etree.Element('root'))
        docRoot = self.doc.getroot()
        for fileName in sorted(os.listdir(inputDirname)):
            if fileName.lower().endswith('xml'):
                root = etree.parse(os.path.join(inputDirname, fileName)).getroot()
                for node in root:
                    docRoot.append(node)

        self.fOut = open(outputFilename, 'w') if outputFilename else sys.stdout
        self.fOutDocstring = (open(docstringFilename, 'w')
                              if docstringFilename else None)
        self.fOutPythonprepend = (open(pythonprependFilename, 'w')
                                  if pythonprependFilename else None)
        self.fOutPythonappend = (open(pythonappendFilename, 'w')
                                 if pythonappendFilename else None)

        # Enum names per short class name.  writeEnumerations() fills this,
        # and writeMethods() reads it to map enum OUTPUT args to int.
        self._enumByClassname = {}

        self._orderedClassNodes = self._buildOrderedClassNodes()

    def _getNodeByID(self, nodeID):
        """Return the compounddef whose id is *nodeID*, caching the result.

        If several elements share the id, the last one wins.  A KeyError is
        raised if there is none.
        """
        if nodeID not in self.nodeByID:
            for node in findNodes(self.doc.getroot(), "compounddef", id=nodeID):
                self.nodeByID[nodeID] = node
        return self.nodeByID[nodeID]

    def _buildOrderedClassNodes(self):
        """List public class nodes so that every base precedes its subclasses."""
        orderedClassNodes = []
        for node in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            self._findBaseNodes(node, orderedClassNodes)
        return orderedClassNodes

    def _findBaseNodes(self, node, excludedClassNodes=None):
        """Append *node* to *excludedClassNodes* after its public bases.

        Nodes already in the list, and classes whose unqualified name is a
        1-tuple skip entry, are not added.  The default is None rather than
        a shared mutable list.
        """
        if excludedClassNodes is None:
            excludedClassNodes = []
        if node in excludedClassNodes:
            return
        nodeName = getText("compoundname", node)
        if (nodeName.split("::")[-1],) in self.skipMethods:
            return
        for baseNodePnt in findNodes(node, "basecompoundref", prot="public"):
            if "refid" in baseNodePnt.attrib:
                baseNode = self._getNodeByID(baseNodePnt.attrib["refid"])
                self._findBaseNodes(baseNode, excludedClassNodes)
        excludedClassNodes.append(node)

    def _writeFactory(self, target, subclassNames):
        """Write one ``%factory(target, OpenMM::Sub, ...)`` directive.

        Subclass names are emitted in sorted order.
        """
        self.fOut.write("%%factory(%s" % target)
        for name in sorted(subclassNames):
            self.fOut.write(",\n         OpenMM::%s" % name)
        self.fOut.write(");\n\n")

    def _collectDirectSubclasses(self):
        """Map each factory base class name to the short names of the
        non-skipped public classes that list it as a direct public base."""
        subclasses = dict((baseName, []) for baseName, _ in self._FACTORY_TARGETS)
        for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            className = getText("compoundname", classNode)
            if (className.split("::")[-1],) in self.skipMethods:
                continue
            shortClassName = stripOpenmmPrefix(className)
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
                if "refid" in baseNodePnt.attrib:
                    baseNode = self._getNodeByID(baseNodePnt.attrib["refid"])
                    baseName = getText("compoundname", baseNode)
                    if baseName in subclasses:
                        subclasses[baseName].append(shortClassName)
        return subclasses

    def writeFactories(self):
        """Write the %factory directives that let SWIG downcast base-class
        return values (Force, Integrator, TabulatedFunction, VirtualSite)."""
        self.fOut.write("\n/* Declare factories */\n\n")
        subclasses = self._collectDirectSubclasses()
        for baseName, targets in self._FACTORY_TARGETS:
            for target in targets:
                self._writeFactory(target, subclasses[baseName])

        # Every wrapped accessor declared as
        # "TabulatedFunction& getTabulatedFunction" gets its own factory.
        tabulatedSubclasses = subclasses['OpenMM::TabulatedFunction']
        for classNode in self._orderedClassNodes:
            for items in getClassMethodList(classNode, self.skipMethods):
                shortClassName, shortMethDefinition, methName = items[0], items[2], items[3]
                if shortMethDefinition == 'TabulatedFunction& getTabulatedFunction':
                    self._writeFactory("OpenMM::TabulatedFunction& OpenMM::%s::%s"
                                       % (shortClassName, methName),
                                       tabulatedSubclasses)

        self.fOut.write("%factory(OpenMM::VirtualSite& OpenMM::System::getVirtualSite, OpenMM::TwoParticleAverageSite, OpenMM::ThreeParticleAverageSite, OpenMM::OutOfPlaneSite, OpenMM::LocalCoordinatesSite);\n\n")
        self.fOut.write("\n")

    def writeGlobalConstants(self):
        """Write the public static variables of namespace OpenMM as
        ``static <definition> = <initializer>;`` lines.

        Raises StopIteration if no namespace named OpenMM is present.
        """
        self.fOut.write("/* Global Constants */\n\n")
        namespaces = findNodes(self.doc.getroot(), "compounddef", kind="namespace")
        node = next(x for x in namespaces if x.findtext("compoundname") == "OpenMM")
        for section in findNodes(node, "sectiondef", kind="var"):
            for memberNode in findNodes(section, "memberdef", kind="variable",
                                        mutable="no", prot="public", static="yes"):
                vDef = stripOpenmmPrefix(getText("definition", memberNode))
                iDef = getText("initializer", memberNode)
                if iDef.startswith("="):
                    iDef = iDef[1:]
                self.fOut.write("static %s = %s;\n" % (vDef, iDef))
        self.fOut.write("\n")

    def writeForwardDeclarations(self):
        """Forward-declare every ordered class, plus ``%copyctor`` for classes
        that have at least one wrapped constructor."""
        self.fOut.write("\n/* Forward Declarations */\n\n")
        for classNode in self._orderedClassNodes:
            methodList = getClassMethodList(classNode, self.skipMethods)
            hasConstructor = any(items[4] for items in methodList)
            className = stripOpenmmPrefix(getText("compoundname", classNode))
            # A class with a constructor also gets a SWIG copy constructor.
            if hasConstructor:
                self.fOut.write("%%copyctor %s ;\n" % className)
            self.fOut.write("class %s ;\n" % className)
        self.fOut.write("\n")

    def writeClassDeclarations(self):
        """Write each class body (enums, then methods) and its class docstring."""
        self.fOut.write("\n/* Class Declarations */\n\n")
        for classNode in self._orderedClassNodes:
            className = stripOpenmmPrefix(getText("compoundname", classNode))

            if self.fOutDocstring:
                dNode = classNode.find('detaileddescription')
                if dNode is not None:
                    docstring = getNodeText(dNode).strip().replace('"', '\\"')
                    docstring = striphtmltags(docstring)
                    self.fOutDocstring.write('%%feature("docstring") %s "%s";\n' % (className, docstring))

            bases = []
            if className in self.configModule.MISSING_BASE_CLASSES:
                bases.append(self.configModule.MISSING_BASE_CLASSES[className])
            for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
                if "refid" in baseNodePnt.attrib:
                    bases.append(stripOpenmmPrefix(getText(".", baseNodePnt)))

            self.fOut.write("class %s" % className)
            if bases:
                # Several bases must be comma-separated; repeating " : "
                # produced invalid C++.
                self.fOut.write(" : " + ", ".join("public %s" % b for b in bases))
            self.fOut.write(" {\n")
            self.fOut.write("public:\n")
            self.writeEnumerations(classNode)
            self.writeMethods(classNode)
            self.fOut.write("};\n\n")
        self.fOut.write("\n")

    def writeEnumerations(self, classNode):
        """Write the public enums of *classNode* and record their names in
        self._enumByClassname."""
        enumNodes = [node
                     for section in findNodes(classNode, "sectiondef", kind="public-type")
                     for node in findNodes(section, "memberdef", kind="enum", prot="public")]
        if not enumNodes:
            return

        shortClassName = stripOpenmmPrefix(getText("compoundname", classNode))
        for enumNode in enumNodes:
            enumName = getText("name", enumNode)
            self._enumByClassname.setdefault(shortClassName, []).append(enumName)
            self.fOut.write("%senum %s {" % (INDENT, enumName))
            argSep = "\n"
            for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
                vName = getText("name", valueNode)
                vInit = getText("initializer", valueNode)
                if vInit.startswith("="):
                    vInit = vInit[1:]
                self.fOut.write("%s%s%s = %s" % (argSep, 2*INDENT, vName, vInit))
                argSep = ",\n"
            self.fOut.write("\n%s};\n" % INDENT)
        self.fOut.write("\n")

    @staticmethod
    def _exceptionsSuffix(memberNode):
        """Return " <exceptions>" for a member (just " " when it has none)."""
        return " %s" % getText('exceptions', memberNode)

    @staticmethod
    def _isNonConstRef(pType):
        """True for a parameter type containing '&' but no 'const' token."""
        return pType.find('&') >= 0 and 'const' not in pType.split()

    @staticmethod
    def _pythonModArgsstring(memberNode, isConstructor):
        """Argsstring used in %pythonprepend/%pythonappend selectors."""
        # Specifying args to constructors seems to prevent append and prepend
        # from working, so constructors use an empty argsstring.
        if isConstructor:
            return ''
        return getText("argsstring", memberNode)

    def writeMethods(self, classNode):
        """Write the member declarations of one class, then its
        %pythonprepend / %pythonappend blocks and method docstrings."""
        methodList = getClassMethodList(classNode, self.skipMethods)

        self._writeConstructorsAndDestructors(methodList)
        self.fOut.write("\n")
        methodsWithOutputArgs = self._writeRegularMethods(methodList)

        # Python mod blocks: prepend then append, method by method.
        for items in methodList:
            self._writePythonPrepend(items, methodsWithOutputArgs)
            self._writePythonAppend(items, methodsWithOutputArgs)

        if self.fOutDocstring:
            for items in methodList:
                self._writeMethodDocstring(items)

    def _writeConstructorsAndDestructors(self, methodList):
        """Declare all constructors, then all destructors, in list order."""
        constructors = [items for items in methodList if items[4]]
        destructors = [items for items in methodList if items[5]]
        for items in constructors + destructors:
            memberNode, shortMethDefinition = items[1], items[2]
            self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition,
                                             getText("argsstring", memberNode),
                                             self._exceptionsSuffix(memberNode)))

    def _writeRegularMethods(self, methodList):
        """Declare every method that is neither a constructor nor a destructor.

        Non-const reference parameters become SWIG OUTPUT arguments (with a
        matching %clear after the declaration) unless
        (class, method, paramName) is in NO_OUTPUT_ARGS.  Returns the set of
        (class, method) keys that got at least one OUTPUT argument.
        """
        config = self.configModule
        methodsWithOutputArgs = set()
        for items in methodList:
            (shortClassName, memberNode,
             shortMethDefinition, methName,
             isConstructors, isDestructor, templateType, templateName) = items
            if isConstructors or isDestructor:
                continue

            key = (shortClassName, methName)
            if key in config.DOC_STRINGS:
                self.fOut.write('%%feature("autodoc", "%s") %s;\n' %
                                (config.DOC_STRINGS[key], methName))

            clearOutput = ""
            classEnums = self._enumByClassname.get(shortClassName, ())
            for pNode in findNodes(memberNode, 'param'):
                pType = getText('type', pNode)
                pName = getText('declname', pNode)
                if not self._isNonConstRef(pType):
                    continue
                if (shortClassName, methName, pName) in config.NO_OUTPUT_ARGS:
                    continue
                # Enum-typed output args of this class are mapped to int.
                eType = pType.split()[0]
                if eType in classEnums:
                    simpleType = re.sub(eType, 'int', pType)
                else:
                    simpleType = pType
                self.fOut.write("%s%%apply %s OUTPUT { %s %s };\n" %
                                (INDENT, simpleType, pType, pName))
                clearOutput += "%s%%clear %s %s;\n" % (INDENT, pType, pName)
                methodsWithOutputArgs.add(key)

            mArgsstring = getText("argsstring", memberNode)
            pExceptions = self._exceptionsSuffix(memberNode)
            if memberNode.attrib["virt"].strip() != 'non-virtual' and \
               'virtual' not in shortMethDefinition.split():
                shortMethDefinition = "virtual %s" % shortMethDefinition
            if len(templateType) > 0:
                self.fOut.write("%stemplate<%s %s> %s%s%s;\n" % (INDENT, templateType, templateName, shortMethDefinition, mArgsstring, pExceptions))
            else:
                self.fOut.write("%s%s%s%s;\n" % (INDENT, shortMethDefinition, mArgsstring, pExceptions))
            if clearOutput:
                self.fOut.write(clearOutput)
        return methodsWithOutputArgs

    def _writePythonPrepend(self, items, methodsWithOutputArgs):
        """Write a %pythonprepend block for one method, if it needs one.

        The block checks ownership (STEAL_OWNERSHIP), strips units from
        arguments (UNITS, skipped for methods with OUTPUT args) and converts
        arguments to lists (REQUIRE_ORDERED_SET).
        """
        (shortClassName, memberNode,
         shortMethDefinition, methName,
         isConstructors, isDestructor, templateType, templateName) = items
        paramList = findNodes(memberNode, 'param')
        mArgsstring = self._pythonModArgsstring(memberNode, isConstructors)
        if not self.fOutPythonprepend or not paramList or is_method_abstract(mArgsstring):
            return

        config = self.configModule
        key = (shortClassName, methName)
        textInside = ''
        for argNum in config.STEAL_OWNERSHIP.get(key, []):
            argName = getText('declname', paramList[argNum])
            textInside += '''
    if not {argName}.thisown:
        s = ("the %s object does not own its corresponding OpenMM object"
             % self.__class__.__name__)
        raise Exception(s)'''.format(argName=argName)

        # Convert input arguments to the proper units, if specified.
        if key not in methodsWithOutputArgs:
            if key in config.UNITS:
                argUnits = config.UNITS[key][1]
            elif ("*", methName) in config.UNITS:
                argUnits = config.UNITS[("*", methName)][1]
            else:
                argUnits = ()
            if len(argUnits) > 0 and isConstructors:
                textInside += '''
    args = list(args)'''
            for i, units in enumerate(argUnits):
                if units is None:
                    continue
                if isConstructors:
                    argName = 'args[%s]' % i
                else:
                    argName = getText('declname', paramList[i])
                textInside += '''
    if unit.is_quantity({argName}):
        {argName} = {argName}.value_in_unit({units})'''.format(argName=argName, units=units)

        for argNum in config.REQUIRE_ORDERED_SET.get(key, []):
            argName = getText('declname', paramList[argNum])
            textInside += '''
    {argName} = list({argName})'''.format(argName=argName)

        if textInside:
            # Built in a single pass: formatting the header twice broke when
            # the argsstring (e.g. a "{}" default value) contained braces.
            self.fOutPythonprepend.write(
                "\n%%pythonprepend OpenMM::%s::%s%s %%{%s\n%%}"
                % (shortClassName, methName, mArgsstring, textInside))

    def _writePythonAppend(self, items, methodsWithOutputArgs):
        """Write a %pythonappend block for one method, if it needs one.

        The block attaches units to return values and OUTPUT arguments
        (UNITS) and releases ownership of arguments (STEAL_OWNERSHIP).
        Raises Exception for a get* method whose return type has no UNITS
        entry and is not a known unitless type, or when UNITS lists too few
        output units.
        """
        (shortClassName, memberNode,
         shortMethDefinition, methName,
         isConstructors, isDestructor, templateType, templateName) = items
        mArgsstring = self._pythonModArgsstring(memberNode, isConstructors)
        if not self.fOutPythonappend or is_method_abstract(mArgsstring):
            return

        config = self.configModule
        paramList = findNodes(memberNode, 'param')
        key = (shortClassName, methName)
        returnType = getText("type", memberNode)

        if key in config.UNITS:
            valueUnits = config.UNITS[key]
        elif ("*", methName) in config.UNITS:
            valueUnits = config.UNITS[("*", methName)]
        elif methName.startswith('get') and returnType not in ('void', 'int', 'bool', 'std::string', 'const std::string &'):
            s = 'do not know how to add units to %s %s::%s' \
                % (returnType, shortClassName, methName)
            raise Exception(s)
        else:
            valueUnits = [None, ()]

        addText = ''
        index = 0
        if valueUnits[0] is not None:
            # Diagnostics go to stderr: fOut may be stdout.
            sys.stderr.write("%s.%s() returns %s\n" %
                             (shortClassName, methName, valueUnits[0]))
            if len(valueUnits[1]) > 0:
                addText += "%sval[%d]=unit.Quantity(val[%d], %s)\n" \
                           % (INDENT, index, index, valueUnits[0])
                index += 1
            else:
                addText += "%sval=unit.Quantity(val, %s)\n" \
                           % (INDENT, valueUnits[0])

        for vUnit in valueUnits[1]:
            if vUnit is not None and key in methodsWithOutputArgs:
                addText += "%sval[%s]=unit.Quantity(val[%s], %s)\n" \
                           % (INDENT, index, index, vUnit)
            index += 1

        for argNum in config.STEAL_OWNERSHIP.get(key, []):
            argName = getText('declname', paramList[argNum])
            addText += "%s%s.thisown=0\n" % (INDENT, argName)

        if not addText:
            return

        self.fOutPythonappend.write("%pythonappend")
        self.fOutPythonappend.write(" OpenMM::%s::%s(" % key)
        sepChar = ''
        outputIndex = 0
        for pNode in paramList:
            pType = getText('type', pNode)
            # parse default arguments
            defaultValue = getText('defval', pNode)
            if defaultValue != "":
                defaultValue = "=%s" % defaultValue
            pName = getText('declname', pNode)
            self.fOutPythonappend.write("%s%s %s%s" % (sepChar, pType, pName, defaultValue))
            sepChar = ', '

            # NOTE(review): this uses the 2-tuple key, whereas
            # _writeRegularMethods checks NO_OUTPUT_ARGS with
            # (class, method, paramName); kept as-is, confirm intent.
            if self._isNonConstRef(pType) and \
               key not in config.NO_OUTPUT_ARGS and \
               len(valueUnits[1]) > 0:
                try:
                    unitType = valueUnits[1][outputIndex]
                except IndexError:
                    s = "missing unit type for %s.%s() arg named %s" \
                        % (shortClassName, methName, pName)
                    raise Exception(s)
                sys.stderr.write("%s.%s() returns %s as %s\n" %
                                 (shortClassName, methName, pName, unitType))
                outputIndex += 1

        constString = ' const' if memberNode.attrib["const"] == "yes" else ''
        self.fOutPythonappend.write(")%s %%{\n" % constString)
        self.fOutPythonappend.write(addText)
        self.fOutPythonappend.write("%}\n\n")

    def _writeMethodDocstring(self, items):
        """Write a numpydoc-style %feature("docstring") for one method.

        Raises ValueError when the member lacks exactly one
        detaileddescription, or when documented parameters do not match the
        signature's parameter count.
        """
        (shortClassName, memberNode,
         shortMethDefinition, methName,
         isConstructors, isDestructor, templateType, templateName) = items

        signatureParams = findNodes(memberNode, 'param')
        detailed = findNodes(memberNode, 'detaileddescription')
        if len(detailed) != 1:
            raise ValueError('expected one detaileddescription for %s.%s, found %d'
                             % (shortClassName, methName, len(detailed)))
        dNode = detailed[0]

        description = getText('para', dNode)
        params = findNodes(dNode, 'para/parameterlist/parameteritem')

        paramString = ['Parameters', '----------']
        returnString = ['Returns', '-------']

        if params:
            if len(signatureParams) != len(params):
                raise ValueError('docstring in %s.%s does not match the signature' % (shortClassName, methName))

            for pNode, pSignatureNode in zip(params, signatureParams):
                parameterNameNode = findNodes(pNode, 'parameternamelist/parametername')[0]
                argDoc = getText('parameterdescription/para', pNode)
                argName = getNodeText(parameterNameNode)
                argType = docstringTypemap(getText('type', pSignatureNode))
                entry = ['%s : %s' % (argName, argType), '    %s' % argDoc]
                # [out] parameters are documented as return values.
                if parameterNameNode.get('direction') == 'out':
                    returnString.extend(entry)
                else:
                    paramString.extend(entry)

        # Look for the "return" simplesect specifically; it need not be the
        # first simplesect (e.g. a preceding @note).
        returnNodes = findNodes(dNode, 'para/simplesect', kind='return')
        if returnNodes:
            argType = docstringTypemap(getNodeText(findNodes(memberNode, 'type')[0]))
            returnString.extend([argType, '    %s' % getNodeText(returnNodes[0]).strip()])

        sections = []
        if len(description) > 0:
            sections += [description, '']
        if len(paramString) > 2:
            sections += paramString + ['']
        if len(returnString) > 2:
            sections += returnString
        dString = '\n'.join(sections).strip()
        if dString:
            # Escape every double quote not already preceded by a backslash
            # (including a leading quote and adjacent quotes).
            dString = re.sub(r'(?<!\\)"', r'\\"', dString)
            s = '%%feature("docstring") OpenMM::%s::%s "%s";' \
                % (shortClassName, methName, dString)
            self.fOutDocstring.write("%s\n" % s)

        self.fOutDocstring.write("\n\n")

    def writeSwigFile(self):
        """Write the complete SWIG input file to self.fOut."""
        self.fOut.write("/* Swig input file,\n")
        self.fOut.write("%sgenerated by %s on %s\n*/\n\n\n"
                        % (INDENT, sys.argv[0], time.asctime()))
        self.fOut.write("\nnamespace OpenMM {\n\n")
        self.writeFactories()
        self.writeGlobalConstants()
        self.writeForwardDeclarations()
        self.writeClassDeclarations()
        self.fOut.write("\n} // namespace OpenMM\n\n")


def parseCommandLine():
    """Parse ``sys.argv`` into the settings used to build the SWIG input.

    Options (getopt spec ``'hi:c:o:d:a:z:s:v:'``):
      -h          print usage and exit
      -i DIR      input directory name (required)
      -c FILE     configuration filename (required)
      -o FILE     output filename (default "")
      -d FILE     docstring filename (default "")
      -a FILE     pythonprepend filename (default "")
      -z FILE     pythonappend filename (default "")
      -s NAME     may be given several times; values are collected in a list
      -v VERSION  SWIG version string (default '3.0.2')

    Returns:
        A 9-tuple ``(args_proper, inputDirname, configFilename,
        outputFilename, docstringFilename, pythonprependFilename,
        pythonappendFilename, skipAdditionalMethods, swigVersion)`` where
        ``args_proper`` holds the remaining positional arguments.

    Calls ``usageError()`` (which exits the process) on ``-h``, on an
    unrecognized option or missing option argument, or when ``-i`` or ``-c``
    is absent.
    """
    try:
        opts, args_proper = getopt.getopt(sys.argv[1:], 'hi:c:o:d:a:z:s:v:')
    except getopt.GetoptError as err:
        # Previously an uncaught traceback; report the problem and show usage.
        sys.stderr.write("%s\n" % err)
        usageError()

    inputDirname = None
    configFilename = None
    outputFilename = ""
    docstringFilename = ""
    pythonprependFilename = ""
    pythonappendFilename = ""
    skipAdditionalMethods = []
    swigVersion = '3.0.2'

    for option, parameter in opts:
        if option == '-h':
            usageError()
        elif option == '-i':
            inputDirname = parameter
        elif option == '-c':
            configFilename = parameter
        elif option == '-o':
            outputFilename = parameter
        elif option == '-d':
            docstringFilename = parameter
        elif option == '-a':
            pythonprependFilename = parameter
        elif option == '-z':
            pythonappendFilename = parameter
        elif option == '-s':
            skipAdditionalMethods.append(parameter)
        elif option == '-v':
            swigVersion = parameter

    if not inputDirname or not configFilename:
        usageError()

    return (args_proper, inputDirname, configFilename, outputFilename,
            docstringFilename, pythonprependFilename, pythonappendFilename,
            skipAdditionalMethods, swigVersion)

def main():
    """Script entry point: parse the command line and write the SWIG input.

    Builds a ``SwigInputBuilder`` from the parsed options and calls its
    ``writeSwigFile`` method. Positional command-line arguments
    (``args_proper``) are parsed but not used.
    """
    (args_proper, inputDirname, configFilename, outputFilename,
     docstringFilename, pythonprependFilename, pythonappendFilename,
     skipAdditionalMethods, swigVersion) = parseCommandLine()
    sBuilder = SwigInputBuilder(inputDirname, configFilename, outputFilename,
                                docstringFilename, pythonprependFilename,
                                pythonappendFilename, skipAdditionalMethods,
                                swigVersion)
    sBuilder.writeSwigFile()


def usageError():
    """Print a usage summary to stdout and exit with status 1.

    The first line is ``usage: <program> -i ...``. Every following line is
    padded by the width of the program name so the option flags line up
    under ``-i``; the ``[`` of an optional flag takes the place of the
    separating space. ``-s`` may be given several times.
    """
    prog = os.path.basename(sys.argv[0])
    pad = ' ' * len(prog)
    lines = [
        'usage: %s -i xmlHeadersDirname \\' % prog,
        '       %s -c inputConfigFilename' % pad,
    ]
    optional = [
        '-o swigInputFilename',
        '-d docstringFilename',
        '-a pythonprependFilename',
        '-z pythonappendFilename',
        '-s skippedMethod',
        '-v swigVersion',
    ]
    lines.extend('       %s[%s]' % (pad, opt) for opt in optional)
    sys.stdout.write('\n'.join(lines) + '\n')
    sys.exit(1)

if __name__ == '__main__':
    # Run only when executed as a script; importing this module has no effect.
    main()