Unverified commit bc05f1c0, authored by Emilio Gallicchio, committed by GitHub
Browse files

merge setDisplacements kernel into copyState (#5058)

parent c91894e8
...@@ -1397,14 +1397,12 @@ private: ...@@ -1397,14 +1397,12 @@ private:
bool hasInitializedKernel; bool hasInitializedKernel;
ComputeContext& cc; ComputeContext& cc;
ComputeArray displ1, displ0; // actual displacements used in calculation
ComputeArray displacement1, displacement0; // fixed lab-frame displacements ComputeArray displacement1, displacement0; // fixed lab-frame displacements
ComputeArray displParticles; // variable displacements based on atom positions ComputeArray displParticles; // variable displacements based on atom positions
// int4 arranged as (pDestination1, pOrigin1, pDestination0, pOrigin0 // int4 arranged as (pDestination1, pOrigin1, pDestination0, pOrigin0
ComputeArray invAtomOrder, inner0InvAtomOrder, inner1InvAtomOrder; ComputeArray invAtomOrder, inner0InvAtomOrder, inner1InvAtomOrder;
ComputeArray dforce0, dforce1; // forces due to variable displacements ComputeArray dforce0, dforce1; // forces due to variable displacements
ComputeKernel copyStateKernel; ComputeKernel copyStateKernel;
ComputeKernel setDisplacementsKernel;
ComputeKernel resetDisplForceKernel; ComputeKernel resetDisplForceKernel;
ComputeKernel displForceKernel; ComputeKernel displForceKernel;
ComputeKernel hybridForceKernel; ComputeKernel hybridForceKernel;
......
...@@ -4141,10 +4141,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce& ...@@ -4141,10 +4141,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce&
displVector0[p] = mm_double4(d0[p][0], d0[p][1], d0[p][2], 0); displVector0[p] = mm_double4(d0[p][0], d0[p][1], d0[p][2], 0);
displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]); displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]);
} }
displ1.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displ1");
displacement1.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displacement1"); displacement1.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displacement1");
displacement1.upload(displVector1); displacement1.upload(displVector1);
displ0.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displ0");
displacement0.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displacement0"); displacement0.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displacement0");
displacement0.upload(displVector0); displacement0.upload(displVector0);
} }
...@@ -4156,10 +4154,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce& ...@@ -4156,10 +4154,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce&
displVector0[p] = mm_float4(d0[p][0], d0[p][1], d0[p][2], 0); displVector0[p] = mm_float4(d0[p][0], d0[p][1], d0[p][2], 0);
displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]); displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]);
} }
displ1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ1");
displacement1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displacement1"); displacement1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displacement1");
displacement1.upload(displVector1); displacement1.upload(displVector1);
displ0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ0");
displacement0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displacement0"); displacement0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displacement0");
displacement0.upload(displVector0); displacement0.upload(displVector0);
} }
...@@ -4203,27 +4199,17 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in ...@@ -4203,27 +4199,17 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
ComputeProgram program = cc.compileProgram(CommonKernelSources::atmforce); ComputeProgram program = cc.compileProgram(CommonKernelSources::atmforce);
//create the setDisplacements kernel
setDisplacementsKernel = program->createKernel("setDisplacements");
setDisplacementsKernel->addArg(numParticles);
setDisplacementsKernel->addArg(cc.getPosq());
setDisplacementsKernel->addArg(displacement0);
setDisplacementsKernel->addArg(displacement1);
setDisplacementsKernel->addArg(displParticles);
setDisplacementsKernel->addArg(cc.getAtomIndexArray());
setDisplacementsKernel->addArg(invAtomOrder);
setDisplacementsKernel->addArg(displ0);
setDisplacementsKernel->addArg(displ1);
//create CopyState kernel //create CopyState kernel
copyStateKernel = program->createKernel("copyState"); copyStateKernel = program->createKernel("copyState");
copyStateKernel->addArg(numParticles); copyStateKernel->addArg(numParticles);
copyStateKernel->addArg(cc.getPosq()); copyStateKernel->addArg(cc.getPosq());
copyStateKernel->addArg(cc0.getPosq()); copyStateKernel->addArg(cc0.getPosq());
copyStateKernel->addArg(cc1.getPosq()); copyStateKernel->addArg(cc1.getPosq());
copyStateKernel->addArg(displ0); copyStateKernel->addArg(displacement0);
copyStateKernel->addArg(displ1); copyStateKernel->addArg(displacement1);
copyStateKernel->addArg(displParticles);
copyStateKernel->addArg(cc.getAtomIndexArray()); copyStateKernel->addArg(cc.getAtomIndexArray());
copyStateKernel->addArg(invAtomOrder);
copyStateKernel->addArg(inner0InvAtomOrder); copyStateKernel->addArg(inner0InvAtomOrder);
copyStateKernel->addArg(inner1InvAtomOrder); copyStateKernel->addArg(inner1InvAtomOrder);
if (cc.getUseMixedPrecision()) { if (cc.getUseMixedPrecision()) {
...@@ -4313,7 +4299,6 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context, ...@@ -4313,7 +4299,6 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context,
cc0.reorderAtoms(); cc0.reorderAtoms();
cc1.reorderAtoms(); cc1.reorderAtoms();
setDisplacementsKernel->execute(numParticles);
copyStateKernel->execute(numParticles); copyStateKernel->execute(numParticles);
map<string, double> innerParameters0 = innerContext0.getParameters(); map<string, double> innerParameters0 = innerContext0.getParameters();
......
...@@ -26,47 +26,6 @@ KERNEL void hybridForce(int numParticles, ...@@ -26,47 +26,6 @@ KERNEL void hybridForce(int numParticles,
} }
} }
// Fills displ0/displ1 with the displacement vectors to apply to each atom.
//
// For every atom, displParticles holds two particle-index pairs: (x, y) for
// displ1 and (z, w) for displ0. When the (x, y) pair is set (both >= 0) the
// displacement is the current position of particle x minus that of particle y,
// and likewise for (z, w); an unset (z, w) pair then yields a zero displ0.
// When the (x, y) pair is not set, both outputs are copied unchanged from the
// fixed displacement0/displacement1 arrays.
//
// posq is read through invAtomOrder (particle index -> position slot), while
// displParticles, displacement0/1 and displ0/1 are all indexed by the atom
// index obtained from atomOrder.
KERNEL void setDisplacements(int numParticles,
        GLOBAL real4* RESTRICT posq,
        GLOBAL real4* RESTRICT displacement0,
        GLOBAL real4* RESTRICT displacement1,
        GLOBAL int4* displParticles,
        GLOBAL int* RESTRICT atomOrder,
        GLOBAL int* RESTRICT invAtomOrder,
        GLOBAL real4* RESTRICT displ0,
        GLOBAL real4* RESTRICT displ1) {
    for (int slot = GLOBAL_ID; slot < numParticles; slot += GLOBAL_SIZE) {
        const int atom = atomOrder[slot];
        const int4 pairs = displParticles[atom];

        // No particle pair for displ1: fall back to the fixed displacements.
        if (pairs.x < 0 || pairs.y < 0) {
            displ1[atom] = displacement1[atom];
            displ0[atom] = displacement0[atom];
            continue;
        }

        // Variable displacement 1: position of particle x minus particle y.
        const real4 head1 = posq[invAtomOrder[pairs.x]];
        const real4 tail1 = posq[invAtomOrder[pairs.y]];
        displ1[atom] = make_real4(head1.x - tail1.x,
                                  head1.y - tail1.y,
                                  head1.z - tail1.z, (real) 0);

        // Variable displacement 0 is optional; it is zero when its pair is unset.
        if (pairs.z < 0 || pairs.w < 0) {
            displ0[atom] = make_real4((real) 0, (real) 0, (real) 0, (real) 0);
        }
        else {
            const real4 head0 = posq[invAtomOrder[pairs.z]];
            const real4 tail0 = posq[invAtomOrder[pairs.w]];
            displ0[atom] = make_real4(head0.x - tail0.x,
                                      head0.y - tail0.y,
                                      head0.z - tail0.z, (real) 0);
        }
    }
}
//reset variable displacement forces //reset variable displacement forces
KERNEL void resetDisplForce(int numParticles, KERNEL void resetDisplForce(int numParticles,
int paddedNumParticles, int paddedNumParticles,
...@@ -134,9 +93,11 @@ KERNEL void copyState(int numParticles, ...@@ -134,9 +93,11 @@ KERNEL void copyState(int numParticles,
GLOBAL real4* RESTRICT posq, GLOBAL real4* RESTRICT posq,
GLOBAL real4* RESTRICT posq0, GLOBAL real4* RESTRICT posq0,
GLOBAL real4* RESTRICT posq1, GLOBAL real4* RESTRICT posq1,
GLOBAL real4* RESTRICT displ0, GLOBAL real4* RESTRICT displacement0,
GLOBAL real4* RESTRICT displ1, GLOBAL real4* RESTRICT displacement1,
GLOBAL int4* displParticles,
GLOBAL int* RESTRICT atomOrder, GLOBAL int* RESTRICT atomOrder,
GLOBAL int* RESTRICT invAtomOrder,
GLOBAL int* RESTRICT inner0InvAtomOrder, GLOBAL int* RESTRICT inner0InvAtomOrder,
GLOBAL int* RESTRICT inner1InvAtomOrder GLOBAL int* RESTRICT inner1InvAtomOrder
#ifdef USE_MIXED_PRECISION #ifdef USE_MIXED_PRECISION
...@@ -146,12 +107,41 @@ KERNEL void copyState(int numParticles, ...@@ -146,12 +107,41 @@ KERNEL void copyState(int numParticles,
GLOBAL real4* RESTRICT posq1Correction GLOBAL real4* RESTRICT posq1Correction
#endif #endif
) { ) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) { for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
int atom = atomOrder[i]; int atom = atomOrder[i];
//default fixed lab frame displacement
real4 displ0 = displacement0[atom];
real4 displ1 = displacement1[atom];
//override with variable displacements if set
int pj1 = displParticles[atom].x;
int pi1 = displParticles[atom].y;
int pj0 = displParticles[atom].z;
int pi0 = displParticles[atom].w;
if (pj1 >= 0 && pi1 >= 0) {
// variable system coordinate displacements
int indexj1 = invAtomOrder[pj1];
int indexi1 = invAtomOrder[pi1];
displ1 = make_real4((real) posq[indexj1].x- posq[indexi1].x,
(real) posq[indexj1].y- posq[indexi1].y,
(real) posq[indexj1].z- posq[indexi1].z, (real) 0);
if (pj0 >= 0 && pi0 >= 0) {
int indexj0 = invAtomOrder[pj0];
int indexi0 = invAtomOrder[pi0];
displ0 = make_real4((real) posq[indexj0].x - posq[indexi0].x,
(real) posq[indexj0].y - posq[indexi0].y,
(real) posq[indexj0].z - posq[indexi0].z, (real) 0);
}
else {
displ0 = make_real4((real) 0, (real) 0, (real) 0, (real) 0);
}
}
int index0 = inner0InvAtomOrder[atom]; int index0 = inner0InvAtomOrder[atom];
int index1 = inner1InvAtomOrder[atom]; int index1 = inner1InvAtomOrder[atom];
real4 p0 = posq[i] + make_real4((real) displ0[atom].x, (real) displ0[atom].y, (real) displ0[atom].z, 0); real4 p0 = posq[i] + make_real4((real) displ0.x, (real) displ0.y, (real) displ0.z, 0);
real4 p1 = posq[i] + make_real4((real) displ1[atom].x, (real) displ1[atom].y, (real) displ1[atom].z, 0); real4 p1 = posq[i] + make_real4((real) displ1.x, (real) displ1.y, (real) displ1.z, 0);
p0.w = posq0[i].w; p0.w = posq0[i].w;
p1.w = posq1[i].w; p1.w = posq1[i].w;
posq0[index0] = p0; posq0[index0] = p0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment