Commit ec39f6ff authored by Yutong Zhao's avatar Yutong Zhao
Browse files

nonbonded.cu will now use shuffles on sm_30 or higher

parent 44665537
...@@ -394,7 +394,7 @@ CUmodule CudaContext::createModule(const string source, const map<string, string ...@@ -394,7 +394,7 @@ CUmodule CudaContext::createModule(const string source, const map<string, string
// Write out the source to a temporary file. // Write out the source to a temporary file.
stringstream tempFileName; stringstream tempFileName;
tempFileName << "openmmTempKernel" << /*rand() <<*/ this; // Include a pointer to this context as part of the filename to avoid collisions. tempFileName << "openmmTempKernel" << this; // Include a pointer to this context as part of the filename to avoid collisions.
string inputFile = (tempDir+tempFileName.str()+".cu"); string inputFile = (tempDir+tempFileName.str()+".cu");
string outputFile = (tempDir+tempFileName.str()+".ptx"); string outputFile = (tempDir+tempFileName.str()+".ptx");
string logFile = (tempDir+tempFileName.str()+".log"); string logFile = (tempDir+tempFileName.str()+".log");
...@@ -438,15 +438,15 @@ CUmodule CudaContext::createModule(const string source, const map<string, string ...@@ -438,15 +438,15 @@ CUmodule CudaContext::createModule(const string source, const map<string, string
m<<"Error loading CUDA module: "<<getErrorString(result)<<" ("<<result<<")"; m<<"Error loading CUDA module: "<<getErrorString(result)<<" ("<<result<<")";
throw OpenMMException(m.str()); throw OpenMMException(m.str());
} }
//remove(inputFile.c_str()); remove(inputFile.c_str());
//remove(outputFile.c_str()); remove(outputFile.c_str());
//remove(logFile.c_str()); remove(logFile.c_str());
return module; return module;
} }
catch (...) { catch (...) {
//remove(inputFile.c_str()); remove(inputFile.c_str());
//remove(outputFile.c_str()); remove(outputFile.c_str());
//remove(logFile.c_str()); remove(logFile.c_str());
throw; throw;
} }
} }
......
...@@ -416,6 +416,11 @@ void CudaNonbondedUtilities::setAtomBlockRange(double startFraction, double endF ...@@ -416,6 +416,11 @@ void CudaNonbondedUtilities::setAtomBlockRange(double startFraction, double endF
} }
CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, vector<ParameterInfo>& params, vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) { CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, vector<ParameterInfo>& params, vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) {
map<string, string> defines;
if (context.getComputeCapability() >= 3.0 && !context.getUseDoublePrecision())
defines["ENABLE_SHUFFLE"] = "1";
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_INTERACTION"] = source; replacements["COMPUTE_INTERACTION"] = source;
const string suffixes[] = {"x", "y", "z", "w"}; const string suffixes[] = {"x", "y", "z", "w"};
...@@ -446,88 +451,100 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -446,88 +451,100 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
} }
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
/* stringstream load1;
for (int i = 0; i < (int) params.size(); i++) {
load1 << params[i].getType();
load1 << " ";
load1 << params[i].getName();
load1 << "1 = global_";
load1 << params[i].getName();
load1 << "[atom1];\n";
}
replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
bool useShuffle = (defines["ENABLE_SHUFFLE"]=="1");
// Part 1. Defines for on diagonal exclusion tiles
stringstream loadLocal1; stringstream loadLocal1;
if(useShuffle) {
// not needed if using shuffles as we can directly fetch from
// LOAD_ATOM1_PARAMETERS
} else {
for (int i = 0; i < (int) params.size(); i++) { for (int i = 0; i < (int) params.size(); i++) {
if (params[i].getNumComponents() == 1) { if (params[i].getNumComponents() == 1) {
loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<" = "<<params[i].getName()<<"1;\n"; loadLocal1<<"localData[threadIdx.x]."<<params[i].getName()<<" = "<<params[i].getName()<<"1;\n";
} }
else { else {
for (int j = 0; j < params[i].getNumComponents(); ++j) for (int j = 0; j < params[i].getNumComponents(); ++j)
loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = "<<params[i].getName()<<"1."<<suffixes[j]<<";\n"; loadLocal1<<"localData[threadIdx.x]."<<params[i].getName()<<"_"<<suffixes[j]<<" = "<<params[i].getName()<<"1."<<suffixes[j]<<";\n";
}
} }
} }
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
*/
stringstream loadLocal1;
loadLocal1 << "tempSigmaEpsilon = sigmaEpsilon1;" << endl;
//for (int i = 0; i < (int) params.size(); i++) {
// loadLocal1<<params[i].getType()<<" temp"<<params[i].getName()<<"="<<params[i].getName()<<"1;\n";
//}
//cout << loadLocal1.str() << endl;
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
/* stringstream broadcastWarpData;
stringstream loadLocal2; if(useShuffle) {
for (int i = 0; i < (int) params.size(); i++) { broadcastWarpData << "posq2.x = real_shfl(shflPosq.x, j);\n";
broadcastWarpData << "posq2.y = real_shfl(shflPosq.y, j);\n";
broadcastWarpData << "posq2.z = real_shfl(shflPosq.z, j);\n";
broadcastWarpData << "posq2.w = real_shfl(shflPosq.w, j);\n";
for(int i=0; i< (int) params.size();i++) {
broadcastWarpData << params[i].getType() << " shfl" << params[i].getName() << ";\n";
for(int j=0; j < params[i].getNumComponents(); j++) {
string name;
if (params[i].getNumComponents() == 1) { if (params[i].getNumComponents() == 1) {
loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n"; broadcastWarpData << "shfl" << params[i].getName() << "=real_shfl(" << params[i].getName() <<"1,j);\n";
} else {
broadcastWarpData << "shfl" << params[i].getName()+"."+suffixes[j] << "=real_shfl(" << params[i].getName()+"1."+suffixes[j] <<",j);\n";
} }
else {
loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
for (int j = 0; j < params[i].getNumComponents(); ++j)
loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n";
} }
} }
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str(); } else {
*/ // not used if not shuffling
}
replacements["BROADCAST_WARP_DATA"] = broadcastWarpData.str();
// Part 2. Defines for off-diagonal exclusions, and neighborlist tiles.
stringstream declareLocal2; stringstream declareLocal2;
if(useShuffle) {
for(int i=0; i< (int) params.size(); i++) { for(int i=0; i< (int) params.size(); i++) {
if (params[i].getNumComponents() == 1) { declareLocal2<<params[i].getType()<<" shfl"<<params[i].getName()<<";\n";
// loadLocal2<<params[i].getType()<<" "<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n"; //if (params[i].getNumComponents() == 1) {
} else { //declareLocal2<<params[i].getType()<<" "<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
declareLocal2<<params[i].getType()<<" temp"<<params[i].getName()<<";\n"; //} else {
// declareLocal2<<params[i].getType()<<" temp"<<params[i].getName()<<";\n";
//}
} }
} else {
// not used if using shared memory
} }
replacements["DECLARE_LOCAL_PARAMETERS"] = declareLocal2.str(); replacements["DECLARE_LOCAL_PARAMETERS"] = declareLocal2.str();
stringstream loadLocal2; stringstream loadLocal2;
if(useShuffle) {
for(int i=0; i< (int) params.size(); i++) { for(int i=0; i< (int) params.size(); i++) {
if (params[i].getNumComponents() == 1) { loadLocal2<<"shfl"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
// loadLocal2<<params[i].getType()<<" "<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
} else {
loadLocal2<<"temp"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
} }
} } else {
/*
for (int i = 0; i < (int) params.size(); i++) { for (int i = 0; i < (int) params.size(); i++) {
if (params[i].getNumComponents() == 1) { if (params[i].getNumComponents() == 1) {
loadLocal2<<params[i].getType()<<" "<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n"; loadLocal2<<"localData[threadIdx.x]."<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
} }
else { else {
loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n"; loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
for (int j = 0; j < params[i].getNumComponents(); ++j) for (int j = 0; j < params[i].getNumComponents(); ++j)
loadLocal2<<params[i].getType()<<" "<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n"; loadLocal2<<"localData[threadIdx.x]."<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n";
} }
} }
*/
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
stringstream load1;
for (int i = 0; i < (int) params.size(); i++) {
load1 << params[i].getType();
load1 << " ";
load1 << params[i].getName();
load1 << "1 = global_";
load1 << params[i].getName();
load1 << "[atom1];\n";
} }
replacements["LOAD_ATOM1_PARAMETERS"] = load1.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
/*
stringstream load2j; stringstream load2j;
if(useShuffle) {
for(int i = 0; i < (int) params.size(); i++)
load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = shfl"<<params[i].getName()<<";\n";
} else {
for (int i = 0; i < (int) params.size(); i++) { for (int i = 0; i < (int) params.size(); i++) {
if (params[i].getNumComponents() == 1) { if (params[i].getNumComponents() == 1) {
load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = localData[atom2]."<<params[i].getName()<<";\n"; load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = localData[atom2]."<<params[i].getName()<<";\n";
...@@ -542,67 +559,37 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -542,67 +559,37 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
load2j<<");\n"; load2j<<");\n";
} }
} }
replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();
*/
stringstream load2j;
for (int i = 0; i < (int) params.size(); i++) {
/*
if (params[i].getNumComponents() == 1) {
load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = "<<params[i].getName()<<";\n";
}
else {
load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = make_"<<params[i].getType()<<"(";
for (int j = 0; j < params[i].getNumComponents(); ++j) {
if (j > 0)
load2j<<", ";
load2j<<params[i].getName()<<"_"<<suffixes[j];
}
load2j<<");\n";
}*/
load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = temp"<<params[i].getName()<<";\n";
} }
replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str(); replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();
stringstream broadcastWarpData; stringstream shuffleWarpData;
broadcastWarpData << "posq2.x = __shfl(tempPosq.x, j);\n"; if(useShuffle) {
broadcastWarpData << "posq2.y = __shfl(tempPosq.y, j);\n"; shuffleWarpData << "shflPosq.x = real_shfl(shflPosq.x, tgx+1);\n";
broadcastWarpData << "posq2.z = __shfl(tempPosq.z, j);\n"; shuffleWarpData << "shflPosq.y = real_shfl(shflPosq.y, tgx+1);\n";
broadcastWarpData << "posq2.w = __shfl(tempPosq.w, j);\n"; shuffleWarpData << "shflPosq.z = real_shfl(shflPosq.z, tgx+1);\n";
shuffleWarpData << "shflPosq.w = real_shfl(shflPosq.w, tgx+1);\n";
for(int i=0; i< (int) params.size();i++) { shuffleWarpData << "shflForce.x = real_shfl(shflForce.x, tgx+1);\n";
broadcastWarpData << params[i].getType() << " temp" << params[i].getName() << ";\n"; shuffleWarpData << "shflForce.y = real_shfl(shflForce.y, tgx+1);\n";
for(int j=0; j < params[i].getNumComponents(); j++) { shuffleWarpData << "shflForce.z = real_shfl(shflForce.z, tgx+1);\n";
string name; for(int i=0; i < (int) params.size(); i++) {
if (params[i].getNumComponents() == 1) { if(params[i].getNumComponents() == 1) {
broadcastWarpData << "temp" << params[i].getName() << "=__shfl(" << params[i].getName() <<"1,j);\n"; shuffleWarpData<<"shfl"<<params[i].getName()<<"=real_shfl(shfl"<<params[i].getName()<<", tgx+1);\n";
} else { } else {
broadcastWarpData << "temp" << params[i].getName()+"."+suffixes[j] << "=__shfl(" << params[i].getName()+"1."+suffixes[j] <<",j);\n"; for(int j=0;j<params[i].getNumComponents();j++) {
// looks something like
// shflsigmaEpsilon.x = real_shfl(shflsigmaEpsilon.x,tgx+1);
shuffleWarpData<<"shfl"<<params[i].getName()
<<"."<<suffixes[j]<<"=real_shfl(shfl"
<<params[i].getName()<<"."<<suffixes[j]
<<", tgx+1);\n";
} }
} }
} }
replacements["BROADCAST_WARP_DATA"] = broadcastWarpData.str(); } else {
// not used otherwise
stringstream shuffleWarpData;
shuffleWarpData << "tempPosq.x = __shfl(tempPosq.x, tgx+1);\n";
shuffleWarpData << "tempPosq.y = __shfl(tempPosq.y, tgx+1);\n";
shuffleWarpData << "tempPosq.z = __shfl(tempPosq.z, tgx+1);\n";
shuffleWarpData << "tempPosq.w = __shfl(tempPosq.w, tgx+1);\n";
shuffleWarpData << "tempForces.x = __shfl(tempForces.x, tgx+1);\n";
shuffleWarpData << "tempForces.y = __shfl(tempForces.y, tgx+1);\n";
shuffleWarpData << "tempForces.z = __shfl(tempForces.z, tgx+1);\n";
shuffleWarpData << "tempsigmaEpsilon.x = __shfl(tempsigmaEpsilon.x, tgx+1);\n";
shuffleWarpData << "tempsigmaEpsilon.y = __shfl(tempsigmaEpsilon.y, tgx+1);\n";
/*
for(int i=0; i< (int) params.size(); i++) {
shuffleWarpData << params[i].getName() << "=__shfl(" << params[i].getName() << ", tgx+1);\n";
} }
*/
replacements["SHUFFLE_WARP_DATA"] = shuffleWarpData.str(); replacements["SHUFFLE_WARP_DATA"] = shuffleWarpData.str();
map<string, string> defines;
if (useCutoff) if (useCutoff)
defines["USE_CUTOFF"] = "1"; defines["USE_CUTOFF"] = "1";
if (usePeriodic) if (usePeriodic)
......
This diff is collapsed.
...@@ -872,21 +872,21 @@ int main(int argc, char* argv[]) { ...@@ -872,21 +872,21 @@ int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
platform.setPropertyDefaultValue("CudaPrecision", string(argv[1])); platform.setPropertyDefaultValue("CudaPrecision", string(argv[1]));
//testCoulomb(); testCoulomb();
//testLJ(); testLJ();
//testExclusionsAnd14(); testExclusionsAnd14();
//testCutoff(); testCutoff();
//testCutoff14(); testCutoff14();
//testPeriodic(); testPeriodic();
testLargeSystem(); testLargeSystem();
//testBlockInteractions(false); //testBlockInteractions(false);
//testBlockInteractions(true); //testBlockInteractions(true);
//testDispersionCorrection(); testDispersionCorrection();
//testChangingParameters(); testChangingParameters();
//testParallelComputation(false); testParallelComputation(false);
//testParallelComputation(true); testParallelComputation(true);
//testSwitchingFunction(NonbondedForce::CutoffNonPeriodic); testSwitchingFunction(NonbondedForce::CutoffNonPeriodic);
//testSwitchingFunction(NonbondedForce::PME); testSwitchingFunction(NonbondedForce::PME);
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment