Commit 8eaf3c9c authored by peastman's avatar peastman
Browse files

Merge pull request #1192 from rmcgibbo/add-test-for-1191

Fix gradient of periodicdistance at 0
parents 56d3b854 b00edc8e
...@@ -109,7 +109,8 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -109,7 +109,8 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
} }
out << ");\n"; out << ");\n";
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n"; out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);\n"; out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1; int argIndex = -1;
...@@ -123,17 +124,17 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -123,17 +124,17 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
if (argIndex == -1) if (argIndex == -1)
out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n"; out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
else if (argIndex == 0) else if (argIndex == 0)
out << nodeNames[j] << " = periodicDistance_delta.x*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 1) else if (argIndex == 1)
out << nodeNames[j] << " = periodicDistance_delta.y*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 2) else if (argIndex == 2)
out << nodeNames[j] << " = periodicDistance_delta.z*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
else if (argIndex == 3) else if (argIndex == 3)
out << nodeNames[j] << " = -periodicDistance_delta.x*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 4) else if (argIndex == 4)
out << nodeNames[j] << " = -periodicDistance_delta.y*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 5) else if (argIndex == 5)
out << nodeNames[j] << " = -periodicDistance_delta.z*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
} }
} }
else { else {
......
...@@ -109,7 +109,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -109,7 +109,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
} }
out << ");\n"; out << ");\n";
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n"; out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);\n"; out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1; int argIndex = -1;
...@@ -123,17 +124,17 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -123,17 +124,17 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
if (argIndex == -1) if (argIndex == -1)
out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n"; out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
else if (argIndex == 0) else if (argIndex == 0)
out << nodeNames[j] << " = periodicDistance_delta.x*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 1) else if (argIndex == 1)
out << nodeNames[j] << " = periodicDistance_delta.y*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 2) else if (argIndex == 2)
out << nodeNames[j] << " = periodicDistance_delta.z*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
else if (argIndex == 3) else if (argIndex == 3)
out << nodeNames[j] << " = -periodicDistance_delta.x*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 4) else if (argIndex == 4)
out << nodeNames[j] << " = -periodicDistance_delta.y*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 5) else if (argIndex == 5)
out << nodeNames[j] << " = -periodicDistance_delta.z*periodicDistance_rinv;\n"; out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
} }
} }
else { else {
......
...@@ -1488,6 +1488,8 @@ double ReferenceCalcCustomExternalForceKernel::PeriodicDistanceFunction::evaluat ...@@ -1488,6 +1488,8 @@ double ReferenceCalcCustomExternalForceKernel::PeriodicDistanceFunction::evaluat
delta -= boxVectors[1]*floor(delta[1]/boxVectors[1][1]+0.5); delta -= boxVectors[1]*floor(delta[1]/boxVectors[1][1]+0.5);
delta -= boxVectors[0]*floor(delta[0]/boxVectors[0][0]+0.5); delta -= boxVectors[0]*floor(delta[0]/boxVectors[0][0]+0.5);
double r = sqrt(delta.dot(delta)); double r = sqrt(delta.dot(delta));
if (r == 0)
return 0.0;
if (argIndex < 3) if (argIndex < 3)
return delta[argIndex]/r; return delta[argIndex]/r;
return -delta[argIndex-3]/r; return -delta[argIndex-3]/r;
......
...@@ -167,6 +167,38 @@ void testPeriodic() { ...@@ -167,6 +167,38 @@ void testPeriodic() {
} }
} }
void testZeroPeriodicDistance() {
Vec3 vx(5, 0, 0);
Vec3 vy(0, 6, 0);
Vec3 vz(1, 2, 7);
double x0 = 51, y0 = -17, z0 = 11.2;
System system;
system.setDefaultPeriodicBoxVectors(vx, vy, vz);
system.addParticle(1.0);
CustomExternalForce* force = new CustomExternalForce("periodicdistance(x, y, z, x0, y0, z0)^2");
force->addPerParticleParameter("x0");
force->addPerParticleParameter("y0");
force->addPerParticleParameter("z0");
vector<double> params(3);
params[0] = x0;
params[1] = y0;
params[2] = z0;
force->addParticle(0, params);
system.addForce(force);
ASSERT(force->usesPeriodicBoundaryConditions());
ASSERT(system.usesPeriodicBoundaryConditions());
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(x0, y0, z0);
context.setPositions(positions);
State state = context.getState(State::Positions | State::Forces | State::Energy);
vector<Vec3> forces = state.getForces();
for (int i = 0; i < 3; i++)
ASSERT_EQUAL(forces[0][i], forces[0][i]);
}
void testIllegalVariable() { void testIllegalVariable() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -192,6 +224,7 @@ int main(int argc, char* argv[]) { ...@@ -192,6 +224,7 @@ int main(int argc, char* argv[]) {
testForce(); testForce();
testManyParameters(); testManyParameters();
testPeriodic(); testPeriodic();
testZeroPeriodicDistance();
testIllegalVariable(); testIllegalVariable();
runPlatformTests(); runPlatformTests();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment