/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2015 Stanford University and the Authors.      *
 * Portions copyright (c) 2021 Advanced Micro Devices, Inc.                   *
 * Authors:                                                                   *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "HipFFT3D.h"
#include "HipContext.h"
#include <cstdio>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>

using namespace OpenMM;
using namespace std;

/**
 * Read the whole of a VkFFT kernel cache file into memory.
 *
 * @param path     the cache file to read
 * @param content  receives the raw bytes of the file
 * @return true if the file could be opened and was not empty, false otherwise
 */
static bool readVkFFTCacheFile(const string& path, vector<char>& content) {
    ifstream cache(path.c_str(), ios::in | ios::binary);
    if (!cache.is_open())
        return false;
    content.assign(istreambuf_iterator<char>(cache), istreambuf_iterator<char>());
    // An empty file cannot contain a serialized application, so treat it as "no cache".
    return !content.empty();
}

/**
 * Best-effort write of a serialized VkFFT application to the kernel cache.
 *
 * The bytes are first written to tempFile, which is then renamed to cacheFile.  The
 * temporary file is removed if either the write or the rename fails.  All errors are
 * deliberately ignored: a missing cache only means the kernels are rebuilt next time.
 *
 * @param cacheFile  final location of the cache entry
 * @param tempFile   scratch file used while writing
 * @param data       serialized application data
 * @param size       number of bytes in data
 */
static void writeVkFFTCacheFile(const string& cacheFile, const string& tempFile, const void* data, size_t size) {
    try {
        ofstream stream(tempFile.c_str(), ios::out | ios::binary);
        if (!stream.is_open())
            return;
        stream.write(static_cast<const char*>(data), static_cast<streamsize>(size));
        stream.close();
        if (stream.fail() || rename(tempFile.c_str(), cacheFile.c_str()) != 0)
            remove(tempFile.c_str());
    }
    catch (...) {
        // An error occurred.  Possibly we don't have permission to write to the temp directory.
    }
}

/**
 * Create a 3D FFT plan backed by VkFFT.
 *
 * The kernels for the plan are loaded from the context's cache directory when a cache
 * entry exists; otherwise they are built and saved to the cache on a best-effort basis.
 *
 * @param context        the HIP context the transform runs in
 * @param xsize          grid size along x
 * @param ysize          grid size along y
 * @param zsize          grid size along z
 * @param realToComplex  if true, perform a real-to-complex transform; otherwise complex-to-complex
 * @param stream         the HIP stream the transform is enqueued on
 * @param in             input buffer (real data for real-to-complex, complex data otherwise)
 * @param out            output buffer (complex data)
 * @throws OpenMMException if VkFFT fails to initialize the plan
 */
HipFFT3D::HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex, hipStream_t stream, HipArray& in, HipArray& out) :
        context(context), stream(stream) {

    deviceIndex = context.getDeviceIndex();
    inputBuffer = in.getDevicePointer();
    outputBuffer = out.getDevicePointer();

    // Buffer sizes are in bytes and are computed in size_t so that large grids cannot overflow int.
    // A real-to-complex transform reads xsize*ysize*zsize real values and writes
    // xsize*ysize*(zsize/2+1) complex values.  A complex-to-complex transform reads and writes
    // xsize*ysize*zsize complex values.  Each complex value occupies two scalars.
    const size_t valueSize = context.getUseDoublePrecision() ? sizeof(double) : sizeof(float);
    const size_t complexSize = 2*valueSize;
    const size_t gridPoints = static_cast<size_t>(xsize)*ysize*zsize;
    if (realToComplex) {
        inputBufferSize = gridPoints*valueSize;
        outputBufferSize = static_cast<size_t>(xsize)*ysize*(zsize/2 + 1)*complexSize;
    }
    else {
        inputBufferSize = gridPoints*complexSize;
        outputBufferSize = gridPoints*complexSize;
    }

    VkFFTConfiguration configuration = {};
    configuration.performR2C = realToComplex;
    configuration.device = &deviceIndex;
    configuration.num_streams = 1;
    configuration.stream = &this->stream;
    configuration.doublePrecision = context.getUseDoublePrecision();

    // The z index varies fastest in memory, so it is VkFFT's first dimension.
    configuration.FFTdim = 3;
    configuration.size[0] = zsize;
    configuration.size[1] = ysize;
    configuration.size[2] = xsize;

    configuration.inverseReturnToInputBuffer = true;
    configuration.isInputFormatted = true;
    configuration.inputBufferSize = &inputBufferSize;
    configuration.inputBuffer = &inputBuffer;
    configuration.inputBufferStride[0] = zsize;
    configuration.inputBufferStride[1] = configuration.inputBufferStride[0] * ysize;
    configuration.inputBufferStride[2] = configuration.inputBufferStride[1] * xsize;

    configuration.bufferSize = &outputBufferSize;
    configuration.buffer = &outputBuffer;
    configuration.bufferStride[0] = realToComplex ? (zsize/2 + 1) : zsize;
    configuration.bufferStride[1] = configuration.bufferStride[0] * ysize;
    configuration.bufferStride[2] = configuration.bufferStride[1] * xsize;

    // Combine all parameters that affect the generated kernels into a unique cache key.
    stringstream info;
    int runtimeVersion = 0;  // stays 0 if the query below fails
    (void) hipRuntimeGetVersion(&runtimeVersion);
    info << runtimeVersion;
    info << " " << VkFFTGetVersion();
    info << " " << xsize << " " << ysize << " " << zsize;
    info << " " << realToComplex << " " << context.getUseDoublePrecision();
    const string cacheFile = context.getCacheFileName(info.str());

    // cacheContent must stay alive until initializeVkFFT has consumed it.
    vector<char> cacheContent;
    bool loadedFromCache = readVkFFTCacheFile(cacheFile, cacheContent);
    if (loadedFromCache) {
        // There is an existing cache, load VkFFT kernels from it.
        configuration.loadApplicationFromString = 1;
        configuration.loadApplicationString = cacheContent.data();
    }
    else {
        // There is no existing cache, request saving.
        configuration.saveApplicationToString = 1;
    }

    app = new VkFFTApplication();
    VkFFTResult fftResult = initializeVkFFT(app, configuration);
    if (fftResult != VKFFT_SUCCESS && loadedFromCache) {
        // The cached kernels could not be loaded (for example, the file is truncated or corrupt).
        // Build the kernels from scratch instead, and try to replace the bad cache entry below.
        // NOTE(review): this assumes initializeVkFFT releases its own partial allocations on
        // failure, so only the host-side object is freed here.
        delete app;
        app = new VkFFTApplication();
        configuration.loadApplicationFromString = 0;
        configuration.loadApplicationString = NULL;
        configuration.saveApplicationToString = 1;
        loadedFromCache = false;
        fftResult = initializeVkFFT(app, configuration);
    }
    if (fftResult != VKFFT_SUCCESS) {
        // The destructor does not run when a constructor throws, so release the object here.
        delete app;
        app = NULL;
        throw OpenMMException("Error executing initializeVkFFT: "+context.intToString(fftResult));
    }

    // Persist freshly built kernels so later runs can skip the build step.
    if (!loadedFromCache && app->saveApplicationString != NULL)
        writeVkFFTCacheFile(cacheFile, context.getTempFileName() + ".vkfftcache",
                            app->saveApplicationString, size_t(app->applicationStringSize));
}

/**
 * Destroy the FFT plan.
 *
 * VkFFT first releases the plan's internal resources.  Then the VkFFTApplication
 * object that the constructor allocated with new is freed.
 */
HipFFT3D::~HipFFT3D() {
    VkFFTApplication* const plan = app;
    deleteVkFFT(plan);
    delete plan;
}
void HipFFT3D::execFFT(bool forward) {
    VkFFTResult fftResult = VkFFTAppend(app, forward ? -1 : 1, NULL);
    if (fftResult != VKFFT_SUCCESS) {
        throw OpenMMException("Error executing VkFFTAppend: "+context.intToString(fftResult));
140
141
142
143
144
145
146
147
148
149
    }
}

/**
 * Find the smallest grid dimension that VkFFT can handle and that is not less than a requested minimum.
 *
 * VkFFT supports prime factors up to 13.  The search therefore walks upward from the
 * minimum until it reaches a 13-smooth number (one with no prime factor above 13).
 *
 * @param minimum  the smallest acceptable dimension
 * @return the first legal dimension >= minimum, or 1 if minimum < 1
 */
int HipFFT3D::findLegalDimension(int minimum) {
    if (minimum < 1)
        return 1;
    for (int candidate = minimum; ; candidate++) {
        // Strip every factor of 2..13; anything left over is a prime above 13.
        int remainder = candidate;
        for (int divisor = 2; divisor <= 13 && remainder > 1; divisor++)
            while (remainder % divisor == 0)
                remainder /= divisor;
        if (remainder == 1)
            return candidate;
    }
}