"ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "01016dea345562c8cd116a2d2d2ce9bcf03b99c4"
Unverified Commit f08a1cf8 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Integrators use double precision constants in mixed precision mode (#4079)

parent ee39f7ca
...@@ -414,8 +414,10 @@ public: ...@@ -414,8 +414,10 @@ public:
/** /**
* Convert a number to a string in a format suitable for including in a kernel. * Convert a number to a string in a format suitable for including in a kernel.
* This takes into account whether the context uses single or double precision. * This takes into account whether the context uses single or double precision.
* If mixedIsDouble is true, a double precision constant will also be produced
* in mixed precision mode.
*/ */
std::string doubleToString(double value) const; std::string doubleToString(double value, bool mixedIsDouble=false) const;
/** /**
* Convert a number to a string in a format suitable for including in a kernel. * Convert a number to a string in a format suitable for including in a kernel.
*/ */
......
...@@ -5577,7 +5577,7 @@ void CommonIntegrateNoseHooverStepKernel::initialize(const System& system, const ...@@ -5577,7 +5577,7 @@ void CommonIntegrateNoseHooverStepKernel::initialize(const System& system, const
ContextSelector selector(cc); ContextSelector selector(cc);
bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision(); bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision();
map<string, string> defines; map<string, string> defines;
defines["BOLTZ"] = cc.doubleToString(BOLTZ); defines["BOLTZ"] = cc.doubleToString(BOLTZ, true);
ComputeProgram program = cc.compileProgram(CommonKernelSources::noseHooverIntegrator, defines); ComputeProgram program = cc.compileProgram(CommonKernelSources::noseHooverIntegrator, defines);
kernel1 = program->createKernel("integrateNoseHooverMiddlePart1"); kernel1 = program->createKernel("integrateNoseHooverMiddlePart1");
kernel2 = program->createKernel("integrateNoseHooverMiddlePart2"); kernel2 = program->createKernel("integrateNoseHooverMiddlePart2");
......
...@@ -102,11 +102,12 @@ string ComputeContext::replaceStrings(const string& input, const std::map<std::s ...@@ -102,11 +102,12 @@ string ComputeContext::replaceStrings(const string& input, const std::map<std::s
return result; return result;
} }
string ComputeContext::doubleToString(double value) const { string ComputeContext::doubleToString(double value, bool mixedIsDouble) const {
stringstream s; stringstream s;
s.precision(getUseDoublePrecision() ? 16 : 8); bool useDouble = (getUseDoublePrecision() || (mixedIsDouble && getUseMixedPrecision()));
s.precision(useDouble ? 16 : 8);
s << scientific << value; s << scientific << value;
if (!getUseDoublePrecision()) if (!useDouble)
s << "f"; s << "f";
return s.str(); return s.str();
} }
......
...@@ -72,12 +72,13 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -72,12 +72,13 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
string name = prefix+context.intToString(temps.size()); string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false; bool hasRecordedNode = false;
bool isVecType = (tempType[tempType.size()-1] == '3'); bool isVecType = (tempType[tempType.size()-1] == '3');
bool useMixedPrecision = (tempType.find("double") != string::npos || tempType.find("mixed") != string::npos);
out << tempType << " " << name << " = "; out << tempType << " " << name << " = ";
switch (node.getOperation().getId()) { switch (node.getOperation().getId()) {
case Operation::CONSTANT: case Operation::CONSTANT:
{ {
string value = context.doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue()); string value = context.doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue(), useMixedPrecision);
if (isVecType) if (isVecType)
out << "make_" << tempType << "(" << value << ")"; out << "make_" << tempType << "(" << value << ")";
else else
...@@ -318,7 +319,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -318,7 +319,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
vector<string> paramsFloat, paramsInt; vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) { for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j])); paramsFloat.push_back(context.doubleToString(functionParams[i][j], useMixedPrecision));
paramsInt.push_back(context.intToString((int) functionParams[i][j])); paramsInt.push_back(context.intToString((int) functionParams[i][j]));
} }
vector<string> suffixes; vector<string> suffixes;
...@@ -682,7 +683,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -682,7 +683,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
break; break;
case Operation::ADD_CONSTANT: case Operation::ADD_CONSTANT:
if (isVecType) { if (isVecType) {
string val = context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()); string val = context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue(), useMixedPrecision);
string arg = getTempName(node.getChildren()[0], temps); string arg = getTempName(node.getChildren()[0], temps);
out << "make_" << tempType << "("; out << "make_" << tempType << "(";
out << val << "+" << arg << ".x, "; out << val << "+" << arg << ".x, ";
...@@ -690,10 +691,10 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -690,10 +691,10 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
out << val << "+" << arg << ".z)"; out << val << "+" << arg << ".z)";
} }
else else
out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps); out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue(), useMixedPrecision) << "+" << getTempName(node.getChildren()[0], temps);
break; break;
case Operation::MULTIPLY_CONSTANT: case Operation::MULTIPLY_CONSTANT:
out << context.doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps); out << context.doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue(), useMixedPrecision) << "*" << getTempName(node.getChildren()[0], temps);
break; break;
case Operation::POWER_CONSTANT: case Operation::POWER_CONSTANT:
{ {
...@@ -756,7 +757,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -756,7 +757,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
out << "}"; out << "}";
} }
else else
out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent) << ")"; out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent, useMixedPrecision) << ")";
break; break;
} }
case Operation::MIN: case Operation::MIN:
......
...@@ -137,9 +137,9 @@ void CommonIntegrateRPMDStepKernel::initialize(const System& system, const RPMDI ...@@ -137,9 +137,9 @@ void CommonIntegrateRPMDStepKernel::initialize(const System& system, const RPMDI
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
defines["NUM_COPIES"] = cc.intToString(numCopies); defines["NUM_COPIES"] = cc.intToString(numCopies);
defines["THREAD_BLOCK_SIZE"] = cc.intToString(workgroupSize); defines["THREAD_BLOCK_SIZE"] = cc.intToString(workgroupSize);
defines["HBAR"] = cc.doubleToString(1.054571628e-34*AVOGADRO/(1000*1e-12)); defines["HBAR"] = cc.doubleToString(1.054571628e-34*AVOGADRO/(1000*1e-12), true);
defines["SCALE"] = cc.doubleToString(1.0/sqrt((double) numCopies)); defines["SCALE"] = cc.doubleToString(1.0/sqrt((double) numCopies), true);
defines["M_PI"] = cc.doubleToString(M_PI); defines["M_PI"] = cc.doubleToString(M_PI, true);
map<string, string> replacements; map<string, string> replacements;
replacements["FFT_Q_FORWARD"] = createFFT(numCopies, "q", true); replacements["FFT_Q_FORWARD"] = createFFT(numCopies, "q", true);
replacements["FFT_Q_BACKWARD"] = createFFT(numCopies, "q", false); replacements["FFT_Q_BACKWARD"] = createFFT(numCopies, "q", false);
...@@ -159,8 +159,8 @@ void CommonIntegrateRPMDStepKernel::initialize(const System& system, const RPMDI ...@@ -159,8 +159,8 @@ void CommonIntegrateRPMDStepKernel::initialize(const System& system, const RPMDI
int copies = g.first; int copies = g.first;
replacements.clear(); replacements.clear();
replacements["NUM_CONTRACTED_COPIES"] = cc.intToString(copies); replacements["NUM_CONTRACTED_COPIES"] = cc.intToString(copies);
replacements["POS_SCALE"] = cc.doubleToString(1.0/numCopies); replacements["POS_SCALE"] = cc.doubleToString(1.0/numCopies, true);
replacements["FORCE_SCALE"] = cc.doubleToString(0x100000000/(double) copies); replacements["FORCE_SCALE"] = cc.doubleToString(0x100000000/(double) copies, true);
replacements["FFT_Q_FORWARD"] = createFFT(numCopies, "q", true); replacements["FFT_Q_FORWARD"] = createFFT(numCopies, "q", true);
replacements["FFT_Q_BACKWARD"] = createFFT(copies, "q", false); replacements["FFT_Q_BACKWARD"] = createFFT(copies, "q", false);
replacements["FFT_F_FORWARD"] = createFFT(copies, "f", true); replacements["FFT_F_FORWARD"] = createFFT(copies, "f", true);
...@@ -474,21 +474,21 @@ string CommonIntegrateRPMDStepKernel::createFFT(int size, const string& variable ...@@ -474,21 +474,21 @@ string CommonIntegrateRPMDStepKernel::createFFT(int size, const string& variable
source<<"mixed3 d0i = c1i+c4i;\n"; source<<"mixed3 d0i = c1i+c4i;\n";
source<<"mixed3 d1r = c2r+c3r;\n"; source<<"mixed3 d1r = c2r+c3r;\n";
source<<"mixed3 d1i = c2i+c3i;\n"; source<<"mixed3 d1i = c2i+c3i;\n";
source<<"mixed3 d2r = "<<cc.doubleToString(sin(0.4*M_PI))<<"*(c1r-c4r);\n"; source<<"mixed3 d2r = "<<cc.doubleToString(sin(0.4*M_PI), true)<<"*(c1r-c4r);\n";
source<<"mixed3 d2i = "<<cc.doubleToString(sin(0.4*M_PI))<<"*(c1i-c4i);\n"; source<<"mixed3 d2i = "<<cc.doubleToString(sin(0.4*M_PI), true)<<"*(c1i-c4i);\n";
source<<"mixed3 d3r = "<<cc.doubleToString(sin(0.4*M_PI))<<"*(c2r-c3r);\n"; source<<"mixed3 d3r = "<<cc.doubleToString(sin(0.4*M_PI), true)<<"*(c2r-c3r);\n";
source<<"mixed3 d3i = "<<cc.doubleToString(sin(0.4*M_PI))<<"*(c2i-c3i);\n"; source<<"mixed3 d3i = "<<cc.doubleToString(sin(0.4*M_PI), true)<<"*(c2i-c3i);\n";
source<<"mixed3 d4r = d0r+d1r;\n"; source<<"mixed3 d4r = d0r+d1r;\n";
source<<"mixed3 d4i = d0i+d1i;\n"; source<<"mixed3 d4i = d0i+d1i;\n";
source<<"mixed3 d5r = "<<cc.doubleToString(0.25*sqrt(5.0))<<"*(d0r-d1r);\n"; source<<"mixed3 d5r = "<<cc.doubleToString(0.25*sqrt(5.0), true)<<"*(d0r-d1r);\n";
source<<"mixed3 d5i = "<<cc.doubleToString(0.25*sqrt(5.0))<<"*(d0i-d1i);\n"; source<<"mixed3 d5i = "<<cc.doubleToString(0.25*sqrt(5.0), true)<<"*(d0i-d1i);\n";
source<<"mixed3 d6r = c0r-0.25f*d4r;\n"; source<<"mixed3 d6r = c0r-0.25f*d4r;\n";
source<<"mixed3 d6i = c0i-0.25f*d4i;\n"; source<<"mixed3 d6i = c0i-0.25f*d4i;\n";
source<<"mixed3 d7r = d6r+d5r;\n"; source<<"mixed3 d7r = d6r+d5r;\n";
source<<"mixed3 d7i = d6i+d5i;\n"; source<<"mixed3 d7i = d6i+d5i;\n";
source<<"mixed3 d8r = d6r-d5r;\n"; source<<"mixed3 d8r = d6r-d5r;\n";
source<<"mixed3 d8i = d6i-d5i;\n"; source<<"mixed3 d8i = d6i-d5i;\n";
string coeff = cc.doubleToString(sin(0.2*M_PI)/sin(0.4*M_PI)); string coeff = cc.doubleToString(sin(0.2*M_PI)/sin(0.4*M_PI), true);
source<<"mixed3 d9r = "<<sign<<"*(d2i+"<<coeff<<"*d3i);\n"; source<<"mixed3 d9r = "<<sign<<"*(d2i+"<<coeff<<"*d3i);\n";
source<<"mixed3 d9i = "<<sign<<"*(-d2r-"<<coeff<<"*d3r);\n"; source<<"mixed3 d9i = "<<sign<<"*(-d2r-"<<coeff<<"*d3r);\n";
source<<"mixed3 d10r = "<<sign<<"*("<<coeff<<"*d2i-d3i);\n"; source<<"mixed3 d10r = "<<sign<<"*("<<coeff<<"*d2i-d3i);\n";
...@@ -541,8 +541,8 @@ string CommonIntegrateRPMDStepKernel::createFFT(int size, const string& variable ...@@ -541,8 +541,8 @@ string CommonIntegrateRPMDStepKernel::createFFT(int size, const string& variable
source<<"mixed3 d0i = c1i+c2i;\n"; source<<"mixed3 d0i = c1i+c2i;\n";
source<<"mixed3 d1r = c0r-0.5f*d0r;\n"; source<<"mixed3 d1r = c0r-0.5f*d0r;\n";
source<<"mixed3 d1i = c0i-0.5f*d0i;\n"; source<<"mixed3 d1i = c0i-0.5f*d0i;\n";
source<<"mixed3 d2r = "<<sign<<"*"<<cc.doubleToString(sin(M_PI/3.0))<<"*(c1i-c2i);\n"; source<<"mixed3 d2r = "<<sign<<"*"<<cc.doubleToString(sin(M_PI/3.0), true)<<"*(c1i-c2i);\n";
source<<"mixed3 d2i = "<<sign<<"*"<<cc.doubleToString(sin(M_PI/3.0))<<"*(c2r-c1r);\n"; source<<"mixed3 d2i = "<<sign<<"*"<<cc.doubleToString(sin(M_PI/3.0), true)<<"*(c2r-c1r);\n";
source<<"real"<<output<<"[i+2*j*"<<m<<"] = c0r+d0r;\n"; source<<"real"<<output<<"[i+2*j*"<<m<<"] = c0r+d0r;\n";
source<<"imag"<<output<<"[i+2*j*"<<m<<"] = c0i+d0i;\n"; source<<"imag"<<output<<"[i+2*j*"<<m<<"] = c0i+d0i;\n";
source<<"real"<<output<<"[i+(2*j+1)*"<<m<<"] = "<<multReal<<"(w[j*"<<size<<"/"<<(3*L)<<"], d1r+d2r, d1i+d2i);\n"; source<<"real"<<output<<"[i+(2*j+1)*"<<m<<"] = "<<multReal<<"(w[j*"<<size<<"/"<<(3*L)<<"], d1r+d2r, d1i+d2i);\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment