Commit 7901ce20 authored by Peter Eastman's avatar Peter Eastman
Browse files

Implementing FFT for OpenCL

parent d21bf41a
......@@ -32,14 +32,14 @@ using namespace OpenMM;
using namespace Lepton;
using namespace std;
static string doubleToString(double value) {
string OpenCLExpressionUtilities::doubleToString(double value) {
stringstream s;
s.precision(8);
s << scientific << value << "f";
return s.str();
}
static string intToString(int value) {
string OpenCLExpressionUtilities::intToString(int value) {
stringstream s;
s << value;
return s.str();
......
......@@ -64,6 +64,14 @@ public:
* @return the spline coefficients
*/
static std::vector<mm_float4> computeFunctionCoefficients(const std::vector<double>& values, bool interpolating);
/**
* Convert a number to a string in a format suitable for including in a kernel.
*/
static std::string doubleToString(double value);
/**
* Convert a number to a string in a format suitable for including in a kernel.
*/
static std::string intToString(int value);
class FunctionPlaceholder;
private:
static void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "OpenCLFFT3D.h"
#include "OpenCLExpressionUtilities.h"
#include "OpenCLExpressionUtilities.h"
#include "../src/SimTKUtilities/SimTKOpenMMRealType.h"
#include <map>
#include <sstream>
#include <string>
using namespace OpenMM;
using namespace std;
OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize) : context(context), xsize(xsize), ysize(ysize), zsize(zsize) {
xkernel = createKernel(xsize);
ykernel = createKernel(ysize);
zkernel = createKernel(zsize);
}
OpenCLFFT3D::~OpenCLFFT3D() {
}
void OpenCLFFT3D::execFFT(OpenCLArray<mm_float2>& data, bool forward) {
xkernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
xkernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f);
context.executeKernel(xkernel, xsize, xsize);
}
cl::Kernel OpenCLFFT3D::createKernel(int size) {
map<string, string> replacements;
replacements["SIZE"] = OpenCLExpressionUtilities::intToString(size);
replacements["M_PI"] = OpenCLExpressionUtilities::doubleToString(M_PI);
stringstream source;
int unfactored = size;
int stage = 0;
int L = size;
int m = 1;
while (unfactored > 1) {
int input = stage%2;
int output = 1-input;
source<<"{\n";
if (unfactored%5 == 0) {
L = L/5;
source<<"// Pass "<<(stage+1)<<" (radix 5)\n";
source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
source<<"float2 c2 = data"<<input<<"[i+"<<(2*L*m)<<"];\n";
source<<"float2 c3 = data"<<input<<"[i+"<<(3*L*m)<<"];\n";
source<<"float2 c4 = data"<<input<<"[i+"<<(4*L*m)<<"];\n";
source<<"float2 d0 = c1+c4;\n";
source<<"float2 d1 = c2+c3;\n";
source<<"float2 d2 = "<<OpenCLExpressionUtilities::doubleToString(sin(0.4*M_PI))<<"*(c1-c4);\n";
source<<"float2 d3 = "<<OpenCLExpressionUtilities::doubleToString(sin(0.4*M_PI))<<"*(c2-c3);\n";
source<<"float2 d4 = d0+d1;\n";
source<<"float2 d5 = "<<OpenCLExpressionUtilities::doubleToString(0.25*sqrt(5.0))<<"*(d0-d1);\n";
source<<"float2 d6 = c0-0.25f*d4;\n";
source<<"float2 d7 = d6+d5;\n";
source<<"float2 d8 = d6-d5;\n";
string coeff = OpenCLExpressionUtilities::doubleToString(sin(0.2*M_PI)/sin(0.4*M_PI));
source<<"float2 d9 = sign*(float2) (d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n";
source<<"float2 d10 = sign*(float2) ("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n";
source<<"data"<<output<<"[i+4*j*"<<m<<"] = c0+d4;\n";
source<<"data"<<output<<"[i+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(5*L)<<"], d7+d9);\n";
source<<"data"<<output<<"[i+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*size)<<"/"<<(5*L)<<"], d8+d10);\n";
source<<"data"<<output<<"[i+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*size)<<"/"<<(5*L)<<"], d8-d10);\n";
source<<"data"<<output<<"[i+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7-d9);\n";
m = m*5;
unfactored /= 5;
}
else if (unfactored%4 == 0) {
L = L/4;
source<<"// Pass "<<(stage+1)<<" (radix 4)\n";
source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
source<<"float2 c2 = data"<<input<<"[i+"<<(2*L*m)<<"];\n";
source<<"float2 c3 = data"<<input<<"[i+"<<(3*L*m)<<"];\n";
source<<"float2 d0 = c0+c2;\n";
source<<"float2 d1 = c0-c2;\n";
source<<"float2 d2 = c1+c3;\n";
source<<"float2 d3 = sign*(float2) (c1.y-c3.y, c3.x-c1.x);\n";
source<<"data"<<output<<"[i+3*j*"<<m<<"] = d0+d2;\n";
source<<"data"<<output<<"[i+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(4*L)<<"], d1+d3);\n";
source<<"data"<<output<<"[i+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*size)<<"/"<<(4*L)<<"], d0-d2);\n";
source<<"data"<<output<<"[i+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1-d3);\n";
m = m*4;
unfactored /= 4;
}
else if (unfactored%3 == 0) {
L = L/3;
source<<"// Pass "<<(stage+1)<<" (radix 3)\n";
source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
source<<"float2 c2 = data"<<input<<"[i+"<<(2*L*m)<<"];\n";
source<<"float2 d0 = c1+c2;\n";
source<<"float2 d1 = c0-0.5f*d0;\n";
source<<"float2 d2 = sign*"<<OpenCLExpressionUtilities::doubleToString(sin(M_PI/3.0))<<"*(float2) (c1.y-c2.y, c2.x-c1.x);\n";
source<<"data"<<output<<"[i+2*j*"<<m<<"] = c0+d0;\n";
source<<"data"<<output<<"[i+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(3*L)<<"], d1+d2);\n";
source<<"data"<<output<<"[i+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1-d2);\n";
m = m*3;
unfactored /= 3;
}
else if (unfactored%2 == 0) {
L = L/2;
source<<"// Pass "<<(stage+1)<<" (radix 2)\n";
source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
source<<"data"<<output<<"[i+j*"<<m<<"] = c0+c1;\n";
source<<"data"<<output<<"[i+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(2*L)<<"], c0-c1);\n";
m = m*2;
unfactored /= 2;
}
else
throw OpenMMException("Illegal size for FFT: "+OpenCLExpressionUtilities::intToString(size));
source<<"barrier(CLK_LOCAL_MEM_FENCE);\n";
source<<"}\n";
++stage;
}
source<<"matrix[i] = data"<<(stage%2)<<"[i];";
replacements["COMPUTE_FFT"] = source.str();
cl::Program program = context.createProgram(context.loadSourceFromFile("fft.cl", replacements));
cl::Kernel kernel(program, "execFFT");
kernel.setArg(2, size*sizeof(mm_float2), NULL);
kernel.setArg(3, size*sizeof(mm_float2), NULL);
kernel.setArg(4, size*sizeof(mm_float2), NULL);
return kernel;
}
#ifndef __OPENMM_OPENCLFFT3D_H__
#define __OPENMM_OPENCLFFT3D_H__
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "OpenCLArray.h"
namespace OpenMM {
class OpenCLFFT3D {
public:
OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize);
~OpenCLFFT3D();
void execFFT(OpenCLArray<mm_float2>& data, bool forward = true);
private:
cl::Kernel createKernel(int size);
int xsize, ysize, zsize;
OpenCLContext& context;
cl::Kernel xkernel, ykernel, zkernel;
};
} // namespace OpenMM
#endif // __OPENMM_OPENCLFFT3D_H__
\ No newline at end of file
float2 multiplyComplex(float2 c1, float2 c2) {
return (float2) (c1.x*c2.x-c1.y*c2.y, c1.x*c2.y+c1.y*c2.x);
}
/**
* Perform a 1D FFT on each row along one axis.
*/
__kernel void execFFT(__global float2* matrix, float sign, __local float2* w, __local float2* data0, __local float2* data1) {
const int i = get_local_id(0);
w[i] = (float2) (cos(-sign*i*2*M_PI/SIZE), sin(-sign*i*2*M_PI/SIZE));
data0[i] = matrix[i];
barrier(CLK_LOCAL_MEM_FENCE);
COMPUTE_FFT
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment