Commit 72ab8864 authored by peastman
Browse files

Continuing CUDA implementation of spherical harmonics for multipoles

parent fa21870f
......@@ -1150,6 +1150,7 @@ void CudaCalcAmoebaMultipoleForceKernel::initialize(const System& system, const
if (maxInducedIterations > 0) {
defines["THREAD_BLOCK_SIZE"] = cu.intToString(inducedFieldThreads);
defines["MAX_PREV_DIIS_DIPOLES"] = cu.intToString(MaxPrevDIISDipoles);
defines["USE_MUTUAL_POLARIZATION"] = "1";
module = cu.createModule(CudaKernelSources::vectorOps+CudaAmoebaKernelSources::multipoleInducedField, defines);
computeInducedFieldKernel = cu.getKernel(module, "computeInducedField");
updateInducedFieldKernel = cu.getKernel(module, "updateInducedFieldByDIIS");
......@@ -1159,10 +1160,8 @@ void CudaCalcAmoebaMultipoleForceKernel::initialize(const System& system, const
stringstream electrostaticsSource;
if (usePME) {
electrostaticsSource << CudaKernelSources::vectorOps;
electrostaticsSource << CudaAmoebaKernelSources::sphericalMultipoles;
electrostaticsSource << CudaAmoebaKernelSources::pmeMultipoleElectrostatics;
electrostaticsSource << (hasQuadrupoles ? CudaAmoebaKernelSources::pmeElectrostaticPairForce : CudaAmoebaKernelSources::pmeElectrostaticPairForceNoQuadrupoles);
electrostaticsSource << "#define APPLY_SCALE\n";
electrostaticsSource << (hasQuadrupoles ? CudaAmoebaKernelSources::pmeElectrostaticPairForce : CudaAmoebaKernelSources::pmeElectrostaticPairForceNoQuadrupoles);
electrostaticsThreadMemory = 24*elementSize+3*sizeof(float)+3*sizeof(int)/(double) cu.TileSize;
if (!useShuffle)
electrostaticsThreadMemory += 3*elementSize;
......@@ -1659,8 +1658,8 @@ double CudaCalcAmoebaMultipoleForceKernel::execute(ContextImpl& context, bool in
&nb.getInteractingTiles().getDevicePointer(), &nb.getInteractionCount().getDevicePointer(),
cu.getPeriodicBoxSizePointer(), cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(),
&maxTiles, &nb.getBlockCenters().getDevicePointer(), &nb.getInteractingAtoms().getDevicePointer(),
&labFrameDipoles->getDevicePointer(), &labFrameQuadrupoles->getDevicePointer(), &inducedDipole->getDevicePointer(),
&inducedDipolePolar->getDevicePointer(), &dampingAndThole->getDevicePointer()};
&labFrameDipoles->getDevicePointer(), &labFrameQuadrupoles->getDevicePointer(), &sphericalDipoles->getDevicePointer(), &sphericalQuadrupoles->getDevicePointer(),
&inducedDipole->getDevicePointer(), &inducedDipolePolar->getDevicePointer(), &dampingAndThole->getDevicePointer()};
cu.executeKernel(electrostaticsKernel, electrostaticsArgs, numForceThreadBlocks*electrostaticsThreads, electrostaticsThreads);
void* pmeTransformInducedPotentialArgs[] = {&pmePhidp->getDevicePointer(), &pmeCphi->getDevicePointer(), recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernel(pmeTransformPotentialKernel, pmeTransformInducedPotentialArgs, cu.getNumAtoms());
......
......@@ -10,10 +10,10 @@ extern "C" __global__ void computeLabFrameMoments(real4* __restrict__ posq, int4
sphericalDipoles[offset+2] = molecularDipoles[offset+1]; // y -> Q_11s
offset = 5*atom;
sphericalQuadrupoles[offset+0] = -3.0f*(molecularQuadrupoles[offset+0]+molecularQuadrupoles[offset+3]); // zz -> Q_20
sphericalQuadrupoles[offset+1] = (2*SQRT(3))*molecularQuadrupoles[offset+2]; // xz -> Q_21c
sphericalQuadrupoles[offset+2] = (2*SQRT(3))*molecularQuadrupoles[offset+4]; // yz -> Q_21s
sphericalQuadrupoles[offset+3] = SQRT(3)*(molecularQuadrupoles[offset+0]-molecularQuadrupoles[offset+3]); // xx-yy -> Q_22c
sphericalQuadrupoles[offset+4] = (2*SQRT(3))*molecularQuadrupoles[offset+1]; // xy -> Q_22s
sphericalQuadrupoles[offset+1] = (2*SQRT((real) 3))*molecularQuadrupoles[offset+2]; // xz -> Q_21c
sphericalQuadrupoles[offset+2] = (2*SQRT((real) 3))*molecularQuadrupoles[offset+4]; // yz -> Q_21s
sphericalQuadrupoles[offset+3] = SQRT((real) 3)*(molecularQuadrupoles[offset+0]-molecularQuadrupoles[offset+3]); // xx-yy -> Q_22c
sphericalQuadrupoles[offset+4] = (2*SQRT((real) 3))*molecularQuadrupoles[offset+1]; // xy -> Q_22s
// get coordinates of this atom and the z & x axis atoms
// compute the vector between the atoms and 1/sqrt(d2), d2 is distance between
......@@ -236,32 +236,32 @@ extern "C" __global__ void computeLabFrameMoments(real4* __restrict__ posq, int4
sphericalQuadrupole[4] *= -1;
}
real rotatedQuadrupole[5] = {0, 0, 0, 0, 0};
real sqrtThree = SQRT(3);
rotatedQuadrupole[0] += sphericalQuadrupole[0]*0.5f*(3.0f*vectorZ.z*vectorZ.z - 1.0f);
rotatedQuadrupole[1] += sphericalQuadrupole[0]*sqrtThree*vectorZ.z*vectorZ.x;
rotatedQuadrupole[2] += sphericalQuadrupole[0]*sqrtThree*vectorZ.z*vectorZ.y;
rotatedQuadrupole[3] += sphericalQuadrupole[0]*0.5f*sqrtThree*(vectorZ.x*vectorZ.x - vectorZ.y*vectorZ.y);
rotatedQuadrupole[4] += sphericalQuadrupole[0]*sqrtThree*vectorZ.x*vectorZ.y;
rotatedQuadrupole[0] += sphericalQuadrupole[1]*sqrtThree*vectorZ.z*vectorX.z;
rotatedQuadrupole[1] += sphericalQuadrupole[1]*(vectorZ.x*vectorX.z + vectorZ.z*vectorX.x);
rotatedQuadrupole[2] += sphericalQuadrupole[1]*(vectorZ.y*vectorX.z + vectorZ.z*vectorX.y);
rotatedQuadrupole[3] += sphericalQuadrupole[1]*(vectorZ.x*vectorX.x - vectorZ.y*vectorX.y);
rotatedQuadrupole[4] += sphericalQuadrupole[1]*(vectorZ.y*vectorX.x + vectorZ.x*vectorX.y);
rotatedQuadrupole[0] += sphericalQuadrupole[2]*sqrtThree*vectorZ.z*vectorY.z;
rotatedQuadrupole[1] += sphericalQuadrupole[2]*(vectorZ.x*vectorY.z + vectorZ.z*vectorY.x);
rotatedQuadrupole[2] += sphericalQuadrupole[2]*(vectorZ.y*vectorY.z + vectorZ.z*vectorY.y);
rotatedQuadrupole[3] += sphericalQuadrupole[2]*(vectorZ.x*vectorY.x - vectorZ.y*vectorY.y);
rotatedQuadrupole[4] += sphericalQuadrupole[2]*(vectorZ.y*vectorY.x + vectorZ.x*vectorY.y);
rotatedQuadrupole[0] += sphericalQuadrupole[3]*0.5f*sqrtThree*(vectorX.z*vectorX.z - vectorY.z*vectorY.z);
rotatedQuadrupole[1] += sphericalQuadrupole[3]*(vectorX.z*vectorX.x - vectorY.z*vectorY.x);
rotatedQuadrupole[2] += sphericalQuadrupole[3]*(vectorX.z*vectorX.y - vectorY.z*vectorY.y);
rotatedQuadrupole[3] += sphericalQuadrupole[3]*0.5f*(vectorX.x*vectorX.x - vectorX.y*vectorX.y - vectorY.x*vectorY.x + vectorY.y*vectorY.y);
rotatedQuadrupole[4] += sphericalQuadrupole[3]*(vectorX.x*vectorX.y - vectorY.x*vectorY.y);
rotatedQuadrupole[0] += sphericalQuadrupole[4]*sqrtThree*vectorX.z*vectorY.z;
rotatedQuadrupole[1] += sphericalQuadrupole[4]*(vectorX.x*vectorY.z + vectorX.z*vectorY.x);
rotatedQuadrupole[2] += sphericalQuadrupole[4]*(vectorX.y*vectorY.z + vectorX.z*vectorY.y);
rotatedQuadrupole[3] += sphericalQuadrupole[4]*(vectorX.x*vectorY.x - vectorX.y*vectorY.y);
rotatedQuadrupole[4] += sphericalQuadrupole[4]*(vectorX.y*vectorY.x + vectorX.x*vectorY.y);
real sqrtThree = SQRT((real) 3);
rotatedQuadrupole[0] += sphericalQuadrupole[0]*0.5f*(3.0f*vectorZ.z*vectorZ.z - 1.0f) +
sphericalQuadrupole[1]*sqrtThree*vectorZ.z*vectorX.z +
sphericalQuadrupole[2]*sqrtThree*vectorZ.z*vectorY.z +
sphericalQuadrupole[3]*0.5f*sqrtThree*(vectorX.z*vectorX.z - vectorY.z*vectorY.z) +
sphericalQuadrupole[4]*sqrtThree*vectorX.z*vectorY.z;
rotatedQuadrupole[1] += sphericalQuadrupole[0]*sqrtThree*vectorZ.z*vectorZ.x +
sphericalQuadrupole[1]*(vectorZ.x*vectorX.z + vectorZ.z*vectorX.x) +
sphericalQuadrupole[2]*(vectorZ.x*vectorY.z + vectorZ.z*vectorY.x) +
sphericalQuadrupole[3]*(vectorX.z*vectorX.x - vectorY.z*vectorY.x) +
sphericalQuadrupole[4]*(vectorX.x*vectorY.z + vectorX.z*vectorY.x);
rotatedQuadrupole[2] += sphericalQuadrupole[0]*sqrtThree*vectorZ.z*vectorZ.y +
sphericalQuadrupole[1]*(vectorZ.y*vectorX.z + vectorZ.z*vectorX.y) +
sphericalQuadrupole[2]*(vectorZ.y*vectorY.z + vectorZ.z*vectorY.y) +
sphericalQuadrupole[3]*(vectorX.z*vectorX.y - vectorY.z*vectorY.y) +
sphericalQuadrupole[4]*(vectorX.y*vectorY.z + vectorX.z*vectorY.y);
rotatedQuadrupole[3] += sphericalQuadrupole[0]*0.5f*sqrtThree*(vectorZ.x*vectorZ.x - vectorZ.y*vectorZ.y) +
sphericalQuadrupole[1]*(vectorZ.x*vectorX.x - vectorZ.y*vectorX.y) +
sphericalQuadrupole[2]*(vectorZ.x*vectorY.x - vectorZ.y*vectorY.y) +
sphericalQuadrupole[3]*0.5f*(vectorX.x*vectorX.x - vectorX.y*vectorX.y - vectorY.x*vectorY.x + vectorY.y*vectorY.y) +
sphericalQuadrupole[4]*(vectorX.x*vectorY.x - vectorX.y*vectorY.y);
rotatedQuadrupole[4] += sphericalQuadrupole[0]*sqrtThree*vectorZ.x*vectorZ.y +
sphericalQuadrupole[1]*(vectorZ.y*vectorX.x + vectorZ.x*vectorX.y) +
sphericalQuadrupole[2]*(vectorZ.y*vectorY.x + vectorZ.x*vectorY.y) +
sphericalQuadrupole[3]*(vectorX.x*vectorX.y - vectorY.x*vectorY.y) +
sphericalQuadrupole[4]*(vectorX.y*vectorY.x + vectorX.x*vectorY.y);
sphericalQuadrupoles[offset] = rotatedQuadrupole[0];
sphericalQuadrupoles[offset+1] = rotatedQuadrupole[1];
sphericalQuadrupoles[offset+2] = rotatedQuadrupole[2];
......
/**
 * Build an orthonormal frame whose first axis points along deltaR and store it
 * as a 3x3 rotation matrix.
 *
 * The "z" axis is deltaR*rInv. The "x" axis starts as a copy of the z axis
 * nudged by one unit along a Cartesian direction; its component along z is then
 * removed and the remainder is normalized. The "y" axis is cross(z, x).
 *
 * @param deltaR          vector that defines the direction of the z axis
 * @param rInv            factor multiplied into deltaR to obtain the z axis
 *                        (presumably 1/|deltaR| -- confirm against callers)
 * @param rotationMatrix  output; rows 0, 1, 2 receive the z, x and y axes, and
 *                        within each row the components are stored as {z, x, y}
 */
__device__ void buildQIRotationMatrix(real3 deltaR, real rInv, real (&rotationMatrix)[3][3]) {
    real3 axisZ = deltaR*rInv;

    // Seed the x axis with something that is not parallel to the z axis. When
    // deltaR has no y or z component it lies along Cartesian x, so nudging along
    // x would leave the seed parallel to z; nudge along y in that case instead.
    real3 axisX = axisZ;
    if (deltaR.y == 0 && deltaR.z == 0)
        axisX.y += 1;
    else
        axisX.x += 1;

    // Strip the part of the seed that lies along z, then rescale to unit length.
    axisX -= axisZ*dot(axisX, axisZ);
    axisX = normalize(axisX);

    real3 axisY = cross(axisZ, axisX);

    // The usual Cartesian {x,y,z} ordering is permuted to {z,x,y} so that the
    // matrix matches the spherical harmonic ordering, both for the rows (which
    // axis) and for the columns (which component of that axis).
    rotationMatrix[0][0] = axisZ.z;
    rotationMatrix[1][0] = axisX.z;
    rotationMatrix[2][0] = axisY.z;

    rotationMatrix[0][1] = axisZ.x;
    rotationMatrix[1][1] = axisX.x;
    rotationMatrix[2][1] = axisY.x;

    rotationMatrix[0][2] = axisZ.y;
    rotationMatrix[1][2] = axisX.y;
    rotationMatrix[2][2] = axisY.y;
}
/**
 * Multiply a dipole by a 3x3 rotation matrix.
 *
 * The dipole's (x, y, z) members are treated as elements 0, 1, 2 of a column
 * vector; the result holds the three row-by-vector products in the same member
 * order.
 *
 * @param dipole          the dipole to rotate (not modified)
 * @param rotationMatrix  the matrix to apply
 * @return the rotated dipole
 */
__device__ real3 rotateDipole(real3& dipole, const real (&rotationMatrix)[3][3]) {
    const real (&R)[3][3] = rotationMatrix;
    const real first  = R[0][0]*dipole.x + R[0][1]*dipole.y + R[0][2]*dipole.z;
    const real second = R[1][0]*dipole.x + R[1][1]*dipole.y + R[1][2]*dipole.z;
    const real third  = R[2][0]*dipole.x + R[2][1]*dipole.y + R[2][2]*dipole.z;
    return make_real3(first, second, third);
}
/**
 * Rotate two five-component quadrupoles with the same rotation.
 *
 * Each of the 25 coefficients of the 5x5 quadrupole rotation is formed from
 * products of elements of the 3x3 matrix and is applied to both inputs, so the
 * coefficient only has to be evaluated once per pair.
 *
 * The results are ADDED to rotated1 and rotated2 (nothing is zeroed here), so
 * callers must initialize the output arrays before the call.
 *
 * @param rotationMatrix  3x3 rotation matrix the coefficients are built from
 * @param quad1           first input quadrupole, elements [0..4] are read
 * @param quad2           second input quadrupole, elements [0..4] are read
 * @param rotated1        accumulator for the rotated quad1, elements [0..4]
 * @param rotated2        accumulator for the rotated quad2, elements [0..4]
 */
__device__ void rotateQuadupoles(const real (&rotationMatrix)[3][3], const real* quad1, const real* quad2, real* rotated1, real* rotated2) {
    const real (&R)[3][3] = rotationMatrix;
    real root3 = SQRT((real) 3);

    // ---- output component 0 ----
    real weight = 0.5f*(3.0f*R[0][0]*R[0][0] - 1.0f);
    rotated1[0] += quad1[0]*weight;
    rotated2[0] += quad2[0]*weight;
    weight = root3*R[0][0]*R[0][1];
    rotated1[0] += quad1[1]*weight;
    rotated2[0] += quad2[1]*weight;
    weight = root3*R[0][0]*R[0][2];
    rotated1[0] += quad1[2]*weight;
    rotated2[0] += quad2[2]*weight;
    weight = 0.5f*root3*(R[0][1]*R[0][1] - R[0][2]*R[0][2]);
    rotated1[0] += quad1[3]*weight;
    rotated2[0] += quad2[3]*weight;
    weight = root3*R[0][1]*R[0][2];
    rotated1[0] += quad1[4]*weight;
    rotated2[0] += quad2[4]*weight;

    // ---- output component 1 ----
    weight = root3*R[0][0]*R[1][0];
    rotated1[1] += quad1[0]*weight;
    rotated2[1] += quad2[0]*weight;
    weight = R[1][0]*R[0][1] + R[0][0]*R[1][1];
    rotated1[1] += quad1[1]*weight;
    rotated2[1] += quad2[1]*weight;
    weight = R[1][0]*R[0][2] + R[0][0]*R[1][2];
    rotated1[1] += quad1[2]*weight;
    rotated2[1] += quad2[2]*weight;
    weight = R[0][1]*R[1][1] - R[0][2]*R[1][2];
    rotated1[1] += quad1[3]*weight;
    rotated2[1] += quad2[3]*weight;
    weight = R[1][1]*R[0][2] + R[0][1]*R[1][2];
    rotated1[1] += quad1[4]*weight;
    rotated2[1] += quad2[4]*weight;

    // ---- output component 2 ----
    weight = root3*R[0][0]*R[2][0];
    rotated1[2] += quad1[0]*weight;
    rotated2[2] += quad2[0]*weight;
    weight = R[2][0]*R[0][1] + R[0][0]*R[2][1];
    rotated1[2] += quad1[1]*weight;
    rotated2[2] += quad2[1]*weight;
    weight = R[2][0]*R[0][2] + R[0][0]*R[2][2];
    rotated1[2] += quad1[2]*weight;
    rotated2[2] += quad2[2]*weight;
    weight = R[0][1]*R[2][1] - R[0][2]*R[2][2];
    rotated1[2] += quad1[3]*weight;
    rotated2[2] += quad2[3]*weight;
    weight = R[2][1]*R[0][2] + R[0][1]*R[2][2];
    rotated1[2] += quad1[4]*weight;
    rotated2[2] += quad2[4]*weight;

    // ---- output component 3 ----
    weight = 0.5f*root3*(R[1][0]*R[1][0] - R[2][0]*R[2][0]);
    rotated1[3] += quad1[0]*weight;
    rotated2[3] += quad2[0]*weight;
    weight = R[1][0]*R[1][1] - R[2][0]*R[2][1];
    rotated1[3] += quad1[1]*weight;
    rotated2[3] += quad2[1]*weight;
    weight = R[1][0]*R[1][2] - R[2][0]*R[2][2];
    rotated1[3] += quad1[2]*weight;
    rotated2[3] += quad2[2]*weight;
    weight = 0.5f*(R[1][1]*R[1][1] - R[2][1]*R[2][1] - R[1][2]*R[1][2] + R[2][2]*R[2][2]);
    rotated1[3] += quad1[3]*weight;
    rotated2[3] += quad2[3]*weight;
    weight = R[1][1]*R[1][2] - R[2][1]*R[2][2];
    rotated1[3] += quad1[4]*weight;
    rotated2[3] += quad2[4]*weight;

    // ---- output component 4 ----
    weight = root3*R[1][0]*R[2][0];
    rotated1[4] += quad1[0]*weight;
    rotated2[4] += quad2[0]*weight;
    weight = R[2][0]*R[1][1] + R[1][0]*R[2][1];
    rotated1[4] += quad1[1]*weight;
    rotated2[4] += quad2[1]*weight;
    weight = R[2][0]*R[1][2] + R[1][0]*R[2][2];
    rotated1[4] += quad1[2]*weight;
    rotated2[4] += quad2[2]*weight;
    weight = R[1][1]*R[2][1] - R[1][2]*R[2][2];
    rotated1[4] += quad1[3]*weight;
    rotated2[4] += quad2[3]*weight;
    weight = R[2][1]*R[1][2] + R[1][1]*R[2][2];
    rotated1[4] += quad1[4]*weight;
    rotated2[4] += quad2[4]*weight;
}
......@@ -6676,6 +6676,6 @@ RealOpenMM AmoebaReferencePmeMultipoleForce::calculateElectrostatic(const vector
energy += computeReciprocalSpaceInducedDipoleForceAndEnergy(getPolarizationType(), particleData, forces, torques);
energy += computeReciprocalSpaceFixedMultipoleForceAndEnergy(particleData, forces, torques);
energy += calculatePmeSelfEnergy(particleData);
return energy;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment