Commit 4517491d authored by Jason Swails's avatar Jason Swails
Browse files

Switch to a deep copy on the tabulated functions in CustomNonbondedForce. This

necessitated creating a Copy method to do that on TabulatedFunction classes.
parent d3046049
...@@ -552,6 +552,10 @@ public: ...@@ -552,6 +552,10 @@ public:
} }
FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) { FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
} }
FunctionInfo Copy() const {
TabulatedFunction new_func = function->Copy();
return FunctionInfo(name, &new_func);
}
}; };
/** /**
......
...@@ -59,6 +59,9 @@ class OPENMM_EXPORT TabulatedFunction { ...@@ -59,6 +59,9 @@ class OPENMM_EXPORT TabulatedFunction {
public: public:
virtual ~TabulatedFunction() { virtual ~TabulatedFunction() {
} }
TabulatedFunction Copy() const {
return TabulatedFunction();
}
}; };
/** /**
...@@ -96,6 +99,10 @@ public: ...@@ -96,6 +99,10 @@ public:
* @param max the value of x corresponding to the last element of values * @param max the value of x corresponding to the last element of values
*/ */
void setFunctionParameters(const std::vector<double>& values, double min, double max); void setFunctionParameters(const std::vector<double>& values, double min, double max);
/**
* Create a deep copy of the tabulated function.
*/
Continuous1DFunction Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
...@@ -151,6 +158,10 @@ public: ...@@ -151,6 +158,10 @@ public:
* @param ymax the value of y corresponding to the last element of values * @param ymax the value of y corresponding to the last element of values
*/ */
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax); void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
/**
* Create a deep copy of the tabulated function
*/
Continuous2DFunction Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
int xsize, ysize; int xsize, ysize;
...@@ -222,6 +233,10 @@ public: ...@@ -222,6 +233,10 @@ public:
* @param zmax the value of z corresponding to the last element of values * @param zmax the value of z corresponding to the last element of values
*/ */
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax); void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
/**
* Create a deep copy of the tabulated function
*/
Continuous3DFunction Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
int xsize, ysize, zsize; int xsize, ysize, zsize;
...@@ -253,6 +268,10 @@ public: ...@@ -253,6 +268,10 @@ public:
* @param values the tabulated values of the function f(x) * @param values the tabulated values of the function f(x)
*/ */
void setFunctionParameters(const std::vector<double>& values); void setFunctionParameters(const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete1DFunction Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
}; };
...@@ -291,6 +310,10 @@ public: ...@@ -291,6 +310,10 @@ public:
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize. * values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/ */
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values); void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete2DFunction Copy() const;
private: private:
int xsize, ysize; int xsize, ysize;
std::vector<double> values; std::vector<double> values;
...@@ -333,6 +356,10 @@ public: ...@@ -333,6 +356,10 @@ public:
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize. * values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/ */
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values); void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete3DFunction Copy() const;
private: private:
int xsize, ysize, zsize; int xsize, ysize, zsize;
std::vector<double> values; std::vector<double> values;
......
...@@ -48,28 +48,27 @@ using std::stringstream; ...@@ -48,28 +48,27 @@ using std::stringstream;
using std::vector; using std::vector;
CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpression(energy), nonbondedMethod(NoCutoff), cutoffDistance(1.0), CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpression(energy), nonbondedMethod(NoCutoff), cutoffDistance(1.0),
switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false), iOwnTabulatedFunctions(true) { switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false) {
} }
CustomNonbondedForce::CustomNonbondedForce(const CustomNonbondedForce& rhs) { CustomNonbondedForce::CustomNonbondedForce(const CustomNonbondedForce& rhs) {
// Copy everything, but the copy does *not* own the tabulated functions // Copy everything and deep copy the tabulated functions
energyExpression = rhs.energyExpression; energyExpression = rhs.energyExpression;
nonbondedMethod = rhs.nonbondedMethod; nonbondedMethod = rhs.nonbondedMethod;
cutoffDistance = rhs.cutoffDistance; cutoffDistance = rhs.cutoffDistance;
switchingDistance = rhs.switchingDistance; switchingDistance = rhs.switchingDistance;
useSwitchingFunction = rhs.useSwitchingFunction; useSwitchingFunction = rhs.useSwitchingFunction;
useLongRangeCorrection = rhs.useLongRangeCorrection; useLongRangeCorrection = rhs.useLongRangeCorrection;
iOwnTabulatedFunctions = false;
parameters = rhs.parameters; parameters = rhs.parameters;
globalParameters = rhs.globalParameters; globalParameters = rhs.globalParameters;
particles = rhs.particles; particles = rhs.particles;
exclusions = rhs.exclusions; exclusions = rhs.exclusions;
functions = rhs.functions;
interactionGroups = rhs.interactionGroups; interactionGroups = rhs.interactionGroups;
for (vector<FunctionInfo>::const_iterator it = rhs.functions.begin(); it != rhs.functions.end(); it++)
functions.push_back(it->Copy());
} }
CustomNonbondedForce::~CustomNonbondedForce() { CustomNonbondedForce::~CustomNonbondedForce() {
if (iOwnTabulatedFunctions)
for (int i = 0; i < (int) functions.size(); i++) for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function; delete functions[i].function;
} }
......
...@@ -61,6 +61,13 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d ...@@ -61,6 +61,13 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->max = max; this->max = max;
} }
Continuous1DFunction Continuous1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Continuous1DFunction(new_vec, min, max);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) { Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2) if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
...@@ -107,6 +114,13 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec ...@@ -107,6 +114,13 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this->ymax = ymax; this->ymax = ymax;
} }
Continuous2DFunction Continuous2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) { Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2) if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
...@@ -166,6 +180,14 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize ...@@ -166,6 +180,14 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this->zmax = zmax; this->zmax = zmax;
} }
Continuous3DFunction Continuous3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax);
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) { Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values; this->values = values;
} }
...@@ -178,6 +200,13 @@ void Discrete1DFunction::setFunctionParameters(const vector<double>& values) { ...@@ -178,6 +200,13 @@ void Discrete1DFunction::setFunctionParameters(const vector<double>& values) {
this->values = values; this->values = values;
} }
Discrete1DFunction Discrete1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Discrete1DFunction(new_vec);
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) { Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize) if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values"); throw OpenMMException("Discrete2DFunction: incorrect number of values");
...@@ -200,6 +229,13 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto ...@@ -200,6 +229,13 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this->values = values; this->values = values;
} }
Discrete2DFunction Discrete2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Discrete2DFunction(xsize, ysize, new_vec);
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) { Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize) if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values"); throw OpenMMException("Discrete3DFunction: incorrect number of values");
...@@ -224,3 +260,10 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, ...@@ -224,3 +260,10 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this->zsize = zsize; this->zsize = zsize;
this->values = values; this->values = values;
} }
Discrete3DFunction Discrete3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Discrete3DFunction(xsize, ysize, zsize, new_vec);
}
...@@ -145,35 +145,35 @@ class TestAmberPrmtopFile(unittest.TestCase): ...@@ -145,35 +145,35 @@ class TestAmberPrmtopFile(unittest.TestCase):
totalMass2 = sum([system2.getParticleMass(i) for i in range(system2.getNumParticles())]).value_in_unit(amu) totalMass2 = sum([system2.getParticleMass(i) for i in range(system2.getNumParticles())]).value_in_unit(amu)
self.assertAlmostEqual(totalMass1, totalMass2) self.assertAlmostEqual(totalMass1, totalMass2)
# def test_NBFIX_LongRange(self): def test_NBFIX_LongRange(self):
# """Test prmtop files with NBFIX LJ modifications w/ long-range correction""" """Test prmtop files with NBFIX LJ modifications w/ long-range correction"""
# system = prmtop3.createSystem(nonbondedMethod=PME, system = prmtop3.createSystem(nonbondedMethod=PME,
# nonbondedCutoff=8*angstroms) nonbondedCutoff=8*angstroms)
# # Check the forces # Check the forces
# has_nonbond_force = has_custom_nonbond_force = False has_nonbond_force = has_custom_nonbond_force = False
# nonbond_exceptions = custom_nonbond_exclusions = 0 nonbond_exceptions = custom_nonbond_exclusions = 0
# for force in system.getForces(): for force in system.getForces():
# if isinstance(force, NonbondedForce): if isinstance(force, NonbondedForce):
# has_nonbond_force = True has_nonbond_force = True
# nonbond_exceptions = force.getNumExceptions() nonbond_exceptions = force.getNumExceptions()
# elif isinstance(force, CustomNonbondedForce): elif isinstance(force, CustomNonbondedForce):
# has_custom_nonbond_force = True has_custom_nonbond_force = True
# custom_nonbond_exceptions = force.getNumExclusions() custom_nonbond_exceptions = force.getNumExclusions()
# self.assertTrue(has_nonbond_force) self.assertTrue(has_nonbond_force)
# self.assertTrue(has_custom_nonbond_force) self.assertTrue(has_custom_nonbond_force)
# self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions) self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions)
# integrator = VerletIntegrator(1.0*femtoseconds) integrator = VerletIntegrator(1.0*femtoseconds)
# # Use reference platform, since it should always be present and # Use reference platform, since it should always be present and
# # 'working', and the system is plenty small so this won't be too slow # 'working', and the system is plenty small so this won't be too slow
# sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference')) sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference'))
# # Check that the energy is about what we expect it to be # Check that the energy is about what we expect it to be
# sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors) sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors)
# sim.context.setPositions(inpcrd3.positions) sim.context.setPositions(inpcrd3.positions)
# ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy() ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy()
# ene = ene.value_in_unit(kilocalories_per_mole) ene = ene.value_in_unit(kilocalories_per_mole)
# # Make sure the energy is relatively close to the value we get with # Make sure the energy is relatively close to the value we get with
# # Amber using this force field. # Amber using this force field.
# self.assertAlmostEqual(-7099.44989739/ene, 1, places=3) self.assertAlmostEqual(-7099.44989739/ene, 1, places=3)
def test_NBFIX_noLongRange(self): def test_NBFIX_noLongRange(self):
"""Test prmtop files with NBFIX LJ modifications w/out long-range correction""" """Test prmtop files with NBFIX LJ modifications w/out long-range correction"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment