Commit 4f911c7a authored by Peter Eastman's avatar Peter Eastman
Browse files

Cleaned up Python API for XmlSerializer. Also ensure that deserialized...

Cleaned up Python API for XmlSerializer.  Also ensure that deserialized objects are returned as the correct class (see bug 1827)
parent 85f8a7e2
......@@ -188,16 +188,15 @@ class SwigInputBuilder:
excludedClassNodes.append(node)
def writeForceSubclassList(self):
self.fOut.write("\n/* Force subclasses */\n\n")
forceSubclassList=[]
def writeFactories(self):
self.fOut.write("\n/* Declare factories */\n\n")
forceSubclassList = []
integratorSubclassList = []
for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
className = getText("compoundname", classNode)
shortClassName=stripOpenmmPrefix(className)
if (className.split("::")[-1],) in self.skipMethods:
continue
#print className
#print classNode.toxml()
for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
if "refid" in baseNodePnt.attrib:
baseNodeID=baseNodePnt.attrib["refid"]
......@@ -205,14 +204,20 @@ class SwigInputBuilder:
baseName = getText("compoundname", baseNode)
if baseName == 'OpenMM::Force':
forceSubclassList.append(shortClassName)
elif baseName == 'OpenMM::Integrator':
integratorSubclassList.append(shortClassName)
self.fOut.write("%factory(OpenMM::Force& OpenMM::System::getForce")
for name in sorted(forceSubclassList):
self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n")
# for classNode in self._orderedClassNodes:
# className = stripOpenmmPrefix(getText("compoundname", classNode))
# self.fOut.write("class %s ;\n" % className)
self.fOut.write("%factory(OpenMM::Force* OpenMM_XmlSerializer__deserializeForce")
for name in sorted(forceSubclassList):
self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::Integrator* OpenMM_XmlSerializer__deserializeIntegrator")
for name in sorted(integratorSubclassList):
self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n")
self.fOut.write("\n")
def writeGlobalConstants(self):
......@@ -379,43 +384,6 @@ class SwigInputBuilder:
if clearOutput:
self.fOut.write(clearOutput)
isXmlSerializer = 0
isSystem = 0
for items in methodList:
(shortClassName, memberNode,
shortMethDefinition, methName,
isConstructors, isDestructor, templateType, templateName) = items
if( shortClassName == 'XmlSerializer' ):isXmlSerializer = 1
if( shortClassName == 'System' ):isSystem = 1
if( isXmlSerializer == 1 ):
extendString = "%extend {\n"
extendString += " static std::string serializeSystem( const OpenMM::System *object ){\n"
extendString += " std::stringstream ss;\n"
extendString += " XmlSerializer::serialize<OpenMM::System>( object, \"System\", ss );\n"
extendString += " return ss.str();\n"
extendString += " }\n"
extendString += "\n"
extendString += " static OpenMM::System* deserializeSystem( const char* inputString ){\n"
extendString += " std::stringstream ss;\n"
extendString += " ss << inputString;\n"
extendString += " return XmlSerializer::deserialize<OpenMM::System>( ss );\n"
extendString += " }\n"
extendString += " static std::string serializeIntegrator( const OpenMM::Integrator *object ){\n"
extendString += " std::stringstream ss;\n"
extendString += " XmlSerializer::serialize<OpenMM::Integrator>( object, \"Integrator\", ss );\n"
extendString += " return ss.str();\n"
extendString += " }\n"
extendString += "\n"
extendString += " static OpenMM::Integrator* deserializeIntegrator( const char* inputString ){\n"
extendString += " std::stringstream ss;\n"
extendString += " ss << inputString;\n"
extendString += " return XmlSerializer::deserialize<OpenMM::Integrator>( ss );\n"
extendString += " }\n"
extendString += " };\n"
self.fOut.write("%s%s\n" % (INDENT, extendString))
#write python mod blocks
for items in methodList:
(shortClassName, memberNode,
......@@ -571,7 +539,7 @@ class SwigInputBuilder:
self.fOut.write("%sgenerated by %s on %s\n*/\n\n\n"
% (INDENT, sys.argv[0], time.asctime()))
self.fOut.write("\nnamespace OpenMM {\n\n")
self.writeForceSubclassList()
self.writeFactories()
self.writeGlobalConstants()
self.writeForwardDeclarations()
self.writeClassDeclarations()
......
......@@ -137,6 +137,8 @@ SKIP_METHODS = [('State',),
('Platform', 'registerKernelFactory'),
('IntegrateRPMDStepKernel',),
('RPMDIntegrator', 'getState'),
('XmlSerializer', 'serialize'),
('XmlSerializer', 'deserialize'),
]
# The build script assumes method args that are non-const references are
......
......@@ -243,6 +243,44 @@ Parameters:
}
%extend OpenMM::XmlSerializer {
%feature(docstring, "This method exists only for backward compatibility. @deprecated Use serialize() instead.") serializeSystem;
static std::string serializeSystem(const OpenMM::System* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::System>(object, "System", ss);
return ss.str();
}
%feature(docstring, "This method exists only for backward compatibility. @deprecated Use deserialize() instead.") deserializeSystem;
static OpenMM::System* deserializeSystem(const char* inputString) {
std::stringstream ss;
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::System>(ss);
}
static std::string _serializeForce(const OpenMM::Force* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Force>(object, "Force", ss);
return ss.str();
}
static OpenMM::Force* _deserializeForce(const char* inputString) {
std::stringstream ss;
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::Force>(ss);
}
static std::string _serializeIntegrator(const OpenMM::Integrator* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Integrator>(object, "Integrator", ss);
return ss.str();
}
static OpenMM::Integrator* _deserializeIntegrator(const char* inputString) {
std::stringstream ss;
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::Integrator>(ss);
}
static std::string _serializeStateAsLists(
const std::vector<Vec3>& pos,
const std::vector<Vec3>& vel,
......@@ -270,7 +308,7 @@ Parameters:
%pythoncode {
@staticmethod
def serializeState(pythonState):
def _serializeState(pythonState):
positions = []
velocities = []
forces = []
......@@ -309,7 +347,7 @@ Parameters:
return string
@staticmethod
def deserializeState(pythonString):
def _deserializeState(pythonString):
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
......@@ -322,6 +360,38 @@ Parameters:
periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap)
return state
@staticmethod
def serialize(object):
"""Serialize an object as XML."""
if isinstance(object, System):
return XmlSerializer.serializeSystem(object)
elif isinstance(object, Force):
return XmlSerializer._serializeForce(object)
elif isinstance(object, Integrator):
return XmlSerializer._serializeIntegrator(object)
elif isinstance(object, State):
return XmlSerializer._serializeState(object)
raise ValueError("Unsupported object type")
@staticmethod
def deserialize(inputString):
"""Reconstruct an object that has been serialized as XML."""
# Look for the first tag to figure out what type of object it is.
import re
match = re.search("<([^?]\S*)", inputString)
if match is None:
raise ValueError("Invalid input string")
type = match.groups()[0]
if type == "System":
return XmlSerializer.deserializeSystem(inputString)
if type == "Force":
return XmlSerializer._deserializeForce(inputString)
if type == "Integrator":
return XmlSerializer._deserializeIntegrator(inputString)
if type == "State":
return XmlSerializer._deserializeState(inputString)
raise ValueError("Unsupported object type")
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment