Commit fdc59e96 authored by Peter Eastman
Browse files

Bug fixes to vector expressions

parent b53c6593
...@@ -122,6 +122,7 @@ private: ...@@ -122,6 +122,7 @@ private:
std::vector<const Lepton::ExpressionTreeNode*>& nodes); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
void callFunction(std::stringstream& out, std::string singleFn, std::string doubleFn, const std::string& arg, const std::string& tempType);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions); std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context; CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3, periodicDistance; FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
......
...@@ -445,62 +445,91 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -445,62 +445,91 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out << "-" << getTempName(node.getChildren()[0], temps); out << "-" << getTempName(node.getChildren()[0], temps);
break; break;
case Operation::SQRT: case Operation::SQRT:
out << "SQRT(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "sqrtf", "sqrt", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::EXP: case Operation::EXP:
out << "EXP(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "expf", "exp", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::LOG: case Operation::LOG:
out << "LOG(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "logf", "log", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::SIN: case Operation::SIN:
out << "SIN(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::COS: case Operation::COS:
out << "COS(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::SEC: case Operation::SEC:
out << "RECIP(COS(" << getTempName(node.getChildren()[0], temps) << "))"; out << "1/";
callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::CSC: case Operation::CSC:
out << "RECIP(SIN(" << getTempName(node.getChildren()[0], temps) << "))"; out << "1/";
callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::TAN: case Operation::TAN:
out << "TAN(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::COT: case Operation::COT:
out << "RECIP(TAN(" << getTempName(node.getChildren()[0], temps) << "))"; out << "1/";
callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::ASIN: case Operation::ASIN:
out << "ASIN(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "asinf", "asin", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::ACOS: case Operation::ACOS:
out << "ACOS(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "acosf", "acos", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::ATAN: case Operation::ATAN:
out << "ATAN(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "atanf", "atan", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::SINH: case Operation::SINH:
out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "sinh", "sinh", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::COSH: case Operation::COSH:
out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "cosh", "cosh", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::TANH: case Operation::TANH:
out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "tanh", "tanh", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::ERF: case Operation::ERF:
out << "erf(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "erf", "erf", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::ERFC: case Operation::ERFC:
out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "erfc", "erfc", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::STEP: case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f"; {
string compareVal = getTempName(node.getChildren()[0], temps);
if (isVecType) {
out << "make_" << tempType << "(0);\n";
out << "{\n";
out << tempType<<" tempCompareValue = " << compareVal << ";\n";
out << name << ".x = (tempCompareValue.x >= 0 ? 1 : 0);\n";
out << name << ".y = (tempCompareValue.y >= 0 ? 1 : 0);\n";
out << name << ".z = (tempCompareValue.z >= 0 ? 1 : 0);\n";
out << "}\n";
}
else
out << compareVal << " >= 0 ? 1 : 0";
break; break;
}
case Operation::DELTA: case Operation::DELTA:
out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f"; {
string compareVal = getTempName(node.getChildren()[0], temps);
if (isVecType) {
out << "make_" << tempType << "(0);\n";
out << "{\n";
out << tempType<<" tempCompareValue = " << compareVal << ";\n";
out << name << ".x = (tempCompareValue.x == 0 ? 1 : 0);\n";
out << name << ".y = (tempCompareValue.y == 0 ? 1 : 0);\n";
out << name << ".z = (tempCompareValue.z == 0 ? 1 : 0);\n";
out << "}\n";
}
else
out << compareVal << " == 0 ? 1 : 0";
break; break;
}
case Operation::SQUARE: case Operation::SQUARE:
{ {
string arg = getTempName(node.getChildren()[0], temps); string arg = getTempName(node.getChildren()[0], temps);
...@@ -586,13 +615,13 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -586,13 +615,13 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out << "max((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")"; out << "max((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
break; break;
case Operation::ABS: case Operation::ABS:
out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "fabs", "fabs", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::FLOOR: case Operation::FLOOR:
out << "floor(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "floor", "floor", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::CEIL: case Operation::CEIL:
out << "ceil(" << getTempName(node.getChildren()[0], temps) << ")"; callFunction(out, "ceil", "ceil", getTempName(node.getChildren()[0], temps), tempType);
break; break;
case Operation::SELECT: case Operation::SELECT:
{ {
...@@ -880,3 +909,20 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta ...@@ -880,3 +909,20 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta
Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder() { Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder() {
return &periodicDistance; return &periodicDistance;
} }
/**
 * Append to the generated source a call to a one-argument function, picking the
 * function name from the type of the value being computed.
 *
 * When tempType names a three-component vector, the function is applied to the
 * .x, .y and .z components of the argument separately and the three results are
 * recombined with make_float3() or make_double3().
 *
 * @param out       the stream the generated source text is appended to
 * @param singleFn  the function name to emit when tempType is not a double type
 * @param doubleFn  the function name to emit when tempType is a double type
 * @param arg       the source text of the argument; in the vector case it is
 *                  emitted three times, once per component
 * @param tempType  the type name of the result; a leading 'd' selects doubleFn and
 *                  a trailing '3' selects the per-component form.  An empty name is
 *                  treated as a scalar, non-double type.
 */
void CudaExpressionUtilities::callFunction(stringstream& out, string singleFn, string doubleFn, const string& arg, const string& tempType) {
    // An empty type name must not be indexed: size()-1 would wrap around and
    // read far outside the string.
    const bool hasType = !tempType.empty();
    const bool isDouble = (hasType && tempType[0] == 'd');
    const bool isVector = (hasType && tempType[tempType.size()-1] == '3');
    const string& fn = (isDouble ? doubleFn : singleFn);
    if (isVector) {
        out<<(isDouble ? "make_double3(" : "make_float3(");
        out<<fn<<"("<<arg<<".x), "<<fn<<"("<<arg<<".y), "<<fn<<"("<<arg<<".z))";
    }
    else
        out<<fn<<"("<<arg<<")";
}
...@@ -51,7 +51,7 @@ extern "C" __global__ void computePerDof(real4* __restrict__ posq, real4* __rest ...@@ -51,7 +51,7 @@ extern "C" __global__ void computePerDof(real4* __restrict__ posq, real4* __rest
#endif #endif
double4 velocity = convertToDouble4(velm[index]); double4 velocity = convertToDouble4(velm[index]);
double4 f = make_double4(forceScale*force[index], forceScale*force[index+PADDED_NUM_ATOMS], forceScale*force[index+PADDED_NUM_ATOMS*2], 0.0); double4 f = make_double4(forceScale*force[index], forceScale*force[index+PADDED_NUM_ATOMS], forceScale*force[index+PADDED_NUM_ATOMS*2], 0.0);
double mass = 1.0/velocity.w; double3 mass = make_double3(1.0/velocity.w);
if (velocity.w != 0.0) { if (velocity.w != 0.0) {
int gaussianIndex = gaussianBaseIndex; int gaussianIndex = gaussianBaseIndex;
int uniformIndex = 0; int uniformIndex = 0;
......
...@@ -478,10 +478,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -478,10 +478,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")"; out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
break; break;
case Operation::STEP: case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f"; out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? (" << tempType << ") 1 : (" << tempType << ") 0";
break; break;
case Operation::DELTA: case Operation::DELTA:
out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f"; out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? (" << tempType << ") 1 : (" << tempType << ") 0";
break; break;
case Operation::SQUARE: case Operation::SQUARE:
{ {
......
...@@ -537,9 +537,11 @@ void testPerDofVariables() { ...@@ -537,9 +537,11 @@ void testPerDofVariables() {
CustomIntegrator integrator(0.01); CustomIntegrator integrator(0.01);
integrator.addPerDofVariable("temp", 0); integrator.addPerDofVariable("temp", 0);
integrator.addPerDofVariable("pos", 0); integrator.addPerDofVariable("pos", 0);
integrator.addPerDofVariable("computed", 0);
integrator.addComputePerDof("v", "v+dt*f/m"); integrator.addComputePerDof("v", "v+dt*f/m");
integrator.addComputePerDof("x", "x+dt*v"); integrator.addComputePerDof("x", "x+dt*v");
integrator.addComputePerDof("pos", "x"); integrator.addComputePerDof("pos", "x");
integrator.addComputePerDof("computed", "step(v)*log(x^2)");
Context context(system, integrator, platform); Context context(system, integrator, platform);
context.setPositions(positions); context.setPositions(positions);
vector<Vec3> initialValues(numParticles); vector<Vec3> initialValues(numParticles);
...@@ -552,13 +554,24 @@ void testPerDofVariables() { ...@@ -552,13 +554,24 @@ void testPerDofVariables() {
vector<Vec3> values; vector<Vec3> values;
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
integrator.step(1); integrator.step(1);
State state = context.getState(State::Positions); State state = context.getState(State::Positions | State::Velocities);
integrator.getPerDofVariable(0, values); integrator.getPerDofVariable(0, values);
for (int j = 0; j < numParticles; j++) for (int j = 0; j < numParticles; j++)
ASSERT_EQUAL_VEC(initialValues[j], values[j], 1e-5); ASSERT_EQUAL_VEC(initialValues[j], values[j], 1e-5);
integrator.getPerDofVariable(1, values); integrator.getPerDofVariable(1, values);
for (int j = 0; j < numParticles; j++) for (int j = 0; j < numParticles; j++)
ASSERT_EQUAL_VEC(state.getPositions()[j], values[j], 1e-5); ASSERT_EQUAL_VEC(state.getPositions()[j], values[j], 1e-5);
integrator.getPerDofVariable(2, values);
for (int j = 0; j < numParticles; j++)
for (int k = 0; k < 3; k++) {
if (state.getVelocities()[j][k] < 0) {
ASSERT(values[j][k] == 0.0);
}
else {
double v = state.getPositions()[j][k];
ASSERT_EQUAL_TOL(log(v*v), values[j][k], 1e-5);
}
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment