Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
3fda9879
Commit
3fda9879
authored
Nov 21, 2013
by
peastman
Browse files
Finished C wrapper generation
parent
19b4fc48
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
118 additions
and
68 deletions
+118
-68
wrappers/generateWrappers.py
wrappers/generateWrappers.py
+118
-68
No files found.
wrappers/generateWrappers.py
View file @
3fda9879
...
@@ -63,6 +63,8 @@ def findNodes(parent, path, **args):
...
@@ -63,6 +63,8 @@ def findNodes(parent, path, **args):
return
nodes
return
nodes
class
WrapperGenerator
:
class
WrapperGenerator
:
"""This is the parent class of generators for various API wrapper files. It defines functions common to all of them."""
def
__init__
(
self
,
inputDirname
,
output
):
def
__init__
(
self
,
inputDirname
,
output
):
self
.
skipClasses
=
[
'OpenMM::Vec3'
,
'OpenMM::XmlSerializer'
,
'OpenMM::Kernel'
,
'OpenMM::KernelImpl'
,
'OpenMM::KernelFactory'
,
'OpenMM::ContextImpl'
,
'OpenMM::SerializationNode'
,
'OpenMM::SerializationProxy'
]
self
.
skipClasses
=
[
'OpenMM::Vec3'
,
'OpenMM::XmlSerializer'
,
'OpenMM::Kernel'
,
'OpenMM::KernelImpl'
,
'OpenMM::KernelFactory'
,
'OpenMM::ContextImpl'
,
'OpenMM::SerializationNode'
,
'OpenMM::SerializationProxy'
]
self
.
skipMethods
=
[
'OpenMM::Context::getState'
,
'OpenMM::Platform::loadPluginsFromDirectory'
,
'OpenMM::Context::createCheckpoint'
,
'OpenMM::Context::loadCheckpoint'
]
self
.
skipMethods
=
[
'OpenMM::Context::getState'
,
'OpenMM::Platform::loadPluginsFromDirectory'
,
'OpenMM::Context::createCheckpoint'
,
'OpenMM::Context::loadCheckpoint'
]
...
@@ -79,6 +81,7 @@ class WrapperGenerator:
...
@@ -79,6 +81,7 @@ class WrapperGenerator:
'std::vector< double >'
:
'OpenMM_DoubleArray'
,
'std::vector< double >'
:
'OpenMM_DoubleArray'
,
'std::vector< int >'
:
'OpenMM_IntArray'
,
'std::vector< int >'
:
'OpenMM_IntArray'
,
'std::set< int >'
:
'OpenMM_IntSet'
}
'std::set< int >'
:
'OpenMM_IntSet'
}
self
.
inverseTranslations
=
dict
((
self
.
typeTranslations
[
key
],
key
)
for
key
in
self
.
typeTranslations
)
self
.
nodeByID
=
{}
self
.
nodeByID
=
{}
# Read all the XML files and merge them into a single document.
# Read all the XML files and merge them into a single document.
...
@@ -88,7 +91,7 @@ class WrapperGenerator:
...
@@ -88,7 +91,7 @@ class WrapperGenerator:
for
node
in
root
:
for
node
in
root
:
self
.
doc
.
getroot
().
append
(
node
)
self
.
doc
.
getroot
().
append
(
node
)
self
.
fO
ut
=
output
self
.
o
ut
=
output
self
.
typesByShortName
=
{}
self
.
typesByShortName
=
{}
self
.
_orderedClassNodes
=
self
.
buildOrderedClassNodes
()
self
.
_orderedClassNodes
=
self
.
buildOrderedClassNodes
()
...
@@ -156,11 +159,13 @@ class WrapperGenerator:
...
@@ -156,11 +159,13 @@ class WrapperGenerator:
return
False
return
False
class
CHeaderGenerator
(
WrapperGenerator
):
class
CHeaderGenerator
(
WrapperGenerator
):
"""This class generates the header file for the C API wrappers."""
def
__init__
(
self
,
inputDirname
,
output
):
def
__init__
(
self
,
inputDirname
,
output
):
WrapperGenerator
.
__init__
(
self
,
inputDirname
,
output
)
WrapperGenerator
.
__init__
(
self
,
inputDirname
,
output
)
def
writeGlobalConstants
(
self
):
def
writeGlobalConstants
(
self
):
self
.
fO
ut
.
write
(
"/* Global Constants */
\n\n
"
)
self
.
o
ut
.
write
(
"/* Global Constants */
\n\n
"
)
node
=
next
((
x
for
x
in
findNodes
(
self
.
doc
.
getroot
(),
"compounddef"
,
kind
=
"namespace"
)
if
x
.
findtext
(
"compoundname"
)
==
"OpenMM"
))
node
=
next
((
x
for
x
in
findNodes
(
self
.
doc
.
getroot
(),
"compounddef"
,
kind
=
"namespace"
)
if
x
.
findtext
(
"compoundname"
)
==
"OpenMM"
))
for
section
in
findNodes
(
node
,
"sectiondef"
,
kind
=
"var"
):
for
section
in
findNodes
(
node
,
"sectiondef"
,
kind
=
"var"
):
for
memberNode
in
findNodes
(
section
,
"memberdef"
,
kind
=
"variable"
,
mutable
=
"no"
,
prot
=
"public"
,
static
=
"yes"
):
for
memberNode
in
findNodes
(
section
,
"memberdef"
,
kind
=
"variable"
,
mutable
=
"no"
,
prot
=
"public"
,
static
=
"yes"
):
...
@@ -168,24 +173,24 @@ class CHeaderGenerator(WrapperGenerator):
...
@@ -168,24 +173,24 @@ class CHeaderGenerator(WrapperGenerator):
iDef
=
getText
(
"initializer"
,
memberNode
)
iDef
=
getText
(
"initializer"
,
memberNode
)
if
iDef
.
startswith
(
"="
):
if
iDef
.
startswith
(
"="
):
iDef
=
iDef
[
1
:]
iDef
=
iDef
[
1
:]
self
.
fO
ut
.
write
(
"static %s = %s;
\n
"
%
(
vDef
,
iDef
))
self
.
o
ut
.
write
(
"static %s = %s;
\n
"
%
(
vDef
,
iDef
))
def
writeTypeDeclarations
(
self
):
def
writeTypeDeclarations
(
self
):
self
.
fO
ut
.
write
(
"
\n
/* Type Declarations */
\n\n
"
)
self
.
o
ut
.
write
(
"
\n
/* Type Declarations */
\n\n
"
)
for
classNode
in
self
.
_orderedClassNodes
:
for
classNode
in
self
.
_orderedClassNodes
:
className
=
getText
(
"compoundname"
,
classNode
)
className
=
getText
(
"compoundname"
,
classNode
)
shortName
=
stripOpenMMPrefix
(
className
)
shortName
=
stripOpenMMPrefix
(
className
)
typeName
=
convertOpenMMPrefix
(
className
)
typeName
=
convertOpenMMPrefix
(
className
)
self
.
fO
ut
.
write
(
"typedef struct %s_struct %s;
\n
"
%
(
typeName
,
typeName
))
self
.
o
ut
.
write
(
"typedef struct %s_struct %s;
\n
"
%
(
typeName
,
typeName
))
self
.
typesByShortName
[
shortName
]
=
typeName
self
.
typesByShortName
[
shortName
]
=
typeName
def
writeClasses
(
self
):
def
writeClasses
(
self
):
for
classNode
in
self
.
_orderedClassNodes
:
for
classNode
in
self
.
_orderedClassNodes
:
className
=
stripOpenMMPrefix
(
getText
(
"compoundname"
,
classNode
))
className
=
stripOpenMMPrefix
(
getText
(
"compoundname"
,
classNode
))
self
.
fO
ut
.
write
(
"
\n
/* %s */
\n
"
%
className
)
self
.
o
ut
.
write
(
"
\n
/* %s */
\n
"
%
className
)
self
.
writeEnumerations
(
classNode
)
self
.
writeEnumerations
(
classNode
)
self
.
writeMethods
(
classNode
)
self
.
writeMethods
(
classNode
)
self
.
fO
ut
.
write
(
"
\n
"
)
self
.
o
ut
.
write
(
"
\n
"
)
def
writeEnumerations
(
self
,
classNode
):
def
writeEnumerations
(
self
,
classNode
):
enumNodes
=
[]
enumNodes
=
[]
...
@@ -198,18 +203,18 @@ class CHeaderGenerator(WrapperGenerator):
...
@@ -198,18 +203,18 @@ class CHeaderGenerator(WrapperGenerator):
for
enumNode
in
enumNodes
:
for
enumNode
in
enumNodes
:
enumName
=
getText
(
"name"
,
enumNode
)
enumName
=
getText
(
"name"
,
enumNode
)
enumTypeName
=
"%s_%s"
%
(
typeName
,
enumName
)
enumTypeName
=
"%s_%s"
%
(
typeName
,
enumName
)
self
.
fO
ut
.
write
(
"typedef enum {
\n
"
)
self
.
o
ut
.
write
(
"typedef enum {
\n
"
)
argSep
=
""
argSep
=
""
for
valueNode
in
findNodes
(
enumNode
,
"enumvalue"
,
prot
=
"public"
):
for
valueNode
in
findNodes
(
enumNode
,
"enumvalue"
,
prot
=
"public"
):
vName
=
convertOpenMMPrefix
(
getText
(
"name"
,
valueNode
))
vName
=
convertOpenMMPrefix
(
getText
(
"name"
,
valueNode
))
vInit
=
getText
(
"initializer"
,
valueNode
)
vInit
=
getText
(
"initializer"
,
valueNode
)
if
vInit
.
startswith
(
"="
):
if
vInit
.
startswith
(
"="
):
vInit
=
vInit
[
1
:].
strip
()
vInit
=
vInit
[
1
:].
strip
()
self
.
fO
ut
.
write
(
"%s%s_%s = %s"
%
(
argSep
,
typeName
,
vName
,
vInit
))
self
.
o
ut
.
write
(
"%s%s_%s = %s"
%
(
argSep
,
typeName
,
vName
,
vInit
))
argSep
=
", "
argSep
=
", "
self
.
fO
ut
.
write
(
"
\n
} %s;
\n
"
%
enumTypeName
)
self
.
o
ut
.
write
(
"
\n
} %s;
\n
"
%
enumTypeName
)
self
.
typesByShortName
[
enumName
]
=
enumTypeName
self
.
typesByShortName
[
enumName
]
=
enumTypeName
if
len
(
enumNodes
)
>
0
:
self
.
fO
ut
.
write
(
"
\n
"
)
if
len
(
enumNodes
)
>
0
:
self
.
o
ut
.
write
(
"
\n
"
)
def
writeMethods
(
self
,
classNode
):
def
writeMethods
(
self
,
classNode
):
methodList
=
self
.
getClassMethods
(
classNode
)
methodList
=
self
.
getClassMethods
(
classNode
)
...
@@ -233,31 +238,40 @@ class CHeaderGenerator(WrapperGenerator):
...
@@ -233,31 +238,40 @@ class CHeaderGenerator(WrapperGenerator):
suffix
=
""
suffix
=
""
else
:
else
:
suffix
=
"_%d"
%
numConstructors
suffix
=
"_%d"
%
numConstructors
self
.
fO
ut
.
write
(
"extern OPENMM_EXPORT %s* %s_create%s("
%
(
typeName
,
typeName
,
suffix
))
self
.
o
ut
.
write
(
"extern OPENMM_EXPORT %s* %s_create%s("
%
(
typeName
,
typeName
,
suffix
))
self
.
writeArguments
(
methodNode
,
False
)
self
.
writeArguments
(
methodNode
,
False
)
self
.
fO
ut
.
write
(
");
\n
"
)
self
.
o
ut
.
write
(
");
\n
"
)
# Write destructor
# Write destructor
self
.
fO
ut
.
write
(
"extern OPENMM_EXPORT void %s_destroy(%s* target);
\n
"
%
(
typeName
,
typeName
))
self
.
o
ut
.
write
(
"extern OPENMM_EXPORT void %s_destroy(%s* target);
\n
"
%
(
typeName
,
typeName
))
# Write other methods
# Record method names for future reference.
methodNames
=
{}
for
methodNode
in
methodList
:
for
methodNode
in
methodList
:
methodDefinition
=
getText
(
"definition"
,
methodNode
)
methodDefinition
=
getText
(
"definition"
,
methodNode
)
shortMethodDefinition
=
stripOpenMMPrefix
(
methodDefinition
)
shortMethodDefinition
=
stripOpenMMPrefix
(
methodDefinition
)
methodName
=
shortMethodDefinition
.
split
()[
-
1
]
methodNames
[
methodNode
]
=
shortMethodDefinition
.
split
()[
-
1
]
# Write other methods
for
methodNode
in
methodList
:
methodName
=
methodNames
[
methodNode
]
if
methodName
in
(
shortClassName
,
destructorName
):
if
methodName
in
(
shortClassName
,
destructorName
):
continue
continue
if
self
.
shouldHideMethod
(
methodNode
):
if
self
.
shouldHideMethod
(
methodNode
):
continue
continue
isConstMethod
=
(
methodNode
.
attrib
[
'const'
]
==
'yes'
)
if
isConstMethod
and
any
(
methodNames
[
m
]
==
methodName
and
m
.
attrib
[
'const'
]
==
'no'
for
m
in
methodList
):
# There are two identical methods that differ only in whether they are const. Skip the const one.
continue
returnType
=
self
.
getType
(
getText
(
"type"
,
methodNode
))
returnType
=
self
.
getType
(
getText
(
"type"
,
methodNode
))
self
.
fO
ut
.
write
(
"extern OPENMM_EXPORT %s %s_%s("
%
(
returnType
,
typeName
,
methodName
))
self
.
o
ut
.
write
(
"extern OPENMM_EXPORT %s %s_%s("
%
(
returnType
,
typeName
,
methodName
))
isInstanceMethod
=
(
methodNode
.
attrib
[
'static'
]
!=
'yes'
)
isInstanceMethod
=
(
methodNode
.
attrib
[
'static'
]
!=
'yes'
)
if
isInstanceMethod
:
if
isInstanceMethod
:
if
methodNode
.
attrib
[
'const'
]
==
'yes'
:
if
isConstMethod
:
self
.
fO
ut
.
write
(
'const '
)
self
.
o
ut
.
write
(
'const '
)
self
.
fO
ut
.
write
(
"%s* target"
%
typeName
)
self
.
o
ut
.
write
(
"%s* target"
%
typeName
)
self
.
writeArguments
(
methodNode
,
isInstanceMethod
)
self
.
writeArguments
(
methodNode
,
isInstanceMethod
)
self
.
fO
ut
.
write
(
");
\n
"
)
self
.
o
ut
.
write
(
");
\n
"
)
def
writeArguments
(
self
,
methodNode
,
initialSeparator
):
def
writeArguments
(
self
,
methodNode
,
initialSeparator
):
paramList
=
findNodes
(
methodNode
,
'param'
)
paramList
=
findNodes
(
methodNode
,
'param'
)
...
@@ -274,7 +288,7 @@ class CHeaderGenerator(WrapperGenerator):
...
@@ -274,7 +288,7 @@ class CHeaderGenerator(WrapperGenerator):
continue
continue
type
=
self
.
getType
(
type
)
type
=
self
.
getType
(
type
)
name
=
getText
(
'declname'
,
node
)
name
=
getText
(
'declname'
,
node
)
self
.
fO
ut
.
write
(
"%s%s %s"
%
(
separator
,
type
,
name
))
self
.
o
ut
.
write
(
"%s%s %s"
%
(
separator
,
type
,
name
))
separator
=
", "
separator
=
", "
def
getType
(
self
,
type
):
def
getType
(
self
,
type
):
...
@@ -289,7 +303,7 @@ class CHeaderGenerator(WrapperGenerator):
...
@@ -289,7 +303,7 @@ class CHeaderGenerator(WrapperGenerator):
return
type
return
type
def
writeOutput
(
self
):
def
writeOutput
(
self
):
print
>>
out
,
"""
print
>>
self
.
out
,
"""
#ifndef OPENMM_CWRAPPER_H_
#ifndef OPENMM_CWRAPPER_H_
#define OPENMM_CWRAPPER_H_
#define OPENMM_CWRAPPER_H_
...
@@ -299,7 +313,7 @@ class CHeaderGenerator(WrapperGenerator):
...
@@ -299,7 +313,7 @@ class CHeaderGenerator(WrapperGenerator):
"""
"""
self
.
writeGlobalConstants
()
self
.
writeGlobalConstants
()
self
.
writeTypeDeclarations
()
self
.
writeTypeDeclarations
()
print
>>
out
,
"""
print
>>
self
.
out
,
"""
typedef struct OpenMM_Vec3Array_struct OpenMM_Vec3Array;
typedef struct OpenMM_Vec3Array_struct OpenMM_Vec3Array;
typedef struct OpenMM_StringArray_struct OpenMM_StringArray;
typedef struct OpenMM_StringArray_struct OpenMM_StringArray;
typedef struct OpenMM_BondArray_struct OpenMM_BondArray;
typedef struct OpenMM_BondArray_struct OpenMM_BondArray;
...
@@ -357,7 +371,7 @@ extern OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyA
...
@@ -357,7 +371,7 @@ extern OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyA
for
type
in
(
'double'
,
'int'
):
for
type
in
(
'double'
,
'int'
):
name
=
'OpenMM_%sArray'
%
type
.
capitalize
()
name
=
'OpenMM_%sArray'
%
type
.
capitalize
()
values
=
{
'type'
:
type
,
'name'
:
name
}
values
=
{
'type'
:
type
,
'name'
:
name
}
print
>>
out
,
"""
print
>>
self
.
out
,
"""
/* %(name)s */
/* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create(int size);
extern OPENMM_EXPORT %(name)s* %(name)s_create(int size);
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* array);
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* array);
...
@@ -370,14 +384,14 @@ extern OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index);"""
...
@@ -370,14 +384,14 @@ extern OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index);"""
for
type
in
(
'int'
,):
for
type
in
(
'int'
,):
name
=
'OpenMM_%sSet'
%
type
.
capitalize
()
name
=
'OpenMM_%sSet'
%
type
.
capitalize
()
values
=
{
'type'
:
type
,
'name'
:
name
}
values
=
{
'type'
:
type
,
'name'
:
name
}
print
>>
out
,
"""
print
>>
self
.
out
,
"""
/* %(name)s */
/* %(name)s */
extern OPENMM_EXPORT %(name)s* %(name)s_create();
extern OPENMM_EXPORT %(name)s* %(name)s_create();
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* set);
extern OPENMM_EXPORT void %(name)s_destroy(%(name)s* set);
extern OPENMM_EXPORT int %(name)s_getSize(const %(name)s* set);
extern OPENMM_EXPORT int %(name)s_getSize(const %(name)s* set);
extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);"""
%
values
extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);"""
%
values
print
>>
out
,
"""
print
>>
self
.
out
,
"""
/* These methods need to be handled specially, since their C++ APIs cannot be directly translated to C.
/* These methods need to be handled specially, since their C++ APIs cannot be directly translated to C.
Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */
Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */
extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox);
extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox);
...
@@ -385,7 +399,7 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
...
@@ -385,7 +399,7 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
self
.
writeClasses
()
self
.
writeClasses
()
print
>>
out
,
"""
print
>>
self
.
out
,
"""
#if defined(__cplusplus)
#if defined(__cplusplus)
}
}
#endif
#endif
...
@@ -394,9 +408,12 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
...
@@ -394,9 +408,12 @@ extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirector
class
CSourceGenerator
(
WrapperGenerator
):
class
CSourceGenerator
(
WrapperGenerator
):
"""This class generates the source file for the C API wrappers."""
def
__init__
(
self
,
inputDirname
,
output
):
def
__init__
(
self
,
inputDirname
,
output
):
WrapperGenerator
.
__init__
(
self
,
inputDirname
,
output
)
WrapperGenerator
.
__init__
(
self
,
inputDirname
,
output
)
self
.
classesByShortName
=
{}
self
.
classesByShortName
=
{}
self
.
enumerationTypes
=
{}
self
.
findTypes
()
self
.
findTypes
()
def
findTypes
(
self
):
def
findTypes
(
self
):
...
@@ -416,16 +433,19 @@ class CSourceGenerator(WrapperGenerator):
...
@@ -416,16 +433,19 @@ class CSourceGenerator(WrapperGenerator):
typeName
=
convertOpenMMPrefix
(
className
)
typeName
=
convertOpenMMPrefix
(
className
)
for
enumNode
in
enumNodes
:
for
enumNode
in
enumNodes
:
enumName
=
getText
(
"name"
,
enumNode
)
enumName
=
getText
(
"name"
,
enumNode
)
self
.
typesByShortName
[
enumName
]
=
"%s_%s"
%
(
typeName
,
enumName
)
enumTypeName
=
"%s_%s"
%
(
typeName
,
enumName
)
self
.
classesByShortName
[
enumName
]
=
'%s::%s'
%
(
className
,
enumName
)
enumClassName
=
"%s::%s"
%
(
className
,
enumName
)
self
.
typesByShortName
[
enumName
]
=
enumTypeName
self
.
classesByShortName
[
enumName
]
=
enumClassName
self
.
enumerationTypes
[
enumClassName
]
=
enumTypeName
def
writeClasses
(
self
):
def
writeClasses
(
self
):
for
classNode
in
self
.
_orderedClassNodes
:
for
classNode
in
self
.
_orderedClassNodes
:
className
=
stripOpenMMPrefix
(
getText
(
"compoundname"
,
classNode
))
className
=
stripOpenMMPrefix
(
getText
(
"compoundname"
,
classNode
))
self
.
fO
ut
.
write
(
"
\n
/* OpenMM::%s */
\n
"
%
className
)
self
.
o
ut
.
write
(
"
\n
/* OpenMM::%s */
\n
"
%
className
)
self
.
findEnumerations
(
classNode
)
self
.
findEnumerations
(
classNode
)
self
.
writeMethods
(
classNode
)
self
.
writeMethods
(
classNode
)
self
.
fO
ut
.
write
(
"
\n
"
)
self
.
o
ut
.
write
(
"
\n
"
)
def
writeMethods
(
self
,
classNode
):
def
writeMethods
(
self
,
classNode
):
methodList
=
self
.
getClassMethods
(
classNode
)
methodList
=
self
.
getClassMethods
(
classNode
)
...
@@ -449,57 +469,69 @@ class CSourceGenerator(WrapperGenerator):
...
@@ -449,57 +469,69 @@ class CSourceGenerator(WrapperGenerator):
suffix
=
""
suffix
=
""
else
:
else
:
suffix
=
"_%d"
%
numConstructors
suffix
=
"_%d"
%
numConstructors
self
.
fO
ut
.
write
(
"OPENMM_EXPORT %s* %s_create%s("
%
(
typeName
,
typeName
,
suffix
))
self
.
o
ut
.
write
(
"OPENMM_EXPORT %s* %s_create%s("
%
(
typeName
,
typeName
,
suffix
))
self
.
writeArguments
(
methodNode
,
False
)
self
.
writeArguments
(
methodNode
,
False
)
self
.
fO
ut
.
write
(
") {
\n
"
)
self
.
o
ut
.
write
(
") {
\n
"
)
self
.
fO
ut
.
write
(
" return reinterpret_cast<%s*>(new %s("
%
(
class
Name
,
className
))
self
.
o
ut
.
write
(
" return reinterpret_cast<%s*>(new %s("
%
(
type
Name
,
className
))
self
.
writeInvocationArguments
(
methodNode
,
False
)
self
.
writeInvocationArguments
(
methodNode
,
False
)
self
.
fO
ut
.
write
(
"));
\n
"
)
self
.
o
ut
.
write
(
"));
\n
"
)
self
.
fO
ut
.
write
(
"}
\n
"
)
self
.
o
ut
.
write
(
"}
\n
"
)
# Write destructor
# Write destructor
self
.
fO
ut
.
write
(
"OPENMM_EXPORT void %s_destroy(%s* target) {
\n
"
%
(
typeName
,
typeName
))
self
.
o
ut
.
write
(
"OPENMM_EXPORT void %s_destroy(%s* target) {
\n
"
%
(
typeName
,
typeName
))
self
.
fO
ut
.
write
(
" delete reinterpret_cast<%s*>(target);
\n
"
%
className
)
self
.
o
ut
.
write
(
" delete reinterpret_cast<%s*>(target);
\n
"
%
className
)
self
.
fO
ut
.
write
(
"}
\n
"
)
self
.
o
ut
.
write
(
"}
\n
"
)
# Write other methods
# Record method names for future reference.
methodNames
=
{}
for
methodNode
in
methodList
:
for
methodNode
in
methodList
:
methodDefinition
=
getText
(
"definition"
,
methodNode
)
methodDefinition
=
getText
(
"definition"
,
methodNode
)
shortMethodDefinition
=
stripOpenMMPrefix
(
methodDefinition
)
shortMethodDefinition
=
stripOpenMMPrefix
(
methodDefinition
)
methodName
=
shortMethodDefinition
.
split
()[
-
1
]
methodNames
[
methodNode
]
=
shortMethodDefinition
.
split
()[
-
1
]
# Write other methods
for
methodNode
in
methodList
:
methodName
=
methodNames
[
methodNode
]
if
methodName
in
(
shortClassName
,
destructorName
):
if
methodName
in
(
shortClassName
,
destructorName
):
continue
continue
if
self
.
shouldHideMethod
(
methodNode
):
if
self
.
shouldHideMethod
(
methodNode
):
continue
continue
isConstMethod
=
(
methodNode
.
attrib
[
'const'
]
==
'yes'
)
if
isConstMethod
and
any
(
methodNames
[
m
]
==
methodName
and
m
.
attrib
[
'const'
]
==
'no'
for
m
in
methodList
):
# There are two identical methods that differ only in whether they are const. Skip the const one.
continue
methodType
=
getText
(
"type"
,
methodNode
)
methodType
=
getText
(
"type"
,
methodNode
)
returnType
=
self
.
getType
(
methodType
)
returnType
=
self
.
getType
(
methodType
)
if
methodType
in
self
.
classesByShortName
:
if
methodType
in
self
.
classesByShortName
:
methodType
=
self
.
classesByShortName
[
methodType
]
methodType
=
self
.
classesByShortName
[
methodType
]
self
.
fO
ut
.
write
(
"OPENMM_EXPORT %s %s_%s("
%
(
returnType
,
typeName
,
methodName
))
self
.
o
ut
.
write
(
"OPENMM_EXPORT %s %s_%s("
%
(
returnType
,
typeName
,
methodName
))
isInstanceMethod
=
(
methodNode
.
attrib
[
'static'
]
!=
'yes'
)
isInstanceMethod
=
(
methodNode
.
attrib
[
'static'
]
!=
'yes'
)
if
isInstanceMethod
:
if
isInstanceMethod
:
isConstMethod
=
(
methodNode
.
attrib
[
'const'
]
==
'yes'
)
if
isConstMethod
:
if
isConstMethod
:
self
.
fO
ut
.
write
(
'const '
)
self
.
o
ut
.
write
(
'const '
)
self
.
fO
ut
.
write
(
"%s* target"
%
typeName
)
self
.
o
ut
.
write
(
"%s* target"
%
typeName
)
self
.
writeArguments
(
methodNode
,
isInstanceMethod
)
self
.
writeArguments
(
methodNode
,
isInstanceMethod
)
self
.
fO
ut
.
write
(
") {
\n
"
)
self
.
o
ut
.
write
(
") {
\n
"
)
self
.
fO
ut
.
write
(
" "
)
self
.
o
ut
.
write
(
" "
)
if
returnType
!=
'void'
:
if
returnType
!=
'void'
:
self
.
fOut
.
write
(
'%s result = '
%
methodType
)
if
methodType
.
endswith
(
'&'
):
# Convert references to pointers
self
.
out
.
write
(
'%s* result = &'
%
methodType
[:
-
1
].
strip
())
else
:
self
.
out
.
write
(
'%s result = '
%
methodType
)
if
isInstanceMethod
:
if
isInstanceMethod
:
self
.
fO
ut
.
write
(
'reinterpret_cast<'
)
self
.
o
ut
.
write
(
'reinterpret_cast<'
)
if
isConstMethod
:
if
isConstMethod
:
self
.
fO
ut
.
write
(
'const '
)
self
.
o
ut
.
write
(
'const '
)
self
.
fO
ut
.
write
(
'%s*>(target)->'
%
className
)
self
.
o
ut
.
write
(
'%s*>(target)->'
%
className
)
else
:
else
:
self
.
fO
ut
.
write
(
'%s::'
%
className
)
self
.
o
ut
.
write
(
'%s::'
%
className
)
self
.
fO
ut
.
write
(
'%s('
%
methodName
)
self
.
o
ut
.
write
(
'%s('
%
methodName
)
self
.
writeInvocationArguments
(
methodNode
,
False
)
self
.
writeInvocationArguments
(
methodNode
,
False
)
self
.
fO
ut
.
write
(
');
\n
'
)
self
.
o
ut
.
write
(
');
\n
'
)
if
returnType
!=
'void'
:
if
returnType
!=
'void'
:
self
.
fO
ut
.
write
(
' return %s;
\n
'
%
self
.
wrapValue
(
methodType
,
'result'
))
self
.
o
ut
.
write
(
' return %s;
\n
'
%
self
.
wrapValue
(
methodType
,
'result'
))
self
.
fO
ut
.
write
(
"}
\n
"
)
self
.
o
ut
.
write
(
"}
\n
"
)
def
writeArguments
(
self
,
methodNode
,
initialSeparator
):
def
writeArguments
(
self
,
methodNode
,
initialSeparator
):
paramList
=
findNodes
(
methodNode
,
'param'
)
paramList
=
findNodes
(
methodNode
,
'param'
)
...
@@ -516,7 +548,7 @@ class CSourceGenerator(WrapperGenerator):
...
@@ -516,7 +548,7 @@ class CSourceGenerator(WrapperGenerator):
continue
continue
type
=
self
.
getType
(
type
)
type
=
self
.
getType
(
type
)
name
=
getText
(
'declname'
,
node
)
name
=
getText
(
'declname'
,
node
)
self
.
fO
ut
.
write
(
"%s%s %s"
%
(
separator
,
type
,
name
))
self
.
o
ut
.
write
(
"%s%s %s"
%
(
separator
,
type
,
name
))
separator
=
", "
separator
=
", "
def
writeInvocationArguments
(
self
,
methodNode
,
initialSeparator
):
def
writeInvocationArguments
(
self
,
methodNode
,
initialSeparator
):
...
@@ -533,7 +565,9 @@ class CSourceGenerator(WrapperGenerator):
...
@@ -533,7 +565,9 @@ class CSourceGenerator(WrapperGenerator):
if
type
==
'void'
:
if
type
==
'void'
:
continue
continue
name
=
getText
(
'declname'
,
node
)
name
=
getText
(
'declname'
,
node
)
self
.
fOut
.
write
(
"%s%s"
%
(
separator
,
name
))
if
self
.
getType
(
type
)
!=
type
:
name
=
self
.
unwrapValue
(
type
,
name
)
self
.
out
.
write
(
"%s%s"
%
(
separator
,
name
))
separator
=
", "
separator
=
", "
def
getType
(
self
,
type
):
def
getType
(
self
,
type
):
...
@@ -553,7 +587,9 @@ class CSourceGenerator(WrapperGenerator):
...
@@ -553,7 +587,9 @@ class CSourceGenerator(WrapperGenerator):
if
type
==
'std::string'
:
if
type
==
'std::string'
:
return
'%s.c_str()'
%
value
return
'%s.c_str()'
%
value
if
type
==
'const std::string &'
:
if
type
==
'const std::string &'
:
return
'%s.c_str()'
%
value
return
'%s->c_str()'
%
value
if
type
in
self
.
enumerationTypes
:
return
'static_cast<%s>(%s)'
%
(
self
.
enumerationTypes
[
type
],
value
)
wrappedType
=
self
.
getType
(
type
)
wrappedType
=
self
.
getType
(
type
)
if
wrappedType
==
type
:
if
wrappedType
==
type
:
return
value
;
return
value
;
...
@@ -561,8 +597,20 @@ class CSourceGenerator(WrapperGenerator):
...
@@ -561,8 +597,20 @@ class CSourceGenerator(WrapperGenerator):
return
'reinterpret_cast<%s>(%s)'
%
(
wrappedType
,
value
)
return
'reinterpret_cast<%s>(%s)'
%
(
wrappedType
,
value
)
return
'static_cast<%s>(%s)'
%
(
wrappedType
,
value
)
return
'static_cast<%s>(%s)'
%
(
wrappedType
,
value
)
def
unwrapValue
(
self
,
type
,
value
):
if
type
.
endswith
(
'&'
):
unwrappedType
=
type
[:
-
1
].
strip
()
if
unwrappedType
in
self
.
classesByShortName
:
unwrappedType
=
self
.
classesByShortName
[
unwrappedType
]
return
'*'
+
self
.
unwrapValue
(
unwrappedType
+
'*'
,
value
)
if
type
in
self
.
classesByShortName
:
return
'static_cast<%s>(%s)'
%
(
self
.
classesByShortName
[
type
],
value
)
if
type
==
'bool'
:
return
value
return
'reinterpret_cast<%s>(%s)'
%
(
type
,
value
)
def
writeOutput
(
self
):
def
writeOutput
(
self
):
print
>>
out
,
"""
print
>>
self
.
out
,
"""
#include "OpenMM.h"
#include "OpenMM.h"
#include "OpenMMCWrapper.h"
#include "OpenMMCWrapper.h"
#include <cstring>
#include <cstring>
...
@@ -599,7 +647,7 @@ OPENMM_EXPORT void OpenMM_Vec3Array_set(OpenMM_Vec3Array* array, int index, cons
...
@@ -599,7 +647,7 @@ OPENMM_EXPORT void OpenMM_Vec3Array_set(OpenMM_Vec3Array* array, int index, cons
(*reinterpret_cast<vector<Vec3>*>(array))[index] = Vec3(vec.x, vec.y, vec.z);
(*reinterpret_cast<vector<Vec3>*>(array))[index] = Vec3(vec.x, vec.y, vec.z);
}
}
OPENMM_EXPORT const OpenMM_Vec3* OpenMM_Vec3Array_get(const OpenMM_Vec3Array* array, int index) {
OPENMM_EXPORT const OpenMM_Vec3* OpenMM_Vec3Array_get(const OpenMM_Vec3Array* array, int index) {
return reinterpret_cast<const OpenMM_Vec3*>((&
amp;
(*reinterpret_cast<const vector<Vec3>*>(array))[index]));
return reinterpret_cast<const OpenMM_Vec3*>((&(*reinterpret_cast<const vector<Vec3>*>(array))[index]));
}
}
/* OpenMM_StringArray */
/* OpenMM_StringArray */
...
@@ -677,7 +725,7 @@ OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyArray* a
...
@@ -677,7 +725,7 @@ OPENMM_EXPORT const char* OpenMM_PropertyArray_get(const OpenMM_PropertyArray* a
for
type
in
(
'double'
,
'int'
):
for
type
in
(
'double'
,
'int'
):
name
=
'OpenMM_%sArray'
%
type
.
capitalize
()
name
=
'OpenMM_%sArray'
%
type
.
capitalize
()
values
=
{
'type'
:
type
,
'name'
:
name
}
values
=
{
'type'
:
type
,
'name'
:
name
}
print
>>
out
,
"""
print
>>
self
.
out
,
"""
/* %(name)s */
/* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create(int size) {
OPENMM_EXPORT %(name)s* %(name)s_create(int size) {
return reinterpret_cast<%(name)s*>(new vector<%(type)s>(size));
return reinterpret_cast<%(name)s*>(new vector<%(type)s>(size));
...
@@ -704,7 +752,7 @@ OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index) {
...
@@ -704,7 +752,7 @@ OPENMM_EXPORT %(type)s %(name)s_get(const %(name)s* array, int index) {
for
type
in
(
'int'
,):
for
type
in
(
'int'
,):
name
=
'OpenMM_%sSet'
%
type
.
capitalize
()
name
=
'OpenMM_%sSet'
%
type
.
capitalize
()
values
=
{
'type'
:
type
,
'name'
:
name
}
values
=
{
'type'
:
type
,
'name'
:
name
}
print
>>
out
,
"""
print
>>
self
.
out
,
"""
/* %(name)s */
/* %(name)s */
OPENMM_EXPORT %(name)s* %(name)s_create() {
OPENMM_EXPORT %(name)s* %(name)s_create() {
return reinterpret_cast<%(name)s*>(new set<%(type)s>());
return reinterpret_cast<%(name)s*>(new set<%(type)s>());
...
@@ -720,9 +768,11 @@ OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) {
...
@@ -720,9 +768,11 @@ OPENMM_EXPORT void %(name)s_insert(%(name)s* s, %(type)s value) {
}"""
%
values
}"""
%
values
self
.
writeClasses
()
self
.
writeClasses
()
print
>>
self
.
out
,
"}
\n
"
inputDirname
=
'/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml'
#inputDirname = '/Users/peastman/workspace/openmm/bin-release/wrappers/doxygen/xml'
out
=
sys
.
stdout
inputDirname
=
sys
.
argv
[
1
]
#builder = CHeaderGenerator(inputDirname, out)
builder
=
CHeaderGenerator
(
inputDirname
,
open
(
os
.
path
.
join
(
sys
.
argv
[
2
],
'OpenMMCWrapper.h'
),
'w'
))
builder
=
CSourceGenerator
(
inputDirname
,
out
)
builder
.
writeOutput
()
builder
=
CSourceGenerator
(
inputDirname
,
open
(
os
.
path
.
join
(
sys
.
argv
[
2
],
'OpenMMCWrapper.cpp'
),
'w'
))
builder
.
writeOutput
()
builder
.
writeOutput
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment