Commit 8a8a2c66 authored by peastman's avatar peastman
Browse files

Merge pull request #539 from swails/fix_customnonbondedforce_segfault

Fix customnonbondedforce segfault
parents 14df6a03 80273800
...@@ -157,6 +157,7 @@ public: ...@@ -157,6 +157,7 @@ public:
* of r, the distance between them, as well as any global and per-particle parameters * of r, the distance between them, as well as any global and per-particle parameters
*/ */
explicit CustomNonbondedForce(const std::string& energy); explicit CustomNonbondedForce(const std::string& energy);
CustomNonbondedForce(const CustomNonbondedForce& rhs); // copy constructor
~CustomNonbondedForce(); ~CustomNonbondedForce();
/** /**
* Get the number of particles for which force field parameters have been defined. * Get the number of particles for which force field parameters have been defined.
...@@ -466,6 +467,7 @@ public: ...@@ -466,6 +467,7 @@ public:
protected: protected:
ForceImpl* createImpl() const; ForceImpl* createImpl() const;
private: private:
// REMEMBER TO UPDATE THE COPY CONSTRUCTOR IF YOU ADD ANY NEW FIELDS !!
class ParticleInfo; class ParticleInfo;
class PerParticleParameterInfo; class PerParticleParameterInfo;
class GlobalParameterInfo; class GlobalParameterInfo;
......
...@@ -59,6 +59,7 @@ class OPENMM_EXPORT TabulatedFunction { ...@@ -59,6 +59,7 @@ class OPENMM_EXPORT TabulatedFunction {
public: public:
virtual ~TabulatedFunction() { virtual ~TabulatedFunction() {
} }
virtual TabulatedFunction* Copy() const = 0;
}; };
/** /**
...@@ -96,6 +97,10 @@ public: ...@@ -96,6 +97,10 @@ public:
* @param max the value of x corresponding to the last element of values * @param max the value of x corresponding to the last element of values
*/ */
void setFunctionParameters(const std::vector<double>& values, double min, double max); void setFunctionParameters(const std::vector<double>& values, double min, double max);
/**
* Create a deep copy of the tabulated function.
*/
Continuous1DFunction* Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
...@@ -151,6 +156,10 @@ public: ...@@ -151,6 +156,10 @@ public:
* @param ymax the value of y corresponding to the last element of values * @param ymax the value of y corresponding to the last element of values
*/ */
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax); void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
/**
* Create a deep copy of the tabulated function
*/
Continuous2DFunction* Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
int xsize, ysize; int xsize, ysize;
...@@ -222,6 +231,10 @@ public: ...@@ -222,6 +231,10 @@ public:
* @param zmax the value of z corresponding to the last element of values * @param zmax the value of z corresponding to the last element of values
*/ */
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax); void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
/**
* Create a deep copy of the tabulated function
*/
Continuous3DFunction* Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
int xsize, ysize, zsize; int xsize, ysize, zsize;
...@@ -253,6 +266,10 @@ public: ...@@ -253,6 +266,10 @@ public:
* @param values the tabulated values of the function f(x) * @param values the tabulated values of the function f(x)
*/ */
void setFunctionParameters(const std::vector<double>& values); void setFunctionParameters(const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete1DFunction* Copy() const;
private: private:
std::vector<double> values; std::vector<double> values;
}; };
...@@ -291,6 +308,10 @@ public: ...@@ -291,6 +308,10 @@ public:
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize. * values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/ */
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values); void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete2DFunction* Copy() const;
private: private:
int xsize, ysize; int xsize, ysize;
std::vector<double> values; std::vector<double> values;
...@@ -333,6 +354,10 @@ public: ...@@ -333,6 +354,10 @@ public:
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize. * values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/ */
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values); void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete3DFunction* Copy() const;
private: private:
int xsize, ysize, zsize; int xsize, ysize, zsize;
std::vector<double> values; std::vector<double> values;
......
...@@ -51,6 +51,23 @@ CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpress ...@@ -51,6 +51,23 @@ CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpress
switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false) { switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false) {
} }
CustomNonbondedForce::CustomNonbondedForce(const CustomNonbondedForce& rhs) {
// Copy everything and deep copy the tabulated functions
energyExpression = rhs.energyExpression;
nonbondedMethod = rhs.nonbondedMethod;
cutoffDistance = rhs.cutoffDistance;
switchingDistance = rhs.switchingDistance;
useSwitchingFunction = rhs.useSwitchingFunction;
useLongRangeCorrection = rhs.useLongRangeCorrection;
parameters = rhs.parameters;
globalParameters = rhs.globalParameters;
particles = rhs.particles;
exclusions = rhs.exclusions;
interactionGroups = rhs.interactionGroups;
for (vector<FunctionInfo>::const_iterator it = rhs.functions.begin(); it != rhs.functions.end(); it++)
functions.push_back(FunctionInfo(it->name, it->function->Copy()));
}
CustomNonbondedForce::~CustomNonbondedForce() { CustomNonbondedForce::~CustomNonbondedForce() {
for (int i = 0; i < (int) functions.size(); i++) for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function; delete functions[i].function;
......
...@@ -61,6 +61,13 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d ...@@ -61,6 +61,13 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->max = max; this->max = max;
} }
Continuous1DFunction* Continuous1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Continuous1DFunction(new_vec, min, max);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) { Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2) if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
...@@ -107,6 +114,13 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec ...@@ -107,6 +114,13 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this->ymax = ymax; this->ymax = ymax;
} }
Continuous2DFunction* Continuous2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) { Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2) if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
...@@ -166,6 +180,14 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize ...@@ -166,6 +180,14 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this->zmax = zmax; this->zmax = zmax;
} }
Continuous3DFunction* Continuous3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax);
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) { Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values; this->values = values;
} }
...@@ -178,6 +200,13 @@ void Discrete1DFunction::setFunctionParameters(const vector<double>& values) { ...@@ -178,6 +200,13 @@ void Discrete1DFunction::setFunctionParameters(const vector<double>& values) {
this->values = values; this->values = values;
} }
Discrete1DFunction* Discrete1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Discrete1DFunction(new_vec);
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) { Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize) if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values"); throw OpenMMException("Discrete2DFunction: incorrect number of values");
...@@ -200,6 +229,13 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto ...@@ -200,6 +229,13 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this->values = values; this->values = values;
} }
Discrete2DFunction* Discrete2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Discrete2DFunction(xsize, ysize, new_vec);
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) { Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize) if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values"); throw OpenMMException("Discrete3DFunction: incorrect number of values");
...@@ -224,3 +260,10 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, ...@@ -224,3 +260,10 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this->zsize = zsize; this->zsize = zsize;
this->values = values; this->values = values;
} }
Discrete3DFunction* Discrete3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Discrete3DFunction(xsize, ysize, zsize, new_vec);
}
...@@ -145,35 +145,35 @@ class TestAmberPrmtopFile(unittest.TestCase): ...@@ -145,35 +145,35 @@ class TestAmberPrmtopFile(unittest.TestCase):
totalMass2 = sum([system2.getParticleMass(i) for i in range(system2.getNumParticles())]).value_in_unit(amu) totalMass2 = sum([system2.getParticleMass(i) for i in range(system2.getNumParticles())]).value_in_unit(amu)
self.assertAlmostEqual(totalMass1, totalMass2) self.assertAlmostEqual(totalMass1, totalMass2)
# def test_NBFIX_LongRange(self): def test_NBFIX_LongRange(self):
# """Test prmtop files with NBFIX LJ modifications w/ long-range correction""" """Test prmtop files with NBFIX LJ modifications w/ long-range correction"""
# system = prmtop3.createSystem(nonbondedMethod=PME, system = prmtop3.createSystem(nonbondedMethod=PME,
# nonbondedCutoff=8*angstroms) nonbondedCutoff=8*angstroms)
# # Check the forces # Check the forces
# has_nonbond_force = has_custom_nonbond_force = False has_nonbond_force = has_custom_nonbond_force = False
# nonbond_exceptions = custom_nonbond_exclusions = 0 nonbond_exceptions = custom_nonbond_exclusions = 0
# for force in system.getForces(): for force in system.getForces():
# if isinstance(force, NonbondedForce): if isinstance(force, NonbondedForce):
# has_nonbond_force = True has_nonbond_force = True
# nonbond_exceptions = force.getNumExceptions() nonbond_exceptions = force.getNumExceptions()
# elif isinstance(force, CustomNonbondedForce): elif isinstance(force, CustomNonbondedForce):
# has_custom_nonbond_force = True has_custom_nonbond_force = True
# custom_nonbond_exceptions = force.getNumExclusions() custom_nonbond_exceptions = force.getNumExclusions()
# self.assertTrue(has_nonbond_force) self.assertTrue(has_nonbond_force)
# self.assertTrue(has_custom_nonbond_force) self.assertTrue(has_custom_nonbond_force)
# self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions) self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions)
# integrator = VerletIntegrator(1.0*femtoseconds) integrator = VerletIntegrator(1.0*femtoseconds)
# # Use reference platform, since it should always be present and # Use reference platform, since it should always be present and
# # 'working', and the system is plenty small so this won't be too slow # 'working', and the system is plenty small so this won't be too slow
# sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference')) sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference'))
# # Check that the energy is about what we expect it to be # Check that the energy is about what we expect it to be
# sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors) sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors)
# sim.context.setPositions(inpcrd3.positions) sim.context.setPositions(inpcrd3.positions)
# ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy() ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy()
# ene = ene.value_in_unit(kilocalories_per_mole) ene = ene.value_in_unit(kilocalories_per_mole)
# # Make sure the energy is relatively close to the value we get with # Make sure the energy is relatively close to the value we get with
# # Amber using this force field. # Amber using this force field.
# self.assertAlmostEqual(-7099.44989739/ene, 1, places=3) self.assertAlmostEqual(-7099.44989739/ene, 1, places=3)
def test_NBFIX_noLongRange(self): def test_NBFIX_noLongRange(self):
"""Test prmtop files with NBFIX LJ modifications w/out long-range correction""" """Test prmtop files with NBFIX LJ modifications w/out long-range correction"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment