Commit 4517491d authored by Jason Swails's avatar Jason Swails
Browse files

Switch to a deep copy on the tabulated functions in CustomNonbondedForce. This

necessitated creating a Copy method to do that on TabulatedFunction classes.
parent d3046049
......@@ -552,6 +552,10 @@ public:
}
FunctionInfo(const std::string& name, TabulatedFunction* function) : name(name), function(function) {
}
FunctionInfo Copy() const {
TabulatedFunction new_func = function->Copy();
return FunctionInfo(name, &new_func);
}
};
/**
......
......@@ -59,6 +59,9 @@ class OPENMM_EXPORT TabulatedFunction {
public:
virtual ~TabulatedFunction() {
}
TabulatedFunction Copy() const {
return TabulatedFunction();
}
};
/**
......@@ -96,6 +99,10 @@ public:
* @param max the value of x corresponding to the last element of values
*/
void setFunctionParameters(const std::vector<double>& values, double min, double max);
/**
* Create a deep copy of the tabulated function.
*/
Continuous1DFunction Copy() const;
private:
std::vector<double> values;
double min, max;
......@@ -151,6 +158,10 @@ public:
* @param ymax the value of y corresponding to the last element of values
*/
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
/**
* Create a deep copy of the tabulated function
*/
Continuous2DFunction Copy() const;
private:
std::vector<double> values;
int xsize, ysize;
......@@ -222,6 +233,10 @@ public:
* @param zmax the value of z corresponding to the last element of values
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
/**
* Create a deep copy of the tabulated function
*/
Continuous3DFunction Copy() const;
private:
std::vector<double> values;
int xsize, ysize, zsize;
......@@ -253,6 +268,10 @@ public:
* @param values the tabulated values of the function f(x)
*/
void setFunctionParameters(const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete1DFunction Copy() const;
private:
std::vector<double> values;
};
......@@ -291,6 +310,10 @@ public:
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete2DFunction Copy() const;
private:
int xsize, ysize;
std::vector<double> values;
......@@ -333,6 +356,10 @@ public:
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values);
/**
* Create a deep copy of the tabulated function
*/
Discrete3DFunction Copy() const;
private:
int xsize, ysize, zsize;
std::vector<double> values;
......
......@@ -48,30 +48,29 @@ using std::stringstream;
using std::vector;
CustomNonbondedForce::CustomNonbondedForce(const string& energy) : energyExpression(energy), nonbondedMethod(NoCutoff), cutoffDistance(1.0),
switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false), iOwnTabulatedFunctions(true) {
switchingDistance(-1.0), useSwitchingFunction(false), useLongRangeCorrection(false) {
}
CustomNonbondedForce::CustomNonbondedForce(const CustomNonbondedForce& rhs) {
// Copy everything, but the copy does *not* own the tabulated functions
// Copy everything and deep copy the tabulated functions
energyExpression = rhs.energyExpression;
nonbondedMethod = rhs.nonbondedMethod;
cutoffDistance = rhs.cutoffDistance;
switchingDistance = rhs.switchingDistance;
useSwitchingFunction = rhs.useSwitchingFunction;
useLongRangeCorrection = rhs.useLongRangeCorrection;
iOwnTabulatedFunctions = false;
parameters = rhs.parameters;
globalParameters = rhs.globalParameters;
particles = rhs.particles;
exclusions = rhs.exclusions;
functions = rhs.functions;
interactionGroups = rhs.interactionGroups;
for (vector<FunctionInfo>::const_iterator it = rhs.functions.begin(); it != rhs.functions.end(); it++)
functions.push_back(it->Copy());
}
CustomNonbondedForce::~CustomNonbondedForce() {
if (iOwnTabulatedFunctions)
for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function;
for (int i = 0; i < (int) functions.size(); i++)
delete functions[i].function;
}
const string& CustomNonbondedForce::getEnergyFunction() const {
......
......@@ -61,6 +61,13 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->max = max;
}
Continuous1DFunction Continuous1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Continuous1DFunction(new_vec, min, max);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
......@@ -107,6 +114,13 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this->ymax = ymax;
}
Continuous2DFunction Continuous2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
......@@ -166,6 +180,14 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this->zmax = zmax;
}
Continuous3DFunction Continuous3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax);
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values;
}
......@@ -178,6 +200,13 @@ void Discrete1DFunction::setFunctionParameters(const vector<double>& values) {
this->values = values;
}
Discrete1DFunction Discrete1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Discrete1DFunction(new_vec);
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
......@@ -200,6 +229,13 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this->values = values;
}
Discrete2DFunction Discrete2DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Discrete2DFunction(xsize, ysize, new_vec);
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
......@@ -224,3 +260,10 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this->zsize = zsize;
this->values = values;
}
Discrete3DFunction Discrete3DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return Discrete3DFunction(xsize, ysize, zsize, new_vec);
}
......@@ -145,35 +145,35 @@ class TestAmberPrmtopFile(unittest.TestCase):
totalMass2 = sum([system2.getParticleMass(i) for i in range(system2.getNumParticles())]).value_in_unit(amu)
self.assertAlmostEqual(totalMass1, totalMass2)
# def test_NBFIX_LongRange(self):
# """Test prmtop files with NBFIX LJ modifications w/ long-range correction"""
# system = prmtop3.createSystem(nonbondedMethod=PME,
# nonbondedCutoff=8*angstroms)
# # Check the forces
# has_nonbond_force = has_custom_nonbond_force = False
# nonbond_exceptions = custom_nonbond_exclusions = 0
# for force in system.getForces():
# if isinstance(force, NonbondedForce):
# has_nonbond_force = True
# nonbond_exceptions = force.getNumExceptions()
# elif isinstance(force, CustomNonbondedForce):
# has_custom_nonbond_force = True
# custom_nonbond_exceptions = force.getNumExclusions()
# self.assertTrue(has_nonbond_force)
# self.assertTrue(has_custom_nonbond_force)
# self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions)
# integrator = VerletIntegrator(1.0*femtoseconds)
# # Use reference platform, since it should always be present and
# # 'working', and the system is plenty small so this won't be too slow
# sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference'))
# # Check that the energy is about what we expect it to be
# sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors)
# sim.context.setPositions(inpcrd3.positions)
# ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy()
# ene = ene.value_in_unit(kilocalories_per_mole)
# # Make sure the energy is relatively close to the value we get with
# # Amber using this force field.
# self.assertAlmostEqual(-7099.44989739/ene, 1, places=3)
def test_NBFIX_LongRange(self):
"""Test prmtop files with NBFIX LJ modifications w/ long-range correction"""
system = prmtop3.createSystem(nonbondedMethod=PME,
nonbondedCutoff=8*angstroms)
# Check the forces
has_nonbond_force = has_custom_nonbond_force = False
nonbond_exceptions = custom_nonbond_exclusions = 0
for force in system.getForces():
if isinstance(force, NonbondedForce):
has_nonbond_force = True
nonbond_exceptions = force.getNumExceptions()
elif isinstance(force, CustomNonbondedForce):
has_custom_nonbond_force = True
custom_nonbond_exceptions = force.getNumExclusions()
self.assertTrue(has_nonbond_force)
self.assertTrue(has_custom_nonbond_force)
self.assertEqual(nonbond_exceptions, custom_nonbond_exceptions)
integrator = VerletIntegrator(1.0*femtoseconds)
# Use reference platform, since it should always be present and
# 'working', and the system is plenty small so this won't be too slow
sim = Simulation(prmtop3.topology, system, integrator, Platform.getPlatformByName('Reference'))
# Check that the energy is about what we expect it to be
sim.context.setPeriodicBoxVectors(*inpcrd3.boxVectors)
sim.context.setPositions(inpcrd3.positions)
ene = sim.context.getState(getEnergy=True, enforcePeriodicBox=True).getPotentialEnergy()
ene = ene.value_in_unit(kilocalories_per_mole)
# Make sure the energy is relatively close to the value we get with
# Amber using this force field.
self.assertAlmostEqual(-7099.44989739/ene, 1, places=3)
def test_NBFIX_noLongRange(self):
"""Test prmtop files with NBFIX LJ modifications w/out long-range correction"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment