/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_triton_example.cc

#include "3rdparty/INIReader.h"
#include <chrono>
#include <memory>
#include <thread>

#include "src/turbomind/macro.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
#include "src/turbomind/utils/custom_ar_comm.h"
#include "src/turbomind/utils/mpi_utils.h"
#include "src/turbomind/utils/nccl_utils.h"
#include "src/turbomind/utils/nvtx_utils.h"
#include "src/turbomind/utils/word_list.h"

namespace ft = turbomind;

// Compile-time switch: when true, main() initializes MPI and the request is broadcast across ranks.
constexpr bool kUSE_MPI{true};

/// Decoding / sampling parameters shared by every sequence of one request.
///
/// Every member has an explicit initial value, so a default-constructed RequestParam is
/// never read uninitialized. The optional knobs start at the values that broadCastRequest
/// treats as "not set" and therefore does not forward to the model:
/// runtime_top_k == 0, runtime_top_p == 0, repetition_penalty == 1, presence_penalty == 0.
/// These values also match the defaults prepareRequest uses when the INI keys are absent.
struct RequestParam {
    int                    beam_width{1};
    int                    request_output_len{0};
    float                  beam_search_diversity_rate{0.0f};
    uint                   runtime_top_k{0};
    float                  runtime_top_p{0.0f};
    float                  temperature{1.0f};
    float                  len_penalty{0.0f};
    float                  repetition_penalty{1.0f};
    float                  presence_penalty{0.0f};
    int                    min_length{0};
    unsigned long long int random_seed{0};
    int                    start_id{0};
    int                    end_id{0};
};

std::vector<std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>>
broadCastRequest(const std::vector<int>& v_start_ids,
                 const std::vector<int>& v_start_lengths,
                 const std::vector<int>& v_bad_words,
                 const int               node_id,
                 const int               gpu_count,
                 const RequestParam      param,
                 std::vector<void*>*     pointer_record)
{
    // broadcast the request to all nodes, and copy "gpu_count" copies on
    // different gpu
    int size_1         = v_start_ids.size();
    int size_2         = v_start_lengths.size();
    int size_bad_words = v_bad_words.size();
    if (kUSE_MPI) {
        ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD);
        ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD);
        ft::mpi::bcast(&size_bad_words, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD);
    }

    std::vector<int> v_input_ids(size_1);
    std::vector<int> v_input_lengths(size_2);
    std::vector<int> v_input_bad_words(size_bad_words);

    if (node_id == 0) {
        memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int));
        memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int));
Li Zhang's avatar
Li Zhang committed
83
84
85
        if (!v_input_bad_words.empty()) {
            memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int));
        }
Li Zhang's avatar
Li Zhang committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
    }
    if (kUSE_MPI) {
        ft::mpi::barrier();
    }

    int request_batch_size = size_2;
    int max_input_len      = size_1 / size_2;

    std::cerr << "request_batch_size=" << request_batch_size << " max_input_len=" << max_input_len << "\n";

    if (kUSE_MPI) {
        ft::mpi::bcast(v_input_ids.data(), size_1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD);
        ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD);
        ft::mpi::bcast(v_input_bad_words.data(), size_bad_words, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD);
    }

    std::vector<std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>> request_list;
    for (int device_id = 0; device_id < gpu_count; device_id++) {
        ft::check_cuda_error(cudaSetDevice(device_id));

        int* d_input_ids;
        // int* d_input_lengths;
        int* d_input_bad_words;

        if (max_input_len == 0) {
            // unconditional case, no input ids, so do nothing.
            d_input_ids = nullptr;
            // d_input_lengths = nullptr;
            max_input_len = 0;
        }
        else {
            // conditional case.
            ft::deviceMalloc(&d_input_ids, size_1, false);
            // ft::deviceMalloc(&d_input_lengths, size_2, false);
            ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1);
            // ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2);
        }

        if (!v_input_bad_words.empty()) {
            ft::deviceMalloc(&d_input_bad_words, size_bad_words, false);
            ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words);
        }
        else {
            d_input_bad_words = nullptr;
        }

        uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t));
        int*      input_lengths_ptr      = (int*)malloc(request_batch_size * sizeof(int));
        for (int i = 0; i < request_batch_size; i++) {
            request_output_len_ptr[i] = param.request_output_len;
            input_lengths_ptr[i]      = v_input_lengths[i];
        }

        int* start_ids_ptr = (int*)malloc(request_batch_size * sizeof(int));
        int* end_ids_ptr   = (int*)malloc(request_batch_size * sizeof(int));
        for (int i = 0; i < request_batch_size; i++) {
            start_ids_ptr[i] = param.start_id;
            end_ids_ptr[i]   = param.end_id;
        }
        pointer_record->push_back(start_ids_ptr);
        pointer_record->push_back(end_ids_ptr);

        request_list.push_back(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>(
            new std::unordered_map<std::string, triton::Tensor>{
                {"input_ids",
                 triton::Tensor{triton::MEMORY_GPU,
                                triton::TYPE_INT32,
                                std::vector<size_t>{(size_t)request_batch_size, (size_t)max_input_len},
                                d_input_ids}},
                {"input_lengths",
                 triton::Tensor{triton::MEMORY_CPU,
                                triton::TYPE_INT32,
                                std::vector<size_t>{(size_t)request_batch_size},
                                input_lengths_ptr}},
                {"request_output_len",
                 triton::Tensor{triton::MEMORY_CPU,
                                triton::TYPE_INT32,
                                std::vector<size_t>{(size_t)request_batch_size},
                                request_output_len_ptr}},
                {"bad_words_list",
                 triton::Tensor{
                     triton::MEMORY_GPU, triton::TYPE_INT32, {2, v_input_bad_words.size() / 2}, d_input_bad_words}},
                {"start_id",
                 triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, start_ids_ptr}},
                {"end_id",
                 triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, end_ids_ptr}}}));

        int* beam_width_ptr = new int(param.beam_width);
        pointer_record->push_back(beam_width_ptr);
        request_list[device_id]->insert(
            {"beam_width",
             triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector<size_t>{1}, beam_width_ptr}});
        if (param.beam_width > 1) {
            float* beam_search_diversity_rate_ptr = new float(param.beam_search_diversity_rate);
            pointer_record->push_back(beam_search_diversity_rate_ptr);
            request_list[device_id]->insert(
                {"beam_search_diversity_rate",
                 triton::Tensor{
                     triton::MEMORY_CPU, triton::TYPE_FP32, std::vector<size_t>{1}, beam_search_diversity_rate_ptr}});
        }
        else {
            if (param.runtime_top_p != 0.0f) {
                float* runtime_top_p_ptr = new float(param.runtime_top_p);
                pointer_record->push_back(runtime_top_p_ptr);
                request_list[device_id]->insert(
                    {"runtime_top_p",
                     triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector<size_t>{1}, runtime_top_p_ptr}});
            }
            if (param.runtime_top_k != 0) {
                uint* runtime_top_k_ptr = new uint(param.runtime_top_k);
                pointer_record->push_back(runtime_top_k_ptr);
                request_list[device_id]->insert(
                    {"runtime_top_k",
                     triton::Tensor{
                         triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector<size_t>{1}, runtime_top_k_ptr}});
            }
        }
        float* temperature_ptr = new float(param.temperature);
        pointer_record->push_back(temperature_ptr);
        request_list[device_id]->insert(
            {"temperature",
             triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector<size_t>{1}, temperature_ptr}});
        float* len_penalty_ptr = new float(param.len_penalty);
        pointer_record->push_back(len_penalty_ptr);
        request_list[device_id]->insert(
            {"len_penalty",
             triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector<size_t>{1}, len_penalty_ptr}});
        if (param.repetition_penalty != 1.0f) {
            float* repetition_penalty_ptr = new float(param.repetition_penalty);
            pointer_record->push_back(repetition_penalty_ptr);
            request_list[device_id]->insert(
                {"repetition_penalty",
                 triton::Tensor{
                     triton::MEMORY_CPU, triton::TYPE_FP32, std::vector<size_t>{1}, repetition_penalty_ptr}});
        }
        if (param.presence_penalty != 0.0f) {
            float* presence_penalty_ptr = new float(param.presence_penalty);
            pointer_record->push_back(presence_penalty_ptr);
            request_list[device_id]->insert(
                {"presence_penalty",
                 triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector<size_t>{1}, presence_penalty_ptr}});
        }
        int* min_length_ptr = new int(param.min_length);
        pointer_record->push_back(min_length_ptr);
        request_list[device_id]->insert(
            {"min_length",
             triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector<size_t>{1}, min_length_ptr}});
        unsigned long long int* random_seed_ptr = new unsigned long long int(param.random_seed);
        pointer_record->push_back(random_seed_ptr);
        request_list[device_id]->insert(
            {"random_seed",
             triton::Tensor{triton::MEMORY_CPU, triton::TYPE_UINT64, std::vector<size_t>{1}, random_seed_ptr}});

        pointer_record->push_back(d_input_ids);
        // pointer_record->push_back(d_input_lengths);
        pointer_record->push_back(d_input_bad_words);
        pointer_record->push_back(request_output_len_ptr);
        pointer_record->push_back(input_lengths_ptr);
    }

    return request_list;
}

/// Reads comma-separated prompt token ids (one prompt per line) from `file_name` and appends
/// them, padded with `end_id`, to `v_start_ids`. The per-row lengths are appended to
/// `v_start_lengths`, and each row is repeated `beam_width` times.
/// Returns the batch size, or 0 when the file cannot be opened.
/// Defined at the bottom of this file.
int read_start_ids(size_t            batch_size,
                   std::vector<int>* v_start_lengths,
                   std::vector<int>* v_start_ids,
                   size_t            max_input_len,
                   const int         end_id,
                   const int         beam_width,
                   std::string       file_name);

/// Builds the example request from the `[request]` section of `ini_name`. It loads the prompt
/// token ids from a CSV file and broadcasts the result to every local GPU via broadCastRequest.
///
/// Optional INI keys (defaults reproduce the previous hard-coded behavior):
///   start_ids_file  path of the prompt CSV (default "../examples/cpp/llama/start_ids.csv")
///   random_seed     sampling seed (default 0)
///
/// @param ini_name        path of the INI config file
/// @param node_id         MPI rank of this process
/// @param gpu_count       number of local GPUs to build requests for
/// @param pointer_record  receives every buffer allocated for the request
/// @return one request tensor map per local GPU
std::vector<std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>>
prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std::vector<void*>* pointer_record)
{
    INIReader reader = INIReader(ini_name);
    if (reader.ParseError() < 0) {
        std::cout << "[ERROR] Can't load '" << ini_name << "'\n";
        ft::FT_CHECK(false);
    }

    const size_t request_batch_size = reader.GetInteger("request", "request_batch_size");
    std::cerr << "request_batch_size=" << request_batch_size << "\n";

    const int start_id      = reader.GetInteger("request", "start_id");
    const int end_id        = reader.GetInteger("request", "end_id");
    const int max_input_len = reader.GetInteger("request", "max_input_len");

    const std::string start_ids_file =
        reader.Get("request", "start_ids_file", "../examples/cpp/llama/start_ids.csv");

    std::vector<int> v_start_ids;
    std::vector<int> v_start_lengths;

    read_start_ids(request_batch_size, &v_start_lengths, &v_start_ids, max_input_len, end_id, 1, start_ids_file);

    // Drop requests beyond request_batch_size. A batch size of 0 makes read_start_ids use every
    // line of the file, so nothing must be dropped in that case (it would empty the request).
    if (request_batch_size > 0 && v_start_lengths.size() > request_batch_size) {
        // Derive the row stride from the data instead of assuming it equals max_input_len.
        const size_t row_width = v_start_ids.size() / v_start_lengths.size();
        v_start_lengths.resize(request_batch_size);
        v_start_ids.resize(request_batch_size * row_width);
    }
    std::cerr << "max_input_len=" << max_input_len << "\n";

    std::vector<int> v_bad_words;
    // ft::read_word_list("../examples/cpp/llama/bad_words.csv", v_bad_words);

    RequestParam param;
    param.beam_width                 = reader.GetInteger("request", "beam_width");
    param.request_output_len         = reader.GetInteger("request", "request_output_len");
    param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate");
    param.runtime_top_k              = reader.GetInteger("request", "top_k");
    param.runtime_top_p              = reader.GetFloat("request", "top_p");
    param.temperature                = reader.GetFloat("request", "temperature");
    param.len_penalty                = reader.GetFloat("request", "len_penalty");
    param.repetition_penalty         = reader.GetFloat("request", "repetition_penalty", 1.0f);
    param.presence_penalty           = reader.GetFloat("request", "presence_penalty", 0.0f);
    param.min_length                 = reader.GetInteger("request", "min_length", 0);
    param.random_seed                = (unsigned long long int)reader.GetInteger("request", "random_seed", 0);
    param.start_id                   = start_id;
    param.end_id                     = end_id;

    auto request_list =
        broadCastRequest(v_start_ids, v_start_lengths, v_bad_words, node_id, gpu_count, param, pointer_record);
    return request_list;
}

/// Thread body that prepares the model on one local GPU.
///
/// It binds the calling thread to `device_id` and creates a fresh CUDA stream. It then builds
/// the shared weights for (`device_id`, `rank`) and stores the created model instance in
/// `model_instances->at(device_id)`. The stream is passed to createModelInstance and is not
/// destroyed here.
///
/// @return always 0
int threadCreateModelInstances(std::shared_ptr<AbstractTransformerModel>                         model,
                               std::vector<std::unique_ptr<AbstractTransformerModelInstance>>*   model_instances,
                               const int                                                         device_id,
                               const int                                                         rank,
                               std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
                               std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr)
{
    printf("[INFO] rank = %d \n", rank);

    // Bind this thread to its GPU and give the instance its own stream.
    ft::check_cuda_error(cudaSetDevice(device_id));
    cudaStream_t instance_stream;
    ft::check_cuda_error(cudaStreamCreate(&instance_stream));

    model->createSharedWeights(device_id, rank);

    auto created_instance =
        model->createModelInstance(device_id, rank, instance_stream, nccl_params, custom_all_reduce_comm);
    model_instances->at(device_id) = std::move(created_instance);

    printf("model instance %d is created \n", device_id);
    ft::print_mem_usage();
    return 0;
}

/// Thread body that runs one forward pass of `model_instance` on `device_id`.
///
/// The device is synchronized before and after the call so the outputs written to
/// `*output_tensors` are complete when the thread is joined. Errors reported by
/// cudaDeviceSynchronize are checked with ft::check_cuda_error instead of being discarded.
///
/// @return always 0
int threadForward(std::unique_ptr<AbstractTransformerModelInstance>*                model_instance,
                  std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>  request,
                  std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>* output_tensors,
                  const int                                                         device_id,
                  ft::AbstractInstanceComm*                                         comm)
{
    ft::check_cuda_error(cudaSetDevice(device_id));
    ft::check_cuda_error(cudaDeviceSynchronize());
    *output_tensors = (*model_instance)->forward(request, comm);
    ft::check_cuda_error(cudaDeviceSynchronize());
    return 0;
}

int main(int argc, char* argv[])
{
    /*
        Prepare the nccl ids, node id, device id and world size
        by MPI or triton
    */

    int node_id  = 0;
    int node_num = 1;

    if (kUSE_MPI) {
        ft::mpi::initialize(&argc, &argv);
        node_id  = ft::mpi::getCommWorldRank();
        node_num = ft::mpi::getCommWorldSize();
    }

    printf("node_id=%d node_num=%d\n", node_id, node_num);

    // Note: Only supports that all nodes have same gpu count
    const int   gpu_count  = ft::getDeviceCount();
    const int   world_size = node_num * gpu_count;
    std::string ini_name   = argc >= 2 ? std::string(argv[1]) : "../examples/cpp/llama/llama_config.ini";

    // step 1: Create model
    std::shared_ptr<AbstractTransformerModel> model              = AbstractTransformerModel::createLlamaModel(ini_name);
    int                                       tensor_para_size   = model->getTensorParaSize();
    int                                       pipeline_para_size = model->getPipelineParaSize();
    printf(
        "world_size=%d tensor_para_size=%d pipeline_para_size=%d\n", world_size, tensor_para_size, pipeline_para_size);
    FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size),
                       "World Size != Tensor Parallel Size * Pipeline Parallel Size !");

    std::cout << model->toString();

    // step 2: Initialize the NCCL
    std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_comms = model->createNcclParams(node_id);
    cudaDeviceSynchronize();

    // Optional Step: create custom all reduce comm
    // std::vector<std::shared_ptr<ft::AbstractCustomComm>>
    // custom_all_reduce_comms; model->createCustomComms(&custom_all_reduce_comms,
    // world_size);

    // step 2.1 create instance comm
    auto instance_comm = model->createInstanceComm(gpu_count);

    // step 3: Create model instances
    std::vector<std::unique_ptr<AbstractTransformerModelInstance>> model_instances((size_t)gpu_count);
    std::vector<std::thread>                                       threads;
    for (int device_id = 0; device_id < gpu_count; device_id++) {
        const int rank = node_id * gpu_count + device_id;
        threads.push_back(
            std::thread(threadCreateModelInstances, model, &model_instances, device_id, rank, nccl_comms, nullptr));
        //   custom_all_reduce_comms[rank]));
    }
    for (auto& t : threads) {
        t.join();
    }

    // step 4: prepare request
    std::vector<void*> pointer_record;  // Used to prevent the pointers are
                                        // release after leaving functions
    std::vector<std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>> request_list =
        prepareRequest(ini_name, node_id, gpu_count, &pointer_record);
    printf("[INFO] request is created \n");

    // step 5: Forward
    std::vector<std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>> output_tensors_lists(
        (size_t)gpu_count);
    for (int i = 0; i < 1; i++) {
        threads.clear();
        for (int device_id = 0; device_id < gpu_count; device_id++) {
            threads.push_back(std::thread(threadForward,
                                          &model_instances[device_id],
                                          request_list[device_id],
                                          &output_tensors_lists[device_id],
                                          device_id,
                                          instance_comm.get()));
        }
        for (auto& t : threads) {
            t.join();
        }
    }
    printf("[INFO] forward is completed. \n");

    const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data;
    const int* d_seq_lens   = (const int*)output_tensors_lists[0].get()->at("sequence_length").data;
    const int  batch_size   = output_tensors_lists[0].get()->at("output_ids").shape[0];
    const int  beam_width   = output_tensors_lists[0].get()->at("output_ids").shape[1];
    const int  seq_len      = output_tensors_lists[0].get()->at("output_ids").shape[2];
Chen Xin's avatar
Chen Xin committed
435

Li Zhang's avatar
Li Zhang committed
436
437
    ft::FT_CHECK(beam_width == 1);

438
    std::vector<int> seq_lens(batch_size);
Li Zhang's avatar
Li Zhang committed
439
440
441
442
443
444
445
446
    // step 6: check results
    if (node_id == 0) {
        std::string fName   = "out";
        auto        outFile = std::ofstream(fName, std::ios::out);
        if (!outFile.is_open()) {
            printf("[WARNING] Cannot write results into output file %s \n", fName.c_str());
        }
        else {
Li Zhang's avatar
Li Zhang committed
447
448
            const size_t outCount = batch_size * beam_width * seq_len;

Li Zhang's avatar
Li Zhang committed
449
            std::vector<int> hBuf(outCount);
Li Zhang's avatar
Li Zhang committed
450

Li Zhang's avatar
Li Zhang committed
451
452
            ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
            ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
Li Zhang's avatar
Li Zhang committed
453

Li Zhang's avatar
Li Zhang committed
454
455
456
457
458
            std::cout << "sequence length: ";
            for (int i = 0; i < batch_size; ++i) {
                std::cout << (i ? ", " : "") << seq_lens[i];
            }
            std::cout << "\n";
Li Zhang's avatar
Li Zhang committed
459
460
461
462
463
464

            for (int i = 0; i < batch_size; ++i) {
                outFile << (i ? "\n" : "");
                auto buf = hBuf.data() + seq_len * i;
                for (int j = 0; j < seq_lens[i]; ++j) {
                    outFile << buf[j] << " ";
Li Zhang's avatar
Li Zhang committed
465
466
467
468
469
470
471
472
473
474
                }
            }
        }
    }

    if (kUSE_MPI) {
        ft::mpi::barrier();
    }
    cudaDeviceSynchronize();

Li Zhang's avatar
Li Zhang committed
475
    if (0) {
Li Zhang's avatar
Li Zhang committed
476
        // test time
Chen Xin's avatar
Chen Xin committed
477
        auto start = std::chrono::high_resolution_clock::now();
Li Zhang's avatar
Li Zhang committed
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499

        const int ite = 1;
        for (int i = 0; i < ite; i++) {
            threads.clear();
            for (int device_id = 0; device_id < gpu_count; device_id++) {
                threads.push_back(std::thread(threadForward,
                                              &model_instances[device_id],
                                              request_list[device_id],
                                              &output_tensors_lists[device_id],
                                              device_id,
                                              instance_comm.get()));
            }
            for (auto& t : threads) {
                t.join();
            }
        }

        cudaDeviceSynchronize();
        if (kUSE_MPI) {
            ft::mpi::barrier();
        }

Chen Xin's avatar
Chen Xin committed
500
501
        auto end = std::chrono::high_resolution_clock::now();
        auto dur = std::chrono::duration<float, std::milli>(end - start);
Li Zhang's avatar
Li Zhang committed
502
503
504
505
506

        printf("[INFO] batch_size %d beam_width %d seq_len %d"
               " FT-CPP-GPT-Triton-time %.2f ms\n",
               batch_size,
               beam_width,
507
               seq_lens[0],
Chen Xin's avatar
Chen Xin committed
508
               dur.count() / ite);
Li Zhang's avatar
Li Zhang committed
509
510
511
512
513
514
515
516
517
518
519
    }

    if (kUSE_MPI) {
        ft::mpi::finalize();
    }
    return 0;
}

/// Loads prompt token ids from a comma-separated file containing one prompt per line.
///
/// - Each prompt is truncated to at most `max_input_len` tokens and right-padded with `end_id`
///   so that all rows share one width. When `max_input_len` is 0 there is no truncation, and
///   rows are padded to the longest prompt.
/// - Blank (whitespace-only) lines are skipped.
/// - If the file holds fewer prompts than `batch_size`, the first prompt is repeated to fill
///   the batch.
/// - Every row is appended `beam_width` times to `v_start_ids`, and its unpadded length the
///   same number of times to `v_start_lengths`.
///
/// @param batch_size       requested number of prompts; 0 means "use every prompt in the file"
/// @param v_start_lengths  output: unpadded length of each appended row
/// @param v_start_ids      output: flattened, padded token ids
/// @param max_input_len    maximum tokens kept per prompt (0 = unlimited). Taken by value, so
///                         the caller's variable is never modified.
/// @param end_id           token used for padding
/// @param beam_width       number of times each row is repeated
/// @param file_name        path of the CSV file
/// @return the batch size, or 0 if the file cannot be opened or contains no prompts
/// @throws std::invalid_argument / std::out_of_range from std::stoi on malformed tokens
int read_start_ids(size_t            batch_size,
                   std::vector<int>* v_start_lengths,
                   std::vector<int>* v_start_ids,
                   size_t            max_input_len,
                   const int         end_id,
                   const int         beam_width,
                   std::string       file_name)
{
    std::ifstream start_id_file(file_name, std::ios::in);
    if (!start_id_file.is_open()) {
        printf("[WARNING] Cannot open the file '%s'. \n", file_name.c_str());
        return 0;
    }

    std::vector<std::vector<int>> tmp_start_ids;
    std::vector<int>              tmp_start_lengths;

    std::string line;
    while (std::getline(start_id_file, line)) {
        // A whitespace-only line is not a prompt (and std::stoi would throw on its token).
        if (line.find_first_not_of(" \t\r") == std::string::npos) {
            continue;
        }
        std::stringstream lineStream(line);
        std::string       vals;
        std::vector<int>  tmp_vec;
        while (std::getline(lineStream, vals, ',')) {
            tmp_vec.push_back(std::stoi(vals));
            if (tmp_vec.size() == max_input_len) {
                break;
            }
        }
        tmp_start_lengths.push_back(tmp_vec.size());
        tmp_start_ids.push_back(tmp_vec);
    }

    if (tmp_start_ids.empty()) {
        // Nothing to pad from: filling the batch below would read tmp_start_ids[0] out of range.
        printf("[WARNING] No start ids found in '%s'. \n", file_name.c_str());
        return 0;
    }

    if (batch_size == 0) {
        batch_size = tmp_start_ids.size();
    }

    // Right-pad every row with end_id to a common width so the flattened buffer has a fixed stride.
    size_t row_width = max_input_len;
    if (row_width == 0) {
        for (const std::vector<int>& row : tmp_start_ids) {
            if (row.size() > row_width) {
                row_width = row.size();
            }
        }
    }
    for (std::vector<int>& row : tmp_start_ids) {
        if (row.size() < row_width) {
            row.resize(row_width, end_id);
        }
    }

    // Fill the batch by repeating the first prompt (copied first: push_back may reallocate).
    const std::vector<int> first_ids    = tmp_start_ids.front();
    const int              first_length = tmp_start_lengths.front();
    while (tmp_start_ids.size() < batch_size) {
        tmp_start_ids.push_back(first_ids);
        tmp_start_lengths.push_back(first_length);
    }

    for (size_t i = 0; i < tmp_start_ids.size(); i++) {
        for (int b = 0; b < beam_width; b++) {
            v_start_ids->insert(v_start_ids->end(), tmp_start_ids[i].begin(), tmp_start_ids[i].end());
            v_start_lengths->push_back(tmp_start_lengths[i]);
        }
    }
    return static_cast<int>(batch_size);
}