Commit 59b0c5fa authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented updateParametersInContext() for NonbondedForce

parent 895f8dac
...@@ -53,7 +53,7 @@ void testTransform() { ...@@ -53,7 +53,7 @@ void testTransform() {
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", ""); OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize();
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt); init_gen_rand(0, sfmt);
int xsize = 32, ysize = 25, zsize = 30; int xsize = 32, ysize = 25, zsize = 30;
......
...@@ -128,6 +128,8 @@ void testExclusionsAnd14() { ...@@ -128,6 +128,8 @@ void testExclusionsAnd14() {
second14 = i; second14 = i;
} }
system.addForce(nonbonded); system.addForce(nonbonded);
LangevinIntegrator integrator(0.0, 0.1, 0.01);
Context context(system, integrator, platform);
for (int i = 1; i < 5; ++i) { for (int i = 1; i < 5; ++i) {
// Test LJ forces // Test LJ forces
...@@ -143,8 +145,7 @@ void testExclusionsAnd14() { ...@@ -143,8 +145,7 @@ void testExclusionsAnd14() {
nonbonded->setExceptionParameters(first14, 0, 3, 0, 1.5, i == 3 ? 0.5 : 0.0); nonbonded->setExceptionParameters(first14, 0, 3, 0, 1.5, i == 3 ? 0.5 : 0.0);
nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0.0); nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0.0);
positions[i] = Vec3(r, 0, 0); positions[i] = Vec3(r, 0, 0);
LangevinIntegrator integrator(0.0, 0.1, 0.01); context.reinitialize();
Context context(system, integrator, platform);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
...@@ -170,10 +171,9 @@ void testExclusionsAnd14() { ...@@ -170,10 +171,9 @@ void testExclusionsAnd14() {
nonbonded->setParticleParameters(i, 2, 1.5, 0); nonbonded->setParticleParameters(i, 2, 1.5, 0);
nonbonded->setExceptionParameters(first14, 0, 3, i == 3 ? 4/1.2 : 0, 1.5, 0); nonbonded->setExceptionParameters(first14, 0, 3, i == 3 ? 4/1.2 : 0, 1.5, 0);
nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0); nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0);
LangevinIntegrator integrator2(0.0, 0.1, 0.01); context.reinitialize();
Context context2(system, integrator2, platform); context.setPositions(positions);
context2.setPositions(positions); state = context.getState(State::Forces | State::Energy);
state = context2.getState(State::Forces | State::Energy);
const vector<Vec3>& forces2 = state.getForces(); const vector<Vec3>& forces2 = state.getForces();
force = ONE_4PI_EPS0*4/(r*r); force = ONE_4PI_EPS0*4/(r*r);
energy = ONE_4PI_EPS0*4/r; energy = ONE_4PI_EPS0*4/r;
...@@ -654,14 +654,12 @@ void testDispersionCorrection() { ...@@ -654,14 +654,12 @@ void testDispersionCorrection() {
numType2++; numType2++;
} }
int numType1 = numParticles-numType2; int numType1 = numParticles-numType2;
nonbonded->updateParametersInContext(context);
energy2 = context.getState(State::Energy).getPotentialEnergy();
nonbonded->setUseDispersionCorrection(true); nonbonded->setUseDispersionCorrection(true);
context.reinitialize(); context.reinitialize();
context.setPositions(positions); context.setPositions(positions);
energy1 = context.getState(State::Energy).getPotentialEnergy(); energy1 = context.getState(State::Energy).getPotentialEnergy();
nonbonded->setUseDispersionCorrection(false);
context.reinitialize();
context.setPositions(positions);
energy2 = context.getState(State::Energy).getPotentialEnergy();
term1 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 12)/pow(cutoff, 9))/9; term1 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 12)/pow(cutoff, 9))/9;
term2 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 6)/pow(cutoff, 3))/3; term2 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 6)/pow(cutoff, 3))/3;
term1 += ((numType2*(numType2+1))/2)*(1*pow(1.0, 12)/pow(cutoff, 9))/9; term1 += ((numType2*(numType2+1))/2)*(1*pow(1.0, 12)/pow(cutoff, 9))/9;
...@@ -676,6 +674,77 @@ void testDispersionCorrection() { ...@@ -676,6 +674,77 @@ void testDispersionCorrection() {
ASSERT_EQUAL_TOL(expected, energy1-energy2, 1e-4); ASSERT_EQUAL_TOL(expected, energy1-energy2, 1e-4);
} }
void testChangingParameters() {
const int numMolecules = 600;
const int numParticles = numMolecules*2;
const double cutoff = 2.0;
const double boxSize = 20.0;
const double tol = 2e-3;
OpenCLPlatform cl;
ReferencePlatform reference;
System system;
for (int i = 0; i < numParticles; i++)
system.addParticle(1.0);
NonbondedForce* nonbonded = new NonbondedForce();
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numMolecules; i++) {
if (i < numMolecules/2) {
nonbonded->addParticle(-1.0, 0.2, 0.1);
nonbonded->addParticle(1.0, 0.1, 0.1);
}
else {
nonbonded->addParticle(-1.0, 0.2, 0.2);
nonbonded->addParticle(1.0, 0.1, 0.2);
}
positions[2*i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
positions[2*i+1] = Vec3(positions[2*i][0]+1.0, positions[2*i][1], positions[2*i][2]);
system.addConstraint(2*i, 2*i+1, 1.0);
nonbonded->addException(2*i, 2*i+1, 0.0, 0.15, 0.0);
}
nonbonded->setNonbondedMethod(NonbondedForce::PME);
nonbonded->setCutoffDistance(cutoff);
system.addForce(nonbonded);
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
// See if Reference and OpenCL give the same forces and energies.
VerletIntegrator integrator1(0.01);
VerletIntegrator integrator2(0.01);
Context clContext(system, integrator1, cl);
Context referenceContext(system, integrator2, reference);
clContext.setPositions(positions);
referenceContext.setPositions(positions);
State clState = clContext.getState(State::Forces | State::Energy);
State referenceState = referenceContext.getState(State::Forces | State::Energy);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(clState.getForces()[i], referenceState.getForces()[i], tol);
ASSERT_EQUAL_TOL(clState.getPotentialEnergy(), referenceState.getPotentialEnergy(), tol);
// Now modify parameters and see if they still agree.
for (int i = 0; i < numParticles; i += 5) {
double charge, sigma, epsilon;
nonbonded->getParticleParameters(i, charge, sigma, epsilon);
nonbonded->setParticleParameters(i, 1.5*charge, 1.1*sigma, 1.7*epsilon);
}
double total = 0;
for (int i = 0; i < numParticles; i++) {
double charge, sigma, epsilon;
nonbonded->getParticleParameters(i, charge, sigma, epsilon);
total += charge;
}
nonbonded->updateParametersInContext(clContext);
nonbonded->updateParametersInContext(referenceContext);
clState = clContext.getState(State::Forces | State::Energy);
referenceState = referenceContext.getState(State::Forces | State::Energy);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(clState.getForces()[i], referenceState.getForces()[i], tol);
ASSERT_EQUAL_TOL(clState.getPotentialEnergy(), referenceState.getPotentialEnergy(), tol);
}
void testParallelComputation(bool useCutoff) { void testParallelComputation(bool useCutoff) {
OpenCLPlatform platform; OpenCLPlatform platform;
System system; System system;
...@@ -699,6 +768,9 @@ void testParallelComputation(bool useCutoff) { ...@@ -699,6 +768,9 @@ void testParallelComputation(bool useCutoff) {
if (delta.dot(delta) < 0.1) if (delta.dot(delta) < 0.1)
force->addException(i, j, 0, 1, 0); force->addException(i, j, 0, 1, 0);
} }
// Create two contexts, one with a single device and one with two devices.
VerletIntegrator integrator1(0.01); VerletIntegrator integrator1(0.01);
Context context1(system, integrator1, platform); Context context1(system, integrator1, platform);
context1.setPositions(positions); context1.setPositions(positions);
...@@ -710,6 +782,24 @@ void testParallelComputation(bool useCutoff) { ...@@ -710,6 +782,24 @@ void testParallelComputation(bool useCutoff) {
Context context2(system, integrator2, platform, props); Context context2(system, integrator2, platform, props);
context2.setPositions(positions); context2.setPositions(positions);
State state2 = context2.getState(State::Forces | State::Energy); State state2 = context2.getState(State::Forces | State::Energy);
// See if they agree.
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), 1e-5);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
// Modify some particle parameters and see if they still agree.
for (int i = 0; i < numParticles; i += 5) {
double charge, sigma, epsilon;
force->getParticleParameters(i, charge, sigma, epsilon);
force->setParticleParameters(i, 0.9*charge, sigma, epsilon);
}
force->updateParametersInContext(context1);
force->updateParametersInContext(context2);
state1 = context1.getState(State::Forces | State::Energy);
state2 = context2.getState(State::Forces | State::Energy);
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), 1e-5); ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), 1e-5);
for (int i = 0; i < numParticles; i++) for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5); ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
...@@ -727,6 +817,7 @@ int main() { ...@@ -727,6 +817,7 @@ int main() {
testBlockInteractions(false); testBlockInteractions(false);
testBlockInteractions(true); testBlockInteractions(true);
testDispersionCorrection(); testDispersionCorrection();
testChangingParameters();
testParallelComputation(false); testParallelComputation(false);
testParallelComputation(true); testParallelComputation(true);
} }
......
...@@ -50,7 +50,7 @@ void testGaussian() { ...@@ -50,7 +50,7 @@ void testGaussian() {
system.addParticle(1.0); system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(system, "", ""); OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize();
context.getIntegrationUtilities().initRandomNumberGenerator(0); context.getIntegrationUtilities().initRandomNumberGenerator(0);
OpenCLArray<mm_float4>& random = context.getIntegrationUtilities().getRandom(); OpenCLArray<mm_float4>& random = context.getIntegrationUtilities().getRandom();
context.getIntegrationUtilities().prepareRandomNumbers(random.getSize()); context.getIntegrationUtilities().prepareRandomNumbers(random.getSize());
......
...@@ -64,7 +64,7 @@ void verifySorting(vector<float> array) { ...@@ -64,7 +64,7 @@ void verifySorting(vector<float> array) {
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", ""); OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize();
OpenCLArray<float> data(context, array.size(), "sortData"); OpenCLArray<float> data(context, array.size(), "sortData");
data.upload(array); data.upload(array);
OpenCLSort<SortTrait> sort(context, array.size()); OpenCLSort<SortTrait> sort(context, array.size());
......
...@@ -743,6 +743,47 @@ double ReferenceCalcNonbondedForceKernel::execute(ContextImpl& context, bool inc ...@@ -743,6 +743,47 @@ double ReferenceCalcNonbondedForceKernel::execute(ContextImpl& context, bool inc
return energy; return energy;
} }
void ReferenceCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
if (force.getNumParticles() != numParticles)
throw OpenMMException("updateParametersInContext: The number of particles has changed");
vector<int> nb14s;
for (int i = 0; i < force.getNumExceptions(); i++) {
int particle1, particle2;
double chargeProd, sigma, epsilon;
force.getExceptionParameters(i, particle1, particle2, chargeProd, sigma, epsilon);
if (chargeProd != 0.0 || epsilon != 0.0)
nb14s.push_back(i);
}
if (nb14s.size() != num14)
throw OpenMMException("updateParametersInContext: The number of non-excluded exceptions has changed");
// Record the values.
for (int i = 0; i < numParticles; ++i) {
double charge, radius, depth;
force.getParticleParameters(i, charge, radius, depth);
particleParamArray[i][0] = static_cast<RealOpenMM>(0.5*radius);
particleParamArray[i][1] = static_cast<RealOpenMM>(2.0*sqrt(depth));
particleParamArray[i][2] = static_cast<RealOpenMM>(charge);
}
for (int i = 0; i < num14; ++i) {
int particle1, particle2;
double charge, radius, depth;
force.getExceptionParameters(nb14s[i], particle1, particle2, charge, radius, depth);
bonded14IndexArray[i][0] = particle1;
bonded14IndexArray[i][1] = particle2;
bonded14ParamArray[i][0] = static_cast<RealOpenMM>(radius);
bonded14ParamArray[i][1] = static_cast<RealOpenMM>(4.0*depth);
bonded14ParamArray[i][2] = static_cast<RealOpenMM>(charge);
}
// Recompute the coefficient for the dispersion correction.
NonbondedForce::NonbondedMethod method = force.getNonbondedMethod();
if (force.getUseDispersionCorrection() && (method == NonbondedForce::CutoffPeriodic || method == NonbondedForce::Ewald || method == NonbondedForce::PME))
dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(context.getSystem(), force);
}
class ReferenceTabulatedFunction : public Lepton::CustomFunction { class ReferenceTabulatedFunction : public Lepton::CustomFunction {
public: public:
ReferenceTabulatedFunction(double min, double max, const vector<double>& values) : ReferenceTabulatedFunction(double min, double max, const vector<double>& values) :
......
...@@ -507,6 +507,13 @@ public: ...@@ -507,6 +507,13 @@ public:
* @return the potential energy due to the force * @return the potential energy due to the force
*/ */
double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal); double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the NonbondedForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const NonbondedForce& force);
private: private:
int numParticles, num14; int numParticles, num14;
int **exclusionArray, **bonded14IndexArray; int **exclusionArray, **bonded14IndexArray;
......
...@@ -165,8 +165,7 @@ void testExclusionsAnd14() { ...@@ -165,8 +165,7 @@ void testExclusionsAnd14() {
nonbonded->setParticleParameters(i, 2, 1.5, 0); nonbonded->setParticleParameters(i, 2, 1.5, 0);
nonbonded->setExceptionParameters(first14, 0, 3, i == 3 ? 4/1.2 : 0, 1.5, 0); nonbonded->setExceptionParameters(first14, 0, 3, i == 3 ? 4/1.2 : 0, 1.5, 0);
nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0); nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0);
context.reinitialize(); nonbonded->updateParametersInContext(context);
context.setPositions(positions);
state = context.getState(State::Forces | State::Energy); state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces2 = state.getForces(); const vector<Vec3>& forces2 = state.getForces();
force = ONE_4PI_EPS0*4/(r*r); force = ONE_4PI_EPS0*4/(r*r);
...@@ -298,8 +297,7 @@ void testCutoff14() { ...@@ -298,8 +297,7 @@ void testCutoff14() {
nonbonded->setParticleParameters(i, q, 1.5, 0); nonbonded->setParticleParameters(i, q, 1.5, 0);
nonbonded->setExceptionParameters(first14, 0, 3, i == 3 ? q*q/1.2 : 0, 1.5, 0); nonbonded->setExceptionParameters(first14, 0, 3, i == 3 ? q*q/1.2 : 0, 1.5, 0);
nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0); nonbonded->setExceptionParameters(second14, 1, 4, 0, 1.5, 0);
context.reinitialize(); nonbonded->updateParametersInContext(context);
context.setPositions(positions);
state = context.getState(State::Forces | State::Energy); state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces2 = state.getForces(); const vector<Vec3>& forces2 = state.getForces();
force = ONE_4PI_EPS0*q*q/(r*r); force = ONE_4PI_EPS0*q*q/(r*r);
...@@ -401,14 +399,12 @@ void testDispersionCorrection() { ...@@ -401,14 +399,12 @@ void testDispersionCorrection() {
numType2++; numType2++;
} }
int numType1 = numParticles-numType2; int numType1 = numParticles-numType2;
nonbonded->updateParametersInContext(context);
energy2 = context.getState(State::Energy).getPotentialEnergy();
nonbonded->setUseDispersionCorrection(true); nonbonded->setUseDispersionCorrection(true);
context.reinitialize(); context.reinitialize();
context.setPositions(positions); context.setPositions(positions);
energy1 = context.getState(State::Energy).getPotentialEnergy(); energy1 = context.getState(State::Energy).getPotentialEnergy();
nonbonded->setUseDispersionCorrection(false);
context.reinitialize();
context.setPositions(positions);
energy2 = context.getState(State::Energy).getPotentialEnergy();
term1 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 12)/pow(cutoff, 9))/9; term1 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 12)/pow(cutoff, 9))/9;
term2 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 6)/pow(cutoff, 3))/3; term2 = ((numType1*(numType1+1))/2)*(0.5*pow(1.1, 6)/pow(cutoff, 3))/3;
term1 += ((numType2*(numType2+1))/2)*(1*pow(1.0, 12)/pow(cutoff, 9))/9; term1 += ((numType2*(numType2+1))/2)*(1*pow(1.0, 12)/pow(cutoff, 9))/9;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment