Unverified Commit 559da024 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Optimized setPositions() and setVelocities() (#4945)

* Optimized setPositions() and setVelocities()

* Fix test failures
parent 7df74a1c
......@@ -154,6 +154,8 @@ public:
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
ComputeContext& cc;
ComputeArray floatBuffer, doubleBuffer;
ComputeKernel copyFloatKernel, copyDoubleKernel;
};
/**
......
......@@ -58,6 +58,21 @@ using namespace std;
using namespace Lepton;
void CommonUpdateStateDataKernel::initialize(const System& system) {
ContextSelector selector(cc);
floatBuffer.initialize<float>(cc, 3*system.getNumParticles(), "floatBuffer");
map<string, string> defines;
ComputeProgram program = cc.compileProgram(CommonKernelSources::copyCoordinateBuffers, defines);
copyFloatKernel = program->createKernel("copyFloatBuffer");
copyFloatKernel->addArg(floatBuffer);
copyFloatKernel->addArg();
copyFloatKernel->addArg(cc.getNumAtoms());
if (cc.getUseMixedPrecision() || cc.getUseDoublePrecision()) {
doubleBuffer.initialize<double>(cc, 3*system.getNumParticles(), "doubleBuffer");
copyDoubleKernel = program->createKernel("copyDoubleBuffer");
copyDoubleKernel->addArg(doubleBuffer);
copyDoubleKernel->addArg();
copyDoubleKernel->addArg(cc.getNumAtoms());
}
}
double CommonUpdateStateDataKernel::getTime(const ContextImpl& context) const {
......@@ -144,32 +159,28 @@ void CommonUpdateStateDataKernel::setPositions(ContextImpl& context, const vecto
const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cc.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
double* pos = (double*) cc.getPinnedBuffer();
for (int i = 0; i < numParticles; ++i) {
mm_double4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = p[0];
pos.y = p[1];
pos.z = p[2];
pos[3*i] = p[0];
pos[3*i+1] = p[1];
pos[3*i+2] = p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
posq[i] = mm_double4(0.0, 0.0, 0.0, 0.0);
cc.getPosq().upload(posq);
doubleBuffer.upload(pos);
copyDoubleKernel->setArg(1, cc.getPosq());
copyDoubleKernel->execute(numParticles);
}
else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
float* pos = (float*) cc.getPinnedBuffer();
for (int i = 0; i < numParticles; ++i) {
mm_float4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = (float) p[0];
pos.y = (float) p[1];
pos.z = (float) p[2];
pos[3*i] = (float) p[0];
pos[3*i+1] = (float) p[1];
pos[3*i+2] = (float) p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
posq[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cc.getPosq().upload(posq);
floatBuffer.upload(pos);
copyFloatKernel->setArg(1, cc.getPosq());
copyFloatKernel->execute(numParticles);
}
if (cc.getUseMixedPrecision()) {
mm_float4* posCorrection = (mm_float4*) cc.getPinnedBuffer();
......@@ -218,32 +229,28 @@ void CommonUpdateStateDataKernel::setVelocities(ContextImpl& context, const vect
const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
mm_double4* velm = (mm_double4*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
double* vel = (double*) cc.getPinnedBuffer();
for (int i = 0; i < numParticles; ++i) {
mm_double4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
vel[3*i] = p[0];
vel[3*i+1] = p[1];
vel[3*i+2] = p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
velm[i] = mm_double4(0.0, 0.0, 0.0, 0.0);
cc.getVelm().upload(velm);
doubleBuffer.upload(vel);
copyDoubleKernel->setArg(1, cc.getVelm());
copyDoubleKernel->execute(numParticles);
}
else {
mm_float4* velm = (mm_float4*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
float* vel = (float*) cc.getPinnedBuffer();
for (int i = 0; i < numParticles; ++i) {
mm_float4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
vel[3*i] = (float) p[0];
vel[3*i+1] = (float) p[1];
vel[3*i+2] = (float) p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
velm[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cc.getVelm().upload(velm);
floatBuffer.upload(vel);
copyFloatKernel->setArg(1, cc.getVelm());
copyFloatKernel->execute(numParticles);
}
}
......
KERNEL void copyFloatBuffer(GLOBAL float* RESTRICT source, GLOBAL float4* RESTRICT dest, int numAtoms) {
for (int i = GLOBAL_ID; i < numAtoms; i += GLOBAL_SIZE) {
dest[i].x = source[3*i];
dest[i].y = source[3*i+1];
dest[i].z = source[3*i+2];
}
}
#ifdef SUPPORTS_DOUBLE_PRECISION
KERNEL void copyDoubleBuffer(GLOBAL double* RESTRICT source, GLOBAL double4* RESTRICT dest, int numAtoms) {
for (int i = GLOBAL_ID; i < numAtoms; i += GLOBAL_SIZE) {
dest[i].x = source[3*i];
dest[i].y = source[3*i+1];
dest[i].z = source[3*i+2];
}
}
#endif
\ No newline at end of file
......@@ -359,6 +359,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
nonbonded = new CudaNonbondedUtilities(*this);
integration = new CudaIntegrationUtilities(*this, system);
expression = new CudaExpressionUtilities(*this);
clearBuffer(posq);
}
CudaContext::~CudaContext() {
......
......@@ -351,6 +351,7 @@ HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSy
nonbonded = new HipNonbondedUtilities(*this);
integration = new HipIntegrationUtilities(*this, system);
expression = new HipExpressionUtilities(*this);
clearBuffer(posq);
}
HipContext::~HipContext() {
......
......@@ -492,6 +492,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
nonbonded = new OpenCLNonbondedUtilities(*this);
integration = new OpenCLIntegrationUtilities(*this, system);
expression = new OpenCLExpressionUtilities(*this);
clearBuffer(posq);
}
OpenCLContext::~OpenCLContext() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment