Commit a27fd3dc authored by peastman's avatar peastman
Browse files

Implemented generating Fortran wrapper header

parent 3fda9879
...@@ -69,19 +69,6 @@ class WrapperGenerator: ...@@ -69,19 +69,6 @@ class WrapperGenerator:
self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy'] self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy']
self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint'] self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint']
self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy'] self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy']
self.typeTranslations = {'bool': 'OpenMM_Boolean',
'Vec3': 'OpenMM_Vec3',
'std::string': 'char*',
'const std::string &': 'const char*',
'std::vector< std::string >': 'OpenMM_StringArray',
'std::vector< Vec3 >': 'OpenMM_Vec3Array',
'std::vector< std::pair< int, int > >': 'OpenMM_BondArray',
'std::map< std::string, double >': 'OpenMM_ParameterArray',
'std::map< std::string, std::string >': 'OpenMM_PropertyArray',
'std::vector< double >': 'OpenMM_DoubleArray',
'std::vector< int >': 'OpenMM_IntArray',
'std::set< int >': 'OpenMM_IntSet'}
self.inverseTranslations = dict((self.typeTranslations[key], key) for key in self.typeTranslations)
self.nodeByID={} self.nodeByID={}
# Read all the XML files and merge them into a single document. # Read all the XML files and merge them into a single document.
...@@ -163,6 +150,18 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -163,6 +150,18 @@ class CHeaderGenerator(WrapperGenerator):
def __init__(self, inputDirname, output): def __init__(self, inputDirname, output):
WrapperGenerator.__init__(self, inputDirname, output) WrapperGenerator.__init__(self, inputDirname, output)
self.typeTranslations = {'bool': 'OpenMM_Boolean',
'Vec3': 'OpenMM_Vec3',
'std::string': 'char*',
'const std::string &': 'const char*',
'std::vector< std::string >': 'OpenMM_StringArray',
'std::vector< Vec3 >': 'OpenMM_Vec3Array',
'std::vector< std::pair< int, int > >': 'OpenMM_BondArray',
'std::map< std::string, double >': 'OpenMM_ParameterArray',
'std::map< std::string, std::string >': 'OpenMM_PropertyArray',
'std::vector< double >': 'OpenMM_DoubleArray',
'std::vector< int >': 'OpenMM_IntArray',
'std::set< int >': 'OpenMM_IntSet'}
def writeGlobalConstants(self): def writeGlobalConstants(self):
self.out.write("/* Global Constants */\n\n") self.out.write("/* Global Constants */\n\n")
...@@ -412,6 +411,19 @@ class CSourceGenerator(WrapperGenerator): ...@@ -412,6 +411,19 @@ class CSourceGenerator(WrapperGenerator):
def __init__(self, inputDirname, output): def __init__(self, inputDirname, output):
WrapperGenerator.__init__(self, inputDirname, output) WrapperGenerator.__init__(self, inputDirname, output)
self.typeTranslations = {'bool': 'OpenMM_Boolean',
'Vec3': 'OpenMM_Vec3',
'std::string': 'char*',
'const std::string &': 'const char*',
'std::vector< std::string >': 'OpenMM_StringArray',
'std::vector< Vec3 >': 'OpenMM_Vec3Array',
'std::vector< std::pair< int, int > >': 'OpenMM_BondArray',
'std::map< std::string, double >': 'OpenMM_ParameterArray',
'std::map< std::string, std::string >': 'OpenMM_PropertyArray',
'std::vector< double >': 'OpenMM_DoubleArray',
'std::vector< int >': 'OpenMM_IntArray',
'std::set< int >': 'OpenMM_IntSet'}
self.inverseTranslations = dict((self.typeTranslations[key], key) for key in self.typeTranslations)
self.classesByShortName = {} self.classesByShortName = {}
self.enumerationTypes = {} self.enumerationTypes = {}
self.findTypes() self.findTypes()
...@@ -770,9 +782,483 @@ OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) { ...@@ -770,9 +782,483 @@ OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) {
self.writeClasses() self.writeClasses()
print >>self.out, "}\n" print >>self.out, "}\n"
class FortranHeaderGenerator(WrapperGenerator):
"""This class generates the header file for the Fortran API wrappers."""
def __init__(self, inputDirname, output):
WrapperGenerator.__init__(self, inputDirname, output)
self.typeTranslations = {'int': 'integer*4',
'bool': 'integer*4',
'double': 'real*8',
'std::string': 'character(*)',
'const std::string &': 'character(*)',
'std::vector< std::string >': 'type (OpenMM_StringArray)',
'std::vector< Vec3 >': 'type (OpenMM_Vec3Array)',
'std::vector< std::pair< int, int > >': 'type (OpenMM_BondArray)',
'std::map< std::string, double >': 'type (OpenMM_ParameterArray)',
'std::map< std::string, std::string >': 'type (OpenMM_PropertyArray)',
'std::vector< double >': 'type (OpenMM_DoubleArray)',
'std::vector< int >': 'type (OpenMM_IntArray)',
'std::set< int >': 'type (OpenMM_IntSet)'}
def writeGlobalConstants(self):
self.out.write(" ! Global Constants\n\n")
node = next((x for x in findNodes(self.doc.getroot(), "compounddef", kind="namespace") if x.findtext("compoundname") == "OpenMM"))
for section in findNodes(node, "sectiondef", kind="var"):
for memberNode in findNodes(section, "memberdef", kind="variable", mutable="no", prot="public", static="yes"):
vDef = convertOpenMMPrefix(getText("definition", memberNode))
iDef = getText("initializer", memberNode)
if iDef.startswith("="):
iDef = iDef[1:]
self.out.write(" real*8, parameter :: %s = %s\n" % (vDef, iDef))
def writeTypeDeclarations(self):
self.out.write("\n ! Type Declarations\n")
for classNode in self._orderedClassNodes:
className = getText("compoundname", classNode)
shortName = stripOpenMMPrefix(className)
typeName = convertOpenMMPrefix(className)
self.out.write("\n type OpenMM_%s\n" % typeName)
self.out.write(" integer*8 :: handle = 0\n")
self.out.write(" end type\n")
self.typesByShortName[shortName] = typeName
def writeClasses(self):
for classNode in self._orderedClassNodes:
className = getText("compoundname", classNode)
self.out.write("\n ! %s\n" % className)
self.writeMethods(classNode)
self.out.write("\n")
def writeEnumerations(self, classNode):
enumNodes = []
for section in findNodes(classNode, "sectiondef", kind="public-type"):
for node in findNodes(section, "memberdef", kind="enum", prot="public"):
enumNodes.append(node)
className = getText("compoundname", classNode)
typeName = convertOpenMMPrefix(className)
for enumNode in enumNodes:
for valueNode in findNodes(enumNode, "enumvalue", prot="public"):
vName = convertOpenMMPrefix(getText("name", valueNode))
vInit = getText("initializer", valueNode)
if vInit.startswith("="):
vInit = vInit[1:].strip()
self.out.write(" integer*4, parameter :: %s_%s = %s\n" % (typeName, vName, vInit))
enumName = getText("name", enumNode)
enumTypeName = "%s_%s" % (typeName, enumName)
self.typesByShortName[enumName] = enumTypeName
if len(enumNodes)>0: self.out.write("\n")
def writeMethods(self, classNode):
methodList = self.getClassMethods(classNode)
className = getText("compoundname", classNode)
shortClassName = stripOpenMMPrefix(className)
typeName = convertOpenMMPrefix(className)
destructorName = '~'+shortClassName
if not ('abstract' in classNode.attrib and classNode.attrib['abstract'] == 'yes'):
# Write constructors
numConstructors = 0
for methodNode in methodList:
methodDefinition = getText("definition", methodNode)
shortMethodDefinition = stripOpenMMPrefix(methodDefinition)
methodName = shortMethodDefinition.split()[-1]
if methodName == shortClassName:
if self.shouldHideMethod(methodNode):
continue
numConstructors += 1
if numConstructors == 1:
suffix = ""
else:
suffix = "_%d" % numConstructors
self.out.write(" subroutine %s_create%s(result, " % (typeName, suffix))
self.writeArguments(methodNode, False)
self.out.write(")\n")
self.out.write(" use OpenMM_Types; implicit none\n")
self.out.write(" type (%s) result\n" % typeName)
self.declareArguments(methodNode)
self.out.write(" end subroutine\n")
# Write destructor
self.out.write(" subroutine %s_destroy(destroy)\n" % typeName)
self.out.write(" use OpenMM_Types; implicit none\n")
self.out.write(" type (%s) destroy\n" % typeName)
self.out.write(" end subroutine\n")
# Record method names for future reference.
methodNames = {}
for methodNode in methodList:
methodDefinition = getText("definition", methodNode)
shortMethodDefinition = stripOpenMMPrefix(methodDefinition)
methodNames[methodNode] = shortMethodDefinition.split()[-1]
# Write other methods
for methodNode in methodList:
methodName = methodNames[methodNode]
if methodName in (shortClassName, destructorName):
continue
if self.shouldHideMethod(methodNode):
continue
isConstMethod = (methodNode.attrib['const'] == 'yes')
if isConstMethod and any(methodNames[m] == methodName and m.attrib['const'] == 'no' for m in methodList):
# There are two identical methods that differ only in whether they are const. Skip the const one.
continue
returnType = self.getType(getText("type", methodNode))
hasReturnValue = (returnType in ('integer*4', 'real*8'))
hasReturnArg = not (hasReturnValue or returnType == 'void')
functionName = "%s_%s" % (typeName, methodName)
if hasReturnValue:
self.out.write(" function ")
else:
self.out.write(" subroutine ")
self.out.write("%s(" % functionName)
isInstanceMethod = (methodNode.attrib['static'] != 'yes')
if isInstanceMethod:
self.out.write("target")
numArgs = self.writeArguments(methodNode, isInstanceMethod)
if hasReturnArg:
if isInstanceMethod or numArgs > 0:
self.out.write(", ")
self.out.write("result")
self.out.write(")\n")
self.out.write(" use OpenMM_Types; implicit none\n")
self.out.write(" type (%s) target\n" % typeName)
self.declareArguments(methodNode)
if hasReturnValue:
self.declareOneArgument(returnType, functionName)
if hasReturnArg:
self.declareOneArgument(returnType, 'result')
if hasReturnValue:
self.out.write(" end function\n")
else:
self.out.write(" end subroutine\n")
def writeArguments(self, methodNode, initialSeparator):
paramList = findNodes(methodNode, 'param')
if initialSeparator:
separator = ", "
else:
separator = ""
numArgs = 0
for node in paramList:
try:
type = getText('type', node)
except IndexError:
type = getText('type/ref', node)
if type == 'void':
continue
name = getText('declname', node)
self.out.write("%s%s" % (separator, name))
separator = ", &\n "
numArgs += 1
return numArgs
def declareOneArgument(self, type, name):
if type == 'void':
return
type = self.getType(type)
if type == 'Vec3':
self.out.write(" real*8 %s(3)\n" % name)
else:
self.out.write(" %s %s\n" % (type, name))
def declareArguments(self, methodNode):
paramList = findNodes(methodNode, 'param')
for node in paramList:
try:
type = getText('type', node)
except IndexError:
type = getText('type/ref', node)
name = getText('declname', node)
self.declareOneArgument(type, name)
def getType(self, type):
if type in self.typeTranslations:
return self.typeTranslations[type]
if type in self.typesByShortName:
return 'type (%s)' % self.typesByShortName[type]
if type.startswith('const '):
return self.getType(type[6:].strip())
if type.endswith('&') or type.endswith('*'):
return self.getType(type[:-1].strip())
return type
def writeOutput(self):
print >>self.out, """
MODULE OpenMM_Types
implicit none
"""
self.writeGlobalConstants()
self.writeTypeDeclarations()
print >>self.out, """
type OpenMM_Vec3Array
integer*8 :: handle = 0
end type
type OpenMM_StringArray
integer*8 :: handle = 0
end type
type OpenMM_BondArray
integer*8 :: handle = 0
end type
type OpenMM_ParameterArray
integer*8 :: handle = 0
end type
type OpenMM_PropertyArray
integer*8 :: handle = 0
end type
type OpenMM_DoubleArray
integer*8 :: handle = 0
end type
type OpenMM_IntArray
integer*8 :: handle = 0
end type
type OpenMM_IntSet
integer*8 :: handle = 0
end type
! Enumerations
integer*4, parameter :: OpenMM_False = 0
integer*4, parameter :: OpenMM_True = 1"""
for classNode in self._orderedClassNodes:
self.writeEnumerations(classNode)
print >>self.out, """
END MODULE OpenMM_Types
MODULE OpenMM
use OpenMM_Types; implicit none
interface
! OpenMM_Vec3
subroutine OpenMM_Vec3_scale(vec, scale, result)
use OpenMM_Types; implicit none
real*8 vec(3)
real*8 scale
real*8 result(3)
end subroutine
! OpenMM_Vec3Array
subroutine OpenMM_Vec3Array_create(result, size)
use OpenMM_Types; implicit none
integer*4 size
type (OpenMM_Vec3Array) result
end subroutine
subroutine OpenMM_Vec3Array_destroy(destroy)
use OpenMM_Types; implicit none
type (OpenMM_Vec3Array) destroy
end subroutine
function OpenMM_Vec3Array_getSize(target)
use OpenMM_Types; implicit none
type (OpenMM_Vec3Array) target
integer*4 OpenMM_Vec3Array_getSize
end function
subroutine OpenMM_Vec3Array_resize(target, size)
use OpenMM_Types; implicit none
type (OpenMM_Vec3Array) target
integer*4 size
end subroutine
subroutine OpenMM_Vec3Array_append(target, vec)
use OpenMM_Types; implicit none
type (OpenMM_Vec3Array) target
real*8 vec(3)
end subroutine
subroutine OpenMM_Vec3Array_set(target, index, vec)
use OpenMM_Types; implicit none
type (OpenMM_Vec3Array) target
integer*4 index
real*8 vec(3)
end subroutine
subroutine OpenMM_Vec3Array_get(target, index, result)
use OpenMM_Types; implicit none
type (OpenMM_Vec3Array) target
integer*4 index
real*8 result(3)
end subroutine
! OpenMM_StringArray
subroutine OpenMM_StringArray_create(result, size)
use OpenMM_Types; implicit none
integer*4 size
type (OpenMM_StringArray) result
end subroutine
subroutine OpenMM_StringArray_destroy(destroy)
use OpenMM_Types; implicit none
type (OpenMM_StringArray) destroy
end subroutine
function OpenMM_StringArray_getSize(target)
use OpenMM_Types; implicit none
type (OpenMM_StringArray) target
integer*4 OpenMM_StringArray_getSize
end function
subroutine OpenMM_StringArray_resize(target, size)
use OpenMM_Types; implicit none
type (OpenMM_StringArray) target
integer*4 size
end subroutine
subroutine OpenMM_StringArray_append(target, str)
use OpenMM_Types; implicit none
type (OpenMM_StringArray) target
character(*) str
end subroutine
subroutine OpenMM_StringArray_set(target, index, str)
use OpenMM_Types; implicit none
type (OpenMM_StringArray) target
integer*4 index
character(*) str
end subroutine
subroutine OpenMM_StringArray_get(target, index, result)
use OpenMM_Types; implicit none
type (OpenMM_StringArray) target
integer*4 index
character(*) result
end subroutine
! OpenMM_BondArray
subroutine OpenMM_BondArray_create(result, size)
use OpenMM_Types; implicit none
integer*4 size
type (OpenMM_BondArray) result
end subroutine
subroutine OpenMM_BondArray_destroy(destroy)
use OpenMM_Types; implicit none
type (OpenMM_BondArray) destroy
end subroutine
function OpenMM_BondArray_getSize(target)
use OpenMM_Types; implicit none
type (OpenMM_BondArray) target
integer*4 OpenMM_BondArray_getSize
end function
subroutine OpenMM_BondArray_resize(target, size)
use OpenMM_Types; implicit none
type (OpenMM_BondArray) target
integer*4 size
end subroutine
subroutine OpenMM_BondArray_append(target, particle1, particle2)
use OpenMM_Types; implicit none
type (OpenMM_BondArray) target
integer*4 particle1
integer*4 particle2
end subroutine
subroutine OpenMM_BondArray_set(target, index, particle1, particle2)
use OpenMM_Types; implicit none
type (OpenMM_BondArray) target
integer*4 index
integer*4 particle1
integer*4 particle2
end subroutine
subroutine OpenMM_BondArray_get(target, index, particle1, particle2)
use OpenMM_Types; implicit none
type (OpenMM_BondArray) target
integer*4 index
integer*4 particle1
integer*4 particle2
end subroutine
! OpenMM_ParameterArray
function OpenMM_ParameterArray_getSize(target)
use OpenMM_Types; implicit none
type (OpenMM_ParameterArray) target
integer*4 OpenMM_ParameterArray_getSize
end function
subroutine OpenMM_ParameterArray_get(target, name, result)
use OpenMM_Types; implicit none
type (OpenMM_ParameterArray) target
character(*) name
character(*) result
end subroutine
! OpenMM_PropertyArray
function OpenMM_PropertyArray_getSize(target)
use OpenMM_Types; implicit none
type (OpenMM_ParameterArray) target
integer*4 OpenMM_PropertyArray_getSize
end function
subroutine OpenMM_PropertyArray_get(target, name, result)
use OpenMM_Types; implicit none
type (OpenMM_PropertyArray) target
character(*) name
character(*) result
end subroutine"""
arrayTypes = {'OpenMM_DoubleArray':'real*8', 'OpenMM_IntArray':'integer*4'}
for name in arrayTypes:
values = {'type':arrayTypes[name], 'name':name}
print >>self.out, """
! %(name)s
subroutine %(name)s_create(result, size)
use OpenMM_Types; implicit none
integer*4 size
type (%(name)s) result
end subroutine
subroutine %(name)s_destroy(destroy)
use OpenMM_Types; implicit none
type (%(name)s) destroy
end subroutine
function %(name)s_getSize(target)
use OpenMM_Types; implicit none
type (%(name)s) target
integer*4 %(name)s_getSize
end function
subroutine %(name)s_resize(target, size)
use OpenMM_Types; implicit none
type (%(name)s) target
integer*4 size
end subroutine
subroutine %(name)s_append(target, value)
use OpenMM_Types; implicit none
type (%(name)s) target
%(type)s value
end subroutine
subroutine %(name)s_set(target, index, value)
use OpenMM_Types; implicit none
type (%(name)s) target
integer*4 index
%(type)s value
end subroutine
subroutine %(name)s_get(target, index, result)
use OpenMM_Types; implicit none
type (%(name)s) target
integer*4 index
%(type)s result
end subroutine""" % values
print >>self.out, """
! These methods need to be handled specially, since their C++ APIs cannot be directly translated to Fortran.
! Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself.
subroutine OpenMM_Context_getState(target, types, enforcePeriodicBox, result)
use OpenMM_Types; implicit none
type (OpenMM_Context) target
integer*4 types
integer*4 enforcePeriodicBox
type (OpenMM_State) result
end subroutine
subroutine OpenMM_Platform_loadPluginsFromDirectory(directory, result)
use OpenMM_Types; implicit none
character(*) directory
type (OpenMM_StringArray) result
end subroutine"""
self.writeClasses()
print >>self.out, """
end interface
END MODULE OpenMM"""
#inputDirname = '/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml' #inputDirname = '/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml'
inputDirname = sys.argv[1] inputDirname = sys.argv[1]
builder = CHeaderGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.h'), 'w')) builder = CHeaderGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.h'), 'w'))
builder.writeOutput() builder.writeOutput()
builder = CSourceGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.cpp'), 'w')) builder = CSourceGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMCWrapper.cpp'), 'w'))
builder.writeOutput() builder.writeOutput()
builder = FortranHeaderGenerator(inputDirname, open(os.path.join(sys.argv[2], 'OpenMMFortranModule.f90'), 'w'))
builder.writeOutput()
#builder = FortranHeaderGenerator(inputDirname, sys.stdout)
#builder.writeOutput()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment