inference_context.h 9.07 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

aiss's avatar
aiss committed
5
6
#pragma once

aiss's avatar
aiss committed
7
#include <c10/cuda/CUDAStream.h>
aiss's avatar
aiss committed
8
9
10
11
12
13
14
#include <cuda_runtime_api.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"

aiss's avatar
aiss committed
15
16
17
18
// Byte-count constants (2^20 and 2^30). Both expand to int-typed expressions, so
// products such as 2 * GIGABYTE would overflow int; widen before multiplying.
#define MEGABYTE (1024 * 1024)
#define GIGABYTE (1024 * 1024 * 1024)

// TODO: refactor out
aiss's avatar
aiss committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// Warp width constant; presumably the number of threads per warp assumed by the
// kernels that include this header (TODO confirm for non-NVIDIA builds).
#define WARP_SIZE 32

// Evaluates a CUDA runtime call once and, if it did not return cudaSuccess, writes the
// numeric error code, its cudaGetErrorString() description and the call site to std::cerr,
// then hits assert(0).
// Note: assert is compiled out when NDEBUG is defined, so in such builds execution
// continues after the message has been printed.
// The expansion is a plain brace block (not do/while) so that existing call sites written
// with or without a trailing semicolon keep compiling.
#define CUDA_CHECK(callstr)                                                              \
    {                                                                                    \
        cudaError_t error_code = (callstr);                                              \
        if (error_code != cudaSuccess) {                                                 \
            std::cerr << "CUDA error " << error_code << " ("                             \
                      << cudaGetErrorString(error_code) << ") at " << __FILE__ << ":"    \
                      << __LINE__ << std::endl;                                          \
            assert(0);                                                                   \
        }                                                                                \
    }

// Loop header for use inside a kernel: declares `size_t i`, starts it at this thread's
// global x index and advances it by the total number of threads along x until it
// reaches `n`.
// The first operand of each product is widened to size_t so the multiplication is done
// in size_t instead of wrapping in 32-bit unsigned arithmetic for very large launches.
#define CUDA_1D_KERNEL_LOOP(i, n)                                           \
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
         i += (size_t)blockDim.x * gridDim.x)

// Nested loop header for use inside a kernel: the outer loop declares `size_t i` and walks
// the x dimension up to `n`, the inner loop declares `size_t j` and walks the y dimension
// up to `m`. Each index starts at this thread's global index in that dimension and advances
// by the total number of threads in that dimension.
// The first operand of each product is widened to size_t so the multiplication is done
// in size_t instead of wrapping in 32-bit unsigned arithmetic.
#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                                         \
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < (n);     \
         i += (size_t)blockDim.x * gridDim.x)                                   \
        for (size_t j = (size_t)blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
             j += (size_t)blockDim.y * gridDim.y)

// Threads per block assumed by DS_GET_BLOCKS, and the upper bound on the block count it returns.
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144

// Returns the number of blocks of DS_CUDA_NUM_THREADS threads needed to cover N elements,
// clamped to the range [1, DS_MAXIMUM_NUM_BLOCKS].
inline int DS_GET_BLOCKS(const int N)
{
    // Use at least 1 block, since CUDA does not allow empty block
    if (N <= 0) { return 1; }

    // Ceiling division written as quotient + remainder test: the usual
    // (N + threads - 1) / threads form overflows int when N is close to INT_MAX.
    const int blocks = N / DS_CUDA_NUM_THREADS + ((N % DS_CUDA_NUM_THREADS) != 0 ? 1 : 0);
    return std::min(blocks, DS_MAXIMUM_NUM_BLOCKS);
}

class Context {
public:
aiss's avatar
aiss committed
50
51
52
53
54
55
56
57
    Context()
        : _workspace(nullptr),
          _seed(42),
          _curr_offset(0),
          _stream(0),
          _free_memory_size(0),
          _num_tokens(1),
          _attention_unfused_workspace_offset(0)
aiss's avatar
aiss committed
58
59
60
61
62
63
64
65
    {
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
#ifndef __HIP_PLATFORM_HCC__
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
aiss's avatar
aiss committed
66
#endif
aiss's avatar
aiss committed
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        cudaEventCreate(&_comp1_event);
        cudaEventCreate(&_comp2_event);
        cudaEventCreate(&_comp_event);
        cudaEventCreate(&_comm_event);
    }

    virtual ~Context()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
        cudaEventDestroy(_comp1_event);
        cudaEventDestroy(_comp2_event);
        cudaEventDestroy(_comp_event);
        cudaEventDestroy(_comm_event);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

aiss's avatar
aiss committed
89
90
91
92
93
94
95
96
97
98
    void GenWorkSpace(const unsigned& num_layers,
                      const unsigned& num_heads,
                      const size_t& batch_size,
                      const size_t& prompt_len,
                      const size_t& hidden_dim,
                      const unsigned& mp_size,
                      const bool& external_cache,
                      const size_t& elem_size,
                      const unsigned& rank,
                      unsigned max_out_tokens)
aiss's avatar
aiss committed
99
    {
aiss's avatar
aiss committed
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
        size_t total_size;
        if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }

        // Flash attention requires padded heads and we'll conservatively allocate
        // for that here. Flash attention is only enabled for head size <= 128 right now
        const int head_size = hidden_dim / num_heads;
        const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128);
        const int effective_head_size = (head_size > 128) ? head_size : padded_head_size;

        size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size;
        // Other sequence length dimension is added when the final workSpaceSize is calculated
        size_t temp_size = batch_size * num_heads * max_out_tokens * 2;
        size_t cache_size =
            num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2;
        size_t minimal_requirements =
            temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE;
        if (_free_memory_size < minimal_requirements) {
            printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
                   minimal_requirements,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace can't be allocated, no enough memory.");
        }

        _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) /
                       (activation_size + temp_size + cache_size);
        _max_seq_len = std::min((size_t)max_out_tokens, _max_seq_len);
        size_t workSpaceSize = ((external_cache ? (activation_size + temp_size)
                                                : (activation_size + temp_size + cache_size))) *
                               _max_seq_len * elem_size;
        temp_size *= _max_seq_len * elem_size;
        if (rank == 0 && !_workspace)
            printf(
                "------------------------------------------------------\n"
                "Free memory : %f (GigaBytes)  \n"
                "Total memory: %f (GigaBytes)  \n"
                "Requested memory: %f (GigaBytes) \n"
                "Setting maximum total tokens (input + output) to %lu \n"
                "------------------------------------------------------\n",
                (float)_free_memory_size / GIGABYTE,
                (float)total_size / GIGABYTE,
                (float)workSpaceSize / GIGABYTE,
                _max_seq_len);
aiss's avatar
aiss committed
143
144
        if (!_workspace) {
            assert(_workspace == nullptr);
aiss's avatar
aiss committed
145
146
            cudaMalloc(&_workspace, workSpaceSize);
        } else if (_workSpaceSize < workSpaceSize) {
aiss's avatar
aiss committed
147
            cudaFree(_workspace);
aiss's avatar
aiss committed
148
            cudaMalloc(&_workspace, workSpaceSize);
aiss's avatar
aiss committed
149
150
        }

aiss's avatar
aiss committed
151
152
153
154
155
156
157
158
159
        if (!_workspace) {
            printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
                   workSpaceSize,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace is null.");
        }
        _workSpaceSize = workSpaceSize;
        _attention_unfused_workspace_offset = workSpaceSize - temp_size;
aiss's avatar
aiss committed
160
    }
aiss's avatar
aiss committed
161
    inline size_t GetMaxTokenLenght() const { return _max_seq_len; }
aiss's avatar
aiss committed
162
163
164
165
166

    cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }
aiss's avatar
aiss committed
167
168
169
170
    void* GetAttentionUnfusedWorkspace()
    {
        return (char*)_workspace + _attention_unfused_workspace_offset;
    }
aiss's avatar
aiss committed
171
172
173
174
175
176
177

    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

aiss's avatar
aiss committed
178
    inline void reset_tokens(unsigned initial_tokens = 1)
aiss's avatar
aiss committed
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    cudaStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream = async_op ? at::cuda::getStreamFromPool(true)
                                    : at::cuda::getCurrentCUDAStream();
        return _comm_stream;
    }
    cudaStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::cuda::getStreamFromPool(true);
            return _stream;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        cudaEventRecord(_comp_event, _comp_stream);
        cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    inline void SynchComm()
    {
        cudaEventRecord(_comm_event, _comm_stream);
        cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    cublasHandle_t _cublasHandle;

    cudaEvent_t _comp_event;
    cudaEvent_t _comm_event;

    void* _workspace;
aiss's avatar
aiss committed
236
237
    // offset from _workspace for attention unfused memory
    size_t _attention_unfused_workspace_offset;
aiss's avatar
aiss committed
238
239
    uint64_t _seed;
    uint64_t _curr_offset;
aiss's avatar
aiss committed
240

aiss's avatar
aiss committed
241
    size_t _workSpaceSize;
aiss's avatar
aiss committed
242
243
244
    size_t _free_memory_size;

    size_t _max_seq_len;
aiss's avatar
aiss committed
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259

    cudaEvent_t _comp1_event;
    cudaEvent_t _comp2_event;

    cudaStream_t _stream;

    unsigned _token_length;
    unsigned _num_tokens;
    std::vector<std::array<int, 3>> _gemm_algos;

    cudaStream_t _comp_stream;
    cudaStream_t _comm_stream;

    std::unordered_map<int, int> _world_sizes;
};