Commit fdc59e96 authored by Peter Eastman
Browse files

Bug fixes to vector expressions

parent b53c6593
......@@ -122,6 +122,7 @@ private:
std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
/**
 * Write to a stream the text of a call to a single-argument function.
 * doubleFn is used when tempType starts with 'd', singleFn otherwise.  When
 * tempType ends with '3' the function is applied to the .x, .y and .z
 * components of arg and the results are wrapped in make_double3() or
 * make_float3().
 *
 * @param out       stream the generated text is appended to
 * @param singleFn  function name emitted when tempType does not start with 'd'
 * @param doubleFn  function name emitted when tempType starts with 'd'
 * @param arg       text of the argument passed to the function
 * @param tempType  type name used to select the precision and scalar/vector form
 */
void callFunction(std::stringstream& out, std::string singleFn, std::string doubleFn, const std::string& arg, const std::string& tempType);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
......
......@@ -445,62 +445,91 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out << "-" << getTempName(node.getChildren()[0], temps);
break;
case Operation::SQRT:
out << "SQRT(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "sqrtf", "sqrt", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::EXP:
out << "EXP(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "expf", "exp", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::LOG:
out << "LOG(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "logf", "log", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::SIN:
out << "SIN(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::COS:
out << "COS(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::SEC:
out << "RECIP(COS(" << getTempName(node.getChildren()[0], temps) << "))";
out << "1/";
callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::CSC:
out << "RECIP(SIN(" << getTempName(node.getChildren()[0], temps) << "))";
out << "1/";
callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::TAN:
out << "TAN(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::COT:
out << "RECIP(TAN(" << getTempName(node.getChildren()[0], temps) << "))";
out << "1/";
callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::ASIN:
out << "ASIN(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "asinf", "asin", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::ACOS:
out << "ACOS(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "acosf", "acos", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::ATAN:
out << "ATAN(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "atanf", "atan", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::SINH:
out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "sinh", "sinh", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::COSH:
out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "cosh", "cosh", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::TANH:
out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "tanh", "tanh", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::ERF:
out << "erf(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "erf", "erf", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::ERFC:
out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "erfc", "erfc", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
{
string compareVal = getTempName(node.getChildren()[0], temps);
if (isVecType) {
out << "make_" << tempType << "(0);\n";
out << "{\n";
out << tempType<<" tempCompareValue = " << compareVal << ";\n";
out << name << ".x = (tempCompareValue.x >= 0 ? 1 : 0);\n";
out << name << ".y = (tempCompareValue.y >= 0 ? 1 : 0);\n";
out << name << ".z = (tempCompareValue.z >= 0 ? 1 : 0);\n";
out << "}\n";
}
else
out << compareVal << " >= 0 ? 1 : 0";
break;
}
case Operation::DELTA:
out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f";
{
string compareVal = getTempName(node.getChildren()[0], temps);
if (isVecType) {
out << "make_" << tempType << "(0);\n";
out << "{\n";
out << tempType<<" tempCompareValue = " << compareVal << ";\n";
out << name << ".x = (tempCompareValue.x == 0 ? 1 : 0);\n";
out << name << ".y = (tempCompareValue.y == 0 ? 1 : 0);\n";
out << name << ".z = (tempCompareValue.z == 0 ? 1 : 0);\n";
out << "}\n";
}
else
out << compareVal << " == 0 ? 1 : 0";
break;
}
case Operation::SQUARE:
{
string arg = getTempName(node.getChildren()[0], temps);
......@@ -586,13 +615,13 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out << "max((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::ABS:
out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "fabs", "fabs", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::FLOOR:
out << "floor(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "floor", "floor", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::CEIL:
out << "ceil(" << getTempName(node.getChildren()[0], temps) << ")";
callFunction(out, "ceil", "ceil", getTempName(node.getChildren()[0], temps), tempType);
break;
case Operation::SELECT:
{
......@@ -880,3 +909,20 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta
Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder() {
return &periodicDistance;
}
/**
 * Write to `out` the source text of a call to a single-argument function.
 *
 * The function name is selected from the type name: doubleFn when tempType
 * starts with 'd', singleFn otherwise.  When tempType ends with '3' the
 * function is applied to arg.x, arg.y and arg.z separately and the three
 * results are combined with make_double3() or make_float3().  Otherwise the
 * function is applied to arg directly.
 *
 * @param out       stream the generated text is appended to
 * @param singleFn  function name emitted when tempType does not start with 'd'
 * @param doubleFn  function name emitted when tempType starts with 'd'
 * @param arg       text of the argument passed to the function
 * @param tempType  type name used to select the precision and scalar/vector form;
 *                  an empty name is handled as the scalar, non-double case
 */
void CudaExpressionUtilities::callFunction(stringstream& out, string singleFn, string doubleFn, const string& arg, const string& tempType) {
    // Both character tests are guarded: with an empty type name, size()-1
    // wraps around and indexing with it would read out of range.
    const bool hasType = !tempType.empty();
    const bool isDouble = (hasType && tempType[0] == 'd');
    const bool isVector = (hasType && tempType[tempType.size()-1] == '3');
    const string& fn = (isDouble ? doubleFn : singleFn);
    if (isVector) {
        // Apply the scalar function to each component and rebuild the vector.
        out << (isDouble ? "make_double3(" : "make_float3(")
            << fn << "(" << arg << ".x), "
            << fn << "(" << arg << ".y), "
            << fn << "(" << arg << ".z))";
    }
    else
        out << fn << "(" << arg << ")";
}
......@@ -51,7 +51,7 @@ extern "C" __global__ void computePerDof(real4* __restrict__ posq, real4* __rest
#endif
double4 velocity = convertToDouble4(velm[index]);
double4 f = make_double4(forceScale*force[index], forceScale*force[index+PADDED_NUM_ATOMS], forceScale*force[index+PADDED_NUM_ATOMS*2], 0.0);
double mass = 1.0/velocity.w;
double3 mass = make_double3(1.0/velocity.w);
if (velocity.w != 0.0) {
int gaussianIndex = gaussianBaseIndex;
int uniformIndex = 0;
......
......@@ -478,10 +478,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? (" << tempType << ") 1 : (" << tempType << ") 0";
break;
case Operation::DELTA:
out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f";
out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? (" << tempType << ") 1 : (" << tempType << ") 0";
break;
case Operation::SQUARE:
{
......
......@@ -537,9 +537,11 @@ void testPerDofVariables() {
CustomIntegrator integrator(0.01);
integrator.addPerDofVariable("temp", 0);
integrator.addPerDofVariable("pos", 0);
integrator.addPerDofVariable("computed", 0);
integrator.addComputePerDof("v", "v+dt*f/m");
integrator.addComputePerDof("x", "x+dt*v");
integrator.addComputePerDof("pos", "x");
integrator.addComputePerDof("computed", "step(v)*log(x^2)");
Context context(system, integrator, platform);
context.setPositions(positions);
vector<Vec3> initialValues(numParticles);
......@@ -552,13 +554,24 @@ void testPerDofVariables() {
vector<Vec3> values;
for (int i = 0; i < 100; ++i) {
integrator.step(1);
State state = context.getState(State::Positions);
State state = context.getState(State::Positions | State::Velocities);
integrator.getPerDofVariable(0, values);
for (int j = 0; j < numParticles; j++)
ASSERT_EQUAL_VEC(initialValues[j], values[j], 1e-5);
integrator.getPerDofVariable(1, values);
for (int j = 0; j < numParticles; j++)
ASSERT_EQUAL_VEC(state.getPositions()[j], values[j], 1e-5);
integrator.getPerDofVariable(2, values);
for (int j = 0; j < numParticles; j++)
for (int k = 0; k < 3; k++) {
if (state.getVelocities()[j][k] < 0) {
ASSERT(values[j][k] == 0.0);
}
else {
double v = state.getPositions()[j][k];
ASSERT_EQUAL_TOL(log(v*v), values[j][k], 1e-5);
}
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment