"plugins/freeEnergy/vscode:/vscode.git/clone" did not exist on "465e018fc9f1f5d08d1cfdc9ce01803ce4fccfce"
Unverified Commit dd4eed16 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2266 from peastman/vector

Fixed bug in vector operations in CustomIntegrator on CUDA
parents f36f2dba ddcd2f66
......@@ -123,6 +123,7 @@ private:
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
void callFunction(std::stringstream& out, std::string singleFn, std::string doubleFn, const std::string& arg, const std::string& tempType);
void callFunction2(std::stringstream& out, std::string fn, const std::string& arg1, const std::string& arg2, const std::string& tempType);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
......
......@@ -547,6 +547,15 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ADD_CONSTANT:
if (isVecType) {
string val = context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue());
string arg = getTempName(node.getChildren()[0], temps);
out << "make_" << tempType << "(";
out << val << "+" << arg << ".x, ";
out << val << "+" << arg << ".y, ";
out << val << "+" << arg << ".z)";
}
else
out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
break;
case Operation::MULTIPLY_CONSTANT:
......@@ -610,10 +619,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
break;
}
case Operation::MIN:
out << "min((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
callFunction2(out, "min", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
break;
case Operation::MAX:
out << "max((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
callFunction2(out, "max", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
break;
case Operation::ABS:
callFunction(out, "fabs", "fabs", getTempName(node.getChildren()[0], temps), tempType);
......@@ -927,3 +936,15 @@ void CudaExpressionUtilities::callFunction(stringstream& out, string singleFn, s
out<<singleFn<<"("<<arg<<")";
}
}
void CudaExpressionUtilities::callFunction2(stringstream& out, string fn, const string& arg1, const string& arg2, const string& tempType) {
bool isVector = (tempType[tempType.size()-1] == '3');
if (isVector) {
out<<"make_"<<tempType<<"(";
out<<fn<<"("<<arg1<<".x, "<<arg2<<".x), ";
out<<fn<<"("<<arg1<<".y, "<<arg2<<".y), ";
out<<fn<<"("<<arg1<<".z, "<<arg2<<".z))";
}
else
out<<fn<<"(("<<tempType<<") "<<arg1<<", ("<<tempType<<") "<<arg2<<")";
}
......@@ -1042,10 +1042,12 @@ void testVectorFunctions() {
integrator.addPerDofVariable("angular", 0.0);
integrator.addPerDofVariable("shuffle", 0.0);
integrator.addPerDofVariable("multicross", 0.0);
integrator.addPerDofVariable("maxplus", 0.0);
integrator.addComputeSum("sumy", "x*vector(0, 1, 0)");
integrator.addComputePerDof("angular", "cross(v, x)");
integrator.addComputePerDof("shuffle", "dot(vector(_z(x), _x(x), _y(x)), v)");
integrator.addComputePerDof("multicross", "cross(vector(1, 0, 0), cross(vector(0, 0, 1), vector(1, 0, 0)))");
integrator.addComputePerDof("maxplus", "max(x, 0.1)+0.5");
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
vector<Vec3> positions(numParticles);
......@@ -1063,14 +1065,16 @@ void testVectorFunctions() {
// See if the expressions were computed correctly.
double sumy = 0;
vector<Vec3> angular, shuffle, multicross;
vector<Vec3> angular, shuffle, multicross, maxplus;
integrator.getPerDofVariable(0, angular);
integrator.getPerDofVariable(1, shuffle);
integrator.getPerDofVariable(2, multicross);
integrator.getPerDofVariable(3, maxplus);
for (int i = 0; i < numParticles; i++) {
ASSERT_EQUAL_VEC(velocities[i].cross(positions[i]), angular[i], 1e-5);
ASSERT_EQUAL_VEC(Vec3(1, 1, 1)*velocities[i].dot(Vec3(positions[i][2], positions[i][0], positions[i][1])), shuffle[i], 1e-5);
ASSERT_EQUAL_VEC(Vec3(0, 0, 1), multicross[i], 1e-5);
ASSERT_EQUAL_VEC(Vec3(max(positions[i][0], 0.1)+0.5, max(positions[i][1], 0.1)+0.5, max(positions[i][2], 0.1)+0.5), maxplus[i], 1e-5);
sumy += positions[i][1];
}
ASSERT_EQUAL_TOL(sumy, integrator.getGlobalVariable(0), 1e-5);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment