LlamaDecoderLayerWeight.cc 15 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

Li Zhang's avatar
Li Zhang committed
18
// Modified from
lvhan028's avatar
lvhan028 committed
19
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc
Li Zhang's avatar
Li Zhang committed
20

lvhan028's avatar
lvhan028 committed
21
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
22
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
lvhan028's avatar
lvhan028 committed
23
24
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"
xiabo's avatar
xiabo committed
25
26
27
28
// #include <filesystem>
#include <experimental/filesystem>
#include <sys/stat.h>
#include <string>
Li Zhang's avatar
Li Zhang committed
29

lvhan028's avatar
lvhan028 committed
30
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
31

xiabo's avatar
xiabo committed
32
33
34
35
36
// Returns true when stat() succeeds on `path`, i.e. the path names an existing, reachable filesystem entry.
bool fileExists(const std::string& path)
{
    struct stat info;
    const int   status = stat(path.c_str(), &info);
    return status == 0;
}

Li Zhang's avatar
Li Zhang committed
37
template<typename T>
38
39
40
LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t     head_num,
                                                    size_t     kv_head_num,
                                                    size_t     size_per_head,
Li Zhang's avatar
Li Zhang committed
41
42
                                                    size_t     inter_size,
                                                    WeightType weight_type,
43
                                                    int        group_size,
Li Zhang's avatar
Li Zhang committed
44
45
46
                                                    bool       attn_bias,
                                                    size_t     tensor_para_size,
                                                    size_t     tensor_para_rank):
47
48
49
50
    head_num_(head_num),
    kv_head_num_(kv_head_num),
    size_per_head_(size_per_head),
    hidden_units_(head_num * size_per_head),
Li Zhang's avatar
Li Zhang committed
51
52
    inter_size_(inter_size),
    weight_type_(weight_type),
Li Zhang's avatar
Li Zhang committed
53
    attn_bias_(attn_bias),
Li Zhang's avatar
Li Zhang committed
54
55
56
    tensor_para_size_(tensor_para_size),
    tensor_para_rank_(tensor_para_rank)
{
Li Zhang's avatar
Li Zhang committed
57
    self_attn_weights.qkv.input_dims  = hidden_units_;
58
    self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
Li Zhang's avatar
Li Zhang committed
59
    self_attn_weights.qkv.type        = weight_type;
60
    self_attn_weights.qkv.group_size  = group_size;
Li Zhang's avatar
Li Zhang committed
61
62
63
64

    self_attn_weights.output.input_dims  = hidden_units_ / tensor_para_size_;
    self_attn_weights.output.output_dims = hidden_units_;
    self_attn_weights.output.type        = weight_type;
65
    self_attn_weights.output.group_size  = group_size;
Li Zhang's avatar
Li Zhang committed
66

Li Zhang's avatar
Li Zhang committed
67
68
69
    ffn_weights.gating.input_dims  = hidden_units_;
    ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_;
    ffn_weights.gating.type        = weight_type;
70
    ffn_weights.gating.group_size  = group_size;
Li Zhang's avatar
Li Zhang committed
71
72
73
74

    ffn_weights.intermediate.input_dims  = hidden_units_;
    ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_;
    ffn_weights.intermediate.type        = weight_type;
75
76
77
78
79
80
    ffn_weights.intermediate.group_size  = group_size;

    ffn_weights.fused_gating_intermediate.input_dims  = hidden_units_;
    ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
    ffn_weights.fused_gating_intermediate.type        = weight_type;
    ffn_weights.fused_gating_intermediate.group_size  = group_size;
Li Zhang's avatar
Li Zhang committed
81
82
83
84

    ffn_weights.output.input_dims  = inter_size_ / tensor_para_size_;
    ffn_weights.output.output_dims = hidden_units_;
    ffn_weights.output.type        = weight_type;
85
    ffn_weights.output.group_size  = group_size;
Li Zhang's avatar
Li Zhang committed
86
87
88
89
90
91
92
93
    mallocWeights();
}

// Releases the kernel, bias and scales/zeros buffers of a dense weight with cudaFree and resets each
// pointer to nullptr.
template<typename T>
void freeWeights(LlamaDenseWeight<T>& weights)
{
    const auto release = [](auto& buffer) {
        cudaFree(buffer);
        buffer = nullptr;
    };

    release(weights.kernel);
    release(weights.bias);
    release(weights.scales_and_zeros);
}

// Allocates the buffers of one dense weight according to its already populated dims/type/group_size.
//
// weights: dense weight whose input_dims, output_dims, type and group_size are set.
// bias:    when true, a bias buffer of output_dims elements of T is allocated as well.
//
// For weight types of 16 bits or more the kernel holds input_dims * output_dims elements of T.
// For narrower (quantized) types the kernel is packed into 32-bit words and zero-initialised, and an
// interleaved scales/zeros buffer of T is allocated next to it.
template<typename T>
void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
{
    if (bias) {
        deviceMalloc((T**)&weights.bias, weights.output_dims);
    }
    const size_t bit_size = getBitSize(weights.type);
    if (bit_size >= 16) {  // fp16, fp32
        deviceMalloc((T**)&weights.kernel, weights.input_dims * weights.output_dims);
    }
    else {  // int8, int4
        // Number of quantized values packed into one 32-bit word.
        const int factor = sizeof(float) * 8 / bit_size;
        FT_CHECK(weights.input_dims % factor == 0);

        const size_t packed_count = weights.input_dims * weights.output_dims / factor;
        deviceMalloc((int**)&weights.kernel, packed_count);
        deviceMemSetZero((int*)weights.kernel, packed_count);

        // interleaved scales/zeros
        // A non-positive group_size is treated as a single group, matching the group count that
        // loadWeights() uses when it fills this buffer; this also avoids dividing by zero.
        const size_t group_count = weights.group_size > 0 ? weights.input_dims / weights.group_size : 1;
        deviceMalloc((T**)&weights.scales_and_zeros, group_count * weights.output_dims * 2);
    }
}

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Streams every argument through operator<< into one string, inserting "." between consecutive arguments.
template<typename FirstArg, typename... Args>
std::string concat(FirstArg&& first, Args&&... args)
{
    std::stringstream joined;
    joined << first;

    const auto append = [&joined](auto& part) { joined << "." << part; };
    (append(args), ...);

    return joined.str();
}

// Registers the buffers of one dense weight in `output` under names derived from `prefix`.
//
// weights: dense weight whose buffers are exposed.
// bias:    when true, "<prefix>.bias" is added.
// prefix:  name prefix; entries are named "<prefix>.<suffix>".
// output:  map receiving the tensors.
//
// Weight types of 16 bits or more expose "<prefix>.weight"; narrower (quantized) types expose
// "<prefix>.qweight" and "<prefix>.scales_zeros". Every tensor is described by a single dimension
// holding its size in bytes.
template<typename T>
void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string& prefix, TensorMap& output)
{
    auto get_name = [=](const std::string& name) { return concat(prefix, name); };

    if (bias) {
        output.insert(get_name("bias"),
                      Tensor{MEMORY_GPU, getTensorType<T>(), {weights.output_dims * sizeof(T)}, weights.bias});
    }
    const size_t bit_size = getBitSize(weights.type);
    if (bit_size >= 16) {
        output.insert(get_name("weight"),
                      Tensor{MEMORY_GPU,
                             getTensorType<T>(),
                             {weights.input_dims * weights.output_dims * sizeof(T)},
                             weights.kernel});
    }
    else {  // int8, int4
        // Number of quantized values packed into one 32-bit word.
        const int factor = sizeof(float) * 8 / bit_size;
        output.insert(get_name("qweight"),
                      Tensor{MEMORY_GPU,
                             TYPE_INT32,
                             {weights.input_dims * weights.output_dims * sizeof(int) / factor},
                             weights.kernel});

        // A non-positive group_size is treated as a single group, matching the group count that
        // loadWeights() uses for the scales/zeros buffer; this also avoids dividing by zero.
        const size_t group_count = weights.group_size > 0 ? weights.input_dims / weights.group_size : 1;
        output.insert(get_name("scales_zeros"),
                      Tensor{MEMORY_GPU,
                             getTensorType<T>(),
                             {group_count * weights.output_dims * 2 * sizeof(T)},
                             weights.scales_and_zeros});
    }
}

Li Zhang's avatar
Li Zhang committed
162
template<typename T>
163
164
165
166
167
168
169
void loadWeights(LlamaDenseWeight<T>& w,
                 std::string          prefix,
                 int                  rank,
                 FtCudaDataType       model_file_type,
                 size_t               tensor_para_size,
                 int                  slice_dim   = 0,
                 std::vector<size_t>  slice_shape = {})
Li Zhang's avatar
Li Zhang committed
170
{
171
172
173
174
175
176
177
178
179
180
    auto       max_prefix = prefix + "." + std::to_string(tensor_para_size - 1);
    const auto type       = model_file_type;

    bool enable_slice = true;
    // Disable slice if tensor param rank is 1
    if (tensor_para_size <= 1) {
        enable_slice = false;
    }
    else {
        // Disable slice if weight has already been sliced
xiabo's avatar
xiabo committed
181
182
        // if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) {
        if (fileExists(max_prefix + ".weight") || fileExists(max_prefix + ".qweight")) {
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
            TM_LOG_DEBUG("TP weight exists. Disable runtime TP.");
            enable_slice = false;
        }
    }

    size_t dim0 = w.input_dims;
    size_t dim1 = w.output_dims;
    if (enable_slice) {
        // multiple tp size for slice stride
        if (slice_dim == 0) {
            dim0 = dim0 * tensor_para_size;
            if (slice_shape.size() == 0) {
                slice_shape = {dim0};
            }
        }
        else {
            dim1 = dim1 * tensor_para_size;
            if (slice_shape.size() == 0) {
                slice_shape = {dim1};
            }
        }

        prefix += "." + std::to_string(0);
    }
    else {
        prefix += "." + std::to_string(rank);
    }
Li Zhang's avatar
Li Zhang committed
210
211

    if (w.bias) {
212
213
214
215
        std::vector<ConcateSlice> bias_slices{};
        if (enable_slice) {
            if (slice_dim == 1) {
                size_t       start = 0;
Chen Xin's avatar
Chen Xin committed
216
217
                ConcateSlice slice0{{{0, 1}}};
                ConcateSlice slice1{{{}}};
218
219
220
221
222
223
224
225
226
                for (auto len : slice_shape) {
                    size_t stride = len / tensor_para_size;
                    slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
                    start += len;
                }
                bias_slices = {slice0, slice1};
            }
        }
        loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type, bias_slices);
Li Zhang's avatar
Li Zhang committed
227
228
229
    }
    const size_t bit_size = getBitSize(w.type);
    if (bit_size >= 16) {  // fp16, fp32
230
231
232
233
        std::vector<ConcateSlice> weight_slices{};
        if (enable_slice) {
            if (slice_dim == 1) {
                size_t       start = 0;
Chen Xin's avatar
Chen Xin committed
234
235
                ConcateSlice slice0{{{0, dim0}}};
                ConcateSlice slice1{{{}}};
236
237
238
239
240
241
242
243
244
                for (auto len : slice_shape) {
                    size_t stride = len / tensor_para_size;
                    slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
                    start += len;
                }
                weight_slices = {slice0, slice1};
            }
            else {
                size_t       start = 0;
Chen Xin's avatar
Chen Xin committed
245
246
                ConcateSlice slice0{{}};
                ConcateSlice slice1{{{0, dim1}}};
247
248
249
250
251
252
253
254
255
                for (auto len : slice_shape) {
                    size_t stride = len / tensor_para_size;
                    slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
                    start += len;
                }
                weight_slices = {slice0, slice1};
            }
        }
        loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices);
Li Zhang's avatar
Li Zhang committed
256
257
258
    }
    else {  // int8, int4
        const int factor = sizeof(float) * 8 / bit_size;
259

260
261
262
263
264
265
266
267
        FT_CHECK(dim1 % factor == 0);

        std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
        loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});

        const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;

        loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
Li Zhang's avatar
Li Zhang committed
268
269
270
271
272
273
274
275
276
    }
}

// Allocates every buffer of the layer: the two norm weights (hidden_units_ elements each), the attention
// projections (with bias when attn_bias_ is set) and the feed-forward projections (never with bias).
// Also resets past_kv_scale to {1, 0, 1, 0}.
template<typename T>
void LlamaDecoderLayerWeight<T>::mallocWeights()
{
    const auto allocate = [](auto& dense, bool with_bias) { turbomind::mallocWeights(dense, with_bias); };

    deviceMalloc((T**)&self_attn_norm_weights, hidden_units_);
    deviceMalloc((T**)&ffn_norm_weights, hidden_units_);

    allocate(self_attn_weights.qkv, attn_bias_);
    allocate(self_attn_weights.output, attn_bias_);
    self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f};

    // INT4 uses the fused gating+intermediate weight; every other type keeps the two separate.
    if (weight_type_ != WeightType::kINT4) {
        allocate(ffn_weights.gating, false);
        allocate(ffn_weights.intermediate, false);
    }
    else {
        allocate(ffn_weights.fused_gating_intermediate, false);
    }

    allocate(ffn_weights.output, false);
}

// Frees everything mallocWeights() allocated: the two norm weights and the dense weights that were
// selected for the configured weight type.
template<typename T>
LlamaDecoderLayerWeight<T>::~LlamaDecoderLayerWeight()
{
    cudaFree((void*)self_attn_norm_weights);
    cudaFree((void*)ffn_norm_weights);

    freeWeights(self_attn_weights.qkv);
    freeWeights(self_attn_weights.output);

    // Mirror of the allocation choice: only INT4 owns the fused gating+intermediate weight.
    if (weight_type_ != WeightType::kINT4) {
        freeWeights(ffn_weights.gating);
        freeWeights(ffn_weights.intermediate);
    }
    else {
        freeWeights(ffn_weights.fused_gating_intermediate);
    }

    freeWeights(ffn_weights.output);
}

// Loads all weights of this layer from the binary files whose names start with `dir_path`.
//
// dir_path:        common file name prefix of the layer; the tensor suffixes are appended to it.
// model_file_type: element type stored in the files.
//
// The norm weights are read whole. The dense weights go through loadWeights() with this layer's
// tensor-parallel rank and size. The KV-cache quantization scales are optional: they are read only if
// their file can be opened, otherwise the current past_kv_scale is kept.
template<typename T>
void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType model_file_type)
{
    // Forwards to loadWeights() with the rank/size/type arguments shared by every dense weight.
    const auto load_dense =
        [&](auto& dense, const char* suffix, int slice_dim, std::vector<size_t> slice_shape = {}) {
            loadWeights(dense,
                        dir_path + suffix,
                        tensor_para_rank_,
                        model_file_type,
                        tensor_para_size_,
                        slice_dim,
                        slice_shape);
        };

    loadWeightFromBin(
        (T*)self_attn_norm_weights, {hidden_units_}, dir_path + ".attention_norm.weight", model_file_type);
    loadWeightFromBin((T*)ffn_norm_weights, {hidden_units_}, dir_path + ".ffn_norm.weight", model_file_type);

    // The fused QKV weight is split along dim 1 as three segments (Q, K, V) so that every rank receives
    // its share of each segment.
    load_dense(self_attn_weights.qkv,
               ".attention.w_qkv",
               1,
               {head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_});

    load_dense(self_attn_weights.output, ".attention.wo", 0);

    if (weight_type_ != WeightType::kINT4) {
        load_dense(ffn_weights.gating, ".feed_forward.w1", 1);
        load_dense(ffn_weights.intermediate, ".feed_forward.w3", 1);
    }
    else {
        load_dense(ffn_weights.fused_gating_intermediate, ".feed_forward.w13", 1);
    }
    load_dense(ffn_weights.output, ".feed_forward.w2", 0);

    // load kv_cache quant scale
    std::string   kv_scale_file = dir_path + ".past_kv_scale." + std::to_string(tensor_para_rank_) + ".weight";
    std::ifstream probe(kv_scale_file, std::ios::in);
    if (probe.is_open()) {
        // The stream is only an existence check; close it before the real read.
        probe.close();
        self_attn_weights.past_kv_scale = loadArrayFromBin({4}, kv_scale_file);
    }
}

// Builds a map of all tensors of this layer, keyed by names that start with `prefix`.
//
// Norm weights are keyed "<prefix>.<name>"; dense weights are keyed "<prefix>.<name>.<rank>.<suffix>",
// where <rank> is this layer's tensor-parallel rank. The KV-cache scales are exposed as a CPU tensor
// pointing at past_kv_scale.
template<typename T>
TensorMap LlamaDecoderLayerWeight<T>::getParams(std::string prefix)
{
    TensorMap params;

    const auto norm_bytes = hidden_units_ * sizeof(T);

    params.insert(concat(prefix, "attention_norm.weight"),
                  Tensor{MEMORY_GPU, getTensorType<T>(), {norm_bytes}, self_attn_norm_weights});

    params.insert(concat(prefix, "ffn_norm.weight"),
                  Tensor{MEMORY_GPU, getTensorType<T>(), {norm_bytes}, ffn_norm_weights});

    // Adds the tensors of one dense weight under "<prefix>.<name>.<rank>".
    const auto export_dense = [&](auto& dense, bool with_bias, std::string_view name) {
        getWeightTensor(dense, with_bias, concat(prefix, name, tensor_para_rank_), params);
    };

    export_dense(self_attn_weights.qkv, attn_bias_, "attention.w_qkv");

    export_dense(self_attn_weights.output, attn_bias_, "attention.wo");

    if (weight_type_ != WeightType::kINT4) {
        export_dense(ffn_weights.gating, false, "feed_forward.w1");
        export_dense(ffn_weights.intermediate, false, "feed_forward.w3");
    }
    else {
        export_dense(ffn_weights.fused_gating_intermediate, false, "feed_forward.w13");
    }
    export_dense(ffn_weights.output, false, "feed_forward.w2");

    params.insert(concat(prefix, "past_kv_scale", tensor_para_rank_, "weight"),
                  Tensor{MEMORY_CPU, TYPE_FP32, {4 * sizeof(float)}, self_attn_weights.past_kv_scale.data()});

    return params;
}

// Explicit instantiations of the layer weight for the float and half element types.
template struct LlamaDecoderLayerWeight<float>;
template struct LlamaDecoderLayerWeight<half>;
q.yao's avatar
q.yao committed
389
390
391
// The __nv_bfloat16 instantiation is compiled only when ENABLE_BF16 is defined.
#ifdef ENABLE_BF16
template struct LlamaDecoderLayerWeight<__nv_bfloat16>;
#endif
Li Zhang's avatar
Li Zhang committed
392

lvhan028's avatar
lvhan028 committed
393
}  // namespace turbomind