Commit 4f911c7a authored by Peter Eastman's avatar Peter Eastman
Browse files

Cleaned up Python API for XmlSerializer. Also ensure that deserialized...

Cleaned up Python API for XmlSerializer.  Also ensure that deserialized objects are returned as the correct class (see bug 1827)
parent 85f8a7e2
...@@ -188,16 +188,15 @@ class SwigInputBuilder: ...@@ -188,16 +188,15 @@ class SwigInputBuilder:
excludedClassNodes.append(node) excludedClassNodes.append(node)
def writeForceSubclassList(self): def writeFactories(self):
self.fOut.write("\n/* Force subclasses */\n\n") self.fOut.write("\n/* Declare factories */\n\n")
forceSubclassList=[] forceSubclassList = []
integratorSubclassList = []
for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"): for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
className = getText("compoundname", classNode) className = getText("compoundname", classNode)
shortClassName=stripOpenmmPrefix(className) shortClassName=stripOpenmmPrefix(className)
if (className.split("::")[-1],) in self.skipMethods: if (className.split("::")[-1],) in self.skipMethods:
continue continue
#print className
#print classNode.toxml()
for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"): for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
if "refid" in baseNodePnt.attrib: if "refid" in baseNodePnt.attrib:
baseNodeID=baseNodePnt.attrib["refid"] baseNodeID=baseNodePnt.attrib["refid"]
...@@ -205,14 +204,20 @@ class SwigInputBuilder: ...@@ -205,14 +204,20 @@ class SwigInputBuilder:
baseName = getText("compoundname", baseNode) baseName = getText("compoundname", baseNode)
if baseName == 'OpenMM::Force': if baseName == 'OpenMM::Force':
forceSubclassList.append(shortClassName) forceSubclassList.append(shortClassName)
elif baseName == 'OpenMM::Integrator':
integratorSubclassList.append(shortClassName)
self.fOut.write("%factory(OpenMM::Force& OpenMM::System::getForce") self.fOut.write("%factory(OpenMM::Force& OpenMM::System::getForce")
for name in sorted(forceSubclassList): for name in sorted(forceSubclassList):
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::Force* OpenMM_XmlSerializer__deserializeForce")
# for classNode in self._orderedClassNodes: for name in sorted(forceSubclassList):
# className = stripOpenmmPrefix(getText("compoundname", classNode)) self.fOut.write(",\n OpenMM::%s" % name)
# self.fOut.write("class %s ;\n" % className) self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::Integrator* OpenMM_XmlSerializer__deserializeIntegrator")
for name in sorted(integratorSubclassList):
self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n")
self.fOut.write("\n") self.fOut.write("\n")
def writeGlobalConstants(self): def writeGlobalConstants(self):
...@@ -379,43 +384,6 @@ class SwigInputBuilder: ...@@ -379,43 +384,6 @@ class SwigInputBuilder:
if clearOutput: if clearOutput:
self.fOut.write(clearOutput) self.fOut.write(clearOutput)
isXmlSerializer = 0
isSystem = 0
for items in methodList:
(shortClassName, memberNode,
shortMethDefinition, methName,
isConstructors, isDestructor, templateType, templateName) = items
if( shortClassName == 'XmlSerializer' ):isXmlSerializer = 1
if( shortClassName == 'System' ):isSystem = 1
if( isXmlSerializer == 1 ):
extendString = "%extend {\n"
extendString += " static std::string serializeSystem( const OpenMM::System *object ){\n"
extendString += " std::stringstream ss;\n"
extendString += " XmlSerializer::serialize<OpenMM::System>( object, \"System\", ss );\n"
extendString += " return ss.str();\n"
extendString += " }\n"
extendString += "\n"
extendString += " static OpenMM::System* deserializeSystem( const char* inputString ){\n"
extendString += " std::stringstream ss;\n"
extendString += " ss << inputString;\n"
extendString += " return XmlSerializer::deserialize<OpenMM::System>( ss );\n"
extendString += " }\n"
extendString += " static std::string serializeIntegrator( const OpenMM::Integrator *object ){\n"
extendString += " std::stringstream ss;\n"
extendString += " XmlSerializer::serialize<OpenMM::Integrator>( object, \"Integrator\", ss );\n"
extendString += " return ss.str();\n"
extendString += " }\n"
extendString += "\n"
extendString += " static OpenMM::Integrator* deserializeIntegrator( const char* inputString ){\n"
extendString += " std::stringstream ss;\n"
extendString += " ss << inputString;\n"
extendString += " return XmlSerializer::deserialize<OpenMM::Integrator>( ss );\n"
extendString += " }\n"
extendString += " };\n"
self.fOut.write("%s%s\n" % (INDENT, extendString))
#write python mod blocks #write python mod blocks
for items in methodList: for items in methodList:
(shortClassName, memberNode, (shortClassName, memberNode,
...@@ -571,7 +539,7 @@ class SwigInputBuilder: ...@@ -571,7 +539,7 @@ class SwigInputBuilder:
self.fOut.write("%sgenerated by %s on %s\n*/\n\n\n" self.fOut.write("%sgenerated by %s on %s\n*/\n\n\n"
% (INDENT, sys.argv[0], time.asctime())) % (INDENT, sys.argv[0], time.asctime()))
self.fOut.write("\nnamespace OpenMM {\n\n") self.fOut.write("\nnamespace OpenMM {\n\n")
self.writeForceSubclassList() self.writeFactories()
self.writeGlobalConstants() self.writeGlobalConstants()
self.writeForwardDeclarations() self.writeForwardDeclarations()
self.writeClassDeclarations() self.writeClassDeclarations()
......
...@@ -137,6 +137,8 @@ SKIP_METHODS = [('State',), ...@@ -137,6 +137,8 @@ SKIP_METHODS = [('State',),
('Platform', 'registerKernelFactory'), ('Platform', 'registerKernelFactory'),
('IntegrateRPMDStepKernel',), ('IntegrateRPMDStepKernel',),
('RPMDIntegrator', 'getState'), ('RPMDIntegrator', 'getState'),
('XmlSerializer', 'serialize'),
('XmlSerializer', 'deserialize'),
] ]
# The build script assumes method args that are non-const references are # The build script assumes method args that are non-const references are
......
...@@ -243,6 +243,44 @@ Parameters: ...@@ -243,6 +243,44 @@ Parameters:
} }
%extend OpenMM::XmlSerializer { %extend OpenMM::XmlSerializer {
%feature(docstring, "This method exists only for backward compatibility. @deprecated Use serialize() instead.") serializeSystem;
static std::string serializeSystem(const OpenMM::System* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::System>(object, "System", ss);
return ss.str();
}
%feature(docstring, "This method exists only for backward compatibility. @deprecated Use deserialize() instead.") deserializeSystem;
static OpenMM::System* deserializeSystem(const char* inputString) {
std::stringstream ss;
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::System>(ss);
}
static std::string _serializeForce(const OpenMM::Force* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Force>(object, "Force", ss);
return ss.str();
}
static OpenMM::Force* _deserializeForce(const char* inputString) {
std::stringstream ss;
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::Force>(ss);
}
static std::string _serializeIntegrator(const OpenMM::Integrator* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Integrator>(object, "Integrator", ss);
return ss.str();
}
static OpenMM::Integrator* _deserializeIntegrator(const char* inputString) {
std::stringstream ss;
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::Integrator>(ss);
}
static std::string _serializeStateAsLists( static std::string _serializeStateAsLists(
const std::vector<Vec3>& pos, const std::vector<Vec3>& pos,
const std::vector<Vec3>& vel, const std::vector<Vec3>& vel,
...@@ -270,7 +308,7 @@ Parameters: ...@@ -270,7 +308,7 @@ Parameters:
%pythoncode { %pythoncode {
@staticmethod @staticmethod
def serializeState(pythonState): def _serializeState(pythonState):
positions = [] positions = []
velocities = [] velocities = []
forces = [] forces = []
...@@ -309,7 +347,7 @@ Parameters: ...@@ -309,7 +347,7 @@ Parameters:
return string return string
@staticmethod @staticmethod
def deserializeState(pythonString): def _deserializeState(pythonString):
(simTime, periodicBoxVectorsList, energy, coordList, velList, (simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = XmlSerializer._deserializeStringIntoLists(pythonString) forceList, paramMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
...@@ -322,6 +360,38 @@ Parameters: ...@@ -322,6 +360,38 @@ Parameters:
periodicBoxVectorsList=periodicBoxVectorsList, periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap) paramMap=paramMap)
return state return state
@staticmethod
def serialize(object):
"""Serialize an object as XML."""
if isinstance(object, System):
return XmlSerializer.serializeSystem(object)
elif isinstance(object, Force):
return XmlSerializer._serializeForce(object)
elif isinstance(object, Integrator):
return XmlSerializer._serializeIntegrator(object)
elif isinstance(object, State):
return XmlSerializer._serializeState(object)
raise ValueError("Unsupported object type")
@staticmethod
def deserialize(inputString):
"""Reconstruct an object that has been serialized as XML."""
# Look for the first tag to figure out what type of object it is.
import re
match = re.search("<([^?]\S*)", inputString)
if match is None:
raise ValueError("Invalid input string")
type = match.groups()[0]
if type == "System":
return XmlSerializer.deserializeSystem(inputString)
if type == "Force":
return XmlSerializer._deserializeForce(inputString)
if type == "Integrator":
return XmlSerializer._deserializeIntegrator(inputString)
if type == "State":
return XmlSerializer._deserializeState(inputString)
raise ValueError("Unsupported object type")
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment