// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team
#pragma once

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"

// Byte-count helpers: 1 MiB and 1 GiB, both plain int constant expressions.
#define MEGABYTE (1 << 20)
#define GIGABYTE (1 << 30)

// TODO: refactor out
// Number of threads per warp assumed by this code (32 on NVIDIA GPUs).
// NOTE(review): AMD GPUs reached through HIP may use 64-lane wavefronts; confirm this value
// is appropriate for HIP builds.
#define WARP_SIZE 32

// Evaluates a CUDA runtime call exactly once. If it does not return cudaSuccess, prints the
// numeric error code, its human-readable description and the call site to std::cerr (ending
// the line so the message is flushed), then triggers assert(0).
// NOTE: assert(0) is compiled out under NDEBUG; in that case execution continues after the
// message is printed.
#define CUDA_CHECK(callstr)                                                              \
    {                                                                                    \
        cudaError_t error_code = (callstr);                                              \
        if (error_code != cudaSuccess) {                                                 \
            std::cerr << "CUDA error " << error_code << " ("                             \
                      << cudaGetErrorString(error_code) << ") at " << __FILE__ << ":"    \
                      << __LINE__ << std::endl;                                          \
            assert(0);                                                                   \
        }                                                                                \
    }

// Grid-stride loop over i in [0, n): each thread starts at its global x index and advances by
// the total number of threads along x in the grid, so any launch size covers all of [0, n).
// blockIdx.x / blockDim.x / gridDim.x are 32-bit unsigned; they are widened to size_t before
// multiplying so the start index and stride cannot wrap around on very large grids.
#define CUDA_1D_KERNEL_LOOP(i, n)                                             \
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
         i += (size_t)blockDim.x * gridDim.x)

// Nested grid-stride loops: i walks [0, n) along the grid's x dimension and, for each i,
// j walks [0, m) along the grid's y dimension.
// Start indices and strides are computed in size_t (not 32-bit unsigned) so they cannot wrap
// around on very large grids.
#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                                           \
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < (n);     \
         i += (size_t)blockDim.x * gridDim.x)                                    \
        for (size_t j = (size_t)blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
             j += (size_t)blockDim.y * gridDim.y)

#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144

// Number of DS_CUDA_NUM_THREADS-sized blocks needed to cover N elements, clamped to
// [1, DS_MAXIMUM_NUM_BLOCKS]. At least one block is always returned, because CUDA does not
// allow an empty grid; consequently N <= 0 yields 1.
// The ceil-division is done in 64-bit arithmetic so that N close to INT_MAX does not
// overflow (signed overflow in `N + DS_CUDA_NUM_THREADS - 1` would be undefined behavior).
inline int DS_GET_BLOCKS(const int N)
{
    const long long blocks =
        (static_cast<long long>(N) + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS;
    return static_cast<int>(
        std::max(std::min(blocks, static_cast<long long>(DS_MAXIMUM_NUM_BLOCKS)), 1LL));
}

// Holder for state shared across inference calls. It owns:
//   - a cuBLAS handle,
//   - CUDA events used to order work between a compute and a communication stream,
//   - one device workspace buffer sized by GenWorkSpace(),
//   - a (seed, offset) pair handed out by IncrementOffset().
// Instance() returns a function-local static instance, constructed on first use.
class InferenceContext {
public:
    InferenceContext()
        // Initializers are listed in member-declaration order (the order C++ actually
        // runs them). Every scalar/handle member that is read before being assigned is
        // given a defined starting value.
        : _workspace(nullptr),
          _attention_unfused_workspace_offset(0),
          _seed(42),
          _curr_offset(0),
          _workSpaceSize(0),
          _free_memory_size(0),
          _max_seq_len(0),
          _stream(0),
          _token_length(0),
          _num_tokens(1),
          _comp_stream(0),
          _comm_stream(0)
    {
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
#ifndef __HIP_PLATFORM_HCC__
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
        cudaEventCreate(&_comp1_event);
        cudaEventCreate(&_comp2_event);
        cudaEventCreate(&_comp_event);
        cudaEventCreate(&_comm_event);
    }

    // The destructor releases the cuBLAS handle, the events and the workspace. A copy would
    // release the same resources a second time, so copying is disabled.
    InferenceContext(const InferenceContext&) = delete;
    InferenceContext& operator=(const InferenceContext&) = delete;

    virtual ~InferenceContext()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
        cudaEventDestroy(_comp1_event);
        cudaEventDestroy(_comp2_event);
        cudaEventDestroy(_comp_event);
        cudaEventDestroy(_comm_event);
    }

    // Returns the shared instance, constructing it on the first call.
    static InferenceContext& Instance()
    {
        static InferenceContext _ctx;
        return _ctx;
    }

    // Sizes and, when necessary, (re)allocates the device workspace.
    //
    // Per token, the workspace needs activation_size + temp_size elements, plus cache_size
    // unless external_cache is true; each element is elem_size bytes. The maximum token count
    // (_max_seq_len) is computed as follows:
    //   - Start from the free device memory sampled on the first call.
    //   - Subtract a reserve of temp_size plus 100 MB (500 MB when more than 1 GB was free).
    //   - Divide by elem_size, then by (activation_size + temp_size + cache_size).
    //   - Cap the result at max_out_tokens.
    // The buffer is reallocated only if none is held or the held one is smaller than needed.
    // The last temp_size * _max_seq_len * elem_size bytes are exposed through
    // GetAttentionUnfusedWorkspace().
    //
    // Throws std::runtime_error in these cases:
    //   - num_heads, mp_size, elem_size or batch_size is zero (they feed divisors below);
    //   - free memory is below the reserve;
    //   - fewer than min_out_tokens tokens fit;
    //   - the device allocation fails.
    // Only rank 0 prints the allocation summary.
    void GenWorkSpace(const unsigned& num_layers,
                      const unsigned& num_heads,
                      const size_t& batch_size,
                      const size_t& prompt_len,
                      const size_t& hidden_dim,
                      const unsigned& mp_size,
                      const bool& external_cache,
                      const size_t& elem_size,
                      const unsigned& rank,
                      unsigned max_out_tokens,
                      unsigned min_out_tokens)
    {
        // num_heads, mp_size and elem_size are used directly as divisors. batch_size == 0
        // makes the per-token footprint (also a divisor) zero. Reject all of these rather
        // than dividing by zero.
        if (num_heads == 0 || mp_size == 0 || elem_size == 0 || batch_size == 0) {
            throw std::runtime_error(
                "Workspace can't be allocated, num_heads, mp_size, elem_size and batch_size "
                "must be non-zero.");
        }

        // Query the device on every call so total_size is always valid for the diagnostics
        // below. The free-memory figure is still taken only from the first call.
        // NOTE(review): presumably cached so later calls are not skewed by memory this
        // context has already allocated -- confirm.
        size_t free_size = 0;
        size_t total_size = 0;
        cudaMemGetInfo(&free_size, &total_size);
        if (!_free_memory_size) { _free_memory_size = free_size; }

        // Flash attention requires padded heads and we'll conservatively allocate
        // for that here. Flash attention is only enabled for head size <= 128 right now
        const int head_size = hidden_dim / num_heads;
        const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128);
        const int effective_head_size = (head_size > 128) ? head_size : padded_head_size;

        size_t activation_size = 10 * (num_heads * effective_head_size) * batch_size;
        // Other sequence length dimension is added when the final workSpaceSize is calculated
        size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens;
        size_t cache_size =
            num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2;
        size_t minimal_requirements =
            temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE;
        if (_free_memory_size < minimal_requirements) {
            printf("Requested:\t%zu\nFree:\t%zu\nTotal:\t%zu\n",
                   minimal_requirements,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace can't be allocated, no enough memory.");
        }

        _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) /
                       (activation_size + temp_size + cache_size);
        _max_seq_len = std::min((size_t)max_out_tokens, _max_seq_len);
        size_t workSpaceSize = ((external_cache ? (activation_size + temp_size)
                                                : (activation_size + temp_size + cache_size))) *
                               _max_seq_len * elem_size;
        // From here on, temp_size is in bytes covering _max_seq_len tokens.
        temp_size *= _max_seq_len * elem_size;

        if (_max_seq_len < min_out_tokens) {
            printf(
                "Allocatable workspace available (%zu tokens) is less than minimum requested "
                "workspace (%u tokens)\n",
                _max_seq_len,
                min_out_tokens);
            throw std::runtime_error("Workspace can't be allocated, not enough memory");
        }

        // (Re)allocate only when no buffer is held or the held one is too small.
        // If cudaMalloc fails:
        //   - _workspace is reset to nullptr explicitly, instead of trusting whatever
        //     cudaMalloc left in it (the old buffer, if any, has already been freed);
        //   - the allocation error is cleared so an unrelated later error check does not
        //     pick it up.
        // The null check further down then reports the failure.
        if (!_workspace || _workSpaceSize < workSpaceSize) {
            if (_workspace) {
                cudaFree(_workspace);
                _workspace = nullptr;
            }
            if (cudaMalloc(&_workspace, workSpaceSize) != cudaSuccess) {
                _workspace = nullptr;
                (void)cudaGetLastError();
            }
        }
        if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize))
            printf(
                "------------------------------------------------------\n"
                "Free memory : %f (GigaBytes)  \n"
                "Total memory: %f (GigaBytes)  \n"
                "Requested memory: %f (GigaBytes) \n"
                "Setting maximum total tokens (input + output) to %zu \n"
                "WorkSpace: %p \n"
                "------------------------------------------------------\n",
                (float)_free_memory_size / GIGABYTE,
                (float)total_size / GIGABYTE,
                (float)workSpaceSize / GIGABYTE,
                _max_seq_len,
                _workspace);

        if (!_workspace) {
            printf("Requested:\t%zu\nFree:\t%zu\nTotal:\t%zu\n",
                   workSpaceSize,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace is null.");
        }
        _workSpaceSize = workSpaceSize;
        _attention_unfused_workspace_offset = workSpaceSize - temp_size;
    }

    // Maximum total token count (input + output) that the last GenWorkSpace call sized the
    // workspace for. The misspelled name is kept for existing callers.
    inline size_t GetMaxTokenLenght() const { return _max_seq_len; }

    // Returns _comp1_event when id == 1 and _comp2_event for any other id.
    cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    // Workspace size in bytes recorded by the most recent successful GenWorkSpace call.
    // The underlying allocation may be larger, because a held buffer is never shrunk.
    size_t get_workspace_size() const { return _workSpaceSize; }

    // Base pointer of the device workspace, or nullptr if none is currently held.
    void* GetWorkSpace() { return _workspace; }

    // Pointer into the trailing region of the workspace reserved for attention-unfused
    // memory. The offset is set by GenWorkSpace to (workspace size - temp bytes).
    void* GetAttentionUnfusedWorkspace()
    {
        return (char*)_workspace + _attention_unfused_workspace_offset;
    }

    // Increments _token_length only when layer_id == 0, then returns its value.
    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    // Sets the count reported by current_tokens() to initial_tokens.
    // NOTE(review): _token_length (used by new_token) is not reset here.
    inline void reset_tokens(unsigned initial_tokens = 1) { _num_tokens = initial_tokens; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    // Returns the communication stream, selecting it on the first call only:
    //   - async_op == true: a stream from the c10 stream pool;
    //   - otherwise: PyTorch's current CUDA stream.
    // Later calls return the cached stream and ignore async_op.
    cudaStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream = async_op ? at::cuda::getStreamFromPool(true)
                                    : at::cuda::getCurrentCUDAStream();
        return _comm_stream;
    }

    // With other_stream == true, returns a pool stream that is acquired on first use and then
    // cached. Otherwise returns PyTorch's current CUDA stream.
    cudaStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::cuda::getStreamFromPool(true);
            return _stream;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    // Frees the device workspace. _workSpaceSize is kept so retake_workspace() can allocate
    // a buffer of the same size again.
    void release_workspace()
    {
        cudaFree(_workspace);
        _workspace = nullptr;
    }

    // Re-allocates a workspace of _workSpaceSize bytes after release_workspace().
    // Returns true if a workspace is held afterwards, or if no size has been recorded yet.
    // Returns false if the allocation failed; the allocation error is cleared and
    // _workspace stays nullptr.
    bool retake_workspace()
    {
        if (_workspace != nullptr || _workSpaceSize == 0) return true;
        if (cudaMalloc(&_workspace, _workSpaceSize) != cudaSuccess) {
            _workspace = nullptr;
            (void)cudaGetLastError();
            return false;
        }
        return _workspace != nullptr;
    }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    // Returns (seed, offset), where offset is the value before it is advanced by offset_inc.
    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    // NOTE(review): _gemm_algos is never populated in this class, so this returns an empty
    // vector unless it is filled elsewhere.
    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    // Makes _comm_stream wait for the work already enqueued on _comp_stream.
    // NOTE(review): _comp_stream is never assigned in this class, so it remains the
    // null/default stream (0) it is initialized to.
    inline void SynchComp()
    {
        cudaEventRecord(_comp_event, _comp_stream);
        cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    // Makes _comp_stream wait for the work already enqueued on _comm_stream.
    inline void SynchComm()
    {
        cudaEventRecord(_comm_event, _comm_stream);
        cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    cublasHandle_t _cublasHandle;

    cudaEvent_t _comp_event;
    cudaEvent_t _comm_event;

    void* _workspace;
    // offset from _workspace for attention unfused memory
    size_t _attention_unfused_workspace_offset;

    uint64_t _seed;
    uint64_t _curr_offset;

    // Workspace size in bytes recorded by the last successful GenWorkSpace call.
    size_t _workSpaceSize;

    // Free device memory in bytes, sampled on the first GenWorkSpace call.
    size_t _free_memory_size;

    size_t _max_seq_len;

    cudaEvent_t _comp1_event;
    cudaEvent_t _comp2_event;

    cudaStream_t _stream;

    unsigned _token_length;
    unsigned _num_tokens;
    std::vector<std::array<int, 3>> _gemm_algos;

    cudaStream_t _comp_stream;
    cudaStream_t _comm_stream;

    std::unordered_map<int, int> _world_sizes;
};