generateWrappers.py 34.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import sys, os
import time
import getopt
import re
import xml.etree.ElementTree as etree

def trimToSingleSpace(text):
    """Strip *text*, keeping a single space on each side that originally had whitespace.

    None, empty and all-whitespace input all yield an empty string.
    """
    if not text:
        return ""
    core = text.strip()
    if not core:
        return core
    # Collapse any run of leading/trailing whitespace down to exactly one space.
    lead = " " if text[0].isspace() else ""
    trail = " " if text[-1].isspace() else ""
    return lead + core + trail

def getNodeText(node):
    """Return the text content of an XML element, descending into <para> and <ref> children.

    The text of each <para> child is followed by a blank line.  Other child tags
    contribute only their tail text, not their own content.
    """
    pieces = [node.text if node.text is not None else ""]
    for child in node:
        if child.tag == "para":
            pieces.append("%s\n\n" % getNodeText(child))
        elif child.tag == "ref":
            pieces.append(getNodeText(child))
        if child.tail is not None:
            pieces.append(child.tail)
    return "".join(pieces)

def getText(subNodePath, node):
    """Concatenate the normalized text of every element matching *subNodePath* under *node*.

    Each match's text is trimmed to single surrounding spaces.  A blank line is appended
    after <para> matches.  The combined result is stripped before it is returned.
    """
    parts = []
    for match in node.findall(subNodePath):
        parts.append(trimToSingleSpace(getNodeText(match)))
        if match.tag == "para":
            parts.append("\n\n")
    return "".join(parts).strip()

def convertOpenMMPrefix(name):
    """Rewrite every C++ 'OpenMM::' qualifier in *name* as the C API prefix 'OpenMM_'."""
    return 'OpenMM_'.join(name.split('OpenMM::'))

# Matches "<prefix>OpenMM:...:<rest>": group 1 is the text before the qualifier, group 2 is
# whatever follows the last ':' of the qualifier (e.g. the bare class or method name).
OPENMM_RE_PATTERN=re.compile(r"(.*)OpenMM:[a-zA-Z:]*:(.*)")
def stripOpenMMPrefix(name, rePattern=OPENMM_RE_PATTERN):
    """Remove the 'OpenMM::...::' qualifier from *name*.

    For example "void OpenMM::System::addForce" becomes "void addForce" and
    "OpenMM::System" becomes "System".  The result is stripped of surrounding
    whitespace.  Names that contain no qualifier are returned unchanged.
    """
    m = rePattern.search(name)
    if m is None:
        # No OpenMM qualifier present.
        return name
    return ("%s%s" % m.group(1, 2)).strip()

def findNodes(parent, path, **args):
    """Return the elements under *parent* that match *path* and carry every given attribute value.

    Each keyword argument names an XML attribute.  An element is kept only if it has
    that attribute and its value equals the keyword's value.
    """
    def hasAllAttributes(candidate):
        attributes = candidate.attrib
        return all(key in attributes and attributes[key] == wanted for key, wanted in args.items())
    return [candidate for candidate in parent.findall(path) if hasAllAttributes(candidate)]

class WrapperGenerator:
    """This is the parent class of generators for various API wrapper files.  It defines functions common to all of them."""

    def __init__(self, inputDirname, output):
        """Load the XML files in *inputDirname* and prepare to write generated code to *output*.

        inputDirname -- directory whose files are all parsed as XML; the top-level children of
                        each file are merged under a single synthetic 'root' element
        output       -- writable file-like object that generators write to (self.out)
        """
        # Classes that are not wrapped at all, and methods whose wrappers are handled separately.
        self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy']
        self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint']
        # Methods that mention any of these types in their signature are hidden (see shouldHideMethod).
        self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy']
        # Fixed mapping from C++ type spellings to their C API equivalents.
        self.typeTranslations = {'bool': 'OpenMM_Boolean',
                                 'Vec3': 'OpenMM_Vec3',
                                 'std::string': 'char*',
                                 'const std::string &': 'const char*',
                                 'std::vector< std::string >': 'OpenMM_StringArray',
                                 'std::vector< Vec3 >': 'OpenMM_Vec3Array',
                                 'std::vector< std::pair< int, int > >': 'OpenMM_BondArray',
                                 'std::map< std::string, double >': 'OpenMM_ParameterArray',
                                 'std::map< std::string, std::string >': 'OpenMM_PropertyArray',
                                 'std::vector< double >': 'OpenMM_DoubleArray',
                                 'std::vector< int >': 'OpenMM_IntArray',
                                 'std::set< int >': 'OpenMM_IntSet'}
        self.inverseTranslations = dict((self.typeTranslations[key], key) for key in self.typeTranslations)
        # Cache of compounddef elements keyed by their 'id' attribute (filled lazily).
        self.nodeByID={}

        # Read all the XML files and merge them into a single document.  The directory listing
        # is sorted because os.listdir() order is arbitrary, and the merged document order
        # determines the order of the generated output.
        self.doc = etree.ElementTree(etree.Element('root'))
        mergedRoot = self.doc.getroot()
        for fileName in sorted(os.listdir(inputDirname)):
            root = etree.parse(os.path.join(inputDirname, fileName)).getroot()
            for node in root:
                mergedRoot.append(node)

        self.out = output

        self.typesByShortName = {}
        self._orderedClassNodes = self.buildOrderedClassNodes()

    def getNodeByID(self, id):
        """Return the compounddef element whose 'id' attribute equals *id*.

        Raises KeyError if no such element exists in the merged document.
        """
        if id not in self.nodeByID:
            for node in findNodes(self.doc.getroot(), "compounddef", id=id):
                self.nodeByID[id] = node
        return self.nodeByID[id]

    def buildOrderedClassNodes(self):
        """Return the public class compounddefs ordered so every public base precedes its subclasses."""
        orderedClassNodes=[]
        for node in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
            self.findBaseNodes(node, orderedClassNodes)
        return orderedClassNodes

    def findBaseNodes(self, node, excludedClassNodes=None):
        """Append *node* to *excludedClassNodes* after first appending its public base classes.

        Nodes already in the list, private classes and classes listed in self.skipClasses
        are not appended.  The list is modified in place.  A fresh list is used when none
        is given, so separate calls never share state through a default argument.
        """
        if excludedClassNodes is None:
            excludedClassNodes = []
        if node in excludedClassNodes:
            return
        if node.attrib['prot'] == 'private':
            return
        nodeName = getText("compoundname", node)
        if nodeName in self.skipClasses:
            return
        for baseNodePnt in findNodes(node, "basecompoundref", prot="public"):
            # Bases without a refid are outside the parsed documentation and are ignored.
            if "refid" in baseNodePnt.attrib:
                baseNodeID = baseNodePnt.attrib["refid"]
                baseNode = self.getNodeByID(baseNodeID)
                self.findBaseNodes(baseNode, excludedClassNodes)
        excludedClassNodes.append(node)

    def getClassMethods(self, classNode):
        """Return the public (static and non-static) function memberdefs of *classNode*.

        Methods named in self.skipMethods are omitted.  Static functions come first,
        in document order.
        """
        className = getText("compoundname", classNode)
        methodList = []
        for section in findNodes(classNode, "sectiondef", kind="public-static-func")+findNodes(classNode, "sectiondef", kind="public-func"):
            for memberNode in findNodes(section, "memberdef", kind="function", prot="public"):
                methodDefinition = getText("definition", memberNode)
                shortMethodDefinition = stripOpenMMPrefix(methodDefinition)
                methodName = shortMethodDefinition.split()[-1]
                if className+'::'+methodName in self.skipMethods:
                    continue
                methodList.append(memberNode)
        return methodList

    def shouldHideType(self, typeName):
        """Return True if *typeName*, ignoring one 'const ' prefix and one trailing '&'/'*', is a hidden class."""
        if typeName.startswith('const '):
            typeName = typeName[6:].strip()
        if typeName.endswith('&') or typeName.endswith('*'):
            typeName = typeName[:-1].strip()
        return typeName in self.hideClasses

    def shouldHideMethod(self, methodNode):
        """Return True if the return type or any parameter type of *methodNode* is a hidden class."""
        paramList = findNodes(methodNode, 'param')
        returnType = self.getType(getText("type", methodNode))
        if self.shouldHideType(returnType):
            return True
        for node in paramList:
            try:
                type = getText('type', node)
            except IndexError:
                type = getText('type/ref', node)
            if self.shouldHideType(type):
                return True
        return False

    def getType(self, type):
        """Translate a C++ type name into its C API equivalent (subclasses may override).

        Lookup order:
        1. self.typeTranslations.
        2. Names recorded in self.typesByShortName.
        3. Otherwise a leading 'const ' or trailing '&'/'*' is peeled off recursively;
           references become pointers.
        Unknown types are returned unchanged.
        """
        if type in self.typeTranslations:
            return self.typeTranslations[type]
        if type in self.typesByShortName:
            return self.typesByShortName[type]
        if type.startswith('const '):
            return 'const '+self.getType(type[6:].strip())
        if type.endswith('&') or type.endswith('*'):
            return self.getType(type[:-1].strip())+'*'
        return type

class CHeaderGenerator(WrapperGenerator):
    """This class generates the header file for the C API wrappers."""

    def __init__(self, inputDirname, output):
        WrapperGenerator.__init__(self, inputDirname, output)

    def _writeLine(self, text):
        """Write *text* followed by a newline to self.out (portable across Python 2 and 3)."""
        self.out.write(text)
        self.out.write("\n")

    def writeGlobalConstants(self):
        """Emit a `static` C definition for each public, static, non-mutable variable of namespace OpenMM."""
        self.out.write("/* Global Constants */\n\n")
        node = next((x for x in findNodes(self.doc.getroot(), "compounddef", kind="namespace") if x.findtext("compoundname") == "OpenMM"))
        for section in findNodes(node, "sectiondef", kind="var"):
            for memberNode in findNodes(section, "memberdef", kind="variable", mutable="no", prot="public", static="yes"):
                vDef = convertOpenMMPrefix(getText("definition", memberNode))
                iDef = getText("initializer", memberNode)
                # The initializer text may begin with '='; drop it since we write our own.
                if iDef.startswith("="):
                    iDef = iDef[1:]
                self.out.write("static %s = %s;\n" % (vDef, iDef))

    def writeTypeDeclarations(self):
        """Emit an opaque struct typedef per wrapped class and record it in self.typesByShortName."""
        self.out.write("\n/* Type Declarations */\n\n")
        for classNode in self._orderedClassNodes:
            className = getText("compoundname", classNode)
            shortName = stripOpenMMPrefix(className)
            typeName = convertOpenMMPrefix(className)
            self.out.write("typedef struct %s_struct %s;\n" % (typeName, typeName))
            self.typesByShortName[shortName] = typeName

    def writeClasses(self):
        """Emit the enumerations and function prototypes for every wrapped class."""
        for classNode in self._orderedClassNodes:
            className = stripOpenMMPrefix(getText("compoundname", classNode))
            self.out.write("\n/* %s */\n" % className)
            self.writeEnumerations(classNode)
            self.writeMethods(classNode)
        self.out.write("\n")

    def writeEnumerations(self, classNode):
        """Emit a C `typedef enum` for each public enum of *classNode*.

        Each enum type name is recorded in self.typesByShortName under its short name.
        """
        enumNodes = []
        for section in findNodes(classNode, "sectiondef", kind="public-type"):
            for node in findNodes(section, "memberdef", kind="enum", prot="public"):
                enumNodes.append(node)
        className = getText("compoundname", classNode)
        typeName = convertOpenMMPrefix(className)
        for enumNode in enumNodes:
            enumName = getText("name", enumNode)
            enumTypeName = "%s_%s" % (typeName, enumName)
            self.out.write("typedef enum {\n  ")
            argSep=""
            for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
                vName = convertOpenMMPrefix(getText("name", valueNode))
                vInit = getText("initializer", valueNode)
                if vInit.startswith("="):
                    vInit = vInit[1:].strip()
                if vInit:
                    self.out.write("%s%s_%s = %s" % (argSep, typeName, vName, vInit))
                else:
                    # No explicit value: writing "X = " would be invalid C, so let the
                    # enumerator take the implicit next value (same rule as C++).
                    self.out.write("%s%s_%s" % (argSep, typeName, vName))
                argSep=", "
            self.out.write("\n} %s;\n" % enumTypeName)
            self.typesByShortName[enumName] = enumTypeName
        if len(enumNodes)>0: self.out.write("\n")

    def writeMethods(self, classNode):
        """Emit C prototypes for the constructors, destructor and other methods of *classNode*.

        Non-abstract classes get OpenMM_X_create[_N] and OpenMM_X_destroy.  Other visible
        methods become OpenMM_X_method, taking a (possibly const) `target` pointer first
        when they are instance methods.
        """
        methodList = self.getClassMethods(classNode)
        className = getText("compoundname", classNode)
        shortClassName = stripOpenMMPrefix(className)
        typeName = convertOpenMMPrefix(className)
        destructorName = '~'+shortClassName

        # Record the unqualified name of every method once.
        methodNames = {}
        for methodNode in methodList:
            shortMethodDefinition = stripOpenMMPrefix(getText("definition", methodNode))
            methodNames[methodNode] = shortMethodDefinition.split()[-1]

        if not ('abstract' in classNode.attrib and classNode.attrib['abstract'] == 'yes'):
            # Write constructors.  Overloads after the first get suffixes _2, _3, ...
            numConstructors = 0
            for methodNode in methodList:
                if methodNames[methodNode] != shortClassName:
                    continue
                if self.shouldHideMethod(methodNode):
                    continue
                numConstructors += 1
                if numConstructors == 1:
                    suffix = ""
                else:
                    suffix = "_%d" % numConstructors
                self.out.write("extern OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
                self.writeArguments(methodNode, False)
                self.out.write(");\n")

            # Write destructor
            self.out.write("extern OPENMM_EXPORT void %s_destroy(%s* target);\n" % (typeName, typeName))

        # Names that have at least one non-const overload (computed once, not per method).
        nonConstNames = set(methodNames[m] for m in methodList if m.attrib.get('const') == 'no')

        # Write other methods
        for methodNode in methodList:
            methodName = methodNames[methodNode]
            if methodName in (shortClassName, destructorName):
                continue
            if self.shouldHideMethod(methodNode):
                continue
            isConstMethod = (methodNode.attrib['const'] == 'yes')
            if isConstMethod and methodName in nonConstNames:
                # There are two identical methods that differ only in whether they are const.  Skip the const one.
                continue
            returnType = self.getType(getText("type", methodNode))
            self.out.write("extern OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
            isInstanceMethod = (methodNode.attrib['static'] != 'yes')
            if isInstanceMethod:
                if isConstMethod:
                    self.out.write('const ')
                self.out.write("%s* target" % typeName)
            self.writeArguments(methodNode, isInstanceMethod)
            self.out.write(");\n")

    def writeArguments(self, methodNode, initialSeparator):
        """Write the C parameter list of *methodNode*, preceded by ", " if *initialSeparator*."""
        paramList = findNodes(methodNode, 'param')
        if initialSeparator:
            separator = ", "
        else:
            separator = ""
        for node in paramList:
            try:
                paramType = getText('type', node)
            except IndexError:
                paramType = getText('type/ref', node)
            if paramType == 'void':
                continue
            paramType = self.getType(paramType)
            name = getText('declname', node)
            self.out.write("%s%s %s" % (separator, paramType, name))
            separator = ", "

    def getType(self, type):
        """Translate a C++ type name into its C API equivalent; references become pointers."""
        if type in self.typeTranslations:
            return self.typeTranslations[type]
        if type in self.typesByShortName:
            return self.typesByShortName[type]
        if type.startswith('const '):
            return 'const '+self.getType(type[6:].strip())
        if type.endswith('&') or type.endswith('*'):
            return self.getType(type[:-1].strip())+'*'
        return type

    def writeOutput(self):
        """Write the complete C header file to self.out."""
        self._writeLine("""
#ifndef OPENMM_CWRAPPER_H_
#define OPENMM_CWRAPPER_H_

#ifndef OPENMM_EXPORT
#define OPENMM_EXPORT
#endif
""")
        self.writeGlobalConstants()
        self.writeTypeDeclarations()
        self._writeLine("""
typedef struct OpenMM_Vec3Array_struct OpenMM_Vec3Array;
typedef struct OpenMM_StringArray_struct OpenMM_StringArray;
typedef struct OpenMM_BondArray_struct OpenMM_BondArray;
typedef struct OpenMM_ParameterArray_struct OpenMM_ParameterArray;
typedef struct OpenMM_PropertyArray_struct OpenMM_PropertyArray;
typedef struct OpenMM_DoubleArray_struct OpenMM_DoubleArray;
typedef struct OpenMM_IntArray_struct OpenMM_IntArray;
typedef struct OpenMM_IntSet_struct OpenMM_IntSet;
typedef struct {double x, y, z;} OpenMM_Vec3;

typedef enum {OpenMM_False = 0, OpenMM_True = 1} OpenMM_Boolean;

#if defined(__cplusplus)
extern "C" {
#endif

/* OpenMM_Vec3 */
extern OPENMM_EXPORT OpenMM_Vec3 OpenMM_Vec3_scale(const OpenMM_Vec3 vec, double scale);

/* OpenMM_Vec3Array */
extern OPENMM_EXPORT OpenMM_Vec3Array* OpenMM_Vec3Array_create(int size);
extern OPENMM_EXPORT void OpenMM_Vec3Array_destroy(OpenMM_Vec3Array* array);
extern OPENMM_EXPORT int OpenMM_Vec3Array_getSize(const OpenMM_Vec3Array* array);
extern OPENMM_EXPORT void OpenMM_Vec3Array_resize(OpenMM_Vec3Array* array, int size);
extern OPENMM_EXPORT void OpenMM_Vec3Array_append(OpenMM_Vec3Array* array, const OpenMM_Vec3 vec);
extern OPENMM_EXPORT void OpenMM_Vec3Array_set(OpenMM_Vec3Array* array, int index, const OpenMM_Vec3 vec);
extern OPENMM_EXPORT const OpenMM_Vec3* OpenMM_Vec3Array_get(const OpenMM_Vec3Array* array, int index);

/* OpenMM_StringArray */
extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_StringArray_create(int size);
extern OPENMM_EXPORT void OpenMM_StringArray_destroy(OpenMM_StringArray* array);
extern OPENMM_EXPORT int OpenMM_StringArray_getSize(const OpenMM_StringArray* array);
extern OPENMM_EXPORT void OpenMM_StringArray_resize(OpenMM_StringArray* array, int size);
extern OPENMM_EXPORT void OpenMM_StringArray_append(OpenMM_StringArray* array, const char* string);
extern OPENMM_EXPORT void OpenMM_StringArray_set(OpenMM_StringArray* array, int index, const char* string);
extern OPENMM_EXPORT const char* OpenMM_StringArray_get(const OpenMM_StringArray* array, int index);

/* OpenMM_BondArray */
extern OPENMM_EXPORT OpenMM_BondArray* OpenMM_BondArray_create(int size);
extern OPENMM_EXPORT void OpenMM_BondArray_destroy(OpenMM_BondArray* array);
extern OPENMM_EXPORT int OpenMM_BondArray_getSize(const OpenMM_BondArray* array);
extern OPENMM_EXPORT void OpenMM_BondArray_resize(OpenMM_BondArray* array, int size);
extern OPENMM_EXPORT void OpenMM_BondArray_append(OpenMM_BondArray* array, int particle1, int particle2);
extern OPENMM_EXPORT void OpenMM_BondArray_set(OpenMM_BondArray* array, int index, int particle1, int particle2);
extern OPENMM_EXPORT void OpenMM_BondArray_get(const OpenMM_BondArray* array, int index, int* particle1, int* particle2);

/* OpenMM_ParameterArray */
extern OPENMM_EXPORT int OpenMM_ParameterArray_getSize(const OpenMM_ParameterArray* array);
extern OPENMM_EXPORT double OpenMM_ParameterArray_get(const OpenMM_ParameterArray* array, const char* name);

/* OpenMM_PropertyArray */
extern OPENMM_EXPORT int OpenMM_PropertyArray_getSize(const OpenMM_PropertyArray* array);
extern OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyArray* array, const char* name);""")

        for elementType in ('double', 'int'):
            name = 'OpenMM_%sArray' % elementType.capitalize()
            values = {'type':elementType, 'name':name}
            self._writeLine("""
/* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create(int size);
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* array);
extern OPENMM_EXPORT int %(name)s_getSize(const %(name)s* array);
extern OPENMM_EXPORT void %(name)s_resize(%(name)s* array, int size);
extern OPENMM_EXPORT void %(name)s_append(%(name)s* array, %(type)s value);
extern OPENMM_EXPORT void %(name)s_set(%(name)s* array, int index, %(type)s value);
extern OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index);""" % values)

        for elementType in ('int',):
            name = 'OpenMM_%sSet' % elementType.capitalize()
            values = {'type':elementType, 'name':name}
            self._writeLine("""
/* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create();
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* set);
extern OPENMM_EXPORT int %(name)s_getSize(const %(name)s* set);
extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);""" % values)

        self._writeLine("""
/* These methods need to be handled specially, since their C++ APIs cannot be directly translated to C.
   Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */
extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox);
extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirectory(const char* directory);""")

        self.writeClasses()

        self._writeLine("""
#if defined(__cplusplus)
}
#endif

#endif /*OPENMM_CWRAPPER_H_*/""")


class CSourceGenerator(WrapperGenerator):
    """This class generates the source file for the C API wrappers."""

    def __init__(self, inputDirname, output):
        """Load the documentation via WrapperGenerator and register all wrapped class types."""
        WrapperGenerator.__init__(self, inputDirname, output)
        # Short name (class or enum) -> fully qualified C++ name.
        self.classesByShortName = {}
        # Fully qualified C++ enum name -> C enum type name (filled by findEnumerations).
        self.enumerationTypes = {}
        self.findTypes()
    
    def findTypes(self):
        """Register every wrapped class under its namespace-free short name.

        self.typesByShortName receives the C API type name (OpenMM_X).
        self.classesByShortName receives the fully qualified C++ name (OpenMM::X).
        """
        for classNode in self._orderedClassNodes:
            fullName = getText("compoundname", classNode)
            shortName = stripOpenMMPrefix(fullName)
            self.typesByShortName[shortName] = convertOpenMMPrefix(fullName)
            self.classesByShortName[shortName] = fullName

    def findEnumerations(self, classNode):
        """Record the public enums of *classNode* in the type lookup tables (no output is written).

        For each enum E of class OpenMM::X:
        - typesByShortName[E] = OpenMM_X_E
        - classesByShortName[E] = OpenMM::X::E
        - enumerationTypes[OpenMM::X::E] = OpenMM_X_E
        """
        enumNodes = []
        for section in findNodes(classNode, "sectiondef", kind="public-type"):
            for node in findNodes(section, "memberdef", kind="enum", prot="public"):
                enumNodes.append(node)
        className = getText("compoundname", classNode)
        typeName = convertOpenMMPrefix(className)
        for enumNode in enumNodes:
            enumName = getText("name", enumNode)
            enumTypeName = "%s_%s" % (typeName, enumName)
            enumClassName = "%s::%s" % (className, enumName)
            self.typesByShortName[enumName] = enumTypeName
            self.classesByShortName[enumName] = enumClassName
            self.enumerationTypes[enumClassName] = enumTypeName

    def writeClasses(self):
        """Write the wrapper definitions for every wrapped class, in base-before-derived order.

        Enumerations are registered before each class's methods are written.
        """
        for classNode in self._orderedClassNodes:
            className = stripOpenMMPrefix(getText("compoundname", classNode))
            self.out.write("\n/* OpenMM::%s */\n" % className)
            self.findEnumerations(classNode)
            self.writeMethods(classNode)
        self.out.write("\n")

    def writeMethods(self, classNode):
        """Write the C wrapper function definitions for the public methods of *classNode*.

        Non-abstract classes get OpenMM_X_create[_N] (heap-allocating the C++ object) and
        OpenMM_X_destroy (deleting it).  Every other visible method becomes OpenMM_X_method,
        which forwards to the C++ method through `target` (instance methods) or X:: (static).
        """
        methodList = self.getClassMethods(classNode)
        className = getText("compoundname", classNode)
        shortClassName = stripOpenMMPrefix(className)
        typeName = convertOpenMMPrefix(className)
        destructorName = '~'+shortClassName

        # Record the unqualified name of every method once.
        methodNames = {}
        for methodNode in methodList:
            shortMethodDefinition = stripOpenMMPrefix(getText("definition", methodNode))
            methodNames[methodNode] = shortMethodDefinition.split()[-1]

        if not ('abstract' in classNode.attrib and classNode.attrib['abstract'] == 'yes'):
            # Write constructors.  Overloads after the first get suffixes _2, _3, ...
            numConstructors = 0
            for methodNode in methodList:
                if methodNames[methodNode] != shortClassName:
                    continue
                if self.shouldHideMethod(methodNode):
                    continue
                numConstructors += 1
                if numConstructors == 1:
                    suffix = ""
                else:
                    suffix = "_%d" % numConstructors
                self.out.write("OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
                self.writeArguments(methodNode, False)
                self.out.write(") {\n")
                self.out.write("    return reinterpret_cast<%s*>(new %s(" % (typeName, className))
                self.writeInvocationArguments(methodNode, False)
                self.out.write("));\n")
                self.out.write("}\n")

            # Write destructor
            self.out.write("OPENMM_EXPORT void %s_destroy(%s* target) {\n" % (typeName, typeName))
            self.out.write("    delete reinterpret_cast<%s*>(target);\n" % className)
            self.out.write("}\n")

        # Names that have at least one non-const overload (computed once, not per method).
        nonConstNames = set(methodNames[m] for m in methodList if m.attrib.get('const') == 'no')

        # Write other methods
        for methodNode in methodList:
            methodName = methodNames[methodNode]
            if methodName in (shortClassName, destructorName):
                continue
            if self.shouldHideMethod(methodNode):
                continue
            isConstMethod = (methodNode.attrib['const'] == 'yes')
            if isConstMethod and methodName in nonConstNames:
                # There are two identical methods that differ only in whether they are const.  Skip the const one.
                continue
            methodType = getText("type", methodNode)
            returnType = self.getType(methodType)
            # A bare wrapped class/enum name is qualified so the C++ local variable type resolves.
            if methodType in self.classesByShortName:
                methodType = self.classesByShortName[methodType]
            self.out.write("OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
            isInstanceMethod = (methodNode.attrib['static'] != 'yes')
            if isInstanceMethod:
                if isConstMethod:
                    self.out.write('const ')
                self.out.write("%s* target" % typeName)
            self.writeArguments(methodNode, isInstanceMethod)
            self.out.write(") {\n")
            self.out.write("    ")
            if returnType != 'void':
                if methodType.endswith('&'):
                    # Convert references to pointers
                    self.out.write('%s* result = &' % methodType[:-1].strip())
                else:
                    self.out.write('%s result = ' % methodType)
            if isInstanceMethod:
                self.out.write('reinterpret_cast<')
                if isConstMethod:
                    self.out.write('const ')
                self.out.write('%s*>(target)->' % className)
            else:
                self.out.write('%s::' % className)
            self.out.write('%s(' % methodName)
            self.writeInvocationArguments(methodNode, False)
            self.out.write(');\n')
            if returnType != 'void':
                self.out.write('    return %s;\n' % self.wrapValue(methodType, 'result'))
            self.out.write("}\n")
    
    def writeArguments(self, methodNode, initialSeparator):
        """Write the C parameter declarations of *methodNode*, preceded by ", " if *initialSeparator*.

        A parameter of type 'void' (an empty C++ parameter list) is skipped.
        """
        paramList = findNodes(methodNode, 'param')
        if initialSeparator:
            separator = ", "
        else:
            separator = ""
        for node in paramList:
            try:
                paramType = getText('type', node)
            except IndexError:
                paramType = getText('type/ref', node)
            if paramType == 'void':
                continue
            paramType = self.getType(paramType)
            name = getText('declname', node)
            self.out.write("%s%s %s" % (separator, paramType, name))
            separator = ", "
    
    def writeInvocationArguments(self, methodNode, initialSeparator):
        """Write the argument list used to forward a C call to the C++ method *methodNode*.

        Arguments whose C type differs from their C++ type are converted with unwrapValue().
        The list is preceded by ", " if *initialSeparator*.
        """
        paramList = findNodes(methodNode, 'param')
        if initialSeparator:
            separator = ", "
        else:
            separator = ""
        for node in paramList:
            try:
                paramType = getText('type', node)
            except IndexError:
                paramType = getText('type/ref', node)
            if paramType == 'void':
                continue
            name = getText('declname', node)
            if self.getType(paramType) != paramType:
                name = self.unwrapValue(paramType, name)
            self.out.write("%s%s" % (separator, name))
            separator = ", "
    
    def getType(self, type):
        """Map a C++ type spelling onto its C API spelling.

        Lookup order:
        1. Fixed translations.
        2. Recorded short names.
        3. Otherwise one leading 'const ' or one trailing '&'/'*' is peeled off and the
           remainder translated recursively, with references becoming pointers.
        Anything else is returned as-is.
        """
        for table in (self.typeTranslations, self.typesByShortName):
            if type in table:
                return table[type]
        if type.startswith('const '):
            inner = type[6:].strip()
            return 'const ' + self.getType(inner)
        if type[-1:] in ('&', '*'):
            pointee = type[:-1].strip()
            return self.getType(pointee) + '*'
        return type
    
    def wrapValue(self, type, value):
        """Return a C++ expression converting *value* (of C++ type *type*) to its C API type.

        References have already been turned into pointers by writeMethods, so reference
        types are handled with '->' / pointer casts.

        NOTE(review): for a by-value 'std::string' this returns result.c_str() on a local
        std::string, which would dangle once the wrapper returns.  It would also be a
        const char* returned as char*.  Confirm that no wrapped method returns
        std::string by value.
        """
        if type == 'bool':
            return '(%s ? OpenMM_True : OpenMM_False)' % value
        if type == 'std::string':
            return '%s.c_str()' % value
        if type == 'const std::string &':
            return '%s->c_str()' % value
        if type in self.enumerationTypes:
            return 'static_cast<%s>(%s)' % (self.enumerationTypes[type], value)
        wrappedType = self.getType(type)
        if wrappedType == type:
            return value
        if type.endswith('*') or type.endswith('&'):
            return 'reinterpret_cast<%s>(%s)' % (wrappedType, value)
        return 'static_cast<%s>(%s)' % (wrappedType, value)
    
    def unwrapValue(self, type, value):
        """Return a C++ expression converting the C argument *value* back to C++ type *type*.

        References are handled by casting to a pointer and dereferencing.  Wrapped
        class/enum short names use static_cast to their qualified name.  bool is passed
        through unchanged.  Everything else is reinterpret_cast to *type*.
        """
        if type.endswith('&'):
            unwrappedType = type[:-1].strip()
            if unwrappedType in self.classesByShortName:
                unwrappedType = self.classesByShortName[unwrappedType]
            return '*'+self.unwrapValue(unwrappedType+'*', value)
        if type in self.classesByShortName:
            return 'static_cast<%s>(%s)' % (self.classesByShortName[type], value)
        if type == 'bool':
            return value
        return 'reinterpret_cast<%s>(%s)' % (type, value)

    def writeOutput(self):
peastman's avatar
peastman committed
613
        print >>self.out, """
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
#include "OpenMM.h"
#include "OpenMMCWrapper.h"
#include <cstring>
#include <vector>

using namespace OpenMM;
using namespace std;

extern "C" {

/* OpenMM_Vec3 */
OPENMM_EXPORT OpenMM_Vec3 OpenMM_Vec3_scale(const OpenMM_Vec3 vec, double scale) {
    OpenMM_Vec3 result = {vec.x*scale, vec.y*scale, vec.z*scale};
    return result;
}

/* OpenMM_Vec3Array */
OPENMM_EXPORT OpenMM_Vec3Array* OpenMM_Vec3Array_create(int size) {
    return reinterpret_cast<OpenMM_Vec3Array*>(new vector<Vec3>(size));
}
OPENMM_EXPORT void OpenMM_Vec3Array_destroy(OpenMM_Vec3Array* array) {
    delete reinterpret_cast<vector<Vec3>*>(array);
}
OPENMM_EXPORT int OpenMM_Vec3Array_getSize(const OpenMM_Vec3Array* array) {
    return reinterpret_cast<const vector<Vec3>*>(array)->size();
}
OPENMM_EXPORT void OpenMM_Vec3Array_resize(OpenMM_Vec3Array* array, int size) {
    reinterpret_cast<vector<Vec3>*>(array)->resize(size);
}
OPENMM_EXPORT void OpenMM_Vec3Array_append(OpenMM_Vec3Array* array, const OpenMM_Vec3 vec) {
    reinterpret_cast<vector<Vec3>*>(array)->push_back(Vec3(vec.x, vec.y, vec.z));
}
OPENMM_EXPORT void OpenMM_Vec3Array_set(OpenMM_Vec3Array* array, int index, const OpenMM_Vec3 vec) {
    (*reinterpret_cast<vector<Vec3>*>(array))[index] = Vec3(vec.x, vec.y, vec.z);
}
OPENMM_EXPORT const OpenMM_Vec3* OpenMM_Vec3Array_get(const OpenMM_Vec3Array* array, int index) {
peastman's avatar
peastman committed
650
    return reinterpret_cast<const OpenMM_Vec3*>((&(*reinterpret_cast<const vector<Vec3>*>(array))[index]));
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
}

/* OpenMM_StringArray */
OPENMM_EXPORT OpenMM_StringArray* OpenMM_StringArray_create(int size) {
    return reinterpret_cast<OpenMM_StringArray*>(new vector<string>(size));
}
OPENMM_EXPORT void OpenMM_StringArray_destroy(OpenMM_StringArray* array) {
    delete reinterpret_cast<vector<string>*>(array);
}
OPENMM_EXPORT int OpenMM_StringArray_getSize(const OpenMM_StringArray* array) {
    return reinterpret_cast<const vector<string>*>(array)->size();
}
OPENMM_EXPORT void OpenMM_StringArray_resize(OpenMM_StringArray* array, int size) {
    reinterpret_cast<vector<string>*>(array)->resize(size);
}
OPENMM_EXPORT void OpenMM_StringArray_append(OpenMM_StringArray* array, const char* str) {
    reinterpret_cast<vector<string>*>(array)->push_back(string(str));
}
OPENMM_EXPORT void OpenMM_StringArray_set(OpenMM_StringArray* array, int index, const char* str) {
    (*reinterpret_cast<vector<string>*>(array))[index] = string(str);
}
OPENMM_EXPORT const char* OpenMM_StringArray_get(const OpenMM_StringArray* array, int index) {
    return (*reinterpret_cast<const vector<string>*>(array))[index].c_str();
}

/* OpenMM_BondArray */
OPENMM_EXPORT OpenMM_BondArray* OpenMM_BondArray_create(int size) {
    return reinterpret_cast<OpenMM_BondArray*>(new vector<pair<int, int> >(size));
}
OPENMM_EXPORT void OpenMM_BondArray_destroy(OpenMM_BondArray* array) {
    delete reinterpret_cast<vector<pair<int, int> >*>(array);
}
OPENMM_EXPORT int OpenMM_BondArray_getSize(const OpenMM_BondArray* array) {
    return reinterpret_cast<const vector<pair<int, int> >*>(array)->size();
}
OPENMM_EXPORT void OpenMM_BondArray_resize(OpenMM_BondArray* array, int size) {
    reinterpret_cast<vector<pair<int, int> >*>(array)->resize(size);
}
OPENMM_EXPORT void OpenMM_BondArray_append(OpenMM_BondArray* array, int particle1, int particle2) {
    reinterpret_cast<vector<pair<int, int> >*>(array)->push_back(pair<int, int>(particle1, particle2));
}
OPENMM_EXPORT void OpenMM_BondArray_set(OpenMM_BondArray* array, int index, int particle1, int particle2) {
    (*reinterpret_cast<vector<pair<int, int> >*>(array))[index] = pair<int, int>(particle1, particle2);
}
OPENMM_EXPORT void OpenMM_BondArray_get(const OpenMM_BondArray* array, int index, int* particle1, int* particle2) {
    pair<int, int> particles = (*reinterpret_cast<const vector<pair<int, int> >*>(array))[index];
    *particle1 = particles.first;
    *particle2 = particles.second;
}

/* OpenMM_ParameterArray */
OPENMM_EXPORT int OpenMM_ParameterArray_getSize(const OpenMM_ParameterArray* array) {
    return reinterpret_cast<const map<string, double>*>(array)->size();
}
OPENMM_EXPORT double OpenMM_ParameterArray_get(const OpenMM_ParameterArray* array, const char* name) {
    const map<string, double>* params = reinterpret_cast<const map<string, double>*>(array);
    const map<string, double>::const_iterator iter = params->find(string(name));
    if (iter == params->end())
        throw OpenMMException("OpenMM_ParameterArray_get: No such parameter");
    return iter->second;
}

/* OpenMM_PropertyArray */
OPENMM_EXPORT int OpenMM_PropertyArray_getSize(const OpenMM_PropertyArray* array) {
    return reinterpret_cast<const map<string, double>*>(array)->size();
}
OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyArray* array, const char* name) {
    const map<string, string>* params = reinterpret_cast<const map<string, string>*>(array);
    const map<string, string>::const_iterator iter = params->find(string(name));
    if (iter == params->end())
        throw OpenMMException("OpenMM_PropertyArray_get: No such property");
    return iter->second.c_str();
}"""

        for type in ('double', 'int'):
            name = 'OpenMM_%sArray' % type.capitalize()
            values = {'type':type, 'name':name}
peastman's avatar
peastman committed
728
            print >>self.out, """
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
/* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create(int size) {
    return reinterpret_cast<%(name)s*>(new vector<%(type)s>(size));
}
OPENMM_EXPORT void %(name)s_destroy(%(name)s* array) {
    delete reinterpret_cast<vector<%(type)s>*>(array);
}
OPENMM_EXPORT int %(name)s_getSize(const %(name)s* array) {
    return reinterpret_cast<const vector<%(type)s>*>(array)->size();
}
OPENMM_EXPORT void %(name)s_resize(%(name)s* array, int size) {
    reinterpret_cast<vector<%(type)s>*>(array)->resize(size);
}
OPENMM_EXPORT void %(name)s_append(%(name)s* array, %(type)s value) {
    reinterpret_cast<vector<%(type)s>*>(array)->push_back(value);
}
OPENMM_EXPORT void %(name)s_set(%(name)s* array, int index, %(type)s value) {
    (*reinterpret_cast<vector<%(type)s>*>(array))[index] = value;
}
OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index) {
    return (*reinterpret_cast<const vector<%(type)s>*>(array))[index];
}""" % values

        for type in ('int',):
            name = 'OpenMM_%sSet' % type.capitalize()
            values = {'type':type, 'name':name}
peastman's avatar
peastman committed
755
            print >>self.out, """
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
/* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create() {
    return reinterpret_cast<%(name)s*>(new set<%(type)s>());
}
OPENMM_EXPORT void %(name)s_destroy(%(name)s* s) {
    delete reinterpret_cast<set<%(type)s>*>(s);
}
OPENMM_EXPORT int %(name)s_getSize(const %(name)s* s) {
    return reinterpret_cast<const set<%(type)s>*>(s)->size();
}
OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) {
    reinterpret_cast<set<%(type)s>*>(s)->insert(value);
}""" % values

        self.writeClasses()
peastman's avatar
peastman committed
771
        print >>self.out, "}\n"
772

peastman's avatar
peastman committed
773
774
775
776
777
# Script entry: sys.argv[1] is the directory passed to the generators
# (Doxygen XML output, e.g. .../wrappers/doxygen/xml) and sys.argv[2] is the
# directory that receives OpenMMCWrapper.h and OpenMMCWrapper.cpp.
# Each output file is closed (and therefore flushed) as soon as it is written,
# instead of relying on garbage collection or interpreter exit.
inputDirname = sys.argv[1]
outputDirname = sys.argv[2]
with open(os.path.join(outputDirname, 'OpenMMCWrapper.h'), 'w') as headerFile:
    builder = CHeaderGenerator(inputDirname, headerFile)
    builder.writeOutput()
with open(os.path.join(outputDirname, 'OpenMMCWrapper.cpp'), 'w') as sourceFile:
    builder = CSourceGenerator(inputDirname, sourceFile)
    builder.writeOutput()