"ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "4e50d72141b60a7a37ce6f2ef94d49d85a1529b1"
Commit 496f469a authored by peastman's avatar peastman
Browse files

Fixed bug when multiple virtual sites depend on the same particles (see bug 2019)

parent ed554f52
...@@ -524,6 +524,47 @@ void testReordering() { ...@@ -524,6 +524,47 @@ void testReordering() {
} }
} }
/**
* Test a System where multiple virtual sites are all calculated from the same particles.
*/
void testOverlappingSites() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->addParticle(1.0, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
vector<Vec3> positions;
positions.push_back(Vec3(0, 0, 0));
positions.push_back(Vec3(10, 0, 0));
positions.push_back(Vec3(0, 10, 0));
for (int i = 0; i < 20; i++) {
system.addParticle(0.0);
double u = 0.1*((i+1)%4);
double v = 0.05*i;
system.setVirtualSite(3+i, new ThreeParticleAverageSite(0, 1, 2, u, v, 1-u-v));
nonbonded->addParticle(i%2 == 0 ? -1.0 : 1.0, 0.0, 0.0);
positions.push_back(Vec3());
}
VerletIntegrator i1(0.002);
VerletIntegrator i2(0.002);
Context c1(system, i1, Platform::getPlatformByName("Reference"));
Context c2(system, i2, platform);
c1.setPositions(positions);
c2.setPositions(positions);
c1.applyConstraints(0.0001);
c2.applyConstraints(0.0001);
State s1 = c1.getState(State::Positions | State::Forces);
State s2 = c2.getState(State::Positions | State::Forces);
for (int i = 0; i < system.getNumParticles(); i++)
ASSERT_EQUAL_VEC(s1.getPositions()[i], s2.getPositions()[i], 1e-5);
for (int i = 0; i < 3; i++)
ASSERT_EQUAL_VEC(s1.getForces()[i], s2.getForces()[i], 1e-5);
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -535,6 +576,7 @@ int main(int argc, char* argv[]) { ...@@ -535,6 +576,7 @@ int main(int argc, char* argv[]) {
testLocalCoordinates(); testLocalCoordinates();
testConservationLaws(); testConservationLaws();
testReordering(); testReordering();
testOverlappingSites();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -121,7 +121,7 @@ private: ...@@ -121,7 +121,7 @@ private:
cl::Kernel ccmaPosForceKernel, ccmaVelForceKernel; cl::Kernel ccmaPosForceKernel, ccmaVelForceKernel;
cl::Kernel ccmaMultiplyKernel; cl::Kernel ccmaMultiplyKernel;
cl::Kernel ccmaPosUpdateKernel, ccmaVelUpdateKernel; cl::Kernel ccmaPosUpdateKernel, ccmaVelUpdateKernel;
cl::Kernel vsitePositionKernel, vsiteForceKernel; cl::Kernel vsitePositionKernel, vsiteForceKernel, vsiteAddForcesKernel;
cl::Kernel randomKernel, timeShiftKernel; cl::Kernel randomKernel, timeShiftKernel;
OpenCLArray* posDelta; OpenCLArray* posDelta;
OpenCLArray* settleAtoms; OpenCLArray* settleAtoms;
...@@ -152,7 +152,7 @@ private: ...@@ -152,7 +152,7 @@ private:
OpenCLArray* vsiteLocalCoordsParams; OpenCLArray* vsiteLocalCoordsParams;
int randomPos; int randomPos;
int lastSeed, numVsites; int lastSeed, numVsites;
bool hasInitializedPosConstraintKernels, hasInitializedVelConstraintKernels, ccmaUseDirectBuffer; bool hasInitializedPosConstraintKernels, hasInitializedVelConstraintKernels, ccmaUseDirectBuffer, hasOverlappingVsites;
struct ShakeCluster; struct ShakeCluster;
struct ConstraintOrderer; struct ConstraintOrderer;
}; };
......
...@@ -101,7 +101,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c ...@@ -101,7 +101,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
ccmaConstraintMatrixValue(NULL), ccmaDelta1(NULL), ccmaDelta2(NULL), ccmaConverged(NULL), ccmaConvergedHostBuffer(NULL), ccmaConstraintMatrixValue(NULL), ccmaDelta1(NULL), ccmaDelta2(NULL), ccmaConverged(NULL), ccmaConvergedHostBuffer(NULL),
vsite2AvgAtoms(NULL), vsite2AvgWeights(NULL), vsite3AvgAtoms(NULL), vsite3AvgWeights(NULL), vsite2AvgAtoms(NULL), vsite2AvgWeights(NULL), vsite3AvgAtoms(NULL), vsite3AvgWeights(NULL),
vsiteOutOfPlaneAtoms(NULL), vsiteOutOfPlaneWeights(NULL), vsiteLocalCoordsAtoms(NULL), vsiteLocalCoordsParams(NULL), vsiteOutOfPlaneAtoms(NULL), vsiteOutOfPlaneWeights(NULL), vsiteLocalCoordsAtoms(NULL), vsiteLocalCoordsParams(NULL),
hasInitializedPosConstraintKernels(false), hasInitializedVelConstraintKernels(false) { hasInitializedPosConstraintKernels(false), hasInitializedVelConstraintKernels(false), hasOverlappingVsites(false) {
// Create workspace arrays. // Create workspace arrays.
if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) { if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) {
...@@ -649,6 +649,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c ...@@ -649,6 +649,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
int num3Avg = vsite3AvgAtomVec.size(); int num3Avg = vsite3AvgAtomVec.size();
int numOutOfPlane = vsiteOutOfPlaneAtomVec.size(); int numOutOfPlane = vsiteOutOfPlaneAtomVec.size();
int numLocalCoords = vsiteLocalCoordsAtomVec.size(); int numLocalCoords = vsiteLocalCoordsAtomVec.size();
numVsites = num2Avg+num3Avg+numOutOfPlane+numLocalCoords;
vsite2AvgAtoms = OpenCLArray::create<mm_int4>(context, max(1, num2Avg), "vsite2AvgAtoms"); vsite2AvgAtoms = OpenCLArray::create<mm_int4>(context, max(1, num2Avg), "vsite2AvgAtoms");
vsite3AvgAtoms = OpenCLArray::create<mm_int4>(context, max(1, num3Avg), "vsite3AvgAtoms"); vsite3AvgAtoms = OpenCLArray::create<mm_int4>(context, max(1, num3Avg), "vsite3AvgAtoms");
vsiteOutOfPlaneAtoms = OpenCLArray::create<mm_int4>(context, max(1, numOutOfPlane), "vsiteOutOfPlaneAtoms"); vsiteOutOfPlaneAtoms = OpenCLArray::create<mm_int4>(context, max(1, numOutOfPlane), "vsiteOutOfPlaneAtoms");
...@@ -706,6 +707,20 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c ...@@ -706,6 +707,20 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
} }
} }
// If multiple virtual sites depend on the same particle, make sure the force distribution
// can be done safely.
vector<int> atomCounts(numAtoms, 0);
for (int i = 0; i < numAtoms; i++)
if (system.isVirtualSite(i))
for (int j = 0; j < system.getVirtualSite(i).getNumParticles(); j++)
atomCounts[system.getVirtualSite(i).getParticle(j)]++;
for (int i = 0; i < numAtoms; i++)
if (atomCounts[i] > 1)
hasOverlappingVsites = true;
if (hasOverlappingVsites && context.getUseDoublePrecision() && !context.getSupports64BitGlobalAtomics())
throw OpenMMException("This device does not support 64 bit atomics. Cannot use double precision when multiple virtual sites depend on the same atom.");
// Create the kernels for virtual sites. // Create the kernels for virtual sites.
map<string, string> defines; map<string, string> defines;
...@@ -713,6 +728,10 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c ...@@ -713,6 +728,10 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
defines["NUM_3_AVERAGE"] = context.intToString(num3Avg); defines["NUM_3_AVERAGE"] = context.intToString(num3Avg);
defines["NUM_OUT_OF_PLANE"] = context.intToString(numOutOfPlane); defines["NUM_OUT_OF_PLANE"] = context.intToString(numOutOfPlane);
defines["NUM_LOCAL_COORDS"] = context.intToString(numLocalCoords); defines["NUM_LOCAL_COORDS"] = context.intToString(numLocalCoords);
defines["NUM_ATOMS"] = context.intToString(numAtoms);
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
if (hasOverlappingVsites)
defines["HAS_OVERLAPPING_VSITES"] = "1";
cl::Program vsiteProgram = context.createProgram(OpenCLKernelSources::virtualSites, defines); cl::Program vsiteProgram = context.createProgram(OpenCLKernelSources::virtualSites, defines);
vsitePositionKernel = cl::Kernel(vsiteProgram, "computeVirtualSites"); vsitePositionKernel = cl::Kernel(vsiteProgram, "computeVirtualSites");
int index = 0; int index = 0;
...@@ -731,6 +750,8 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c ...@@ -731,6 +750,8 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
index = 0; index = 0;
vsiteForceKernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
index++; // Skip argument 1: the force array hasn't been created yet. index++; // Skip argument 1: the force array hasn't been created yet.
if (context.getSupports64BitGlobalAtomics())
index++; // Skip argument 2: the force array hasn't been created yet.
if (context.getUseMixedPrecision()) if (context.getUseMixedPrecision())
vsiteForceKernel.setArg<cl::Buffer>(index++, context.getPosqCorrection().getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(index++, context.getPosqCorrection().getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(index++, vsite2AvgAtoms->getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(index++, vsite2AvgAtoms->getDeviceBuffer());
...@@ -741,7 +762,8 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c ...@@ -741,7 +762,8 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteOutOfPlaneWeights->getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteOutOfPlaneWeights->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteLocalCoordsAtoms->getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteLocalCoordsAtoms->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteLocalCoordsParams->getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteLocalCoordsParams->getDeviceBuffer());
numVsites = num2Avg+num3Avg+numOutOfPlane+numLocalCoords; if (hasOverlappingVsites && context.getSupports64BitGlobalAtomics())
vsiteAddForcesKernel = cl::Kernel(vsiteProgram, "addDistributedForces");
} }
OpenCLIntegrationUtilities::~OpenCLIntegrationUtilities() { OpenCLIntegrationUtilities::~OpenCLIntegrationUtilities() {
...@@ -941,8 +963,25 @@ void OpenCLIntegrationUtilities::computeVirtualSites() { ...@@ -941,8 +963,25 @@ void OpenCLIntegrationUtilities::computeVirtualSites() {
void OpenCLIntegrationUtilities::distributeForcesFromVirtualSites() { void OpenCLIntegrationUtilities::distributeForcesFromVirtualSites() {
if (numVsites > 0) { if (numVsites > 0) {
// Set arguments that didn't exist yet in the constructor.
vsiteForceKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer());
if (context.getSupports64BitGlobalAtomics()) {
vsiteForceKernel.setArg<cl::Buffer>(2, context.getLongForceBuffer().getDeviceBuffer());
if (hasOverlappingVsites) {
// We'll be using 64 bit atomics for the force redistribution, so clear the buffer.
context.clearBuffer(context.getLongForceBuffer());
}
}
context.executeKernel(vsiteForceKernel, numVsites); context.executeKernel(vsiteForceKernel, numVsites);
if (context.getSupports64BitGlobalAtomics() && hasOverlappingVsites) {
// Add the redistributed forces from the virtual sites to the main force array.
vsiteAddForcesKernel.setArg<cl::Buffer>(0, context.getLongForceBuffer().getDeviceBuffer());
vsiteAddForcesKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer());
context.executeKernel(vsiteAddForcesKernel, context.getNumAtoms());
}
} }
} }
......
...@@ -108,10 +108,66 @@ __kernel void computeVirtualSites(__global real4* restrict posq, ...@@ -108,10 +108,66 @@ __kernel void computeVirtualSites(__global real4* restrict posq,
} }
} }
#ifdef HAS_OVERLAPPING_VSITES
#ifdef SUPPORTS_64_BIT_ATOMICS
// We will use 64 bit atomics for force redistribution.
#define ADD_FORCE(index, f) addForce(index, f, longForce);
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
void addForce(int index, float4 f, __global long* longForce) {
atom_add(&longForce[index], (long) (f.x*0x100000000));
atom_add(&longForce[index+PADDED_NUM_ATOMS], (long) (f.y*0x100000000));
atom_add(&longForce[index+2*PADDED_NUM_ATOMS], (long) (f.z*0x100000000));
}
__kernel void addDistributedForces(__global const long* restrict longForces, __global real4* restrict forces) {
real scale = 1/(real) 0x100000000;
for (int index = get_global_id(0); index < NUM_ATOMS; index += get_global_size(0)) {
real4 f = (real4) (scale*longForces[index], scale*longForces[index+PADDED_NUM_ATOMS], scale*longForces[index+2*PADDED_NUM_ATOMS], 0);
forces[index] += f;
}
}
#else
// 64 bit atomics aren't supported, so we have to use atomic_cmpxchg() for force redistribution.
#define ADD_FORCE(index, f) addForce(index, f, force);
void atomicAddFloat(__global float* p, float v) {
__global int* ip = (__global int*) p;
while (true) {
union {
float f;
int i;
} oldval, newval;
oldval.f = *p;
newval.f = oldval.f+v;
if (atomic_cmpxchg(ip, oldval.i, newval.i) == oldval.i)
return;
}
}
void addForce(int index, float4 f, __global float4* force) {
__global float* components = (__global float*) force;
atomicAddFloat(&components[4*index], f.x);
atomicAddFloat(&components[4*index+1], f.y);
atomicAddFloat(&components[4*index+2], f.z);
}
#endif
#else
// There are no overlapping virtual sites, so we can just store forces directly.
#define ADD_FORCE(index, f) force[index].xyz += (f).xyz;
#endif
/** /**
* Distribute forces from virtual sites to the atoms they are based on. * Distribute forces from virtual sites to the atoms they are based on.
*/ */
__kernel void distributeForces(__global const real4* restrict posq, __global real4* restrict force, __kernel void distributeForces(__global const real4* restrict posq, __global real4* restrict force,
#ifdef SUPPORTS_64_BIT_ATOMICS
__global long* restrict longForce,
#endif
#ifdef USE_MIXED_PRECISION #ifdef USE_MIXED_PRECISION
__global real4* restrict posqCorrection, __global real4* restrict posqCorrection,
#endif #endif
...@@ -129,12 +185,8 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea ...@@ -129,12 +185,8 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
int4 atoms = avg2Atoms[index]; int4 atoms = avg2Atoms[index];
real2 weights = avg2Weights[index]; real2 weights = avg2Weights[index];
real4 f = force[atoms.x]; real4 f = force[atoms.x];
real4 f1 = force[atoms.y]; ADD_FORCE(atoms.y, f*weights.x);
real4 f2 = force[atoms.z]; ADD_FORCE(atoms.z, f*weights.y);
f1.xyz += f.xyz*weights.x;
f2.xyz += f.xyz*weights.y;
force[atoms.y] = f1;
force[atoms.z] = f2;
} }
// Three particle average sites. // Three particle average sites.
...@@ -143,15 +195,9 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea ...@@ -143,15 +195,9 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
int4 atoms = avg3Atoms[index]; int4 atoms = avg3Atoms[index];
real4 weights = avg3Weights[index]; real4 weights = avg3Weights[index];
real4 f = force[atoms.x]; real4 f = force[atoms.x];
real4 f1 = force[atoms.y]; ADD_FORCE(atoms.y, f*weights.x);
real4 f2 = force[atoms.z]; ADD_FORCE(atoms.z, f*weights.y);
real4 f3 = force[atoms.w]; ADD_FORCE(atoms.w, f*weights.z);
f1.xyz += f.xyz*weights.x;
f2.xyz += f.xyz*weights.y;
f3.xyz += f.xyz*weights.z;
force[atoms.y] = f1;
force[atoms.z] = f2;
force[atoms.w] = f3;
} }
// Out of plane sites. // Out of plane sites.
...@@ -165,21 +211,15 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea ...@@ -165,21 +211,15 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
mixed4 v12 = pos2-pos1; mixed4 v12 = pos2-pos1;
mixed4 v13 = pos3-pos1; mixed4 v13 = pos3-pos1;
real4 f = force[atoms.x]; real4 f = force[atoms.x];
real4 f1 = force[atoms.y];
real4 f2 = force[atoms.z];
real4 f3 = force[atoms.w];
real4 fp2 = (real4) (weights.x*f.x - weights.z*v13.z*f.y + weights.z*v13.y*f.z, real4 fp2 = (real4) (weights.x*f.x - weights.z*v13.z*f.y + weights.z*v13.y*f.z,
weights.z*v13.z*f.x + weights.x*f.y - weights.z*v13.x*f.z, weights.z*v13.z*f.x + weights.x*f.y - weights.z*v13.x*f.z,
-weights.z*v13.y*f.x + weights.z*v13.x*f.y + weights.x*f.z, 0.0f); -weights.z*v13.y*f.x + weights.z*v13.x*f.y + weights.x*f.z, 0.0f);
real4 fp3 = (real4) (weights.y*f.x + weights.z*v12.z*f.y - weights.z*v12.y*f.z, real4 fp3 = (real4) (weights.y*f.x + weights.z*v12.z*f.y - weights.z*v12.y*f.z,
-weights.z*v12.z*f.x + weights.y*f.y + weights.z*v12.x*f.z, -weights.z*v12.z*f.x + weights.y*f.y + weights.z*v12.x*f.z,
weights.z*v12.y*f.x - weights.z*v12.x*f.y + weights.y*f.z, 0.0f); weights.z*v12.y*f.x - weights.z*v12.x*f.y + weights.y*f.z, 0.0f);
f1.xyz += f.xyz-fp2.xyz-fp3.xyz; ADD_FORCE(atoms.y, f-fp2-fp3);
f2.xyz += fp2.xyz; ADD_FORCE(atoms.z, fp2);
f3.xyz += fp3.xyz; ADD_FORCE(atoms.w, fp3);
force[atoms.y] = f1;
force[atoms.z] = f2;
force[atoms.w] = f3;
} }
// Local coordinates sites. // Local coordinates sites.
...@@ -230,9 +270,9 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea ...@@ -230,9 +270,9 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
mixed sz3 = t32*dz.x-t31*dz.y; mixed sz3 = t32*dz.x-t31*dz.y;
mixed4 wxScaled = wx*invNormXdir; mixed4 wxScaled = wx*invNormXdir;
real4 f = force[atoms.x]; real4 f = force[atoms.x];
real4 f1 = force[atoms.y]; real4 f1 = 0;
real4 f2 = force[atoms.z]; real4 f2 = 0;
real4 f3 = force[atoms.w]; real4 f3 = 0;
mixed4 fp1 = localPosition*f.x; mixed4 fp1 = localPosition*f.x;
mixed4 fp2 = localPosition*f.y; mixed4 fp2 = localPosition*f.y;
mixed4 fp3 = localPosition*f.z; mixed4 fp3 = localPosition*f.z;
...@@ -263,8 +303,8 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea ...@@ -263,8 +303,8 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
f3.x += fp3.x*wxScaled.z*( -dx.z*dx.x) + fp3.z*(dz.z*sx3+t32) + fp3.y*((-dx.x*dy.z-dz.y)*wxScaled.z + dy.z*sx3 + dx.x*t33); f3.x += fp3.x*wxScaled.z*( -dx.z*dx.x) + fp3.z*(dz.z*sx3+t32) + fp3.y*((-dx.x*dy.z-dz.y)*wxScaled.z + dy.z*sx3 + dx.x*t33);
f3.y += fp3.x*wxScaled.z*( -dx.z*dx.y) + fp3.z*(dz.z*sy3-t31) + fp3.y*((-dx.y*dy.z+dz.x)*wxScaled.z + dy.z*sy3 + dx.y*t33); f3.y += fp3.x*wxScaled.z*( -dx.z*dx.y) + fp3.z*(dz.z*sy3-t31) + fp3.y*((-dx.y*dy.z+dz.x)*wxScaled.z + dy.z*sy3 + dx.y*t33);
f3.z += fp3.x*wxScaled.z*(1-dx.z*dx.z) + fp3.z*(dz.z*sz3 ) + fp3.y*((-dx.z*dy.z )*wxScaled.z + dy.z*sz3 - dx.x*t31 - dx.y*t32) + f.z*originWeights.z; f3.z += fp3.x*wxScaled.z*(1-dx.z*dx.z) + fp3.z*(dz.z*sz3 ) + fp3.y*((-dx.z*dy.z )*wxScaled.z + dy.z*sz3 - dx.x*t31 - dx.y*t32) + f.z*originWeights.z;
force[atoms.y] = f1; ADD_FORCE(atoms.y, f1);
force[atoms.z] = f2; ADD_FORCE(atoms.z, f2);
force[atoms.w] = f3; ADD_FORCE(atoms.w, f3);
} }
} }
...@@ -524,6 +524,47 @@ void testReordering() { ...@@ -524,6 +524,47 @@ void testReordering() {
} }
} }
/**
* Test a System where multiple virtual sites are all calculated from the same particles.
*/
void testOverlappingSites() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->addParticle(1.0, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
vector<Vec3> positions;
positions.push_back(Vec3(0, 0, 0));
positions.push_back(Vec3(10, 0, 0));
positions.push_back(Vec3(0, 10, 0));
for (int i = 0; i < 20; i++) {
system.addParticle(0.0);
double u = 0.1*((i+1)%4);
double v = 0.05*i;
system.setVirtualSite(3+i, new ThreeParticleAverageSite(0, 1, 2, u, v, 1-u-v));
nonbonded->addParticle(i%2 == 0 ? -1.0 : 1.0, 0.0, 0.0);
positions.push_back(Vec3());
}
VerletIntegrator i1(0.002);
VerletIntegrator i2(0.002);
Context c1(system, i1, Platform::getPlatformByName("Reference"));
Context c2(system, i2, platform);
c1.setPositions(positions);
c2.setPositions(positions);
c1.applyConstraints(0.0001);
c2.applyConstraints(0.0001);
State s1 = c1.getState(State::Positions | State::Forces);
State s2 = c2.getState(State::Positions | State::Forces);
for (int i = 0; i < system.getNumParticles(); i++)
ASSERT_EQUAL_VEC(s1.getPositions()[i], s2.getPositions()[i], 1e-5);
for (int i = 0; i < 3; i++)
ASSERT_EQUAL_VEC(s1.getForces()[i], s2.getForces()[i], 1e-5);
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -535,6 +576,7 @@ int main(int argc, char* argv[]) { ...@@ -535,6 +576,7 @@ int main(int argc, char* argv[]) {
testLocalCoordinates(); testLocalCoordinates();
testConservationLaws(); testConservationLaws();
testReordering(); testReordering();
testOverlappingSites();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment