Commit 7c86e874 authored by peastman's avatar peastman
Browse files

Support copy() and deepcopy() for TabulatedFunctions

parent 54b39e91
...@@ -59,6 +59,9 @@ class OPENMM_EXPORT TabulatedFunction { ...@@ -59,6 +59,9 @@ class OPENMM_EXPORT TabulatedFunction {
public: public:
virtual ~TabulatedFunction() { virtual ~TabulatedFunction() {
} }
/**
* @deprecated This will be removed in a future release.
*/
virtual TabulatedFunction* Copy() const = 0; virtual TabulatedFunction* Copy() const = 0;
}; };
...@@ -99,6 +102,8 @@ public: ...@@ -99,6 +102,8 @@ public:
void setFunctionParameters(const std::vector<double>& values, double min, double max); void setFunctionParameters(const std::vector<double>& values, double min, double max);
/** /**
* Create a deep copy of the tabulated function. * Create a deep copy of the tabulated function.
*
* @deprecated This will be removed in a future release.
*/ */
Continuous1DFunction* Copy() const; Continuous1DFunction* Copy() const;
private: private:
...@@ -158,6 +163,8 @@ public: ...@@ -158,6 +163,8 @@ public:
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax); void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
/** /**
* Create a deep copy of the tabulated function * Create a deep copy of the tabulated function
*
* @deprecated This will be removed in a future release.
*/ */
Continuous2DFunction* Copy() const; Continuous2DFunction* Copy() const;
private: private:
...@@ -233,6 +240,8 @@ public: ...@@ -233,6 +240,8 @@ public:
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax); void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
/** /**
* Create a deep copy of the tabulated function * Create a deep copy of the tabulated function
*
* @deprecated This will be removed in a future release.
*/ */
Continuous3DFunction* Copy() const; Continuous3DFunction* Copy() const;
private: private:
...@@ -268,6 +277,8 @@ public: ...@@ -268,6 +277,8 @@ public:
void setFunctionParameters(const std::vector<double>& values); void setFunctionParameters(const std::vector<double>& values);
/** /**
* Create a deep copy of the tabulated function * Create a deep copy of the tabulated function
*
* @deprecated This will be removed in a future release.
*/ */
Discrete1DFunction* Copy() const; Discrete1DFunction* Copy() const;
private: private:
...@@ -310,6 +321,8 @@ public: ...@@ -310,6 +321,8 @@ public:
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values); void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values);
/** /**
* Create a deep copy of the tabulated function * Create a deep copy of the tabulated function
*
* @deprecated This will be removed in a future release.
*/ */
Discrete2DFunction* Copy() const; Discrete2DFunction* Copy() const;
private: private:
...@@ -356,6 +369,8 @@ public: ...@@ -356,6 +369,8 @@ public:
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values); void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values);
/** /**
* Create a deep copy of the tabulated function * Create a deep copy of the tabulated function
*
* @deprecated This will be removed in a future release.
*/ */
Discrete3DFunction* Copy() const; Discrete3DFunction* Copy() const;
private: private:
......
...@@ -242,6 +242,7 @@ class SwigInputBuilder: ...@@ -242,6 +242,7 @@ class SwigInputBuilder:
self.fOut.write("\n/* Declare factories */\n\n") self.fOut.write("\n/* Declare factories */\n\n")
forceSubclassList = [] forceSubclassList = []
integratorSubclassList = [] integratorSubclassList = []
tabulatedFunctionSubclassList = []
for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"): for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
className = getText("compoundname", classNode) className = getText("compoundname", classNode)
shortClassName=stripOpenmmPrefix(className) shortClassName=stripOpenmmPrefix(className)
...@@ -256,6 +257,8 @@ class SwigInputBuilder: ...@@ -256,6 +257,8 @@ class SwigInputBuilder:
forceSubclassList.append(shortClassName) forceSubclassList.append(shortClassName)
elif baseName == 'OpenMM::Integrator': elif baseName == 'OpenMM::Integrator':
integratorSubclassList.append(shortClassName) integratorSubclassList.append(shortClassName)
elif baseName == 'OpenMM::TabulatedFunction':
tabulatedFunctionSubclassList.append(shortClassName)
self.fOut.write("%factory(OpenMM::Force& OpenMM::System::getForce") self.fOut.write("%factory(OpenMM::Force& OpenMM::System::getForce")
for name in sorted(forceSubclassList): for name in sorted(forceSubclassList):
...@@ -292,6 +295,16 @@ class SwigInputBuilder: ...@@ -292,6 +295,16 @@ class SwigInputBuilder:
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::TabulatedFunction* OpenMM::TabulatedFunction::__copy__")
for name in sorted(tabulatedFunctionSubclassList):
self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::TabulatedFunction* OpenMM_XmlSerializer__deserializeTabulatedFunction")
for name in sorted(tabulatedFunctionSubclassList):
self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::VirtualSite& OpenMM::System::getVirtualSite, OpenMM::TwoParticleAverageSite, OpenMM::ThreeParticleAverageSite, OpenMM::OutOfPlaneSite, OpenMM::LocalCoordinatesSite);\n\n") self.fOut.write("%factory(OpenMM::VirtualSite& OpenMM::System::getVirtualSite, OpenMM::TwoParticleAverageSite, OpenMM::ThreeParticleAverageSite, OpenMM::OutOfPlaneSite, OpenMM::LocalCoordinatesSite);\n\n")
self.fOut.write("\n") self.fOut.write("\n")
......
...@@ -363,6 +363,19 @@ Parameters: ...@@ -363,6 +363,19 @@ Parameters:
return OpenMM::XmlSerializer::deserialize<OpenMM::Integrator>(ss); return OpenMM::XmlSerializer::deserialize<OpenMM::Integrator>(ss);
} }
static std::string _serializeTabulatedFunction(const OpenMM::TabulatedFunction* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::TabulatedFunction>(object, "TabulatedFunction", ss);
return ss.str();
}
%newobject _deserializeTabulatedFunction;
static OpenMM::TabulatedFunction* _deserializeTabulatedFunction(const char* inputString) {
std::stringstream ss;
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::TabulatedFunction>(ss);
}
static std::string _serializeStateAsLists( static std::string _serializeStateAsLists(
const std::vector<Vec3>& pos, const std::vector<Vec3>& pos,
const std::vector<Vec3>& vel, const std::vector<Vec3>& vel,
...@@ -463,6 +476,8 @@ Parameters: ...@@ -463,6 +476,8 @@ Parameters:
return XmlSerializer._serializeIntegrator(object) return XmlSerializer._serializeIntegrator(object)
elif isinstance(object, State): elif isinstance(object, State):
return XmlSerializer._serializeState(object) return XmlSerializer._serializeState(object)
elif isinstance(object, TabulatedFunction):
return XmlSerializer._serializeTabulatedFunction(object)
raise ValueError("Unsupported object type") raise ValueError("Unsupported object type")
@staticmethod @staticmethod
...@@ -481,6 +496,8 @@ Parameters: ...@@ -481,6 +496,8 @@ Parameters:
return XmlSerializer._deserializeIntegrator(inputString) return XmlSerializer._deserializeIntegrator(inputString)
if type == "State": if type == "State":
return XmlSerializer._deserializeState(inputString) return XmlSerializer._deserializeState(inputString)
if type == "TabulatedFunction":
return XmlSerializer._deserializeTabulatedFunction(inputString)
raise ValueError("Unsupported object type") raise ValueError("Unsupported object type")
%} %}
} }
...@@ -530,3 +547,22 @@ Parameters: ...@@ -530,3 +547,22 @@ Parameters:
return OpenMM::XmlSerializer::clone<OpenMM::Integrator>(*self); return OpenMM::XmlSerializer::clone<OpenMM::Integrator>(*self);
} }
} }
%extend OpenMM::TabulatedFunction {
%pythoncode %{
def __getstate__(self):
serializationString = XmlSerializer.serialize(self)
return serializationString
def __setstate__(self, serializationString):
system = XmlSerializer.deserialize(serializationString)
self.this = system.this
def __deepcopy__(self, memo):
return self.__copy__()
%}
%newobject __copy__;
OpenMM::TabulatedFunction* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::TabulatedFunction>(*self);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment