// OpenCLArray.h
#ifndef OPENMM_OPENCLARRAY_H_
#define OPENMM_OPENCLARRAY_H_

/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009 Stanford University and the Authors.           *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLContext.h"
#include "openmm/OpenMMException.h"
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

namespace OpenMM {

/**
 * This class encapsulates an OpenCL Buffer.  It provides a simplified API for working with it,
 * an optionally includes a buffer in host memory for copying data to and from the OpenCL Buffer.
 */

template <class T>
class OpenCLArray {
public:
    /**
     * Create an OpenCLArray object.
     *
     * @param context           the context for which to create the array
     * @param size              the number of elements in the array
     * @param name              the name of the array
     * @param createHostBuffer  specifies whether to create a buffer in host memory for copying data to and from
     *                          the OpenCL Buffer
54
     * @param flags             the set of flags to specify when creating the OpenCL Buffer
55
     */
56
    OpenCLArray(OpenCLContext& context, int size, const std::string& name, bool createHostBuffer = false, cl_int flags = CL_MEM_READ_WRITE) :
57
            context(context), size(size), name(name), local(createHostBuffer ? size : 0), ownsBuffer(true) {
58
        try {
59
            buffer = new cl::Buffer(context.getContext(), flags, size*sizeof(T));
60
61
62
63
64
65
        }
        catch (cl::Error err) {
            std::stringstream str;
            str<<"Error creating array "<<name<<": "<<err.what()<<" ("<<err.err()<<")";
            throw OpenMMException(str.str());
        }
66
    }
67
68
69
70
71
72
73
74
75
76
77
78
79
    /**
     * Create an OpenCLArray object the uses a preexisting Buffer.
     *
     * @param context           the context for which to create the array
     * @param buffer            the OpenCL Buffer this object encapsulates
     * @param size              the number of elements in the array
     * @param name              the name of the array
     * @param createHostBuffer  specifies whether to create a buffer in host memory for copying data to and from
     *                          the OpenCL Buffer
     */
    OpenCLArray(OpenCLContext& context, cl::Buffer* buffer, int size, const std::string& name, bool createHostBuffer = false) :
            context(context), buffer(buffer), size(size), name(name), local(createHostBuffer ? size : 0), ownsBuffer(false) {
    }
80
    ~OpenCLArray() {
81
82
        if (ownsBuffer)
            delete buffer;
83
84
85
86
87
88
89
    }
    const T& operator[](int index) const {
        return local[index];
    }
    T& operator[](int index) {
        return local[index];
    }
90
91
92
93
94
95
    /**
     * Get the size of the array.
     */
    int getSize() {
        return size;
    }
96
97
98
99
100
101
    /**
     * Get the name of the array.
     */
    const std::string& getName() {
        return name;
    }
102
103
104
105
106
107
    /**
     * Get the OpenCL Buffer object.
     */
    cl::Buffer& getDeviceBuffer() {
        return *buffer;
    }
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    /**
     * Get a pointer to the host buffer.
     */
    T* getHostBuffer() {
        return &local[0];
    }
    /**
     * Get an element of the host buffer.
     */
    const T& get(int index) const {
        return local[index];
    }
    /**
     * Set an element of the host buffer.
     */
    void set(int index, const T& value) {
        local[index] = value;
    }
    /**
     * Copy the values in a vector to the Buffer.
     */
    void upload(std::vector<T>& data) {
130
131
132
133
134
135
136
137
        try {
            context.getQueue().enqueueWriteBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), &data[0]);
        }
        catch (cl::Error err) {
            std::stringstream str;
            str<<"Error uploading array "<<name<<": "<<err.what()<<" ("<<err.err()<<")";
            throw OpenMMException(str.str());
        }
138
139
140
141
142
    }
    /**
     * Copy the values in the Buffer to a vector.
     */
    void download(std::vector<T>& data) const {
143
144
        if (data.size() != size)
            data.resize(size);
145
146
147
148
149
150
151
152
        try {
            context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), &data[0]);
        }
        catch (cl::Error err) {
            std::stringstream str;
            str<<"Error downloading array "<<name<<": "<<err.what()<<" ("<<err.err()<<")";
            throw OpenMMException(str.str());
        }
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    }
    /**
     * Copy the values in the host buffer to the OpenCL Buffer.
     */
    void upload() {
        if (local.size() == 0)
            throw OpenMMException(name+": Called upload() on an OpenCLArray with no host buffer");
        upload(local);
    }
    /**
     * Copy the values in the Buffer to the host buffer.
     */
    void download() {
        if (local.size() == 0)
            throw OpenMMException(name+": Called download() on an OpenCLArray with no host buffer");
        download(local);
    }
private:
    OpenCLContext& context;
    cl::Buffer* buffer;
    std::vector<T> local;
    int size;
175
    bool ownsBuffer;
176
177
178
179
180
181
    std::string name;
};

} // namespace OpenMM

#endif /*OPENMM_OPENCLARRAY_H_*/