Commit 1330af5c authored by peastman
Browse files

Minor code cleanup

parent cfd815ec
...@@ -26,5 +26,4 @@ __kernel void execFFT(__global const real2* restrict in, __global real2* restric ...@@ -26,5 +26,4 @@ __kernel void execFFT(__global const real2* restrict in, __global real2* restric
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
COMPUTE_FFT COMPUTE_FFT
} }
} }
...@@ -382,7 +382,6 @@ void CudaIntegrateRPMDStepKernel::copyToContext(int copy, ContextImpl& context) ...@@ -382,7 +382,6 @@ void CudaIntegrateRPMDStepKernel::copyToContext(int copy, ContextImpl& context)
string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable, bool forward) { string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable, bool forward) {
stringstream source; stringstream source;
int unfactored = size;
int stage = 0; int stage = 0;
int L = size; int L = size;
int m = 1; int m = 1;
...@@ -398,16 +397,27 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable, ...@@ -398,16 +397,27 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable,
// Factor size, generating an appropriate block of code for each factor. // Factor size, generating an appropriate block of code for each factor.
while (unfactored > 1) { while (L > 1) {
int input = stage%2; int input = stage%2;
int output = 1-input; int output = 1-input;
int radix;
if (L%5 == 0)
radix = 5;
else if (L%4 == 0)
radix = 4;
else if (L%3 == 0)
radix = 3;
else if (L%2 == 0)
radix = 2;
else
throw OpenMMException("Illegal size for FFT: "+cu.intToString(size));
source<<"{\n"; source<<"{\n";
if (unfactored%5 == 0) { L = L/radix;
L = L/5; source<<"// Pass "<<(stage+1)<<" (radix "<<radix<<")\n";
source<<"// Pass "<<(stage+1)<<" (radix 5)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n"; source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n"; source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n"; source<<"int j = i/"<<m<<";\n";
if (radix == 5) {
source<<"mixed3 c0r = real"<<input<<"[i];\n"; source<<"mixed3 c0r = real"<<input<<"[i];\n";
source<<"mixed3 c0i = imag"<<input<<"[i];\n"; source<<"mixed3 c0i = imag"<<input<<"[i];\n";
source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -451,16 +461,8 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable, ...@@ -451,16 +461,8 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable,
source<<"imag"<<output<<"[i+(4*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(5*L)<<"], d8r-d10r, d8i-d10i);\n"; source<<"imag"<<output<<"[i+(4*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(5*L)<<"], d8r-d10r, d8i-d10i);\n";
source<<"real"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multReal<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n"; source<<"real"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multReal<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n";
source<<"imag"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multImag<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n"; source<<"imag"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multImag<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n";
source<<"}\n";
m = m*5;
unfactored /= 5;
} }
else if (unfactored%4 == 0) { else if (radix == 4) {
L = L/4;
source<<"// Pass "<<(stage+1)<<" (radix 4)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n";
source<<"mixed3 c0r = real"<<input<<"[i];\n"; source<<"mixed3 c0r = real"<<input<<"[i];\n";
source<<"mixed3 c0i = imag"<<input<<"[i];\n"; source<<"mixed3 c0i = imag"<<input<<"[i];\n";
source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -485,16 +487,8 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable, ...@@ -485,16 +487,8 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable,
source<<"imag"<<output<<"[i+(3*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(4*L)<<"], d0r-d2r, d0i-d2i);\n"; source<<"imag"<<output<<"[i+(3*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(4*L)<<"], d0r-d2r, d0i-d2i);\n";
source<<"real"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multReal<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n"; source<<"real"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multReal<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n";
source<<"imag"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n"; source<<"imag"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n";
source<<"}\n";
m = m*4;
unfactored /= 4;
} }
else if (unfactored%3 == 0) { else if (radix == 3) {
L = L/3;
source<<"// Pass "<<(stage+1)<<" (radix 3)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n";
source<<"mixed3 c0r = real"<<input<<"[i];\n"; source<<"mixed3 c0r = real"<<input<<"[i];\n";
source<<"mixed3 c0i = imag"<<input<<"[i];\n"; source<<"mixed3 c0i = imag"<<input<<"[i];\n";
source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -513,16 +507,8 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable, ...@@ -513,16 +507,8 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable,
source<<"imag"<<output<<"[i+(2*j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(3*L)<<"], d1r+d2r, d1i+d2i);\n"; source<<"imag"<<output<<"[i+(2*j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(3*L)<<"], d1r+d2r, d1i+d2i);\n";
source<<"real"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multReal<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n"; source<<"real"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multReal<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n";
source<<"imag"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n"; source<<"imag"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n";
source<<"}\n";
m = m*3;
unfactored /= 3;
} }
else if (unfactored%2 == 0) { else if (radix == 2) {
L = L/2;
source<<"// Pass "<<(stage+1)<<" (radix 2)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n";
source<<"mixed3 c0r = real"<<input<<"[i];\n"; source<<"mixed3 c0r = real"<<input<<"[i];\n";
source<<"mixed3 c0i = imag"<<input<<"[i];\n"; source<<"mixed3 c0i = imag"<<input<<"[i];\n";
source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed3 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -531,12 +517,9 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable, ...@@ -531,12 +517,9 @@ string CudaIntegrateRPMDStepKernel::createFFT(int size, const string& variable,
source<<"imag"<<output<<"[i+j*"<<m<<"] = c0i+c1i;\n"; source<<"imag"<<output<<"[i+j*"<<m<<"] = c0i+c1i;\n";
source<<"real"<<output<<"[i+(j+1)*"<<m<<"] = "<<multReal<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n"; source<<"real"<<output<<"[i+(j+1)*"<<m<<"] = "<<multReal<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n";
source<<"imag"<<output<<"[i+(j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n"; source<<"imag"<<output<<"[i+(j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n";
source<<"}\n";
m = m*2;
unfactored /= 2;
} }
else source<<"}\n";
throw OpenMMException("Illegal size for FFT: "+cu.intToString(size)); m = m*radix;
source<<"__syncthreads();\n"; source<<"__syncthreads();\n";
source<<"}\n"; source<<"}\n";
++stage; ++stage;
......
...@@ -387,7 +387,6 @@ void OpenCLIntegrateRPMDStepKernel::copyToContext(int copy, ContextImpl& context ...@@ -387,7 +387,6 @@ void OpenCLIntegrateRPMDStepKernel::copyToContext(int copy, ContextImpl& context
string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable, bool forward) { string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable, bool forward) {
stringstream source; stringstream source;
int unfactored = size;
int stage = 0; int stage = 0;
int L = size; int L = size;
int m = 1; int m = 1;
...@@ -403,16 +402,27 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable ...@@ -403,16 +402,27 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable
// Factor size, generating an appropriate block of code for each factor. // Factor size, generating an appropriate block of code for each factor.
while (unfactored > 1) { while (L > 1) {
int input = stage%2; int input = stage%2;
int output = 1-input; int output = 1-input;
int radix;
if (L%5 == 0)
radix = 5;
else if (L%4 == 0)
radix = 4;
else if (L%3 == 0)
radix = 3;
else if (L%2 == 0)
radix = 2;
else
throw OpenMMException("Illegal size for FFT: "+cl.intToString(size));
source<<"{\n"; source<<"{\n";
if (unfactored%5 == 0) { L = L/radix;
L = L/5; source<<"// Pass "<<(stage+1)<<" (radix "<<radix<<")\n";
source<<"// Pass "<<(stage+1)<<" (radix 5)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n"; source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n"; source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n"; source<<"int j = i/"<<m<<";\n";
if (radix == 5) {
source<<"mixed4 c0r = real"<<input<<"[i];\n"; source<<"mixed4 c0r = real"<<input<<"[i];\n";
source<<"mixed4 c0i = imag"<<input<<"[i];\n"; source<<"mixed4 c0i = imag"<<input<<"[i];\n";
source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -456,16 +466,8 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable ...@@ -456,16 +466,8 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable
source<<"imag"<<output<<"[i+(4*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(5*L)<<"], d8r-d10r, d8i-d10i);\n"; source<<"imag"<<output<<"[i+(4*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(5*L)<<"], d8r-d10r, d8i-d10i);\n";
source<<"real"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multReal<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n"; source<<"real"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multReal<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n";
source<<"imag"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multImag<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n"; source<<"imag"<<output<<"[i+(4*j+4)*"<<m<<"] = "<<multImag<<"(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7r-d9r, d7i-d9i);\n";
source<<"}\n";
m = m*5;
unfactored /= 5;
} }
else if (unfactored%4 == 0) { else if (radix == 4) {
L = L/4;
source<<"// Pass "<<(stage+1)<<" (radix 4)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n";
source<<"mixed4 c0r = real"<<input<<"[i];\n"; source<<"mixed4 c0r = real"<<input<<"[i];\n";
source<<"mixed4 c0i = imag"<<input<<"[i];\n"; source<<"mixed4 c0i = imag"<<input<<"[i];\n";
source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -490,16 +492,8 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable ...@@ -490,16 +492,8 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable
source<<"imag"<<output<<"[i+(3*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(4*L)<<"], d0r-d2r, d0i-d2i);\n"; source<<"imag"<<output<<"[i+(3*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(4*L)<<"], d0r-d2r, d0i-d2i);\n";
source<<"real"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multReal<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n"; source<<"real"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multReal<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n";
source<<"imag"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n"; source<<"imag"<<output<<"[i+(3*j+3)*"<<m<<"] = "<<multImag<<"(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1r-d3r, d1i-d3i);\n";
source<<"}\n";
m = m*4;
unfactored /= 4;
} }
else if (unfactored%3 == 0) { else if (radix == 3) {
L = L/3;
source<<"// Pass "<<(stage+1)<<" (radix 3)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n";
source<<"mixed4 c0r = real"<<input<<"[i];\n"; source<<"mixed4 c0r = real"<<input<<"[i];\n";
source<<"mixed4 c0i = imag"<<input<<"[i];\n"; source<<"mixed4 c0i = imag"<<input<<"[i];\n";
source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -518,16 +512,8 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable ...@@ -518,16 +512,8 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable
source<<"imag"<<output<<"[i+(2*j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(3*L)<<"], d1r+d2r, d1i+d2i);\n"; source<<"imag"<<output<<"[i+(2*j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(3*L)<<"], d1r+d2r, d1i+d2i);\n";
source<<"real"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multReal<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n"; source<<"real"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multReal<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n";
source<<"imag"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n"; source<<"imag"<<output<<"[i+(2*j+2)*"<<m<<"] = "<<multImag<<"(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1r-d2r, d1i-d2i);\n";
source<<"}\n";
m = m*3;
unfactored /= 3;
} }
else if (unfactored%2 == 0) { else if (radix == 2) {
L = L/2;
source<<"// Pass "<<(stage+1)<<" (radix 2)\n";
source<<"if (indexInBlock < "<<(L*m)<<") {\n";
source<<"int i = indexInBlock;\n";
source<<"int j = i/"<<m<<";\n";
source<<"mixed4 c0r = real"<<input<<"[i];\n"; source<<"mixed4 c0r = real"<<input<<"[i];\n";
source<<"mixed4 c0i = imag"<<input<<"[i];\n"; source<<"mixed4 c0i = imag"<<input<<"[i];\n";
source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"mixed4 c1r = real"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -536,12 +522,9 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable ...@@ -536,12 +522,9 @@ string OpenCLIntegrateRPMDStepKernel::createFFT(int size, const string& variable
source<<"imag"<<output<<"[i+j*"<<m<<"] = c0i+c1i;\n"; source<<"imag"<<output<<"[i+j*"<<m<<"] = c0i+c1i;\n";
source<<"real"<<output<<"[i+(j+1)*"<<m<<"] = "<<multReal<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n"; source<<"real"<<output<<"[i+(j+1)*"<<m<<"] = "<<multReal<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n";
source<<"imag"<<output<<"[i+(j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n"; source<<"imag"<<output<<"[i+(j+1)*"<<m<<"] = "<<multImag<<"(w[j*"<<size<<"/"<<(2*L)<<"], c0r-c1r, c0i-c1i);\n";
source<<"}\n";
m = m*2;
unfactored /= 2;
} }
else source<<"}\n";
throw OpenMMException("Illegal size for FFT: "+cl.intToString(size)); m = m*radix;
source<<"barrier(CLK_LOCAL_MEM_FENCE);\n"; source<<"barrier(CLK_LOCAL_MEM_FENCE);\n";
source<<"}\n"; source<<"}\n";
++stage; ++stage;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment