allocator.h 14.7 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/**
 * Memory Allocator
 **/

#pragma once

#include "cuda_utils.h"
Chen Xin's avatar
Chen Xin committed
23
#include "src/turbomind/macro.h"
Li Zhang's avatar
Li Zhang committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <cuda_runtime.h>
#include <unordered_map>
#include <vector>

#ifdef GOOGLE_CUDA
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#endif

#ifdef TORCH_CUDA
#include "torch/extension.h"
#include <memory>
#endif

lvhan028's avatar
lvhan028 committed
45
#include "src/turbomind/utils/logger.h"
Li Zhang's avatar
Li Zhang committed
46
47
48
49
50

#if defined(CUDART_VERSION) && CUDART_VERSION < 11020
#define CUDA_MEMORY_POOL_DISABLED
#endif

lvhan028's avatar
lvhan028 committed
51
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
52

AllentDan's avatar
AllentDan committed
53
54
/**
 * Selects which `Allocator<...>` specialization is used.
 *
 * Each enumerator corresponds to one specialization declared later in this
 * header.
 */
enum class AllocatorType
{
    CUDA,  // Allocator built on cudaMalloc* / cudaFree* calls
    TF,    // Allocator built on TensorFlow tensors (compiled only with GOOGLE_CUDA)
    TH     // Allocator built on torch tensors (compiled only with TORCH_CUDA)
};

AllentDan's avatar
AllentDan committed
60
61
/**
 * Result of comparing a tracked buffer's recorded size with a requested size.
 *
 * Produced by `isReMalloc()` and consumed by `IAllocator::reMalloc()`.
 */
enum class ReallocType
{
    INCREASE,  // recorded size is smaller than the requested size
    REUSE,     // recorded size equals the requested size
    DECREASE,  // recorded size is larger than the requested size
};

class IAllocator {
public:
    virtual ~IAllocator(){};

    virtual void*        malloc(size_t size, const bool is_set_zero = true, bool is_host = false) = 0;
    virtual void         free(void** ptr, bool is_host = false) const                             = 0;
    virtual void         setStream(cudaStream_t stream)                                           = 0;
    virtual cudaStream_t returnStream()                                                           = 0;
    virtual void         memSet(void* ptr, const int val, const size_t size)                      = 0;

    template<typename T>
    void* reMalloc(T* ptr, size_t size, const bool is_set_zero = true, bool is_host = false)
    {
lvhan028's avatar
lvhan028 committed
80
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
81
82
83
84
85
86
        size              = ((size + 31) / 32) * 32;  // make the buffer align with 32 bytes
        void* void_ptr    = (void*)ptr;
        void* ptr_address = getAddress(void_ptr);
        if (isExist(ptr_address)) {
            ReallocType realloc_type = isReMalloc(ptr_address, size);
            if (realloc_type == ReallocType::INCREASE) {
lvhan028's avatar
lvhan028 committed
87
                TM_LOG_DEBUG("ReMalloc the buffer %p since it is too small.", void_ptr);
Li Zhang's avatar
Li Zhang committed
88
89
90
91
92
                free((void**)(&void_ptr), is_host);
                return malloc(size, is_set_zero, is_host);
            }
#if !defined(CUDA_MEMORY_POOL_DISABLED)
            else if (realloc_type == ReallocType::DECREASE) {
lvhan028's avatar
lvhan028 committed
93
                TM_LOG_DEBUG("ReMalloc the buffer %p to release unused memory to memory pools.", void_ptr);
Li Zhang's avatar
Li Zhang committed
94
95
96
97
98
                free((void**)(&void_ptr), is_host);
                return malloc(size, is_set_zero, is_host);
            }
#endif
            else {
lvhan028's avatar
lvhan028 committed
99
                TM_LOG_DEBUG("Reuse original buffer %p with size %d and do nothing for reMalloc.", void_ptr, size);
Li Zhang's avatar
Li Zhang committed
100
101
102
103
104
105
106
                if (is_set_zero) {
                    memSet(void_ptr, 0, size);
                }
                return void_ptr;
            }
        }
        else {
lvhan028's avatar
lvhan028 committed
107
            TM_LOG_DEBUG("Cannot find buffer %p, mallocing new one.", void_ptr);
Li Zhang's avatar
Li Zhang committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
            return malloc(size, is_set_zero, is_host);
        }
    }

protected:
    virtual bool        isExist(void* address) const                 = 0;
    virtual ReallocType isReMalloc(void* address, size_t size) const = 0;

    void* getAddress(void* ptr) const
    {
        return ptr;
    }
};

// Primary template: declared only, never defined. Concrete allocators are
// provided as explicit specializations on AllocatorType further down.
template<AllocatorType AllocType_>
class Allocator;

template<>
class Allocator<AllocatorType::CUDA>: public IAllocator {
private:
    const int                          device_id_;
    cudaStream_t                       stream_ = 0;  // initialize as default stream
    std::unordered_map<void*, size_t>* pointer_mapping_;

    bool isExist(void* address) const
    {
        return pointer_mapping_->count(address) > 0;
    }
    ReallocType isReMalloc(void* address, size_t size) const
    {
        FT_CHECK(isExist(address));
        if (pointer_mapping_->at(address) < size) {
            return ReallocType::INCREASE;
        }
        else if (pointer_mapping_->at(address) == size) {
            return ReallocType::REUSE;
        }
        else {
            return ReallocType::DECREASE;
        }
    }

public:
    Allocator(int device_id): device_id_(device_id)
    {
lvhan028's avatar
lvhan028 committed
153
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
154
155
        pointer_mapping_ = new std::unordered_map<void*, size_t>();
#if defined(CUDA_MEMORY_POOL_DISABLED)
lvhan028's avatar
lvhan028 committed
156
        TM_LOG_WARNING(
Li Zhang's avatar
Li Zhang committed
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
            "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
            "Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP");
#else
        int device_count = 1;
        check_cuda_error(cudaGetDeviceCount(&device_count));
        cudaMemPool_t mempool;
        check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id));
        cudaMemAccessDesc desc                  = {};
        int               peer_access_available = 0;
        for (int i = 0; i < device_count; i++) {
            if (i == device_id) {
                continue;
            }
            check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i));
            if (!peer_access_available) {
lvhan028's avatar
lvhan028 committed
172
                TM_LOG_WARNING("Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i)
Li Zhang's avatar
Li Zhang committed
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
                               + " is not available.");
                continue;
            }
            desc.location.type = cudaMemLocationTypeDevice;
            desc.location.id   = i;
            desc.flags         = cudaMemAccessFlagsProtReadWrite;
            check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1));
        }
        // set memory pool threshold to avoid shrinking the pool
        uint64_t setVal = UINT64_MAX;
        check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal));
#endif
    }

    virtual ~Allocator()
    {
lvhan028's avatar
lvhan028 committed
189
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
        while (!pointer_mapping_->empty()) {
            free((void**)(&pointer_mapping_->begin()->first));
        }
        delete pointer_mapping_;
    }

    void setStream(cudaStream_t stream)
    {
        stream_ = stream;
    }

    cudaStream_t returnStream()
    {
        return stream_;
    };

    void* malloc(size_t size, const bool is_set_zero = true, bool is_host = false)
    {
lvhan028's avatar
lvhan028 committed
208
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
        if (size == 0) {
            return nullptr;
        }
        void* ptr      = nullptr;
        int   o_device = 0;

        check_cuda_error(getSetDevice(device_id_, &o_device));
        if (is_host) {
            check_cuda_error(cudaMallocHost(&ptr, (size_t)(ceil(size / 32.)) * 32));
        }
        else {
#if defined(CUDA_MEMORY_POOL_DISABLED)
            check_cuda_error(cudaMalloc(&ptr, (size_t)(ceil(size / 32.)) * 32));
#else
            check_cuda_error(cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_));
#endif
        }
        if (is_set_zero) {
            check_cuda_error(cudaMemsetAsync(ptr, 0, (size_t)(ceil(size / 32.)) * 32, stream_));
        }
        check_cuda_error(getSetDevice(o_device));
lvhan028's avatar
lvhan028 committed
230
        TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size);
Li Zhang's avatar
Li Zhang committed
231
232
233
234
235
236
237
238

        pointer_mapping_->insert({getAddress(ptr), size});

        return ptr;
    }

    void free(void** ptr, bool is_host = false) const
    {
lvhan028's avatar
lvhan028 committed
239
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
240
241
242
243
        void* address = getAddress(*ptr);
        if (*ptr != nullptr) {
            int o_device = 0;
            if (pointer_mapping_->count(address)) {
lvhan028's avatar
lvhan028 committed
244
                TM_LOG_DEBUG("Free buffer %p", address);
Li Zhang's avatar
Li Zhang committed
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
                check_cuda_error(getSetDevice(device_id_, &o_device));
                if (is_host) {
                    check_cuda_error(cudaFreeHost(*ptr));
                }
                else {
#if defined(CUDA_MEMORY_POOL_DISABLED)
                    check_cuda_error(cudaFree(*ptr));
#else
                    check_cuda_error(cudaFreeAsync(*ptr, stream_));
                    cudaStreamSynchronize(stream_);
#endif
                }
                check_cuda_error(getSetDevice(o_device));
                pointer_mapping_->erase(address);
            }
            else {
lvhan028's avatar
lvhan028 committed
261
                TM_LOG_WARNING("pointer_mapping_ does not have information of ptr at %p.", address);
Li Zhang's avatar
Li Zhang committed
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
            }
        }
        *ptr = nullptr;
        return;
    }

    void memSet(void* ptr, const int val, const size_t size)
    {
        check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_));
    }
};

#ifdef GOOGLE_CUDA
using namespace tensorflow;
template<>
class Allocator<AllocatorType::TF>: public IAllocator {
    OpKernelContext*                               context_;
    std::unordered_map<void*, tensorflow::Tensor>* pointer_mapping_;
    cudaStream_t                                   stream_;

    bool isExist(void* address) const
    {
        return pointer_mapping_->count(address) > 0;
    }
    ReallocType isReMalloc(void* address, size_t size) const
    {
        FT_CHECK(isExist(address));
        size_t current_buffer_size = 1;
        for (int i = 0; i < pointer_mapping_->at(address).dims(); i++) {
            current_buffer_size *= pointer_mapping_->at(address).dim_size(i);
        }
lvhan028's avatar
lvhan028 committed
293
        TM_LOG_DEBUG("current_buffer_size: %d, new buffer: %d", current_buffer_size, size);
Li Zhang's avatar
Li Zhang committed
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
        if (current_buffer_size < size) {
            return ReallocType::INCREASE;
        }
        else if (current_buffer_size == size) {
            return ReallocType::REUSE;
        }
        else {
            return ReallocType::DECREASE;
        }
    }

public:
    Allocator(OpKernelContext* context, cudaStream_t stream): context_(context), stream_(stream)
    {
        pointer_mapping_ = new std::unordered_map<void*, tensorflow::Tensor>();
    }

    void setStream(cudaStream_t stream)
    {
        stream_ = stream;
    }

    cudaStream_t returnStream()
    {
        return stream_;
    };

    void* malloc(size_t size, const bool is_set_zero = true, bool is_host = false)
    {
lvhan028's avatar
lvhan028 committed
323
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
        tensorflow::Tensor buf;
        long long int      buf_size = ((long long int)ceil(size / 32.) * 32);
        tensorflow::Status status;
        if (is_host) {
            tensorflow::AllocatorAttributes pinned_allocator;
            pinned_allocator.set_on_host(true);
            pinned_allocator.set_gpu_compatible(true);
            status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf, pinned_allocator);
        }
        else {
            status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf);
        }

        if (status != tensorflow::Status::OK()) {
            throw std::runtime_error("TF error: context->allocate_temp failed");
        }

        auto  flat = buf.flat<uint8>();
        void* ptr  = (void*)flat.data();
        if (is_set_zero) {
            cudaMemsetAsync(ptr, 0, buf_size, stream_);
        }
        pointer_mapping_->insert({getAddress(ptr), buf});

        return ptr;
    }

    void free(void** ptr, bool is_host = false) const
    {
lvhan028's avatar
lvhan028 committed
353
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
        void* address = getAddress(*ptr);
        pointer_mapping_->erase(address);
        *ptr = nullptr;
        return;
    }

    virtual ~Allocator()
    {
        while (!pointer_mapping_->empty()) {
            void* ptr = pointer_mapping_->begin()->second.flat<uint8>().data();
            free((void**)(&ptr));
        }
        pointer_mapping_->clear();
        delete pointer_mapping_;
    }

    void memSet(void* ptr, const int val, const size_t size)
    {
        check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_));
    }
};
#endif

#ifdef TORCH_CUDA
template<>
class Allocator<AllocatorType::TH>: public IAllocator {
    std::unordered_map<void*, torch::Tensor>* pointer_mapping_;

    bool isExist(void* address) const
    {
        return pointer_mapping_->count(address) > 0;
    }
    ReallocType isReMalloc(void* address, size_t size) const
    {
        FT_CHECK(isExist(address));
        size_t current_buffer_size = 1;
        for (int i = 0; i < pointer_mapping_->at(address).dim(); i++) {
            current_buffer_size *= pointer_mapping_->at(address).size(i);
        }
lvhan028's avatar
lvhan028 committed
393
        TM_LOG_DEBUG(
Li Zhang's avatar
Li Zhang committed
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
            "current_buffer_size: %d, original buffer: %p, new buffer: %d", current_buffer_size, address, size);
        if (current_buffer_size < size) {
            return ReallocType::INCREASE;
        }
        else if (current_buffer_size == size) {
            return ReallocType::REUSE;
        }
        else {
            return ReallocType::DECREASE;
        }
    }

public:
    Allocator()
    {
        pointer_mapping_ = new std::unordered_map<void*, torch::Tensor>();
    }

    void setStream(cudaStream_t stream)
    {
        // nothing to do here;
    }

    cudaStream_t returnStream()
    {
        // nothing to do here;
        return 0;
    };

    void* malloc(size_t size, const bool is_set_zero = true, bool is_host = false)
    {
lvhan028's avatar
lvhan028 committed
425
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
426
427
428
429
430
431
432
433
434
435
436
437
        int64_t       buf_size = static_cast<int64_t>(ceil(size / 32.)) * 32;
        torch::Tensor buf;
        if (is_host) {
            buf = torch::empty({buf_size}, torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true));
        }
        else {
            buf = torch::empty({buf_size}, torch::dtype(torch::kUInt8).device(torch::kCUDA));
        }
        void* ptr = buf.data_ptr();
        if (is_set_zero) {
            cudaMemset(ptr, 0, buf_size);
        }
lvhan028's avatar
lvhan028 committed
438
        TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, buf_size);
Li Zhang's avatar
Li Zhang committed
439
440
441
442
443
444
        pointer_mapping_->insert({getAddress(ptr), buf});
        return ptr;
    }

    void free(void** ptr, bool is_host = false) const
    {
lvhan028's avatar
lvhan028 committed
445
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
446
447
448
449
450
451
452
453
        void* address = getAddress(*ptr);
        pointer_mapping_->erase(address);
        *ptr = nullptr;
        return;
    }

    virtual ~Allocator()
    {
lvhan028's avatar
lvhan028 committed
454
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);
Li Zhang's avatar
Li Zhang committed
455
456
457
458
459
460
461
462
463
464
465
466
467
468
        while (!pointer_mapping_->empty()) {
            void* ptr = pointer_mapping_->begin()->second.data_ptr();
            free((void**)(&ptr));
        }
        pointer_mapping_->clear();
        delete pointer_mapping_;
    }

    void memSet(void* ptr, const int val, const size_t size)
    {
        check_cuda_error(cudaMemset(ptr, val, size));
    }
};
#endif
lvhan028's avatar
lvhan028 committed
469
}  // namespace turbomind