Unverified commit 88f32f2d, authored by Peter Eastman, committed by GitHub
Browse files

Optimize computing kinetic energy (#4946)

parent f19c9f59
......@@ -145,7 +145,7 @@ protected:
ComputeKernel ccmaDirectionsKernel, ccmaPosForceKernel, ccmaVelForceKernel;
ComputeKernel ccmaMultiplyKernel, ccmaUpdateKernel, ccmaFullKernel;
ComputeKernel vsitePositionKernel, vsiteForceKernel, vsiteSaveForcesKernel;
ComputeKernel randomKernel, timeShiftKernel;
ComputeKernel randomKernel, timeShiftKernel, kineticEnergyKernel;
ComputeArray posDelta;
ComputeArray settleAtoms;
ComputeArray settleParams;
......@@ -177,7 +177,8 @@ protected:
ComputeArray vsiteLocalCoordsPos;
ComputeArray vsiteLocalCoordsStartIndex;
ComputeArray vsiteStage;
int randomPos, lastSeed, numVsites, numVsiteStages;
ComputeArray kineticEnergy;
int randomPos, lastSeed, numVsites, numVsiteStages, keWorkGroupSize;
bool hasOverlappingVsites;
mm_double2 lastStepSize;
struct ShakeCluster;
......
......@@ -101,6 +101,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
posDelta.upload(deltas);
stepSize.initialize<mm_double2>(context, 1, "stepSize");
stepSize.upload(&lastStepSize);
kineticEnergy.initialize<double>(context, 1, "kineticEnergy");
}
else {
posDelta.initialize<mm_float4>(context, context.getPaddedNumAtoms(), "posDelta");
......@@ -109,7 +110,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
stepSize.initialize<mm_float2>(context, 1, "stepSize");
mm_float2 lastStepSizeFloat = mm_float2(0.0f, 0.0f);
stepSize.upload(&lastStepSizeFloat);
kineticEnergy.initialize<float>(context, 1, "kineticEnergy");
}
keWorkGroupSize = context.getMaxThreadBlockSize();
if (keWorkGroupSize > 512)
keWorkGroupSize = 512;
// Record the set of constraints and how many constraints each atom is involved in.
......@@ -573,6 +578,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
defines["NUM_OUT_OF_PLANE"] = context.intToString(numOutOfPlane);
defines["NUM_LOCAL_COORDS"] = context.intToString(numLocalCoords);
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
defines["KE_WORK_GROUP_SIZE"] = context.intToString(keWorkGroupSize);
if (hasOverlappingVsites)
defines["HAS_OVERLAPPING_VSITES"] = "1";
if (numVsiteStages > 1)
......@@ -593,6 +599,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
vsiteSaveForcesKernel = program->createKernel("saveDistributedForces");
randomKernel = program->createKernel("generateRandomNumbers");
timeShiftKernel = program->createKernel("timeShiftVelocities");
kineticEnergyKernel = program->createKernel("computeKineticEnergy");
// Set arguments for virtual site kernels.
......@@ -740,6 +747,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
for (int i = 0; i < 3; i++)
timeShiftKernel->addArg();
// Set arguments of kinetic energy kernel.
kineticEnergyKernel->addArg(context.getVelm());
kineticEnergyKernel->addArg(kineticEnergy);
}
void IntegrationUtilities::setNextStepSize(double size) {
......@@ -874,31 +886,22 @@ double IntegrationUtilities::computeKineticEnergy(double timeShift) {
// Compute the kinetic energy.
double energy = 0.0;
if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) {
auto velm = (mm_double4*)context.getPinnedBuffer();
context.getVelm().download(velm);
for (int i = 0; i < numParticles; i++) {
mm_double4 v = velm[i];
if (v.w != 0)
energy += (v.x*v.x+v.y*v.y+v.z*v.z)/v.w;
}
}
else {
auto velm = (mm_float4*)context.getPinnedBuffer();
context.getVelm().download(velm);
for (int i = 0; i < numParticles; i++) {
mm_float4 v = velm[i];
if (v.w != 0)
energy += (v.x*v.x+v.y*v.y+v.z*v.z)/v.w;
}
}
kineticEnergyKernel->execute(keWorkGroupSize, keWorkGroupSize);
// Restore the velocities.
if (timeShift != 0)
posDelta.copyTo(context.getVelm());
return 0.5*energy;
if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) {
double energy;
kineticEnergy.download(&energy);
return energy;
}
else {
float energy;
kineticEnergy.download(&energy);
return energy;
}
}
void IntegrationUtilities::computeShiftedVelocities(double timeShift, vector<Vec3>& velocities) {
......
......@@ -1075,3 +1075,24 @@ KERNEL void timeShiftVelocities(GLOBAL mixed4* RESTRICT velm, GLOBAL const mm_lo
}
}
}
/**
 * Reduce the per-atom velocity data in velm to a single kinetic energy value.
 *
 * Every thread strides over the atoms using LOCAL_ID and LOCAL_SIZE only, so one
 * work group covers all NUM_ATOMS entries.  The partial sums are then combined
 * in a LOCAL buffer and thread 0 stores half of the total in *result.
 *
 * The reduction indexes the buffer up to KE_WORK_GROUP_SIZE, so the kernel is
 * expected to be launched with exactly KE_WORK_GROUP_SIZE threads per group.
 *
 * @param velm    per-atom (x, y, z, w) values; entries with w == 0 are skipped
 *                (w is presumably the inverse mass -- confirm against the host code)
 * @param result  single-element output receiving 0.5 * sum((x*x+y*y+z*z)/w)
 */
KERNEL void computeKineticEnergy(GLOBAL mixed4* RESTRICT velm, GLOBAL mixed* result) {
    LOCAL mixed partialSums[KE_WORK_GROUP_SIZE];

    // Phase 1: each thread accumulates its own strided share of the atoms.
    mixed threadTotal = 0;
    for (unsigned int atom = LOCAL_ID; atom < NUM_ATOMS; atom += LOCAL_SIZE) {
        mixed4 velocity = velm[atom];
        if (velocity.w == 0)
            continue;
        threadTotal += (velocity.x*velocity.x+velocity.y*velocity.y+velocity.z*velocity.z)/velocity.w;
    }
    partialSums[LOCAL_ID] = threadTotal;

    // Phase 2: pairwise reduction.  On each pass the threads whose index is a
    // multiple of 2*stride fold in the slot located stride entries above them.
    // The barrier at the top of the pass makes the previous writes visible.
    for (int stride = 1; stride < KE_WORK_GROUP_SIZE; stride *= 2) {
        SYNC_THREADS;
        if (LOCAL_ID%(stride*2) == 0) {
            if (LOCAL_ID+stride < KE_WORK_GROUP_SIZE)
                partialSums[LOCAL_ID] += partialSums[LOCAL_ID+stride];
        }
    }

    // Slot 0 now holds the full sum; only thread 0 (which wrote it last) publishes it.
    if (LOCAL_ID == 0)
        *result = 0.5f*partialSums[0];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment