Unverified Commit 88f32f2d authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Optimize computing kinetic energy (#4946)

parent f19c9f59
...@@ -145,7 +145,7 @@ protected: ...@@ -145,7 +145,7 @@ protected:
ComputeKernel ccmaDirectionsKernel, ccmaPosForceKernel, ccmaVelForceKernel; ComputeKernel ccmaDirectionsKernel, ccmaPosForceKernel, ccmaVelForceKernel;
ComputeKernel ccmaMultiplyKernel, ccmaUpdateKernel, ccmaFullKernel; ComputeKernel ccmaMultiplyKernel, ccmaUpdateKernel, ccmaFullKernel;
ComputeKernel vsitePositionKernel, vsiteForceKernel, vsiteSaveForcesKernel; ComputeKernel vsitePositionKernel, vsiteForceKernel, vsiteSaveForcesKernel;
ComputeKernel randomKernel, timeShiftKernel; ComputeKernel randomKernel, timeShiftKernel, kineticEnergyKernel;
ComputeArray posDelta; ComputeArray posDelta;
ComputeArray settleAtoms; ComputeArray settleAtoms;
ComputeArray settleParams; ComputeArray settleParams;
...@@ -177,7 +177,8 @@ protected: ...@@ -177,7 +177,8 @@ protected:
ComputeArray vsiteLocalCoordsPos; ComputeArray vsiteLocalCoordsPos;
ComputeArray vsiteLocalCoordsStartIndex; ComputeArray vsiteLocalCoordsStartIndex;
ComputeArray vsiteStage; ComputeArray vsiteStage;
int randomPos, lastSeed, numVsites, numVsiteStages; ComputeArray kineticEnergy;
int randomPos, lastSeed, numVsites, numVsiteStages, keWorkGroupSize;
bool hasOverlappingVsites; bool hasOverlappingVsites;
mm_double2 lastStepSize; mm_double2 lastStepSize;
struct ShakeCluster; struct ShakeCluster;
......
...@@ -101,6 +101,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System ...@@ -101,6 +101,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
posDelta.upload(deltas); posDelta.upload(deltas);
stepSize.initialize<mm_double2>(context, 1, "stepSize"); stepSize.initialize<mm_double2>(context, 1, "stepSize");
stepSize.upload(&lastStepSize); stepSize.upload(&lastStepSize);
kineticEnergy.initialize<double>(context, 1, "kineticEnergy");
} }
else { else {
posDelta.initialize<mm_float4>(context, context.getPaddedNumAtoms(), "posDelta"); posDelta.initialize<mm_float4>(context, context.getPaddedNumAtoms(), "posDelta");
...@@ -109,7 +110,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System ...@@ -109,7 +110,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
stepSize.initialize<mm_float2>(context, 1, "stepSize"); stepSize.initialize<mm_float2>(context, 1, "stepSize");
mm_float2 lastStepSizeFloat = mm_float2(0.0f, 0.0f); mm_float2 lastStepSizeFloat = mm_float2(0.0f, 0.0f);
stepSize.upload(&lastStepSizeFloat); stepSize.upload(&lastStepSizeFloat);
kineticEnergy.initialize<float>(context, 1, "kineticEnergy");
} }
keWorkGroupSize = context.getMaxThreadBlockSize();
if (keWorkGroupSize > 512)
keWorkGroupSize = 512;
// Record the set of constraints and how many constraints each atom is involved in. // Record the set of constraints and how many constraints each atom is involved in.
...@@ -573,6 +578,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System ...@@ -573,6 +578,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
defines["NUM_OUT_OF_PLANE"] = context.intToString(numOutOfPlane); defines["NUM_OUT_OF_PLANE"] = context.intToString(numOutOfPlane);
defines["NUM_LOCAL_COORDS"] = context.intToString(numLocalCoords); defines["NUM_LOCAL_COORDS"] = context.intToString(numLocalCoords);
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
defines["KE_WORK_GROUP_SIZE"] = context.intToString(keWorkGroupSize);
if (hasOverlappingVsites) if (hasOverlappingVsites)
defines["HAS_OVERLAPPING_VSITES"] = "1"; defines["HAS_OVERLAPPING_VSITES"] = "1";
if (numVsiteStages > 1) if (numVsiteStages > 1)
...@@ -593,6 +599,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System ...@@ -593,6 +599,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
vsiteSaveForcesKernel = program->createKernel("saveDistributedForces"); vsiteSaveForcesKernel = program->createKernel("saveDistributedForces");
randomKernel = program->createKernel("generateRandomNumbers"); randomKernel = program->createKernel("generateRandomNumbers");
timeShiftKernel = program->createKernel("timeShiftVelocities"); timeShiftKernel = program->createKernel("timeShiftVelocities");
kineticEnergyKernel = program->createKernel("computeKineticEnergy");
// Set arguments for virtual site kernels. // Set arguments for virtual site kernels.
...@@ -740,6 +747,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System ...@@ -740,6 +747,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
timeShiftKernel->addArg(); timeShiftKernel->addArg();
// Set arguments of kinetic energy kernel.
kineticEnergyKernel->addArg(context.getVelm());
kineticEnergyKernel->addArg(kineticEnergy);
} }
void IntegrationUtilities::setNextStepSize(double size) { void IntegrationUtilities::setNextStepSize(double size) {
...@@ -874,31 +886,22 @@ double IntegrationUtilities::computeKineticEnergy(double timeShift) { ...@@ -874,31 +886,22 @@ double IntegrationUtilities::computeKineticEnergy(double timeShift) {
// Compute the kinetic energy. // Compute the kinetic energy.
double energy = 0.0; kineticEnergyKernel->execute(keWorkGroupSize, keWorkGroupSize);
if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) {
auto velm = (mm_double4*)context.getPinnedBuffer();
context.getVelm().download(velm);
for (int i = 0; i < numParticles; i++) {
mm_double4 v = velm[i];
if (v.w != 0)
energy += (v.x*v.x+v.y*v.y+v.z*v.z)/v.w;
}
}
else {
auto velm = (mm_float4*)context.getPinnedBuffer();
context.getVelm().download(velm);
for (int i = 0; i < numParticles; i++) {
mm_float4 v = velm[i];
if (v.w != 0)
energy += (v.x*v.x+v.y*v.y+v.z*v.z)/v.w;
}
}
// Restore the velocities. // Restore the velocities.
if (timeShift != 0) if (timeShift != 0)
posDelta.copyTo(context.getVelm()); posDelta.copyTo(context.getVelm());
return 0.5*energy; if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) {
double energy;
kineticEnergy.download(&energy);
return energy;
}
else {
float energy;
kineticEnergy.download(&energy);
return energy;
}
} }
void IntegrationUtilities::computeShiftedVelocities(double timeShift, vector<Vec3>& velocities) { void IntegrationUtilities::computeShiftedVelocities(double timeShift, vector<Vec3>& velocities) {
......
...@@ -1075,3 +1075,24 @@ KERNEL void timeShiftVelocities(GLOBAL mixed4* RESTRICT velm, GLOBAL const mm_lo ...@@ -1075,3 +1075,24 @@ KERNEL void timeShiftVelocities(GLOBAL mixed4* RESTRICT velm, GLOBAL const mm_lo
} }
} }
} }
/**
* Compute the total kinetic energy.
*/
KERNEL void computeKineticEnergy(GLOBAL mixed4* RESTRICT velm, GLOBAL mixed* result) {
LOCAL mixed tempBuffer[KE_WORK_GROUP_SIZE];
mixed sum = 0;
for (unsigned int index = LOCAL_ID; index < NUM_ATOMS; index += LOCAL_SIZE) {
mixed4 v = velm[index];
if (v.w != 0)
sum += (v.x*v.x+v.y*v.y+v.z*v.z)/v.w;
}
tempBuffer[LOCAL_ID] = sum;
for (int i = 1; i < KE_WORK_GROUP_SIZE; i *= 2) {
SYNC_THREADS;
if (LOCAL_ID%(i*2) == 0 && LOCAL_ID+i < KE_WORK_GROUP_SIZE)
tempBuffer[LOCAL_ID] += tempBuffer[LOCAL_ID+i];
}
if (LOCAL_ID == 0)
*result = 0.5f*tempBuffer[0];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment