Commit 8eaf3c9c authored by peastman's avatar peastman
Browse files

Merge pull request #1192 from rmcgibbo/add-test-for-1191

Fix gradient of periodicdistance at 0
parents 56d3b854 b00edc8e
......@@ -109,7 +109,8 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
}
out << ");\n";
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);\n";
out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1;
......@@ -123,17 +124,17 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
if (argIndex == -1)
out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
else if (argIndex == 0)
out << nodeNames[j] << " = periodicDistance_delta.x*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 1)
out << nodeNames[j] << " = periodicDistance_delta.y*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 2)
out << nodeNames[j] << " = periodicDistance_delta.z*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
else if (argIndex == 3)
out << nodeNames[j] << " = -periodicDistance_delta.x*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 4)
out << nodeNames[j] << " = -periodicDistance_delta.y*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 5)
out << nodeNames[j] << " = -periodicDistance_delta.z*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
}
}
else {
......
......@@ -109,7 +109,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
}
out << ");\n";
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);\n";
out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1;
......@@ -123,17 +124,17 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
if (argIndex == -1)
out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
else if (argIndex == 0)
out << nodeNames[j] << " = periodicDistance_delta.x*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 1)
out << nodeNames[j] << " = periodicDistance_delta.y*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 2)
out << nodeNames[j] << " = periodicDistance_delta.z*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
else if (argIndex == 3)
out << nodeNames[j] << " = -periodicDistance_delta.x*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 4)
out << nodeNames[j] << " = -periodicDistance_delta.y*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 5)
out << nodeNames[j] << " = -periodicDistance_delta.z*periodicDistance_rinv;\n";
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
}
}
else {
......
......@@ -1488,6 +1488,8 @@ double ReferenceCalcCustomExternalForceKernel::PeriodicDistanceFunction::evaluat
delta -= boxVectors[1]*floor(delta[1]/boxVectors[1][1]+0.5);
delta -= boxVectors[0]*floor(delta[0]/boxVectors[0][0]+0.5);
double r = sqrt(delta.dot(delta));
if (r == 0)
return 0.0;
if (argIndex < 3)
return delta[argIndex]/r;
return -delta[argIndex-3]/r;
......
......@@ -167,6 +167,38 @@ void testPeriodic() {
}
}
void testZeroPeriodicDistance() {
Vec3 vx(5, 0, 0);
Vec3 vy(0, 6, 0);
Vec3 vz(1, 2, 7);
double x0 = 51, y0 = -17, z0 = 11.2;
System system;
system.setDefaultPeriodicBoxVectors(vx, vy, vz);
system.addParticle(1.0);
CustomExternalForce* force = new CustomExternalForce("periodicdistance(x, y, z, x0, y0, z0)^2");
force->addPerParticleParameter("x0");
force->addPerParticleParameter("y0");
force->addPerParticleParameter("z0");
vector<double> params(3);
params[0] = x0;
params[1] = y0;
params[2] = z0;
force->addParticle(0, params);
system.addForce(force);
ASSERT(force->usesPeriodicBoundaryConditions());
ASSERT(system.usesPeriodicBoundaryConditions());
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(x0, y0, z0);
context.setPositions(positions);
State state = context.getState(State::Positions | State::Forces | State::Energy);
vector<Vec3> forces = state.getForces();
for (int i = 0; i < 3; i++)
ASSERT_EQUAL(forces[0][i], forces[0][i]);
}
void testIllegalVariable() {
System system;
system.addParticle(1.0);
......@@ -192,6 +224,7 @@ int main(int argc, char* argv[]) {
testForce();
testManyParameters();
testPeriodic();
testZeroPeriodicDistance();
testIllegalVariable();
runPlatformTests();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment