Commit 3fda9879 authored by peastman's avatar peastman
Browse files

Finished C wrapper generation

parent 19b4fc48
......@@ -63,6 +63,8 @@ def findNodes(parent, path, **args):
return nodes
class WrapperGenerator:
"""This is the parent class of generators for various API wrapper files. It defines functions common to all of them."""
def __init__(self, inputDirname, output):
self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy']
self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint']
......@@ -79,6 +81,7 @@ class WrapperGenerator:
'std::vector< double >': 'OpenMM_DoubleArray',
'std::vector< int >': 'OpenMM_IntArray',
'std::set< int >': 'OpenMM_IntSet'}
self.inverseTranslations = dict((self.typeTranslations[key], key) for key in self.typeTranslations)
self.nodeByID={}
# Read all the XML files and merge them into a single document.
......@@ -88,7 +91,7 @@ class WrapperGenerator:
for node in root:
self.doc.getroot().append(node)
self.fOut = output
self.out = output
self.typesByShortName = {}
self._orderedClassNodes = self.buildOrderedClassNodes()
......@@ -156,11 +159,13 @@ class WrapperGenerator:
return False
class CHeaderGenerator(WrapperGenerator):
"""This class generates the header file for the C API wrappers."""
def __init__(self, inputDirname, output):
WrapperGenerator.__init__(self, inputDirname, output)
def writeGlobalConstants(self):
self.fOut.write("/* Global Constants */\n\n")
self.out.write("/* Global Constants */\n\n")
node = next((x for x in findNodes(self.doc.getroot(), "compounddef", kind="namespace") if x.findtext("compoundname") == "OpenMM"))
for section in findNodes(node, "sectiondef", kind="var"):
for memberNode in findNodes(section, "memberdef", kind="variable", mutable="no", prot="public", static="yes"):
......@@ -168,24 +173,24 @@ class CHeaderGenerator(WrapperGenerator):
iDef = getText("initializer", memberNode)
if iDef.startswith("="):
iDef = iDef[1:]
self.fOut.write("static %s = %s;\n" % (vDef, iDef))
self.out.write("static %s = %s;\n" % (vDef, iDef))
def writeTypeDeclarations(self):
self.fOut.write("\n/* Type Declarations */\n\n")
self.out.write("\n/* Type Declarations */\n\n")
for classNode in self._orderedClassNodes:
className = getText("compoundname", classNode)
shortName = stripOpenMMPrefix(className)
typeName = convertOpenMMPrefix(className)
self.fOut.write("typedef struct %s_struct %s;\n" % (typeName, typeName))
self.out.write("typedef struct %s_struct %s;\n" % (typeName, typeName))
self.typesByShortName[shortName] = typeName
def writeClasses(self):
for classNode in self._orderedClassNodes:
className = stripOpenMMPrefix(getText("compoundname", classNode))
self.fOut.write("\n/* %s */\n" % className)
self.out.write("\n/* %s */\n" % className)
self.writeEnumerations(classNode)
self.writeMethods(classNode)
self.fOut.write("\n")
self.out.write("\n")
def writeEnumerations(self, classNode):
enumNodes = []
......@@ -198,18 +203,18 @@ class CHeaderGenerator(WrapperGenerator):
for enumNode in enumNodes:
enumName = getText("name", enumNode)
enumTypeName = "%s_%s" % (typeName, enumName)
self.fOut.write("typedef enum {\n ")
self.out.write("typedef enum {\n ")
argSep=""
for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
vName = convertOpenMMPrefix(getText("name", valueNode))
vInit = getText("initializer", valueNode)
if vInit.startswith("="):
vInit = vInit[1:].strip()
self.fOut.write("%s%s_%s = %s" % (argSep, typeName, vName, vInit))
self.out.write("%s%s_%s = %s" % (argSep, typeName, vName, vInit))
argSep=", "
self.fOut.write("\n} %s;\n" % enumTypeName)
self.out.write("\n} %s;\n" % enumTypeName)
self.typesByShortName[enumName] = enumTypeName
if len(enumNodes)>0: self.fOut.write("\n")
if len(enumNodes)>0: self.out.write("\n")
def writeMethods(self, classNode):
methodList = self.getClassMethods(classNode)
......@@ -233,31 +238,40 @@ class CHeaderGenerator(WrapperGenerator):
suffix = ""
else:
suffix = "_%d" % numConstructors
self.fOut.write("extern OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
self.out.write("extern OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
self.writeArguments(methodNode, False)
self.fOut.write(");\n")
self.out.write(");\n")
# Write destructor
self.fOut.write("extern OPENMM_EXPORT void %s_destroy(%s* target);\n" % (typeName, typeName))
self.out.write("extern OPENMM_EXPORT void %s_destroy(%s* target);\n" % (typeName, typeName))
# Write other methods
# Record method names for future reference.
methodNames = {}
for methodNode in methodList:
methodDefinition = getText("definition", methodNode)
shortMethodDefinition = stripOpenMMPrefix(methodDefinition)
methodName = shortMethodDefinition.split()[-1]
methodNames[methodNode] = shortMethodDefinition.split()[-1]
# Write other methods
for methodNode in methodList:
methodName = methodNames[methodNode]
if methodName in (shortClassName, destructorName):
continue
if self.shouldHideMethod(methodNode):
continue
isConstMethod = (methodNode.attrib['const'] == 'yes')
if isConstMethod and any(methodNames[m] == methodName and m.attrib['const'] == 'no' for m in methodList):
# There are two identical methods that differ only in whether they are const. Skip the const one.
continue
returnType = self.getType(getText("type", methodNode))
self.fOut.write("extern OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
self.out.write("extern OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
isInstanceMethod = (methodNode.attrib['static'] != 'yes')
if isInstanceMethod:
if methodNode.attrib['const'] == 'yes':
self.fOut.write('const ')
self.fOut.write("%s* target" % typeName)
if isConstMethod:
self.out.write('const ')
self.out.write("%s* target" % typeName)
self.writeArguments(methodNode, isInstanceMethod)
self.fOut.write(");\n")
self.out.write(");\n")
def writeArguments(self, methodNode, initialSeparator):
paramList = findNodes(methodNode, 'param')
......@@ -274,7 +288,7 @@ class CHeaderGenerator(WrapperGenerator):
continue
type = self.getType(type)
name = getText('declname', node)
self.fOut.write("%s%s %s" % (separator, type, name))
self.out.write("%s%s %s" % (separator, type, name))
separator = ", "
def getType(self, type):
......@@ -289,7 +303,7 @@ class CHeaderGenerator(WrapperGenerator):
return type
def writeOutput(self):
print >>out, """
print >>self.out, """
#ifndef OPENMM_CWRAPPER_H_
#define OPENMM_CWRAPPER_H_
......@@ -299,7 +313,7 @@ class CHeaderGenerator(WrapperGenerator):
"""
self.writeGlobalConstants()
self.writeTypeDeclarations()
print >>out, """
print >>self.out, """
typedef struct OpenMM_Vec3Array_struct OpenMM_Vec3Array;
typedef struct OpenMM_StringArray_struct OpenMM_StringArray;
typedef struct OpenMM_BondArray_struct OpenMM_BondArray;
......@@ -357,7 +371,7 @@ extern OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyA
for type in ('double', 'int'):
name = 'OpenMM_%sArray' % type.capitalize()
values = {'type':type, 'name':name}
print >>out, """
print >>self.out, """
/* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create(int size);
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* array);
......@@ -370,14 +384,14 @@ extern OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index);"""
for type in ('int',):
name = 'OpenMM_%sSet' % type.capitalize()
values = {'type':type, 'name':name}
print >>out, """
print >>self.out, """
/* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create();
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* set);
extern OPENMM_EXPORT int %(name)s_getSize(const %(name)s* set);
extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);""" % values
print >>out, """
print >>self.out, """
/* These methods need to be handled specially, since their C++ APIs cannot be directly translated to C.
Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */
extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox);
......@@ -385,7 +399,7 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
self.writeClasses()
print >>out, """
print >>self.out, """
#if defined(__cplusplus)
}
#endif
......@@ -394,9 +408,12 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
class CSourceGenerator(WrapperGenerator):
"""This class generates the source file for the C API wrappers."""
def __init__(self, inputDirname, output):
WrapperGenerator.__init__(self, inputDirname, output)
self.classesByShortName = {}
self.enumerationTypes = {}
self.findTypes()
def findTypes(self):
......@@ -416,16 +433,19 @@ class CSourceGenerator(WrapperGenerator):
typeName = convertOpenMMPrefix(className)
for enumNode in enumNodes:
enumName = getText("name", enumNode)
self.typesByShortName[enumName] = "%s_%s" % (typeName, enumName)
self.classesByShortName[enumName] = '%s::%s' % (className, enumName)
enumTypeName = "%s_%s" % (typeName, enumName)
enumClassName = "%s::%s" % (className, enumName)
self.typesByShortName[enumName] = enumTypeName
self.classesByShortName[enumName] = enumClassName
self.enumerationTypes[enumClassName] = enumTypeName
def writeClasses(self):
for classNode in self._orderedClassNodes:
className = stripOpenMMPrefix(getText("compoundname", classNode))
self.fOut.write("\n/* OpenMM::%s */\n" % className)
self.out.write("\n/* OpenMM::%s */\n" % className)
self.findEnumerations(classNode)
self.writeMethods(classNode)
self.fOut.write("\n")
self.out.write("\n")
def writeMethods(self, classNode):
methodList = self.getClassMethods(classNode)
......@@ -449,57 +469,69 @@ class CSourceGenerator(WrapperGenerator):
suffix = ""
else:
suffix = "_%d" % numConstructors
self.fOut.write("OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
self.out.write("OPENMM_EXPORT %s* %s_create%s(" % (typeName, typeName, suffix))
self.writeArguments(methodNode, False)
self.fOut.write(") {\n")
self.fOut.write(" return reinterpret_cast<%s*>(new %s(" % (className, className))
self.out.write(") {\n")
self.out.write(" return reinterpret_cast<%s*>(new %s(" % (typeName, className))
self.writeInvocationArguments(methodNode, False)
self.fOut.write("));\n")
self.fOut.write("}\n")
self.out.write("));\n")
self.out.write("}\n")
# Write destructor
self.fOut.write("OPENMM_EXPORT void %s_destroy(%s* target) {\n" % (typeName, typeName))
self.fOut.write(" delete reinterpret_cast<%s*>(target);\n" % className)
self.fOut.write("}\n")
self.out.write("OPENMM_EXPORT void %s_destroy(%s* target) {\n" % (typeName, typeName))
self.out.write(" delete reinterpret_cast<%s*>(target);\n" % className)
self.out.write("}\n")
# Write other methods
# Record method names for future reference.
methodNames = {}
for methodNode in methodList:
methodDefinition = getText("definition", methodNode)
shortMethodDefinition = stripOpenMMPrefix(methodDefinition)
methodName = shortMethodDefinition.split()[-1]
methodNames[methodNode] = shortMethodDefinition.split()[-1]
# Write other methods
for methodNode in methodList:
methodName = methodNames[methodNode]
if methodName in (shortClassName, destructorName):
continue
if self.shouldHideMethod(methodNode):
continue
isConstMethod = (methodNode.attrib['const'] == 'yes')
if isConstMethod and any(methodNames[m] == methodName and m.attrib['const'] == 'no' for m in methodList):
# There are two identical methods that differ only in whether they are const. Skip the const one.
continue
methodType = getText("type", methodNode)
returnType = self.getType(methodType)
if methodType in self.classesByShortName:
methodType = self.classesByShortName[methodType]
self.fOut.write("OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
self.out.write("OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
isInstanceMethod = (methodNode.attrib['static'] != 'yes')
if isInstanceMethod:
isConstMethod = (methodNode.attrib['const'] == 'yes')
if isConstMethod:
self.fOut.write('const ')
self.fOut.write("%s* target" % typeName)
self.out.write('const ')
self.out.write("%s* target" % typeName)
self.writeArguments(methodNode, isInstanceMethod)
self.fOut.write(") {\n")
self.fOut.write(" ")
self.out.write(") {\n")
self.out.write(" ")
if returnType != 'void':
self.fOut.write('%s result = ' % methodType)
if methodType.endswith('&'):
# Convert references to pointers
self.out.write('%s* result = &' % methodType[:-1].strip())
else:
self.out.write('%s result = ' % methodType)
if isInstanceMethod:
self.fOut.write('reinterpret_cast<')
self.out.write('reinterpret_cast<')
if isConstMethod:
self.fOut.write('const ')
self.fOut.write('%s*>(target)->' % className)
self.out.write('const ')
self.out.write('%s*>(target)->' % className)
else:
self.fOut.write('%s::' % className)
self.fOut.write('%s(' % methodName)
self.out.write('%s::' % className)
self.out.write('%s(' % methodName)
self.writeInvocationArguments(methodNode, False)
self.fOut.write(');\n')
self.out.write(');\n')
if returnType != 'void':
self.fOut.write(' return %s;\n' % self.wrapValue(methodType, 'result'))
self.fOut.write("}\n")
self.out.write(' return %s;\n' % self.wrapValue(methodType, 'result'))
self.out.write("}\n")
def writeArguments(self, methodNode, initialSeparator):
paramList = findNodes(methodNode, 'param')
......@@ -516,7 +548,7 @@ class CSourceGenerator(WrapperGenerator):
continue
type = self.getType(type)
name = getText('declname', node)
self.fOut.write("%s%s %s" % (separator, type, name))
self.out.write("%s%s %s" % (separator, type, name))
separator = ", "
def writeInvocationArguments(self, methodNode, initialSeparator):
......@@ -533,7 +565,9 @@ class CSourceGenerator(WrapperGenerator):
if type == 'void':
continue
name = getText('declname', node)
self.fOut.write("%s%s" % (separator, name))
if self.getType(type) != type:
name = self.unwrapValue(type, name)
self.out.write("%s%s" % (separator, name))
separator = ", "
def getType(self, type):
......@@ -553,7 +587,9 @@ class CSourceGenerator(WrapperGenerator):
if type == 'std::string':
return '%s.c_str()' % value
if type == 'const std::string &':
return '%s.c_str()' % value
return '%s->c_str()' % value
if type in self.enumerationTypes:
return 'static_cast<%s>(%s)' % (self.enumerationTypes[type], value)
wrappedType = self.getType(type)
if wrappedType == type:
return value;
......@@ -561,8 +597,20 @@ class CSourceGenerator(WrapperGenerator):
return 'reinterpret_cast<%s>(%s)' % (wrappedType, value)
return 'static_cast<%s>(%s)' % (wrappedType, value)
def unwrapValue(self, type, value):
if type.endswith('&'):
unwrappedType = type[:-1].strip()
if unwrappedType in self.classesByShortName:
unwrappedType = self.classesByShortName[unwrappedType]
return '*'+self.unwrapValue(unwrappedType+'*', value)
if type in self.classesByShortName:
return 'static_cast<%s>(%s)' % (self.classesByShortName[type], value)
if type == 'bool':
return value
return 'reinterpret_cast<%s>(%s)' % (type, value)
def writeOutput(self):
print >>out, """
print >>self.out, """
#include "OpenMM.h"
#include "OpenMMCWrapper.h"
#include <cstring>
......@@ -599,7 +647,7 @@ OPENMM_EXPORT void OpenMM_Vec3Array_set(OpenMM_Vec3Array* array, int index, cons
(*reinterpret_cast<vector<Vec3>*>(array))[index] = Vec3(vec.x, vec.y, vec.z);
}
OPENMM_EXPORT const OpenMM_Vec3* OpenMM_Vec3Array_get(const OpenMM_Vec3Array* array, int index) {
return reinterpret_cast<const OpenMM_Vec3*>((&amp;(*reinterpret_cast<const vector<Vec3>*>(array))[index]));
return reinterpret_cast<const OpenMM_Vec3*>((&(*reinterpret_cast<const vector<Vec3>*>(array))[index]));
}
/* OpenMM_StringArray */
......@@ -677,7 +725,7 @@ OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyArray* a
for type in ('double', 'int'):
name = 'OpenMM_%sArray' % type.capitalize()
values = {'type':type, 'name':name}
print >>out, """
print >>self.out, """
/* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create(int size) {
return reinterpret_cast<%(name)s*>(new vector<%(type)s>(size));
......@@ -704,7 +752,7 @@ OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index) {
for type in ('int',):
name = 'OpenMM_%sSet' % type.capitalize()
values = {'type':type, 'name':name}
print >>out, """
print >>self.out, """
/* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create() {
return reinterpret_cast<%(name)s*>(new set<%(type)s>());
......@@ -720,9 +768,11 @@ OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) {
}""" % values
self.writeClasses()
print >>self.out, "}\n"
inputDirname = '/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml'
out = sys.stdout
#builder = CHeaderGenerator(inputDirname, out)
builder = CSourceGenerator(inputDirname, out)
#inputDirname = '/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml'
inputDirname = sys.argv[1]
builder = CHeaderGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.h'), 'w'))
builder.writeOutput()
builder = CSourceGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.cpp'), 'w'))
builder.writeOutput()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment