Commit 3fda9879 authored by peastman's avatar peastman
Browse files

Finished C wrapper generation

parent 19b4fc48
...@@ -63,6 +63,8 @@ def findNodes(parent, path, **args): ...@@ -63,6 +63,8 @@ def findNodes(parent, path, **args):
return nodes return nodes
class WrapperGenerator: class WrapperGenerator:
"""This is the parent class of generators for various API wrapper files. It defines functions common to all of them."""
def __init__(self, inputDirname, output): def __init__(self, inputDirname, output):
self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy'] self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy']
self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint'] self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint']
...@@ -79,6 +81,7 @@ class WrapperGenerator: ...@@ -79,6 +81,7 @@ class WrapperGenerator:
'std::vector< double >': 'OpenMM_DoubleArray', 'std::vector< double >': 'OpenMM_DoubleArray',
'std::vector< int >': 'OpenMM_IntArray', 'std::vector< int >': 'OpenMM_IntArray',
'std::set< int >': 'OpenMM_IntSet'} 'std::set< int >': 'OpenMM_IntSet'}
self.inverseTranslations = dict((self.typeTranslations[key], key) for key in self.typeTranslations)
self.nodeByID={} self.nodeByID={}
# Read all the XML files and merge them into a single document. # Read all the XML files and merge them into a single document.
...@@ -88,7 +91,7 @@ class WrapperGenerator: ...@@ -88,7 +91,7 @@ class WrapperGenerator:
for node in root: for node in root:
self.doc.getroot().append(node) self.doc.getroot().append(node)
self.fOut = output self.out = output
self.typesByShortName = {} self.typesByShortName = {}
self._orderedClassNodes = self.buildOrderedClassNodes() self._orderedClassNodes = self.buildOrderedClassNodes()
...@@ -156,11 +159,13 @@ class WrapperGenerator: ...@@ -156,11 +159,13 @@ class WrapperGenerator:
return False return False
class CHeaderGenerator(WrapperGenerator): class CHeaderGenerator(WrapperGenerator):
"""This class generates the header file for the C API wrappers."""
def __init__(self, inputDirname, output): def __init__(self, inputDirname, output):
WrapperGenerator.__init__(self, inputDirname, output) WrapperGenerator.__init__(self, inputDirname, output)
def writeGlobalConstants(self): def writeGlobalConstants(self):
self.fOut.write("/* Global Constants */\n\n") self.out.write("/* Global Constants */\n\n")
node = next((x for x in findNodes(self.doc.getroot(), "compounddef", kind="namespace") if x.findtext("compoundname") == "OpenMM")) node = next((x for x in findNodes(self.doc.getroot(), "compounddef", kind="namespace") if x.findtext("compoundname") == "OpenMM"))
for section in findNodes(node, "sectiondef", kind="var"): for section in findNodes(node, "sectiondef", kind="var"):
for memberNode in findNodes(section, "memberdef", kind="variable", mutable="no", prot="public", static="yes"): for memberNode in findNodes(section, "memberdef", kind="variable", mutable="no", prot="public", static="yes"):
...@@ -168,24 +173,24 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -168,24 +173,24 @@ class CHeaderGenerator(WrapperGenerator):
iDef = getText("initializer", memberNode) iDef = getText("initializer", memberNode)
if iDef.startswith("="): if iDef.startswith("="):
iDef = iDef[1:] iDef = iDef[1:]
self.fOut.write("static %s = %s;\n" % (vDef, iDef)) self.out.write("static %s = %s;\n" % (vDef, iDef))
def writeTypeDeclarations(self): def writeTypeDeclarations(self):
self.fOut.write("\n/* Type Declarations */\n\n") self.out.write("\n/* Type Declarations */\n\n")
for classNode in self._orderedClassNodes: for classNode in self._orderedClassNodes:
className = getText("compoundname", classNode) className = getText("compoundname", classNode)
shortName = stripOpenMMPrefix(className) shortName = stripOpenMMPrefix(className)
typeName = convertOpenMMPrefix(className) typeName = convertOpenMMPrefix(className)
self.fOut.write("typedef struct %s_struct %s;\n" % (typeName, typeName)) self.out.write("typedef struct %s_struct %s;\n" % (typeName, typeName))
self.typesByShortName[shortName] = typeName self.typesByShortName[shortName] = typeName
def writeClasses(self): def writeClasses(self):
for classNode in self._orderedClassNodes: for classNode in self._orderedClassNodes:
className = stripOpenMMPrefix(getText("compoundname", classNode)) className = stripOpenMMPrefix(getText("compoundname", classNode))
self.fOut.write("\n/* %s */\n" % className) self.out.write("\n/* %s */\n" % className)
self.writeEnumerations(classNode) self.writeEnumerations(classNode)
self.writeMethods(classNode) self.writeMethods(classNode)
self.fOut.write("\n") self.out.write("\n")
def writeEnumerations(self, classNode): def writeEnumerations(self, classNode):
enumNodes = [] enumNodes = []
...@@ -198,18 +203,18 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -198,18 +203,18 @@ class CHeaderGenerator(WrapperGenerator):
for enumNode in enumNodes: for enumNode in enumNodes:
enumName = getText("name", enumNode) enumName = getText("name", enumNode)
enumTypeName = "%s_%s" % (typeName, enumName) enumTypeName = "%s_%s" % (typeName, enumName)
self.fOut.write("typedef enum {\n ") self.out.write("typedef enum {\n ")
argSep="" argSep=""
for valueNode in findNodes(enumNode, "enumvalue", prot="public"): for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
vName = convertOpenMMPrefix(getText("name", valueNode)) vName = convertOpenMMPrefix(getText("name", valueNode))
vInit = getText("initializer", valueNode) vInit = getText("initializer", valueNode)
if vInit.startswith("="): if vInit.startswith("="):
vInit = vInit[1:].strip() vInit = vInit[1:].strip()
self.fOut.write("%s%s_%s = %s" % (argSep, typeName, vName, vInit)) self.out.write("%s%s_%s = %s" % (argSep, typeName, vName, vInit))
argSep=", " argSep=", "
self.fOut.write("\n} %s;\n" % enumTypeName) self.out.write("\n} %s;\n" % enumTypeName)
self.typesByShortName[enumName] = enumTypeName self.typesByShortName[enumName] = enumTypeName
if len(enumNodes)>0: self.fOut.write("\n") if len(enumNodes)>0: self.out.write("\n")
def writeMethods(self, classNode): def writeMethods(self, classNode):
methodList = self.getClassMethods(classNode) methodList = self.getClassMethods(classNode)
...@@ -233,31 +238,40 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -233,31 +238,40 @@ class CHeaderGenerator(WrapperGenerator):
suffix = "" suffix = ""
else: else:
suffix = "_%d" % numConstructors suffix = "_%d" % numConstructors
self.fOut.write("extern OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix)) self.out.write("extern OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
self.writeArguments(methodNode, False) self.writeArguments(methodNode, False)
self.fOut.write(");\n") self.out.write(");\n")
# Write destructor # Write destructor
self.fOut.write("extern OPENMM_EXPORT void %s_destroy(%s* target);\n" % (typeName, typeName)) self.out.write("extern OPENMM_EXPORT void %s_destroy(%s* target);\n" % (typeName, typeName))
# Write other methods # Record method names for future reference.
methodNames = {}
for methodNode in methodList: for methodNode in methodList:
methodDefinition = getText("definition", methodNode) methodDefinition = getText("definition", methodNode)
shortMethodDefinition = stripOpenMMPrefix(methodDefinition) shortMethodDefinition = stripOpenMMPrefix(methodDefinition)
methodName = shortMethodDefinition.split()[-1] methodNames[methodNode] = shortMethodDefinition.split()[-1]
# Write other methods
for methodNode in methodList:
methodName = methodNames[methodNode]
if methodName in (shortClassName, destructorName): if methodName in (shortClassName, destructorName):
continue continue
if self.shouldHideMethod(methodNode): if self.shouldHideMethod(methodNode):
continue continue
isConstMethod = (methodNode.attrib['const'] == 'yes')
if isConstMethod and any(methodNames[m] == methodName and m.attrib['const'] == 'no' for m in methodList):
# There are two identical methods that differ only in whether they are const. Skip the const one.
continue
returnType = self.getType(getText("type", methodNode)) returnType = self.getType(getText("type", methodNode))
self.fOut.write("extern OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName)) self.out.write("extern OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
isInstanceMethod = (methodNode.attrib['static'] != 'yes') isInstanceMethod = (methodNode.attrib['static'] != 'yes')
if isInstanceMethod: if isInstanceMethod:
if methodNode.attrib['const'] == 'yes': if isConstMethod:
self.fOut.write('const ') self.out.write('const ')
self.fOut.write("%s* target" % typeName) self.out.write("%s* target" % typeName)
self.writeArguments(methodNode, isInstanceMethod) self.writeArguments(methodNode, isInstanceMethod)
self.fOut.write(");\n") self.out.write(");\n")
def writeArguments(self, methodNode, initialSeparator): def writeArguments(self, methodNode, initialSeparator):
paramList = findNodes(methodNode, 'param') paramList = findNodes(methodNode, 'param')
...@@ -274,7 +288,7 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -274,7 +288,7 @@ class CHeaderGenerator(WrapperGenerator):
continue continue
type = self.getType(type) type = self.getType(type)
name = getText('declname', node) name = getText('declname', node)
self.fOut.write("%s%s %s" % (separator, type, name)) self.out.write("%s%s %s" % (separator, type, name))
separator = ", " separator = ", "
def getType(self, type): def getType(self, type):
...@@ -289,7 +303,7 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -289,7 +303,7 @@ class CHeaderGenerator(WrapperGenerator):
return type return type
def writeOutput(self): def writeOutput(self):
print >>out, """ print >>self.out, """
#ifndef OPENMM_CWRAPPER_H_ #ifndef OPENMM_CWRAPPER_H_
#define OPENMM_CWRAPPER_H_ #define OPENMM_CWRAPPER_H_
...@@ -299,7 +313,7 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -299,7 +313,7 @@ class CHeaderGenerator(WrapperGenerator):
""" """
self.writeGlobalConstants() self.writeGlobalConstants()
self.writeTypeDeclarations() self.writeTypeDeclarations()
print >>out, """ print >>self.out, """
typedef struct OpenMM_Vec3Array_struct OpenMM_Vec3Array; typedef struct OpenMM_Vec3Array_struct OpenMM_Vec3Array;
typedef struct OpenMM_StringArray_struct OpenMM_StringArray; typedef struct OpenMM_StringArray_struct OpenMM_StringArray;
typedef struct OpenMM_BondArray_struct OpenMM_BondArray; typedef struct OpenMM_BondArray_struct OpenMM_BondArray;
...@@ -357,7 +371,7 @@ extern OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyA ...@@ -357,7 +371,7 @@ extern OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyA
for type in ('double', 'int'): for type in ('double', 'int'):
name = 'OpenMM_%sArray' % type.capitalize() name = 'OpenMM_%sArray' % type.capitalize()
values = {'type':type, 'name':name} values = {'type':type, 'name':name}
print >>out, """ print >>self.out, """
/* %(name)s */ /* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create(int size); extern OPENMM_EXPORT %(name)s* %(name)s_create(int size);
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* array); extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* array);
...@@ -370,14 +384,14 @@ extern OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index);""" ...@@ -370,14 +384,14 @@ extern OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index);"""
for type in ('int',): for type in ('int',):
name = 'OpenMM_%sSet' % type.capitalize() name = 'OpenMM_%sSet' % type.capitalize()
values = {'type':type, 'name':name} values = {'type':type, 'name':name}
print >>out, """ print >>self.out, """
/* %(name)s */ /* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create(); extern OPENMM_EXPORT %(name)s* %(name)s_create();
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* set); extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* set);
extern OPENMM_EXPORT int %(name)s_getSize(const %(name)s* set); extern OPENMM_EXPORT int %(name)s_getSize(const %(name)s* set);
extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);""" % values extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);""" % values
print >>out, """ print >>self.out, """
/* These methods need to be handled specially, since their C++ APIs cannot be directly translated to C. /* These methods need to be handled specially, since their C++ APIs cannot be directly translated to C.
Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */ Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */
extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox); extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox);
...@@ -385,7 +399,7 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector ...@@ -385,7 +399,7 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
self.writeClasses() self.writeClasses()
print >>out, """ print >>self.out, """
#if defined(__cplusplus) #if defined(__cplusplus)
} }
#endif #endif
...@@ -394,9 +408,12 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector ...@@ -394,9 +408,12 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
class CSourceGenerator(WrapperGenerator): class CSourceGenerator(WrapperGenerator):
"""This class generates the source file for the C API wrappers."""
def __init__(self, inputDirname, output): def __init__(self, inputDirname, output):
WrapperGenerator.__init__(self, inputDirname, output) WrapperGenerator.__init__(self, inputDirname, output)
self.classesByShortName = {} self.classesByShortName = {}
self.enumerationTypes = {}
self.findTypes() self.findTypes()
def findTypes(self): def findTypes(self):
...@@ -416,16 +433,19 @@ class CSourceGenerator(WrapperGenerator): ...@@ -416,16 +433,19 @@ class CSourceGenerator(WrapperGenerator):
typeName = convertOpenMMPrefix(className) typeName = convertOpenMMPrefix(className)
for enumNode in enumNodes: for enumNode in enumNodes:
enumName = getText("name", enumNode) enumName = getText("name", enumNode)
self.typesByShortName[enumName] = "%s_%s" % (typeName, enumName) enumTypeName = "%s_%s" % (typeName, enumName)
self.classesByShortName[enumName] = '%s::%s' % (className, enumName) enumClassName = "%s::%s" % (className, enumName)
self.typesByShortName[enumName] = enumTypeName
self.classesByShortName[enumName] = enumClassName
self.enumerationTypes[enumClassName] = enumTypeName
def writeClasses(self): def writeClasses(self):
for classNode in self._orderedClassNodes: for classNode in self._orderedClassNodes:
className = stripOpenMMPrefix(getText("compoundname", classNode)) className = stripOpenMMPrefix(getText("compoundname", classNode))
self.fOut.write("\n/* OpenMM::%s */\n" % className) self.out.write("\n/* OpenMM::%s */\n" % className)
self.findEnumerations(classNode) self.findEnumerations(classNode)
self.writeMethods(classNode) self.writeMethods(classNode)
self.fOut.write("\n") self.out.write("\n")
def writeMethods(self, classNode): def writeMethods(self, classNode):
methodList = self.getClassMethods(classNode) methodList = self.getClassMethods(classNode)
...@@ -449,57 +469,69 @@ class CSourceGenerator(WrapperGenerator): ...@@ -449,57 +469,69 @@ class CSourceGenerator(WrapperGenerator):
suffix = "" suffix = ""
else: else:
suffix = "_%d" % numConstructors suffix = "_%d" % numConstructors
self.fOut.write("OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix)) self.out.write("OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
self.writeArguments(methodNode, False) self.writeArguments(methodNode, False)
self.fOut.write(") {\n") self.out.write(") {\n")
self.fOut.write(" return reinterpret_cast<%s*>(new %s(" % (className, className)) self.out.write(" return reinterpret_cast<%s*>(new %s(" % (typeName, className))
self.writeInvocationArguments(methodNode, False) self.writeInvocationArguments(methodNode, False)
self.fOut.write("));\n") self.out.write("));\n")
self.fOut.write("}\n") self.out.write("}\n")
# Write destructor # Write destructor
self.fOut.write("OPENMM_EXPORT void %s_destroy(%s* target) {\n" % (typeName, typeName)) self.out.write("OPENMM_EXPORT void %s_destroy(%s* target) {\n" % (typeName, typeName))
self.fOut.write(" delete reinterpret_cast<%s*>(target);\n" % className) self.out.write(" delete reinterpret_cast<%s*>(target);\n" % className)
self.fOut.write("}\n") self.out.write("}\n")
# Write other methods # Record method names for future reference.
methodNames = {}
for methodNode in methodList: for methodNode in methodList:
methodDefinition = getText("definition", methodNode) methodDefinition = getText("definition", methodNode)
shortMethodDefinition = stripOpenMMPrefix(methodDefinition) shortMethodDefinition = stripOpenMMPrefix(methodDefinition)
methodName = shortMethodDefinition.split()[-1] methodNames[methodNode] = shortMethodDefinition.split()[-1]
# Write other methods
for methodNode in methodList:
methodName = methodNames[methodNode]
if methodName in (shortClassName, destructorName): if methodName in (shortClassName, destructorName):
continue continue
if self.shouldHideMethod(methodNode): if self.shouldHideMethod(methodNode):
continue continue
isConstMethod = (methodNode.attrib['const'] == 'yes')
if isConstMethod and any(methodNames[m] == methodName and m.attrib['const'] == 'no' for m in methodList):
# There are two identical methods that differ only in whether they are const. Skip the const one.
continue
methodType = getText("type", methodNode) methodType = getText("type", methodNode)
returnType = self.getType(methodType) returnType = self.getType(methodType)
if methodType in self.classesByShortName: if methodType in self.classesByShortName:
methodType = self.classesByShortName[methodType] methodType = self.classesByShortName[methodType]
self.fOut.write("OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName)) self.out.write("OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
isInstanceMethod = (methodNode.attrib['static'] != 'yes') isInstanceMethod = (methodNode.attrib['static'] != 'yes')
if isInstanceMethod: if isInstanceMethod:
isConstMethod = (methodNode.attrib['const'] == 'yes')
if isConstMethod: if isConstMethod:
self.fOut.write('const ') self.out.write('const ')
self.fOut.write("%s* target" % typeName) self.out.write("%s* target" % typeName)
self.writeArguments(methodNode, isInstanceMethod) self.writeArguments(methodNode, isInstanceMethod)
self.fOut.write(") {\n") self.out.write(") {\n")
self.fOut.write(" ") self.out.write(" ")
if returnType != 'void': if returnType != 'void':
self.fOut.write('%s result = ' % methodType) if methodType.endswith('&'):
# Convert references to pointers
self.out.write('%s* result = &' % methodType[:-1].strip())
else:
self.out.write('%s result = ' % methodType)
if isInstanceMethod: if isInstanceMethod:
self.fOut.write('reinterpret_cast<') self.out.write('reinterpret_cast<')
if isConstMethod: if isConstMethod:
self.fOut.write('const ') self.out.write('const ')
self.fOut.write('%s*>(target)->' % className) self.out.write('%s*>(target)->' % className)
else: else:
self.fOut.write('%s::' % className) self.out.write('%s::' % className)
self.fOut.write('%s(' % methodName) self.out.write('%s(' % methodName)
self.writeInvocationArguments(methodNode, False) self.writeInvocationArguments(methodNode, False)
self.fOut.write(');\n') self.out.write(');\n')
if returnType != 'void': if returnType != 'void':
self.fOut.write(' return %s;\n' % self.wrapValue(methodType, 'result')) self.out.write(' return %s;\n' % self.wrapValue(methodType, 'result'))
self.fOut.write("}\n") self.out.write("}\n")
def writeArguments(self, methodNode, initialSeparator): def writeArguments(self, methodNode, initialSeparator):
paramList = findNodes(methodNode, 'param') paramList = findNodes(methodNode, 'param')
...@@ -516,7 +548,7 @@ class CSourceGenerator(WrapperGenerator): ...@@ -516,7 +548,7 @@ class CSourceGenerator(WrapperGenerator):
continue continue
type = self.getType(type) type = self.getType(type)
name = getText('declname', node) name = getText('declname', node)
self.fOut.write("%s%s %s" % (separator, type, name)) self.out.write("%s%s %s" % (separator, type, name))
separator = ", " separator = ", "
def writeInvocationArguments(self, methodNode, initialSeparator): def writeInvocationArguments(self, methodNode, initialSeparator):
...@@ -533,7 +565,9 @@ class CSourceGenerator(WrapperGenerator): ...@@ -533,7 +565,9 @@ class CSourceGenerator(WrapperGenerator):
if type == 'void': if type == 'void':
continue continue
name = getText('declname', node) name = getText('declname', node)
self.fOut.write("%s%s" % (separator, name)) if self.getType(type) != type:
name = self.unwrapValue(type, name)
self.out.write("%s%s" % (separator, name))
separator = ", " separator = ", "
def getType(self, type): def getType(self, type):
...@@ -553,7 +587,9 @@ class CSourceGenerator(WrapperGenerator): ...@@ -553,7 +587,9 @@ class CSourceGenerator(WrapperGenerator):
if type == 'std::string': if type == 'std::string':
return '%s.c_str()' % value return '%s.c_str()' % value
if type == 'const std::string &': if type == 'const std::string &':
return '%s.c_str()' % value return '%s->c_str()' % value
if type in self.enumerationTypes:
return 'static_cast<%s>(%s)' % (self.enumerationTypes[type], value)
wrappedType = self.getType(type) wrappedType = self.getType(type)
if wrappedType == type: if wrappedType == type:
return value; return value;
...@@ -561,8 +597,20 @@ class CSourceGenerator(WrapperGenerator): ...@@ -561,8 +597,20 @@ class CSourceGenerator(WrapperGenerator):
return 'reinterpret_cast<%s>(%s)' % (wrappedType, value) return 'reinterpret_cast<%s>(%s)' % (wrappedType, value)
return 'static_cast<%s>(%s)' % (wrappedType, value) return 'static_cast<%s>(%s)' % (wrappedType, value)
def unwrapValue(self, type, value):
if type.endswith('&'):
unwrappedType = type[:-1].strip()
if unwrappedType in self.classesByShortName:
unwrappedType = self.classesByShortName[unwrappedType]
return '*'+self.unwrapValue(unwrappedType+'*', value)
if type in self.classesByShortName:
return 'static_cast<%s>(%s)' % (self.classesByShortName[type], value)
if type == 'bool':
return value
return 'reinterpret_cast<%s>(%s)' % (type, value)
def writeOutput(self): def writeOutput(self):
print >>out, """ print >>self.out, """
#include "OpenMM.h" #include "OpenMM.h"
#include "OpenMMCWrapper.h" #include "OpenMMCWrapper.h"
#include <cstring> #include <cstring>
...@@ -599,7 +647,7 @@ OPENMM_EXPORT void OpenMM_Vec3Array_set(OpenMM_Vec3Array* array, int index, cons ...@@ -599,7 +647,7 @@ OPENMM_EXPORT void OpenMM_Vec3Array_set(OpenMM_Vec3Array* array, int index, cons
(*reinterpret_cast<vector<Vec3>*>(array))[index] = Vec3(vec.x, vec.y, vec.z); (*reinterpret_cast<vector<Vec3>*>(array))[index] = Vec3(vec.x, vec.y, vec.z);
} }
OPENMM_EXPORT const OpenMM_Vec3* OpenMM_Vec3Array_get(const OpenMM_Vec3Array* array, int index) { OPENMM_EXPORT const OpenMM_Vec3* OpenMM_Vec3Array_get(const OpenMM_Vec3Array* array, int index) {
return reinterpret_cast<const OpenMM_Vec3*>((&amp;(*reinterpret_cast<const vector<Vec3>*>(array))[index])); return reinterpret_cast<const OpenMM_Vec3*>((&(*reinterpret_cast<const vector<Vec3>*>(array))[index]));
} }
/* OpenMM_StringArray */ /* OpenMM_StringArray */
...@@ -677,7 +725,7 @@ OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyArray* a ...@@ -677,7 +725,7 @@ OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyArray* a
for type in ('double', 'int'): for type in ('double', 'int'):
name = 'OpenMM_%sArray' % type.capitalize() name = 'OpenMM_%sArray' % type.capitalize()
values = {'type':type, 'name':name} values = {'type':type, 'name':name}
print >>out, """ print >>self.out, """
/* %(name)s */ /* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create(int size) { OPENMM_EXPORT %(name)s* %(name)s_create(int size) {
return reinterpret_cast<%(name)s*>(new vector<%(type)s>(size)); return reinterpret_cast<%(name)s*>(new vector<%(type)s>(size));
...@@ -704,7 +752,7 @@ OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index) { ...@@ -704,7 +752,7 @@ OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index) {
for type in ('int',): for type in ('int',):
name = 'OpenMM_%sSet' % type.capitalize() name = 'OpenMM_%sSet' % type.capitalize()
values = {'type':type, 'name':name} values = {'type':type, 'name':name}
print >>out, """ print >>self.out, """
/* %(name)s */ /* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create() { OPENMM_EXPORT %(name)s* %(name)s_create() {
return reinterpret_cast<%(name)s*>(new set<%(type)s>()); return reinterpret_cast<%(name)s*>(new set<%(type)s>());
...@@ -720,9 +768,11 @@ OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) { ...@@ -720,9 +768,11 @@ OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) {
}""" % values }""" % values
self.writeClasses() self.writeClasses()
print >>self.out, "}\n"
inputDirname = '/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml' #inputDirname = '/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml'
out = sys.stdout inputDirname = sys.argv[1]
#builder = CHeaderGenerator(inputDirname, out) builder = CHeaderGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.h'), 'w'))
builder = CSourceGenerator(inputDirname, out) builder.writeOutput()
builder = CSourceGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.cpp'), 'w'))
builder.writeOutput() builder.writeOutput()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment