Unverified commit 559da024, authored by Peter Eastman and committed by GitHub
Browse files

Optimized setPositions() and setVelocities() (#4945)

* Optimized setPositions() and setVelocities()

* Fix test failures
parent 7df74a1c
...@@ -154,6 +154,8 @@ public: ...@@ -154,6 +154,8 @@ public:
void loadCheckpoint(ContextImpl& context, std::istream& stream); void loadCheckpoint(ContextImpl& context, std::istream& stream);
private: private:
ComputeContext& cc; ComputeContext& cc;
ComputeArray floatBuffer, doubleBuffer;
ComputeKernel copyFloatKernel, copyDoubleKernel;
}; };
/** /**
......
...@@ -58,6 +58,21 @@ using namespace std; ...@@ -58,6 +58,21 @@ using namespace std;
using namespace Lepton; using namespace Lepton;
// Allocates the staging buffers and builds the copy kernels that
// setPositions() and setVelocities() use to write packed xyz data into
// cc.getPosq() and cc.getVelm().
void CommonUpdateStateDataKernel::initialize(const System& system) { void CommonUpdateStateDataKernel::initialize(const System& system) {
ContextSelector selector(cc);
// Staging buffer sized for three floats (x, y, z) per particle.
floatBuffer.initialize<float>(cc, 3*system.getNumParticles(), "floatBuffer");
map<string, string> defines;
ComputeProgram program = cc.compileProgram(CommonKernelSources::copyCoordinateBuffers, defines);
copyFloatKernel = program->createKernel("copyFloatBuffer");
copyFloatKernel->addArg(floatBuffer);
// Argument 1 (the destination array) is left unbound here; the setters bind
// it with setArg(1, ...) to either posq or velm before each execute().
copyFloatKernel->addArg();
copyFloatKernel->addArg(cc.getNumAtoms());
// The double buffer and kernel are only created for mixed or double
// precision; they remain uninitialized in single precision mode.
if (cc.getUseMixedPrecision() || cc.getUseDoublePrecision()) {
doubleBuffer.initialize<double>(cc, 3*system.getNumParticles(), "doubleBuffer");
copyDoubleKernel = program->createKernel("copyDoubleBuffer");
copyDoubleKernel->addArg(doubleBuffer);
// Same unbound destination slot as the float kernel above.
copyDoubleKernel->addArg();
copyDoubleKernel->addArg(cc.getNumAtoms());
}
} }
double CommonUpdateStateDataKernel::getTime(const ContextImpl& context) const { double CommonUpdateStateDataKernel::getTime(const ContextImpl& context) const {
...@@ -144,32 +159,28 @@ void CommonUpdateStateDataKernel::setPositions(ContextImpl& context, const vecto ...@@ -144,32 +159,28 @@ void CommonUpdateStateDataKernel::setPositions(ContextImpl& context, const vecto
const vector<int>& order = cc.getAtomIndex(); const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
if (cc.getUseDoublePrecision()) { if (cc.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cc.getPinnedBuffer(); double* pos = (double*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
mm_double4& pos = posq[i];
const Vec3& p = positions[order[i]]; const Vec3& p = positions[order[i]];
pos.x = p[0]; pos[3*i] = p[0];
pos.y = p[1]; pos[3*i+1] = p[1];
pos.z = p[2]; pos[3*i+2] = p[2];
} }
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++) doubleBuffer.upload(pos);
posq[i] = mm_double4(0.0, 0.0, 0.0, 0.0); copyDoubleKernel->setArg(1, cc.getPosq());
cc.getPosq().upload(posq); copyDoubleKernel->execute(numParticles);
} }
else { else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer(); float* pos = (float*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
mm_float4& pos = posq[i];
const Vec3& p = positions[order[i]]; const Vec3& p = positions[order[i]];
pos.x = (float) p[0]; pos[3*i] = (float) p[0];
pos.y = (float) p[1]; pos[3*i+1] = (float) p[1];
pos.z = (float) p[2]; pos[3*i+2] = (float) p[2];
} }
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++) floatBuffer.upload(pos);
posq[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f); copyFloatKernel->setArg(1, cc.getPosq());
cc.getPosq().upload(posq); copyFloatKernel->execute(numParticles);
} }
if (cc.getUseMixedPrecision()) { if (cc.getUseMixedPrecision()) {
mm_float4* posCorrection = (mm_float4*) cc.getPinnedBuffer(); mm_float4* posCorrection = (mm_float4*) cc.getPinnedBuffer();
...@@ -218,32 +229,28 @@ void CommonUpdateStateDataKernel::setVelocities(ContextImpl& context, const vect ...@@ -218,32 +229,28 @@ void CommonUpdateStateDataKernel::setVelocities(ContextImpl& context, const vect
const vector<int>& order = cc.getAtomIndex(); const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) { if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
mm_double4* velm = (mm_double4*) cc.getPinnedBuffer(); double* vel = (double*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
mm_double4& vel = velm[i];
const Vec3& p = velocities[order[i]]; const Vec3& p = velocities[order[i]];
vel.x = p[0]; vel[3*i] = p[0];
vel.y = p[1]; vel[3*i+1] = p[1];
vel.z = p[2]; vel[3*i+2] = p[2];
} }
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++) doubleBuffer.upload(vel);
velm[i] = mm_double4(0.0, 0.0, 0.0, 0.0); copyDoubleKernel->setArg(1, cc.getVelm());
cc.getVelm().upload(velm); copyDoubleKernel->execute(numParticles);
} }
else { else {
mm_float4* velm = (mm_float4*) cc.getPinnedBuffer(); float* vel = (float*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
mm_float4& vel = velm[i];
const Vec3& p = velocities[order[i]]; const Vec3& p = velocities[order[i]];
vel.x = p[0]; vel[3*i] = (float) p[0];
vel.y = p[1]; vel[3*i+1] = (float) p[1];
vel.z = p[2]; vel[3*i+2] = (float) p[2];
} }
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++) floatBuffer.upload(vel);
velm[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f); copyFloatKernel->setArg(1, cc.getVelm());
cc.getVelm().upload(velm); copyFloatKernel->execute(numParticles);
} }
} }
......
/**
 * Scatter packed single-precision triples into a float4 array.
 *
 * For each atom index below numAtoms, the three consecutive values starting at
 * source[3*index] are stored in the x, y and z components of dest[index].
 * The w component of dest is never written, so its previous value is kept.
 * Each work item begins at GLOBAL_ID and advances by GLOBAL_SIZE.
 */
KERNEL void copyFloatBuffer(GLOBAL float* RESTRICT source, GLOBAL float4* RESTRICT dest, int numAtoms) {
    int atom = GLOBAL_ID;
    while (atom < numAtoms) {
        // Offset of this atom's first component in the packed source array.
        const int base = 3*atom;
        dest[atom].x = source[base];
        dest[atom].y = source[base+1];
        dest[atom].z = source[base+2];
        atom += GLOBAL_SIZE;
    }
}
#ifdef SUPPORTS_DOUBLE_PRECISION
/**
 * Scatter packed double-precision triples into a double4 array.
 *
 * For each atom index below numAtoms, the three consecutive values starting at
 * source[3*index] are stored in the x, y and z components of dest[index].
 * The w component of dest is never written, so its previous value is kept.
 * Each work item begins at GLOBAL_ID and advances by GLOBAL_SIZE.
 */
KERNEL void copyDoubleBuffer(GLOBAL double* RESTRICT source, GLOBAL double4* RESTRICT dest, int numAtoms) {
    int atom = GLOBAL_ID;
    while (atom < numAtoms) {
        // Offset of this atom's first component in the packed source array.
        const int base = 3*atom;
        dest[atom].x = source[base];
        dest[atom].y = source[base+1];
        dest[atom].z = source[base+2];
        atom += GLOBAL_SIZE;
    }
}
#endif
\ No newline at end of file
...@@ -359,6 +359,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking ...@@ -359,6 +359,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
nonbonded = new CudaNonbondedUtilities(*this); nonbonded = new CudaNonbondedUtilities(*this);
integration = new CudaIntegrationUtilities(*this, system); integration = new CudaIntegrationUtilities(*this, system);
expression = new CudaExpressionUtilities(*this); expression = new CudaExpressionUtilities(*this);
clearBuffer(posq);
} }
CudaContext::~CudaContext() { CudaContext::~CudaContext() {
......
...@@ -351,6 +351,7 @@ HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSy ...@@ -351,6 +351,7 @@ HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSy
nonbonded = new HipNonbondedUtilities(*this); nonbonded = new HipNonbondedUtilities(*this);
integration = new HipIntegrationUtilities(*this, system); integration = new HipIntegrationUtilities(*this, system);
expression = new HipExpressionUtilities(*this); expression = new HipExpressionUtilities(*this);
clearBuffer(posq);
} }
HipContext::~HipContext() { HipContext::~HipContext() {
......
...@@ -492,6 +492,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device ...@@ -492,6 +492,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
nonbonded = new OpenCLNonbondedUtilities(*this); nonbonded = new OpenCLNonbondedUtilities(*this);
integration = new OpenCLIntegrationUtilities(*this, system); integration = new OpenCLIntegrationUtilities(*this, system);
expression = new OpenCLExpressionUtilities(*this); expression = new OpenCLExpressionUtilities(*this);
clearBuffer(posq);
} }
OpenCLContext::~OpenCLContext() { OpenCLContext::~OpenCLContext() {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment