Commit cb1c6c5d authored by Peter Eastman's avatar Peter Eastman
Browse files

Reduced duplicated code in bonded forces

parent ba3cefdc
...@@ -293,9 +293,9 @@ void OpenCLVirtualSitesKernel::computePositions(ContextImpl& context) { ...@@ -293,9 +293,9 @@ void OpenCLVirtualSitesKernel::computePositions(ContextImpl& context) {
cl.getIntegrationUtilities().computeVirtualSites(); cl.getIntegrationUtilities().computeVirtualSites();
} }
class OpenCLBondForceInfo : public OpenCLForceInfo { class OpenCLHarmonicBondForceInfo : public OpenCLForceInfo {
public: public:
OpenCLBondForceInfo(const HarmonicBondForce& force) : OpenCLForceInfo(0), force(force) { OpenCLHarmonicBondForceInfo(const HarmonicBondForce& force) : OpenCLForceInfo(0), force(force) {
} }
int getNumParticleGroups() { int getNumParticleGroups() {
return force.getNumBonds(); return force.getNumBonds();
...@@ -341,9 +341,10 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H ...@@ -341,9 +341,10 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H
} }
params->upload(paramVector); params->upload(paramVector);
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = OpenCLKernelSources::harmonicBondForce;
replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float2"); replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float2");
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::harmonicBondForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
cl.addForce(new OpenCLBondForceInfo(force)); cl.addForce(new OpenCLHarmonicBondForceInfo(force));
} }
double OpenCLCalcHarmonicBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double OpenCLCalcHarmonicBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -448,7 +449,7 @@ void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const Cus ...@@ -448,7 +449,7 @@ void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const Cus
compute << OpenCLExpressionUtilities::createExpressions(expressions, variables, functions, "temp", ""); compute << OpenCLExpressionUtilities::createExpressions(expressions, variables, functions, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customBondForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
} }
double OpenCLCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double OpenCLCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -466,9 +467,9 @@ double OpenCLCalcCustomBondForceKernel::execute(ContextImpl& context, bool inclu ...@@ -466,9 +467,9 @@ double OpenCLCalcCustomBondForceKernel::execute(ContextImpl& context, bool inclu
return 0.0; return 0.0;
} }
class OpenCLAngleForceInfo : public OpenCLForceInfo { class OpenCLHarmonicAngleForceInfo : public OpenCLForceInfo {
public: public:
OpenCLAngleForceInfo(const HarmonicAngleForce& force) : OpenCLForceInfo(0), force(force) { OpenCLHarmonicAngleForceInfo(const HarmonicAngleForce& force) : OpenCLForceInfo(0), force(force) {
} }
int getNumParticleGroups() { int getNumParticleGroups() {
return force.getNumAngles(); return force.getNumAngles();
...@@ -516,9 +517,10 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const ...@@ -516,9 +517,10 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const
} }
params->upload(paramVector); params->upload(paramVector);
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = OpenCLKernelSources::harmonicAngleForce;
replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float2"); replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float2");
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::harmonicAngleForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
cl.addForce(new OpenCLAngleForceInfo(force)); cl.addForce(new OpenCLHarmonicAngleForceInfo(force));
} }
double OpenCLCalcHarmonicAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double OpenCLCalcHarmonicAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -624,7 +626,7 @@ void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const Cu ...@@ -624,7 +626,7 @@ void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const Cu
compute << OpenCLExpressionUtilities::createExpressions(expressions, variables, functions, "temp", ""); compute << OpenCLExpressionUtilities::createExpressions(expressions, variables, functions, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customAngleForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
} }
double OpenCLCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double OpenCLCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -694,8 +696,9 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons ...@@ -694,8 +696,9 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons
} }
params->upload(paramVector); params->upload(paramVector);
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = OpenCLKernelSources::periodicTorsionForce;
replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float4"); replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float4");
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::periodicTorsionForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
cl.addForce(new OpenCLPeriodicTorsionForceInfo(force)); cl.addForce(new OpenCLPeriodicTorsionForceInfo(force));
} }
...@@ -754,8 +757,9 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo ...@@ -754,8 +757,9 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo
} }
params->upload(paramVector); params->upload(paramVector);
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = OpenCLKernelSources::rbTorsionForce;
replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float8"); replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float8");
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::rbTorsionForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
cl.addForce(new OpenCLRBTorsionForceInfo(force)); cl.addForce(new OpenCLRBTorsionForceInfo(force));
} }
...@@ -950,8 +954,7 @@ void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const ...@@ -950,8 +954,7 @@ void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const
compute << OpenCLExpressionUtilities::createExpressions(expressions, variables, functions, "temp", ""); compute << OpenCLExpressionUtilities::createExpressions(expressions, variables, functions, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
replacements["M_PI"] = doubleToString(M_PI); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customTorsionForce, replacements), force.getForceGroup());
} }
double OpenCLCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double OpenCLCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......
...@@ -2,11 +2,11 @@ float4 v0 = pos2-pos1; ...@@ -2,11 +2,11 @@ float4 v0 = pos2-pos1;
float4 v1 = pos2-pos3; float4 v1 = pos2-pos3;
float4 cp = cross(v0, v1); float4 cp = cross(v0, v1);
float rp = cp.x*cp.x + cp.y*cp.y + cp.z*cp.z; float rp = cp.x*cp.x + cp.y*cp.y + cp.z*cp.z;
rp = max(sqrt(rp), 1.0e-06f); rp = max(SQRT(rp), 1.0e-06f);
float r21 = v0.x*v0.x + v0.y*v0.y + v0.z*v0.z; float r21 = v0.x*v0.x + v0.y*v0.y + v0.z*v0.z;
float r23 = v1.x*v1.x + v1.y*v1.y + v1.z*v1.z; float r23 = v1.x*v1.x + v1.y*v1.y + v1.z*v1.z;
float dot = v0.x*v1.x + v0.y*v1.y + v0.z*v1.z; float dot = v0.x*v1.x + v0.y*v1.y + v0.z*v1.z;
float cosine = clamp(dot/sqrt(r21*r23), -1.0f, 1.0f); float cosine = clamp(dot*RSQRT(r21*r23), -1.0f, 1.0f);
float theta = acos(cosine); float theta = acos(cosine);
COMPUTE_FORCE COMPUTE_FORCE
float4 force1 = cross(v0, cp)*(dEdAngle/(r21*rp)); float4 force1 = cross(v0, cp)*(dEdAngle/(r21*rp));
......
float4 delta = pos2-pos1; float4 delta = pos2-pos1;
float r = SQRT(delta.x*delta.x + delta.y*delta.y + delta.z*delta.z); float r = SQRT(delta.x*delta.x + delta.y*delta.y + delta.z*delta.z);
COMPUTE_FORCE COMPUTE_FORCE
delta.xyz *= -dEdR/r; dEdR = (r > 0.0f) ? (dEdR / r) : 0.0f;
float4 force1 = -delta; delta.xyz *= dEdR;
float4 force2 = delta; float4 force1 = delta;
float4 force2 = -delta;
\ No newline at end of file
float2 angleParams = PARAMS[index]; float2 angleParams = PARAMS[index];
float4 v0 = pos2-pos1; float deltaIdeal = theta-angleParams.x;
float4 v1 = pos2-pos3;
float4 cp = cross(v0, v1);
float rp = cp.x*cp.x + cp.y*cp.y + cp.z*cp.z;
rp = max(SQRT(rp), 1.0e-06f);
float r21 = v0.x*v0.x + v0.y*v0.y + v0.z*v0.z;
float r23 = v1.x*v1.x + v1.y*v1.y + v1.z*v1.z;
float dot = v0.x*v1.x + v0.y*v1.y + v0.z*v1.z;
float cosine = clamp(dot*RSQRT(r21*r23), -1.0f, 1.0f);
float deltaIdeal = acos(cosine)-angleParams.x;
energy += 0.5f*angleParams.y*deltaIdeal*deltaIdeal; energy += 0.5f*angleParams.y*deltaIdeal*deltaIdeal;
float dEdR = angleParams.y*deltaIdeal; float dEdAngle = angleParams.y*deltaIdeal;
float4 force1 = cross(v0, cp)*(dEdR/(r21*rp));
float4 force3 = cross(cp, v1)*(dEdR/(r23*rp));
float4 force2 = -(force1+force3);
float4 delta = pos2-pos1;
float2 bondParams = PARAMS[index]; float2 bondParams = PARAMS[index];
float r = SQRT(delta.x*delta.x + delta.y*delta.y + delta.z*delta.z);
float deltaIdeal = r-bondParams.x; float deltaIdeal = r-bondParams.x;
energy += 0.5f * bondParams.y*deltaIdeal*deltaIdeal; energy += 0.5f * bondParams.y*deltaIdeal*deltaIdeal;
float dEdR = bondParams.y * deltaIdeal; float dEdR = bondParams.y * deltaIdeal;
dEdR = (r > 0.0f) ? (dEdR / r) : 0.0f;
delta.xyz *= dEdR;
float4 force1 = delta;
float4 force2 = -delta;
\ No newline at end of file
const float PI = 3.14159265358979323846f;
float4 torsionParams = PARAMS[index]; float4 torsionParams = PARAMS[index];
float4 v0 = (float4) (pos1.xyz-pos2.xyz, 0.0f); float deltaAngle = torsionParams.z*theta-torsionParams.y;
float4 v1 = (float4) (pos3.xyz-pos2.xyz, 0.0f);
float4 v2 = (float4) (pos3.xyz-pos4.xyz, 0.0f);
float4 cp0 = cross(v0, v1);
float4 cp1 = cross(v1, v2);
float cosangle = dot(normalize(cp0), normalize(cp1));
float dihedralAngle;
if (cosangle > 0.99f || cosangle < -0.99f) {
// We're close to the singularity in acos(), so take the cross product and use asin() instead.
float4 cross_prod = cross(cp0, cp1);
float scale = dot(cp0, cp0)*dot(cp1, cp1);
dihedralAngle = asin(sqrt(dot(cross_prod, cross_prod)/scale));
if (cosangle < 0.0f)
dihedralAngle = PI-dihedralAngle;
}
else
dihedralAngle = acos(cosangle);
dihedralAngle = (dot(v0, cp1) >= 0 ? dihedralAngle : -dihedralAngle);
float deltaAngle = torsionParams.z*dihedralAngle-torsionParams.y;
energy += torsionParams.x*(1.0f+cos(deltaAngle)); energy += torsionParams.x*(1.0f+cos(deltaAngle));
float sinDeltaAngle = sin(deltaAngle); float sinDeltaAngle = sin(deltaAngle);
float dEdAngle = -torsionParams.x*torsionParams.z*sinDeltaAngle; float dEdAngle = -torsionParams.x*torsionParams.z*sinDeltaAngle;
float normCross1 = dot(cp0, cp0);
float normSqrBC = dot(v1, v1);
float normBC = sqrt(normSqrBC);
float normCross2 = dot(cp1, cp1);
float dp = 1.0f/normSqrBC;
float4 ff = (float4) ((-dEdAngle*normBC)/normCross1, dot(v0, v1)*dp, dot(v2, v1)*dp, (dEdAngle*normBC)/normCross2);
float4 force1 = ff.x*cp0;
float4 force4 = ff.w*cp1;
float4 s = ff.y*force1 - ff.z*force4;
float4 force2 = s-force1;
float4 force3 = -s-force4;
\ No newline at end of file
const float PI = 3.14159265358979323846f;
float8 torsionParams = PARAMS[index]; float8 torsionParams = PARAMS[index];
float4 v0 = (float4) (pos1.xyz-pos2.xyz, 0.0f); if (theta < 0.0f)
float4 v1 = (float4) (pos3.xyz-pos2.xyz, 0.0f); theta += PI;
float4 v2 = (float4) (pos3.xyz-pos4.xyz, 0.0f);
float4 cp0 = cross(v0, v1);
float4 cp1 = cross(v1, v2);
float cosangle = dot(normalize(cp0), normalize(cp1));
float dihedralAngle;
if (cosangle > 0.99f || cosangle < -0.99f) {
// We're close to the singularity in acos(), so take the cross product and use asin() instead.
float4 cross_prod = cross(cp0, cp1);
float scale = dot(cp0, cp0)*dot(cp1, cp1);
dihedralAngle = asin(sqrt(dot(cross_prod, cross_prod)/scale));
if (cosangle < 0.0f)
dihedralAngle = PI-dihedralAngle;
}
else else
dihedralAngle = acos(cosangle); theta -= PI;
dihedralAngle = (dot(v0, cp1) >= 0 ? dihedralAngle : -dihedralAngle);
if (dihedralAngle < 0.0f)
dihedralAngle += PI;
else
dihedralAngle -= PI;
cosangle = -cosangle; cosangle = -cosangle;
float cosFactor = cosangle; float cosFactor = cosangle;
float dEdAngle = -torsionParams.s1; float dEdAngle = -torsionParams.s1;
...@@ -40,15 +20,4 @@ dEdAngle -= 5.0f*torsionParams.s5*cosFactor; ...@@ -40,15 +20,4 @@ dEdAngle -= 5.0f*torsionParams.s5*cosFactor;
rbEnergy += torsionParams.s4*cosFactor; rbEnergy += torsionParams.s4*cosFactor;
rbEnergy += torsionParams.s5*cosFactor*cosangle; rbEnergy += torsionParams.s5*cosFactor*cosangle;
energy += rbEnergy; energy += rbEnergy;
dEdAngle *= sin(dihedralAngle); dEdAngle *= sin(theta);
float normCross1 = dot(cp0, cp0);
float normSqrBC = dot(v1, v1);
float normBC = sqrt(normSqrBC);
float normCross2 = dot(cp1, cp1);
float dp = 1.0f/normSqrBC;
float4 ff = (float4) ((-dEdAngle*normBC)/normCross1, dot(v0, v1)*dp, dot(v2, v1)*dp, (dEdAngle*normBC)/normCross2);
float4 force1 = ff.x*cp0;
float4 force4 = ff.w*cp1;
float4 s = ff.y*force1 - ff.z*force4;
float4 force2 = s-force1;
float4 force3 = -s-force4;
const float PI = 3.14159265358979323846f;
float4 v0 = (float4) (pos1.xyz-pos2.xyz, 0.0f); float4 v0 = (float4) (pos1.xyz-pos2.xyz, 0.0f);
float4 v1 = (float4) (pos3.xyz-pos2.xyz, 0.0f); float4 v1 = (float4) (pos3.xyz-pos2.xyz, 0.0f);
float4 v2 = (float4) (pos3.xyz-pos4.xyz, 0.0f); float4 v2 = (float4) (pos3.xyz-pos4.xyz, 0.0f);
...@@ -12,7 +13,7 @@ if (cosangle > 0.99f || cosangle < -0.99f) { ...@@ -12,7 +13,7 @@ if (cosangle > 0.99f || cosangle < -0.99f) {
float scale = dot(cp0, cp0)*dot(cp1, cp1); float scale = dot(cp0, cp0)*dot(cp1, cp1);
theta = asin(sqrt(dot(cross_prod, cross_prod)/scale)); theta = asin(sqrt(dot(cross_prod, cross_prod)/scale));
if (cosangle < 0.0f) if (cosangle < 0.0f)
theta = M_PI-theta; theta = PI-theta;
} }
else else
theta = acos(cosangle); theta = acos(cosangle);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment