OpenCLArray.h 5.86 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#ifndef OPENMM_OPENCLARRAY_H_
#define OPENMM_OPENCLARRAY_H_

/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009 Stanford University and the Authors.           *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLContext.h"
#include "openmm/OpenMMException.h"
#include <vector>

namespace OpenMM {

/**
 * This class encapsulates an OpenCL Buffer.  It provides a simplified API for working with it,
 * an optionally includes a buffer in host memory for copying data to and from the OpenCL Buffer.
 */

template <class T>
class OpenCLArray {
public:
    /**
     * Create an OpenCLArray object.
     *
     * @param context           the context for which to create the array
     * @param size              the number of elements in the array
     * @param name              the name of the array
     * @param createHostBuffer  specifies whether to create a buffer in host memory for copying data to and from
     *                          the OpenCL Buffer
     */
    OpenCLArray(OpenCLContext& context, int size, const std::string& name, bool createHostBuffer = false) :
54
            context(context), size(size), name(name), local(createHostBuffer ? size : 0), ownsBuffer(true) {
55
56
        buffer = new cl::Buffer(context.getContext(), CL_MEM_READ_WRITE, size*sizeof(T));
    }
57
58
59
60
61
62
63
64
65
66
67
68
69
    /**
     * Create an OpenCLArray object the uses a preexisting Buffer.
     *
     * @param context           the context for which to create the array
     * @param buffer            the OpenCL Buffer this object encapsulates
     * @param size              the number of elements in the array
     * @param name              the name of the array
     * @param createHostBuffer  specifies whether to create a buffer in host memory for copying data to and from
     *                          the OpenCL Buffer
     */
    OpenCLArray(OpenCLContext& context, cl::Buffer* buffer, int size, const std::string& name, bool createHostBuffer = false) :
            context(context), buffer(buffer), size(size), name(name), local(createHostBuffer ? size : 0), ownsBuffer(false) {
    }
70
    ~OpenCLArray() {
71
72
        if (ownsBuffer)
            delete buffer;
73
74
75
76
77
78
79
    }
    const T& operator[](int index) const {
        return local[index];
    }
    T& operator[](int index) {
        return local[index];
    }
80
81
82
83
84
85
    /**
     * Get the size of the array.
     */
    int getSize() {
        return size;
    }
86
87
88
89
90
91
    /**
     * Get the name of the array.
     */
    const std::string& getName() {
        return name;
    }
92
93
94
95
96
97
    /**
     * Get the OpenCL Buffer object.
     */
    cl::Buffer& getDeviceBuffer() {
        return *buffer;
    }
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
    /**
     * Get a pointer to the host buffer.
     */
    T* getHostBuffer() {
        return &local[0];
    }
    /**
     * Get an element of the host buffer.
     */
    const T& get(int index) const {
        return local[index];
    }
    /**
     * Set an element of the host buffer.
     */
    void set(int index, const T& value) {
        local[index] = value;
    }
    /**
     * Copy the values in a vector to the Buffer.
     */
    void upload(std::vector<T>& data) {
120
        context.getQueue().enqueueWriteBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), &data[0]);
121
122
123
124
125
    }
    /**
     * Copy the values in the Buffer to a vector.
     */
    void download(std::vector<T>& data) const {
126
        context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), &data[0]);
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    }
    /**
     * Copy the values in the host buffer to the OpenCL Buffer.
     */
    void upload() {
        if (local.size() == 0)
            throw OpenMMException(name+": Called upload() on an OpenCLArray with no host buffer");
        upload(local);
    }
    /**
     * Copy the values in the Buffer to the host buffer.
     */
    void download() {
        if (local.size() == 0)
            throw OpenMMException(name+": Called download() on an OpenCLArray with no host buffer");
        download(local);
    }
private:
    OpenCLContext& context;
    cl::Buffer* buffer;
    std::vector<T> local;
    int size;
149
    bool ownsBuffer;
150
151
152
153
154
155
    std::string name;
};

} // namespace OpenMM

#endif /*OPENMM_OPENCLARRAY_H_*/