Commit 8a8a2c66 authored by peastman's avatar peastman
Browse files

Merge pull request #539 from swails/fix_customnonbondedforce_segfault

Fix customnonbondedforce segfault
parents 14df6a03 80273800
......@@ -157,6 +157,7 @@ public:
* of r, the distance between them, as well as any global and per-particle parameters
*/
explicit CustomNonbondedForce(const std::string& energy);
CustomNonbondedForce(const CustomNonbondedForce& rhs); // copy constructor
~CustomNonbondedForce();
/**
* Get the number of particles for which force field parameters have been defined.
......@@ -466,6 +467,7 @@ public:
protected:
ForceImpl* createImpl() const;
private:
// REMEMBER TO UPDATE THE COPY CONSTRUCTOR IF YOU ADD ANY NEW FIELDS !!
class ParticleInfo;
class PerParticleParameterInfo;
class GlobalParameterInfo;
......
......@@ -59,6 +59,7 @@ class OPENMM_EXPORT TabulatedFunction {
public:
virtual ~TabulatedFunction() {
}
virtual TabulatedFunction* Copy() const = 0;
};
/**
......@@ -96,6 +97,10 @@ public:
* @param max the value of x corresponding to the last element of values
*/
void setFunctionParameters(const std::vector<double>& values, double min, double max);
/**
* Create a deep copy of the tabulated function.
*/
Continuous1DFunction* Copy() const;
private:
std::vector<double> values;
double min, max;
......@@ -151,6 +156,10 @@ public:
* @param ymax the value of y corresponding to the last element of values
*/
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
/**
* Create a deep copy of the tabulated function
*/
Continuous2DFunction* Copy() const;
private:
std::vector<double> values;
int xsize, ysize;
......@@ -222,6 +231,10 @@ public:
* @param zmax the value of z corresponding to the last element of values
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
/**
* Create a deep copy of the tabulated function
*/
Continuous3DFunction* Copy() const;
private:
std::vector<double> values;
int xsize, ysize, zsize;
......@@ -253,6 +266,10 @@ public:
* @param values the tabulated values of the function f(x)
*/
void setFunctionParameters(const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete1DFunction* Copy() const;
private:
std::vector<double> values;
};
......@@ -291,6 +308,10 @@ public:
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete2DFunction* Copy() const;
private:
int xsize, ysize;
std::vector<double> values;
......@@ -333,6 +354,10 @@ public:
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete3DFunction* Copy() const;
private:
int xsize, ysize, zsize;
std::vector<double> values;
......
......@@ -51,6 +51,23 @@ CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpress
switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false) {
}
CustomNonbondedForce::CustomNonbondedForce(const CustomNonbondedForce& rhs) {
// Copy everything and deep copy the tabulated functions
energyExpression = rhs.energyExpression;
nonbondedMethod = rhs.nonbondedMethod;
cutoffDistance = rhs.cutoffDistance;
switchingDistance = rhs.switchingDistance;
useSwitchingFunction = rhs.useSwitchingFunction;
useLongRangeCorrection = rhs.useLongRangeCorrection;
parameters = rhs.parameters;
globalParameters = rhs.globalParameters;
particles = rhs.particles;
exclusions = rhs.exclusions;
interactionGroups = rhs.interactionGroups;
for (vector<FunctionInfo>::const_iterator it = rhs.functions.begin(); it != rhs.functions.end(); it++)
functions.push_back(FunctionInfo(it->name, it->function->Copy()));
}
CustomNonbondedForce::~CustomNonbondedForce() {
for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function;
......
......@@ -61,6 +61,13 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->max = max;
}
Continuous1DFunction* Continuous1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Continuous1DFunction(new_vec, min, max);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
......@@ -107,6 +114,13 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this->ymax = ymax;
}
Continuous2DFunction* Continuous2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
......@@ -166,6 +180,14 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this->zmax = zmax;
}
Continuous3DFunction* Continuous3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax);
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values;
}
......@@ -178,6 +200,13 @@ void Discrete1DFunction::setFunctionParameters(const vector<double>& values) {
this->values = values;
}
Discrete1DFunction* Discrete1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Discrete1DFunction(new_vec);
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
......@@ -200,6 +229,13 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this->values = values;
}
Discrete2DFunction* Discrete2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Discrete2DFunction(xsize, ysize, new_vec);
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
......@@ -224,3 +260,10 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this->zsize = zsize;
this->values = values;
}
Discrete3DFunction* Discrete3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new Discrete3DFunction(xsize, ysize, zsize, new_vec);
}
......@@ -145,35 +145,35 @@ class TestAmberPrmtopFile(unittest.TestCase):
totalMass2 = sum([system2.getParticleMass(i) for i in range(system2.getNumParticles())]).value_in_unit(amu)
self.assertAlmostEqual(totalMass1, totalMass2)
# def test_NBFIX_LongRange(self):
# """Test prmtop files with NBFIX LJ modifications w/ long-range correction"""
# system = prmtop3.createSystem(nonbondedMethod=PME,
# nonbondedCutoff=8*angstroms)
# # Check the forces
# has_nonbond_force = has_custom_nonbond_force = False
# nonbond_exceptions = custom_nonbond_exclusions = 0
# for force in system.getForces():
# if isinstance(force, NonbondedForce):
# has_nonbond_force = True
# nonbond_exceptions = force.getNumExceptions()
# elif isinstance(force, CustomNonbondedForce):
# has_custom_nonbond_force = True
# custom_nonbond_exceptions = force.getNumExclusions()
# self.assertTrue(has_nonbond_force)
# self.assertTrue(has_custom_nonbond_force)
# self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions)
# integrator = VerletIntegrator(1.0*femtoseconds)
# # Use reference platform, since it should always be present and
# # 'working', and the system is plenty small so this won't be too slow
# sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference'))
# # Check that the energy is about what we expect it to be
# sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors)
# sim.context.setPositions(inpcrd3.positions)
# ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy()
# ene = ene.value_in_unit(kilocalories_per_mole)
# # Make sure the energy is relatively close to the value we get with
# # Amber using this force field.
# self.assertAlmostEqual(-7099.44989739/ene, 1, places=3)
def test_NBFIX_LongRange(self):
"""Test prmtop files with NBFIX LJ modifications w/ long-range correction"""
system = prmtop3.createSystem(nonbondedMethod=PME,
nonbondedCutoff=8*angstroms)
# Check the forces
has_nonbond_force = has_custom_nonbond_force = False
nonbond_exceptions = custom_nonbond_exclusions = 0
for force in system.getForces():
if isinstance(force, NonbondedForce):
has_nonbond_force = True
nonbond_exceptions = force.getNumExceptions()
elif isinstance(force, CustomNonbondedForce):
has_custom_nonbond_force = True
custom_nonbond_exceptions = force.getNumExclusions()
self.assertTrue(has_nonbond_force)
self.assertTrue(has_custom_nonbond_force)
self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions)
integrator = VerletIntegrator(1.0*femtoseconds)
# Use reference platform, since it should always be present and
# 'working', and the system is plenty small so this won't be too slow
sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference'))
# Check that the energy is about what we expect it to be
sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors)
sim.context.setPositions(inpcrd3.positions)
ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy()
ene = ene.value_in_unit(kilocalories_per_mole)
# Make sure the energy is relatively close to the value we get with
# Amber using this force field.
self.assertAlmostEqual(-7099.44989739/ene, 1, places=3)
def test_NBFIX_noLongRange(self):
"""Test prmtop files with NBFIX LJ modifications w/out long-range correction"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment