Unverified Commit ea4b6872 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

ATMForce reorders inner contexts for better performance (#4495)

* ATMForce reorders inner contexts for better performance

* Fixed obsolete comments
parent 6ba168cd
...@@ -115,9 +115,6 @@ void ATMForceImpl::initialize(ContextImpl& context) { ...@@ -115,9 +115,6 @@ void ATMForceImpl::initialize(ContextImpl& context) {
innerContext0 = context.createLinkedContext(innerSystem0, innerIntegrator0); innerContext0 = context.createLinkedContext(innerSystem0, innerIntegrator0);
innerContext1 = context.createLinkedContext(innerSystem1, innerIntegrator1); innerContext1 = context.createLinkedContext(innerSystem1, innerIntegrator1);
vector<Vec3> positions(system.getNumParticles(), Vec3());
innerContext0->setPositions(positions);
innerContext1->setPositions(positions);
// Create the kernel. // Create the kernel.
......
...@@ -1714,19 +1714,15 @@ public: ...@@ -1714,19 +1714,15 @@ public:
virtual ComputeContext& getInnerComputeContext(ContextImpl& innerContext) = 0; virtual ComputeContext& getInnerComputeContext(ContextImpl& innerContext) = 0;
private: private:
class ForceInfo;
class ReorderListener; class ReorderListener;
void initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1); void initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1);
bool hasInitializedKernel; bool hasInitializedKernel;
ComputeContext& cc; ComputeContext& cc;
std::vector<mm_float4> displVector1;
std::vector<mm_float4> displVector0;
ComputeArray displ1; ComputeArray displ1;
ComputeArray displ0; ComputeArray displ0;
ComputeArray invAtomOrder, inner0InvAtomOrder, inner1InvAtomOrder;
ComputeKernel copyStateKernel; ComputeKernel copyStateKernel;
ComputeKernel hybridForceKernel; ComputeKernel hybridForceKernel;
......
...@@ -5629,6 +5629,7 @@ void CommonCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& ...@@ -5629,6 +5629,7 @@ void CommonCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
listener1->execute(); listener1->execute();
listener2->execute(); listener2->execute();
} }
cc2.reorderAtoms();
copyStateKernel->execute(numAtoms); copyStateKernel->execute(numAtoms);
Vec3 a, b, c; Vec3 a, b, c;
context.getPeriodicBoxVectors(a, b, c); context.getPeriodicBoxVectors(a, b, c);
...@@ -8027,49 +8028,20 @@ void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& contex ...@@ -8027,49 +8028,20 @@ void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& contex
cc.setAtomIndex(lastAtomOrder); cc.setAtomIndex(lastAtomOrder);
} }
class CommonCalcATMForceKernel::ForceInfo : public ComputeForceInfo {
public:
ForceInfo(ComputeForceInfo& force) : force(force) {
}
bool areParticlesIdentical(int particle1, int particle2) {
return force.areParticlesIdentical(particle1, particle2);
}
int getNumParticleGroups() {
return force.getNumParticleGroups();
}
void getParticlesInGroup(int index, vector<int>& particles) {
force.getParticlesInGroup(index, particles);
}
bool areGroupsIdentical(int group1, int group2) {
return force.areGroupsIdentical(group1, group2);
}
private:
ComputeForceInfo& force;
};
class CommonCalcATMForceKernel::ReorderListener : public ComputeContext::ReorderListener { class CommonCalcATMForceKernel::ReorderListener : public ComputeContext::ReorderListener {
public: public:
ReorderListener(ComputeContext& cc, vector<mm_float4>& displVector1, ArrayInterface& displ1, ReorderListener(ComputeContext& cc, ArrayInterface& invAtomOrder) : cc(cc), invAtomOrder(invAtomOrder) {
vector<mm_float4>& displVector0, ArrayInterface& displ0) :
cc(cc), displVector1(displVector1), displ1(displ1), displVector0(displVector0), displ0(displ0) {
} }
void execute() { void execute() {
const vector<int>& id = cc.getAtomIndex(); vector<int> invOrder(cc.getPaddedNumAtoms());
vector<mm_float4> newDisplVectorContext1(cc.getPaddedNumAtoms()); const vector<int>& order = cc.getAtomIndex();
vector<mm_float4> newDisplVectorContext0(cc.getPaddedNumAtoms()); for (int i = 0; i < order.size(); i++)
for (int i = 0; i < cc.getNumAtoms(); i++) { invOrder[order[i]] = i;
newDisplVectorContext1[i] = displVector1[id[i]]; invAtomOrder.upload(invOrder);
newDisplVectorContext0[i] = displVector0[id[i]];
}
displ1.upload(newDisplVectorContext1);
displ0.upload(newDisplVectorContext0);
} }
private: private:
ComputeContext& cc; ComputeContext& cc;
ArrayInterface& displ1; ArrayInterface& invAtomOrder;
ArrayInterface& displ0;
std::vector<mm_float4> displVector1;
std::vector<mm_float4> displVector0;
}; };
CommonCalcATMForceKernel::~CommonCalcATMForceKernel() { CommonCalcATMForceKernel::~CommonCalcATMForceKernel() {
...@@ -8080,31 +8052,23 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce& ...@@ -8080,31 +8052,23 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce&
numParticles = force.getNumParticles(); numParticles = force.getNumParticles();
if (numParticles == 0) if (numParticles == 0)
return; return;
displVector1.resize(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0)); vector<mm_float4> displVector1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
displVector0.resize(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0)); vector<mm_float4> displVector0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVectorContext1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVectorContext0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
Vec3 displacement1, displacement0; Vec3 displacement1, displacement0;
force.getParticleParameters(i, displacement1, displacement0); force.getParticleParameters(i, displacement1, displacement0);
displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0); displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0);
displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0); displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0);
} }
const vector<int>& id = cc.getAtomIndex();
for (int i = 0; i < numParticles; i++)
displVectorContext1[i] = displVector1[id[i]];
displ1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ1"); displ1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ1");
displ1.upload(displVectorContext1); displ1.upload(displVector1);
for (int i = 0; i < numParticles; i++)
displVectorContext0[i] = displVector0[id[i]];
displ0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ0"); displ0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ0");
displ0.upload(displVectorContext0); displ0.upload(displVector0);
invAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "invAtomOrder");
inner0InvAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "inner0InvAtomOrder");
inner1InvAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "inner1InvAtomOrder");
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
cc.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i)); cc.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i));
cc.addForce(new ComputeForceInfo());
} }
void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1) { void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1) {
...@@ -8115,10 +8079,22 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in ...@@ -8115,10 +8079,22 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
ComputeContext& cc0 = getInnerComputeContext(innerContext0); ComputeContext& cc0 = getInnerComputeContext(innerContext0);
ComputeContext& cc1 = getInnerComputeContext(innerContext1); ComputeContext& cc1 = getInnerComputeContext(innerContext1);
//initialize the listener, this reorders the displacement vectors // Copy positions to the inner contexts.
ReorderListener* listener = new ReorderListener(cc, displVector1, displ1, displVector0, displ0); vector<Vec3> positions;
context.getPositions(positions);
innerContext0.setPositions(positions);
innerContext1.setPositions(positions);
// Initialize the listeners.
ReorderListener* listener = new ReorderListener(cc, invAtomOrder);
ReorderListener* listener0 = new ReorderListener(cc0, inner0InvAtomOrder);
ReorderListener* listener1 = new ReorderListener(cc1, inner1InvAtomOrder);
cc.addReorderListener(listener); cc.addReorderListener(listener);
cc0.addReorderListener(listener0);
cc1.addReorderListener(listener1);
listener->execute(); listener->execute();
listener0->execute();
listener1->execute();
//create CopyState kernel //create CopyState kernel
ComputeProgram program = cc.compileProgram(CommonKernelSources::atmforce); ComputeProgram program = cc.compileProgram(CommonKernelSources::atmforce);
...@@ -8129,6 +8105,9 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in ...@@ -8129,6 +8105,9 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
copyStateKernel->addArg(cc1.getPosq()); copyStateKernel->addArg(cc1.getPosq());
copyStateKernel->addArg(displ0); copyStateKernel->addArg(displ0);
copyStateKernel->addArg(displ1); copyStateKernel->addArg(displ1);
copyStateKernel->addArg(cc.getAtomIndexArray());
copyStateKernel->addArg(inner0InvAtomOrder);
copyStateKernel->addArg(inner1InvAtomOrder);
if (cc.getUseMixedPrecision()) { if (cc.getUseMixedPrecision()) {
copyStateKernel->addArg(cc.getPosqCorrection()); copyStateKernel->addArg(cc.getPosqCorrection());
copyStateKernel->addArg(cc0.getPosqCorrection()); copyStateKernel->addArg(cc0.getPosqCorrection());
...@@ -8142,6 +8121,9 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in ...@@ -8142,6 +8121,9 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
hybridForceKernel->addArg(cc.getLongForceBuffer()); hybridForceKernel->addArg(cc.getLongForceBuffer());
hybridForceKernel->addArg(cc0.getLongForceBuffer()); hybridForceKernel->addArg(cc0.getLongForceBuffer());
hybridForceKernel->addArg(cc1.getLongForceBuffer()); hybridForceKernel->addArg(cc1.getLongForceBuffer());
hybridForceKernel->addArg(invAtomOrder);
hybridForceKernel->addArg(inner0InvAtomOrder);
hybridForceKernel->addArg(inner1InvAtomOrder);
hybridForceKernel->addArg(); hybridForceKernel->addArg();
hybridForceKernel->addArg(); hybridForceKernel->addArg();
...@@ -8156,12 +8138,12 @@ void CommonCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& in ...@@ -8156,12 +8138,12 @@ void CommonCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& in
ContextSelector selector(cc); ContextSelector selector(cc);
initKernels(context, innerContext0, innerContext1); initKernels(context, innerContext0, innerContext1);
if (cc.getUseDoublePrecision()) { if (cc.getUseDoublePrecision()) {
hybridForceKernel->setArg(5, dEdu0); hybridForceKernel->setArg(8, dEdu0);
hybridForceKernel->setArg(6, dEdu1); hybridForceKernel->setArg(9, dEdu1);
} }
else { else {
hybridForceKernel->setArg(5, (float) dEdu0); hybridForceKernel->setArg(8, (float) dEdu0);
hybridForceKernel->setArg(6, (float) dEdu1); hybridForceKernel->setArg(9, (float) dEdu1);
} }
hybridForceKernel->execute(numParticles); hybridForceKernel->execute(numParticles);
map<string, double>& derivs = cc.getEnergyParamDerivWorkspace(); map<string, double>& derivs = cc.getEnergyParamDerivWorkspace();
...@@ -8175,6 +8157,10 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context, ...@@ -8175,6 +8157,10 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context,
initKernels(context, innerContext0, innerContext1); initKernels(context, innerContext0, innerContext1);
ComputeContext& cc0 = getInnerComputeContext(innerContext0);
ComputeContext& cc1 = getInnerComputeContext(innerContext1);
cc0.reorderAtoms();
cc1.reorderAtoms();
copyStateKernel->execute(numParticles); copyStateKernel->execute(numParticles);
Vec3 a, b, c; Vec3 a, b, c;
...@@ -8195,23 +8181,16 @@ void CommonCalcATMForceKernel::copyParametersToContext(ContextImpl& context, con ...@@ -8195,23 +8181,16 @@ void CommonCalcATMForceKernel::copyParametersToContext(ContextImpl& context, con
ContextSelector selector(cc); ContextSelector selector(cc);
if (force.getNumParticles() != numParticles) if (force.getNumParticles() != numParticles)
throw OpenMMException("copyParametersToContext: The number of ATMMetaForce particles has changed"); throw OpenMMException("copyParametersToContext: The number of ATMMetaForce particles has changed");
displVector1.resize(cc.getPaddedNumAtoms()); vector<mm_float4> displVector1(cc.getPaddedNumAtoms());
displVector0.resize(cc.getPaddedNumAtoms()); vector<mm_float4> displVector0(cc.getPaddedNumAtoms());
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
Vec3 displacement1, displacement0; Vec3 displacement1, displacement0;
force.getParticleParameters(i, displacement1, displacement0); force.getParticleParameters(i, displacement1, displacement0);
displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0); displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0);
displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0); displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0);
} }
const vector<int>& id = cc.getAtomIndex(); displ1.upload(displVector1);
vector<mm_float4> displVectorContext1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0)); displ0.upload(displVector0);
vector<mm_float4> displVectorContext0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
for (int i = 0; i < numParticles; i++) {
displVectorContext1[i] = displVector1[id[i]];
displVectorContext0[i] = displVector0[id[i]];
}
displ1.upload(displVectorContext1);
displ0.upload(displVectorContext0);
} }
class CommonCalcCustomCPPForceKernel::StartCalculationPreComputation : public ComputeContext::ForcePreComputation { class CommonCalcCustomCPPForceKernel::StartCalculationPreComputation : public ComputeContext::ForcePreComputation {
......
...@@ -3,12 +3,18 @@ KERNEL void hybridForce(int numParticles, ...@@ -3,12 +3,18 @@ KERNEL void hybridForce(int numParticles,
GLOBAL mm_long* RESTRICT force, GLOBAL mm_long* RESTRICT force,
GLOBAL mm_long* RESTRICT force0, GLOBAL mm_long* RESTRICT force0,
GLOBAL mm_long* RESTRICT force1, GLOBAL mm_long* RESTRICT force1,
GLOBAL int* RESTRICT invAtomOrder,
GLOBAL int* RESTRICT inner0InvAtomOrder,
GLOBAL int* RESTRICT inner1InvAtomOrder,
real dEdu0, real dEdu0,
real dEdu1) { real dEdu1) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) { for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
force[i] += (mm_long) (dEdu0*force0[i] + dEdu1*force1[i]); int index = invAtomOrder[i];
force[i+paddedNumParticles] += (mm_long) (dEdu0*force0[i+paddedNumParticles] + dEdu1*force1[i+paddedNumParticles]); int index0 = inner0InvAtomOrder[i];
force[i+paddedNumParticles*2] += (mm_long) (dEdu0*force0[i+paddedNumParticles*2] + dEdu1*force1[i+paddedNumParticles*2]); int index1 = inner1InvAtomOrder[i];
force[index] += (mm_long) (dEdu0*force0[index0] + dEdu1*force1[index1]);
force[index+paddedNumParticles] += (mm_long) (dEdu0*force0[index0+paddedNumParticles] + dEdu1*force1[index1+paddedNumParticles]);
force[index+paddedNumParticles*2] += (mm_long) (dEdu0*force0[index0+paddedNumParticles*2] + dEdu1*force1[index1+paddedNumParticles*2]);
} }
} }
...@@ -17,7 +23,10 @@ KERNEL void copyState(int numParticles, ...@@ -17,7 +23,10 @@ KERNEL void copyState(int numParticles,
GLOBAL real4* RESTRICT posq0, GLOBAL real4* RESTRICT posq0,
GLOBAL real4* RESTRICT posq1, GLOBAL real4* RESTRICT posq1,
GLOBAL float4* RESTRICT displ0, GLOBAL float4* RESTRICT displ0,
GLOBAL float4* RESTRICT displ1 GLOBAL float4* RESTRICT displ1,
GLOBAL int* RESTRICT atomOrder,
GLOBAL int* RESTRICT inner0InvAtomOrder,
GLOBAL int* RESTRICT inner1InvAtomOrder
#ifdef USE_MIXED_PRECISION #ifdef USE_MIXED_PRECISION
, ,
GLOBAL real4* RESTRICT posqCorrection, GLOBAL real4* RESTRICT posqCorrection,
...@@ -26,15 +35,18 @@ KERNEL void copyState(int numParticles, ...@@ -26,15 +35,18 @@ KERNEL void copyState(int numParticles,
#endif #endif
) { ) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) { for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
real4 p0 = posq[i] + make_real4((real) displ0[i].x, (real) displ0[i].y, (real) displ0[i].z, 0); int atom = atomOrder[i];
real4 p1 = posq[i] + make_real4((real) displ1[i].x, (real) displ1[i].y, (real) displ1[i].z, 0); int index0 = inner0InvAtomOrder[atom];
int index1 = inner1InvAtomOrder[atom];
real4 p0 = posq[i] + make_real4((real) displ0[atom].x, (real) displ0[atom].y, (real) displ0[atom].z, 0);
real4 p1 = posq[i] + make_real4((real) displ1[atom].x, (real) displ1[atom].y, (real) displ1[atom].z, 0);
p0.w = posq0[i].w; p0.w = posq0[i].w;
p1.w = posq1[i].w; p1.w = posq1[i].w;
posq0[i] = p0; posq0[index0] = p0;
posq1[i] = p1; posq1[index1] = p1;
#ifdef USE_MIXED_PRECISION #ifdef USE_MIXED_PRECISION
posq0Correction[i] = posqCorrection[i]; posq0Correction[index0] = posqCorrection[i];
posq1Correction[i] = posqCorrection[i]; posq1Correction[index1] = posqCorrection[i];
#endif #endif
} }
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2023 Stanford University and the Authors. * * Portions copyright (c) 2008-2024 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -2959,13 +2959,13 @@ void ReferenceCalcATMForceKernel::copyState(ContextImpl& context, ContextImpl& i ...@@ -2959,13 +2959,13 @@ void ReferenceCalcATMForceKernel::copyState(ContextImpl& context, ContextImpl& i
vector<Vec3> pos0(pos); vector<Vec3> pos0(pos);
for (int i = 0; i < pos0.size(); i++) for (int i = 0; i < pos0.size(); i++)
pos0[i] += displ0[i]; pos0[i] += displ0[i];
extractPositions(innerContext0) = pos0; innerContext0.setPositions(pos0);
//in the target state, particles are displaced by displ1 //in the target state, particles are displaced by displ1
vector<Vec3> pos1(pos); vector<Vec3> pos1(pos);
for (int i = 0; i < pos1.size(); i++) for (int i = 0; i < pos1.size(); i++)
pos1[i] += displ1[i]; pos1[i] += displ1[i];
extractPositions(innerContext1) = pos1; innerContext1.setPositions(pos1);
Vec3 a, b, c; Vec3 a, b, c;
context.getPeriodicBoxVectors(a, b, c); context.getPeriodicBoxVectors(a, b, c);
......
...@@ -245,7 +245,6 @@ void testNonbonded() { ...@@ -245,7 +245,6 @@ void testNonbonded() {
System system; System system;
double u0, u1, energy; double u0, u1, energy;
double lambda = 0.5; double lambda = 0.5;
int numParticles = 216;
double width = 4.0; double width = 4.0;
system.setDefaultPeriodicBoxVectors(Vec3(width, 0, 0), Vec3(0, width, 0), Vec3(0, 0, width)); system.setDefaultPeriodicBoxVectors(Vec3(width, 0, 0), Vec3(0, width, 0), Vec3(0, 0, width));
...@@ -270,7 +269,8 @@ void testNonbonded() { ...@@ -270,7 +269,8 @@ void testNonbonded() {
//in this scenario the non-bonded force is added to the System, a copy is added to ATMForce and //in this scenario the non-bonded force is added to the System, a copy is added to ATMForce and
//the System's copy is disabled by giving it a force group that is not evaluated. //the System's copy is disabled by giving it a force group that is not evaluated.
//This should cause atom reordering in the main context //This used to be needed to ensure atoms would be reordered. It isn't anymore, but
//this test is left in to make sure it still works.
system.addForce(nbforce); system.addForce(nbforce);
atm->addForce(XmlSerializer::clone<Force>(*nbforce)); atm->addForce(XmlSerializer::clone<Force>(*nbforce));
nbforce->setForceGroup(1); nbforce->setForceGroup(1);
...@@ -286,7 +286,6 @@ void testNonbonded() { ...@@ -286,7 +286,6 @@ void testNonbonded() {
double epert1 = u1 - u0; double epert1 = u1 - u0;
//in this second scenario the non-bonded force is remove from the System //in this second scenario the non-bonded force is remove from the System
//and atom reordering is disabled.
system.removeForce(0); system.removeForce(0);
LangevinMiddleIntegrator integrator2(300, 1.0, 0.004); LangevinMiddleIntegrator integrator2(300, 1.0, 0.004);
Context context2(system, integrator2, platform); Context context2(system, integrator2, platform);
...@@ -411,6 +410,7 @@ void testLargeSystem() { ...@@ -411,6 +410,7 @@ void testLargeSystem() {
int numParticles = 1000; int numParticles = 1000;
System system; System system;
system.setDefaultPeriodicBoxVectors(Vec3(3, 0, 0), Vec3(0, 3, 0), Vec3(0, 0, 3));
CustomExternalForce* external = new CustomExternalForce("x^2 + 2*y^2 + 3*z^2"); CustomExternalForce* external = new CustomExternalForce("x^2 + 2*y^2 + 3*z^2");
ATMForce* atm = new ATMForce(0.0, 0.0, 0.1, 0.0, 0.0, 1e6, 5e5, 1.0/16, 1.0); ATMForce* atm = new ATMForce(0.0, 0.0, 0.1, 0.0, 0.0, 1e6, 5e5, 1.0/16, 1.0);
atm->addForce(external); atm->addForce(external);
...@@ -427,19 +427,26 @@ void testLargeSystem() { ...@@ -427,19 +427,26 @@ void testLargeSystem() {
atm->addParticle(d); atm->addParticle(d);
} }
// Also add a nonbonded force to trigger atom reordering on the GPU. // Also add nonbonded forces to trigger atom reordering on the GPU.
CustomNonbondedForce* nb = new CustomNonbondedForce("a*r^2"); CustomNonbondedForce* nb = new CustomNonbondedForce("a*r^2");
nb->addGlobalParameter("a", 0.0); nb->addGlobalParameter("a", 0.0);
for (int i = 0; i < numParticles; i++) for (int i = 0; i < numParticles; i++)
nb->addParticle(); nb->addParticle();
nb->setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic);
system.addForce(nb); system.addForce(nb);
VerletIntegrator integrator(1.0); CustomNonbondedForce* nb1 = new CustomNonbondedForce("0");
Context context(system, integrator, platform); nb1->addPerParticleParameter("b");
context.setPositions(positions); for (int i = 0; i < numParticles; i++)
nb1->addParticle({(double) (i%3)});
nb1->setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic);
atm->addForce(nb1);
// Evaluate the forces to see if the particles are at the correct positions. // Evaluate the forces to see if the particles are at the correct positions.
VerletIntegrator integrator(1.0);
Context context(system, integrator, platform);
context.setPositions(positions);
for (double lambda : {0.0, 1.0}) { for (double lambda : {0.0, 1.0}) {
context.setParameter(ATMForce::Lambda1(), lambda); context.setParameter(ATMForce::Lambda1(), lambda);
context.setParameter(ATMForce::Lambda2(), lambda); context.setParameter(ATMForce::Lambda2(), lambda);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment