Commit 496f469a authored by peastman's avatar peastman
Browse files

Fixed bug when multiple virtual sites depend on the same particles (see bug 2019)

parent ed554f52
......@@ -524,6 +524,47 @@ void testReordering() {
}
}
/**
* Test a System where multiple virtual sites are all calculated from the same particles.
*/
void testOverlappingSites() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->addParticle(1.0, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
vector<Vec3> positions;
positions.push_back(Vec3(0, 0, 0));
positions.push_back(Vec3(10, 0, 0));
positions.push_back(Vec3(0, 10, 0));
for (int i = 0; i < 20; i++) {
system.addParticle(0.0);
double u = 0.1*((i+1)%4);
double v = 0.05*i;
system.setVirtualSite(3+i, new ThreeParticleAverageSite(0, 1, 2, u, v, 1-u-v));
nonbonded->addParticle(i%2 == 0 ? -1.0 : 1.0, 0.0, 0.0);
positions.push_back(Vec3());
}
VerletIntegrator i1(0.002);
VerletIntegrator i2(0.002);
Context c1(system, i1, Platform::getPlatformByName("Reference"));
Context c2(system, i2, platform);
c1.setPositions(positions);
c2.setPositions(positions);
c1.applyConstraints(0.0001);
c2.applyConstraints(0.0001);
State s1 = c1.getState(State::Positions | State::Forces);
State s2 = c2.getState(State::Positions | State::Forces);
for (int i = 0; i < system.getNumParticles(); i++)
ASSERT_EQUAL_VEC(s1.getPositions()[i], s2.getPositions()[i], 1e-5);
for (int i = 0; i < 3; i++)
ASSERT_EQUAL_VEC(s1.getForces()[i], s2.getForces()[i], 1e-5);
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
......@@ -535,6 +576,7 @@ int main(int argc, char* argv[]) {
testLocalCoordinates();
testConservationLaws();
testReordering();
testOverlappingSites();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
......@@ -121,7 +121,7 @@ private:
cl::Kernel ccmaPosForceKernel, ccmaVelForceKernel;
cl::Kernel ccmaMultiplyKernel;
cl::Kernel ccmaPosUpdateKernel, ccmaVelUpdateKernel;
cl::Kernel vsitePositionKernel, vsiteForceKernel;
cl::Kernel vsitePositionKernel, vsiteForceKernel, vsiteAddForcesKernel;
cl::Kernel randomKernel, timeShiftKernel;
OpenCLArray* posDelta;
OpenCLArray* settleAtoms;
......@@ -152,7 +152,7 @@ private:
OpenCLArray* vsiteLocalCoordsParams;
int randomPos;
int lastSeed, numVsites;
bool hasInitializedPosConstraintKernels, hasInitializedVelConstraintKernels, ccmaUseDirectBuffer;
bool hasInitializedPosConstraintKernels, hasInitializedVelConstraintKernels, ccmaUseDirectBuffer, hasOverlappingVsites;
struct ShakeCluster;
struct ConstraintOrderer;
};
......
......@@ -101,7 +101,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
ccmaConstraintMatrixValue(NULL), ccmaDelta1(NULL), ccmaDelta2(NULL), ccmaConverged(NULL), ccmaConvergedHostBuffer(NULL),
vsite2AvgAtoms(NULL), vsite2AvgWeights(NULL), vsite3AvgAtoms(NULL), vsite3AvgWeights(NULL),
vsiteOutOfPlaneAtoms(NULL), vsiteOutOfPlaneWeights(NULL), vsiteLocalCoordsAtoms(NULL), vsiteLocalCoordsParams(NULL),
hasInitializedPosConstraintKernels(false), hasInitializedVelConstraintKernels(false) {
hasInitializedPosConstraintKernels(false), hasInitializedVelConstraintKernels(false), hasOverlappingVsites(false) {
// Create workspace arrays.
if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) {
......@@ -649,6 +649,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
int num3Avg = vsite3AvgAtomVec.size();
int numOutOfPlane = vsiteOutOfPlaneAtomVec.size();
int numLocalCoords = vsiteLocalCoordsAtomVec.size();
numVsites = num2Avg+num3Avg+numOutOfPlane+numLocalCoords;
vsite2AvgAtoms = OpenCLArray::create<mm_int4>(context, max(1, num2Avg), "vsite2AvgAtoms");
vsite3AvgAtoms = OpenCLArray::create<mm_int4>(context, max(1, num3Avg), "vsite3AvgAtoms");
vsiteOutOfPlaneAtoms = OpenCLArray::create<mm_int4>(context, max(1, numOutOfPlane), "vsiteOutOfPlaneAtoms");
......@@ -706,6 +707,20 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
}
}
// If multiple virtual sites depend on the same particle, make sure the force distribution
// can be done safely.
vector<int> atomCounts(numAtoms, 0);
for (int i = 0; i < numAtoms; i++)
if (system.isVirtualSite(i))
for (int j = 0; j < system.getVirtualSite(i).getNumParticles(); j++)
atomCounts[system.getVirtualSite(i).getParticle(j)]++;
for (int i = 0; i < numAtoms; i++)
if (atomCounts[i] > 1)
hasOverlappingVsites = true;
if (hasOverlappingVsites && context.getUseDoublePrecision() && !context.getSupports64BitGlobalAtomics())
throw OpenMMException("This device does not support 64 bit atomics. Cannot use double precision when multiple virtual sites depend on the same atom.");
// Create the kernels for virtual sites.
map<string, string> defines;
......@@ -713,6 +728,10 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
defines["NUM_3_AVERAGE"] = context.intToString(num3Avg);
defines["NUM_OUT_OF_PLANE"] = context.intToString(numOutOfPlane);
defines["NUM_LOCAL_COORDS"] = context.intToString(numLocalCoords);
defines["NUM_ATOMS"] = context.intToString(numAtoms);
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
if (hasOverlappingVsites)
defines["HAS_OVERLAPPING_VSITES"] = "1";
cl::Program vsiteProgram = context.createProgram(OpenCLKernelSources::virtualSites, defines);
vsitePositionKernel = cl::Kernel(vsiteProgram, "computeVirtualSites");
int index = 0;
......@@ -731,6 +750,8 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
index = 0;
vsiteForceKernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
index++; // Skip argument 1: the force array hasn't been created yet.
if (context.getSupports64BitGlobalAtomics())
index++; // Skip argument 2: the force array hasn't been created yet.
if (context.getUseMixedPrecision())
vsiteForceKernel.setArg<cl::Buffer>(index++, context.getPosqCorrection().getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(index++, vsite2AvgAtoms->getDeviceBuffer());
......@@ -741,7 +762,8 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteOutOfPlaneWeights->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteLocalCoordsAtoms->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(index++, vsiteLocalCoordsParams->getDeviceBuffer());
numVsites = num2Avg+num3Avg+numOutOfPlane+numLocalCoords;
if (hasOverlappingVsites && context.getSupports64BitGlobalAtomics())
vsiteAddForcesKernel = cl::Kernel(vsiteProgram, "addDistributedForces");
}
OpenCLIntegrationUtilities::~OpenCLIntegrationUtilities() {
......@@ -941,8 +963,25 @@ void OpenCLIntegrationUtilities::computeVirtualSites() {
/**
 * Run the kernel that distributes forces from virtual sites. When several
 * virtual sites depend on the same particle and 64 bit atomics are available,
 * the 64 bit force buffer is cleared first and a second kernel afterwards adds
 * its contents to the main force array.
 * Does nothing if there are no virtual sites.
 */
void OpenCLIntegrationUtilities::distributeForcesFromVirtualSites() {
    if (numVsites <= 0)
        return;
    const bool has64BitAtomics = context.getSupports64BitGlobalAtomics();
    const bool accumulateWithAtomics = (has64BitAtomics && hasOverlappingVsites);

    // The force buffers did not exist yet when the constructor set the other arguments.
    vsiteForceKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer());
    if (has64BitAtomics)
        vsiteForceKernel.setArg<cl::Buffer>(2, context.getLongForceBuffer().getDeviceBuffer());

    // The redistribution will use 64 bit atomics, so clear the buffer before running it.
    if (accumulateWithAtomics)
        context.clearBuffer(context.getLongForceBuffer());

    context.executeKernel(vsiteForceKernel, numVsites);

    if (accumulateWithAtomics) {
        // Add the redistributed forces from the virtual sites to the main force array.
        vsiteAddForcesKernel.setArg<cl::Buffer>(0, context.getLongForceBuffer().getDeviceBuffer());
        vsiteAddForcesKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer());
        context.executeKernel(vsiteAddForcesKernel, context.getNumAtoms());
    }
}
......
......@@ -108,10 +108,66 @@ __kernel void computeVirtualSites(__global real4* restrict posq,
}
}
#ifdef HAS_OVERLAPPING_VSITES
#ifdef SUPPORTS_64_BIT_ATOMICS
// We will use 64 bit atomics for force redistribution.
#define ADD_FORCE(index, f) addForce(index, f, longForce);
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
/**
 * Atomically add a force to the 64 bit fixed point accumulation buffer.
 *
 * The x, y and z components are each scaled by 2^32, truncated to a long and
 * added with atom_add() to longForce[index], longForce[index+PADDED_NUM_ATOMS]
 * and longForce[index+2*PADDED_NUM_ATOMS] respectively. The w component is ignored.
 *
 * The force is taken as real4 rather than float4: callers pass real4 values, and
 * in double precision real4 is double4, which OpenCL C will not implicitly
 * convert to float4.
 *
 * @param index      index of the atom receiving the force
 * @param f          the force to add
 * @param longForce  the 64 bit accumulation buffer
 */
void addForce(int index, real4 f, __global long* longForce) {
    atom_add(&longForce[index], (long) (f.x*0x100000000));
    atom_add(&longForce[index+PADDED_NUM_ATOMS], (long) (f.y*0x100000000));
    atom_add(&longForce[index+2*PADDED_NUM_ATOMS], (long) (f.z*0x100000000));
}
/**
 * Convert the values in the 64 bit buffer back to floating point (dividing by 2^32)
 * and add them to the main force array. The w component of each force is left unchanged.
 */
__kernel void addDistributedForces(__global const long* restrict longForces, __global real4* restrict forces) {
    const real scale = 1/(real) 0x100000000;
    for (int atom = get_global_id(0); atom < NUM_ATOMS; atom += get_global_size(0)) {
        // The three components are stored in separate sections of the buffer, PADDED_NUM_ATOMS apart.
        real4 delta;
        delta.x = scale*longForces[atom];
        delta.y = scale*longForces[atom+PADDED_NUM_ATOMS];
        delta.z = scale*longForces[atom+2*PADDED_NUM_ATOMS];
        delta.w = 0;
        forces[atom] += delta;
    }
}
#else
// 64 bit atomics aren't supported, so we have to use atomic_cmpxchg() for force redistribution.
#define ADD_FORCE(index, f) addForce(index, f, force);
/**
 * Add v to the float stored at p using a compare-and-swap loop: read the current
 * value, compute the sum, and store it with atomic_cmpxchg() only if the stored
 * bits still match what was read. Retry until the exchange succeeds.
 */
void atomicAddFloat(__global float* p, float v) {
    // atomic_cmpxchg() works on integers, so the float is accessed through an int pointer.
    __global int* bits = (__global int*) p;
    union {
        float f;
        int i;
    } expected, desired;
    do {
        expected.f = *p;
        desired.f = expected.f+v;
    } while (atomic_cmpxchg(bits, expected.i, desired.i) != expected.i);
}
/**
 * Add the x, y and z components of f to force[index], one component at a time,
 * using atomicAddFloat(). The w component is not touched.
 */
void addForce(int index, float4 f, __global float4* force) {
    // View the selected float4 element as its individual float components.
    __global float* atomForce = (__global float*) (force+index);
    atomicAddFloat(atomForce, f.x);
    atomicAddFloat(atomForce+1, f.y);
    atomicAddFloat(atomForce+2, f.z);
}
#endif
#else
// There are no overlapping virtual sites, so we can just store forces directly.
#define ADD_FORCE(index, f) force[index].xyz += (f).xyz;
#endif
/**
* Distribute forces from virtual sites to the atoms they are based on.
*/
__kernel void distributeForces(__global const real4* restrict posq, __global real4* restrict force,
#ifdef SUPPORTS_64_BIT_ATOMICS
__global long* restrict longForce,
#endif
#ifdef USE_MIXED_PRECISION
__global real4* restrict posqCorrection,
#endif
......@@ -129,12 +185,8 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
int4 atoms = avg2Atoms[index];
real2 weights = avg2Weights[index];
real4 f = force[atoms.x];
real4 f1 = force[atoms.y];
real4 f2 = force[atoms.z];
f1.xyz += f.xyz*weights.x;
f2.xyz += f.xyz*weights.y;
force[atoms.y] = f1;
force[atoms.z] = f2;
ADD_FORCE(atoms.y, f*weights.x);
ADD_FORCE(atoms.z, f*weights.y);
}
// Three particle average sites.
......@@ -143,15 +195,9 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
int4 atoms = avg3Atoms[index];
real4 weights = avg3Weights[index];
real4 f = force[atoms.x];
real4 f1 = force[atoms.y];
real4 f2 = force[atoms.z];
real4 f3 = force[atoms.w];
f1.xyz += f.xyz*weights.x;
f2.xyz += f.xyz*weights.y;
f3.xyz += f.xyz*weights.z;
force[atoms.y] = f1;
force[atoms.z] = f2;
force[atoms.w] = f3;
ADD_FORCE(atoms.y, f*weights.x);
ADD_FORCE(atoms.z, f*weights.y);
ADD_FORCE(atoms.w, f*weights.z);
}
// Out of plane sites.
......@@ -165,21 +211,15 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
mixed4 v12 = pos2-pos1;
mixed4 v13 = pos3-pos1;
real4 f = force[atoms.x];
real4 f1 = force[atoms.y];
real4 f2 = force[atoms.z];
real4 f3 = force[atoms.w];
real4 fp2 = (real4) (weights.x*f.x - weights.z*v13.z*f.y + weights.z*v13.y*f.z,
weights.z*v13.z*f.x + weights.x*f.y - weights.z*v13.x*f.z,
-weights.z*v13.y*f.x + weights.z*v13.x*f.y + weights.x*f.z, 0.0f);
real4 fp3 = (real4) (weights.y*f.x + weights.z*v12.z*f.y - weights.z*v12.y*f.z,
-weights.z*v12.z*f.x + weights.y*f.y + weights.z*v12.x*f.z,
weights.z*v12.y*f.x - weights.z*v12.x*f.y + weights.y*f.z, 0.0f);
f1.xyz += f.xyz-fp2.xyz-fp3.xyz;
f2.xyz += fp2.xyz;
f3.xyz += fp3.xyz;
force[atoms.y] = f1;
force[atoms.z] = f2;
force[atoms.w] = f3;
ADD_FORCE(atoms.y, f-fp2-fp3);
ADD_FORCE(atoms.z, fp2);
ADD_FORCE(atoms.w, fp3);
}
// Local coordinates sites.
......@@ -230,9 +270,9 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
mixed sz3 = t32*dz.x-t31*dz.y;
mixed4 wxScaled = wx*invNormXdir;
real4 f = force[atoms.x];
real4 f1 = force[atoms.y];
real4 f2 = force[atoms.z];
real4 f3 = force[atoms.w];
real4 f1 = 0;
real4 f2 = 0;
real4 f3 = 0;
mixed4 fp1 = localPosition*f.x;
mixed4 fp2 = localPosition*f.y;
mixed4 fp3 = localPosition*f.z;
......@@ -263,8 +303,8 @@ __kernel void distributeForces(__global const real4* restrict posq, __global rea
f3.x += fp3.x*wxScaled.z*( -dx.z*dx.x) + fp3.z*(dz.z*sx3+t32) + fp3.y*((-dx.x*dy.z-dz.y)*wxScaled.z + dy.z*sx3 + dx.x*t33);
f3.y += fp3.x*wxScaled.z*( -dx.z*dx.y) + fp3.z*(dz.z*sy3-t31) + fp3.y*((-dx.y*dy.z+dz.x)*wxScaled.z + dy.z*sy3 + dx.y*t33);
f3.z += fp3.x*wxScaled.z*(1-dx.z*dx.z) + fp3.z*(dz.z*sz3 ) + fp3.y*((-dx.z*dy.z )*wxScaled.z + dy.z*sz3 - dx.x*t31 - dx.y*t32) + f.z*originWeights.z;
force[atoms.y] = f1;
force[atoms.z] = f2;
force[atoms.w] = f3;
ADD_FORCE(atoms.y, f1);
ADD_FORCE(atoms.z, f2);
ADD_FORCE(atoms.w, f3);
}
}
......@@ -524,6 +524,47 @@ void testReordering() {
}
}
/**
* Test a System where multiple virtual sites are all calculated from the same particles.
*/
void testOverlappingSites() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->addParticle(1.0, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
nonbonded->addParticle(-0.5, 0.0, 0.0);
vector<Vec3> positions;
positions.push_back(Vec3(0, 0, 0));
positions.push_back(Vec3(10, 0, 0));
positions.push_back(Vec3(0, 10, 0));
for (int i = 0; i < 20; i++) {
system.addParticle(0.0);
double u = 0.1*((i+1)%4);
double v = 0.05*i;
system.setVirtualSite(3+i, new ThreeParticleAverageSite(0, 1, 2, u, v, 1-u-v));
nonbonded->addParticle(i%2 == 0 ? -1.0 : 1.0, 0.0, 0.0);
positions.push_back(Vec3());
}
VerletIntegrator i1(0.002);
VerletIntegrator i2(0.002);
Context c1(system, i1, Platform::getPlatformByName("Reference"));
Context c2(system, i2, platform);
c1.setPositions(positions);
c2.setPositions(positions);
c1.applyConstraints(0.0001);
c2.applyConstraints(0.0001);
State s1 = c1.getState(State::Positions | State::Forces);
State s2 = c2.getState(State::Positions | State::Forces);
for (int i = 0; i < system.getNumParticles(); i++)
ASSERT_EQUAL_VEC(s1.getPositions()[i], s2.getPositions()[i], 1e-5);
for (int i = 0; i < 3; i++)
ASSERT_EQUAL_VEC(s1.getForces()[i], s2.getForces()[i], 1e-5);
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
......@@ -535,6 +576,7 @@ int main(int argc, char* argv[]) {
testLocalCoordinates();
testConservationLaws();
testReordering();
testOverlappingSites();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment