//
// Created by mfuntowicz on 6/30/24.
//
#pragma once

#include <cmath>
#include <exception>
#include <filesystem>
#include <limits>
#include <iterator>
#include <vector>

#include <spdlog/spdlog.h>
#include "backends/trtllm/include/ffi.h"


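// Thin FFI wrapper around TensorRtLlmBackend: forwards the engine folder and the
// executor worker path straight to the underlying backend constructor.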
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
        const std::string_view &engineFolder,
        const std::string_view &executorWorker
) : TensorRtLlmBackend(engineFolder, executorWorker) {}


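// Expose the readiness state of the underlying TensorRtLlmBackend to the Rust side.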
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
    return TensorRtLlmBackend::IsReady();
}

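// Enqueue a generation request built from the prompt tokens and the sampling parameters
// received from Rust, returning the request id assigned by the underlying backend.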
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
        rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
        float_t frequency_penalty, uint64_t seed) {

    // Copy the prompt tokens out of the Rust slice into a std::vector, converting each
    // uint32_t token id to the int32_t type expected by the underlying backend
    std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
    return TensorRtLlmBackend::Submit(
            std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
}

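// Poll the backend for the steps generated so far for `requestId`, invoking the Rust
// callback once per step and returning the number of tokens streamed during this call.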
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
        const uint64_t requestId,
        huggingface::tgi::backends::GenerationContext *ctx,
        rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
                      huggingface::tgi::backends::GenerationStep)> callback) {

    size_t numTokens = 0;
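    // Drain every response currently available for this request from the underlying backend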
    for (const auto &item: Poll(requestId)) {
        GenerationStep step;
        if (!item.hasError()) {
            SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
            const auto decoded = item.getResult();

            const auto token = decoded.outputTokenIds[0][0];
            const auto isFinal = decoded.isFinal;
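            // NOTE: assumes log probabilities were requested for this generation; std::optional::value() throws otherwise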
            const auto logProb = decoded.logProbs.value()[0][0];

            ++numTokens;

            SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
            step = huggingface::tgi::backends::GenerationStep{
                    static_cast<uint32_t>(token), logProb, isFinal, false, std::string()
            };
            SPDLOG_DEBUG("\tStreamTokens -> Post callback");
        } else {
            // TODO: Return rust::Result carrying the error instead of an in-band error step
            auto what = item.getErrorMsg();
            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
            step = huggingface::tgi::backends::GenerationStep{
                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
            };
        }

        callback(ctx, std::move(step));
    }

    return numTokens;
}

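// Factory exposed to Rust through the CXX bridge: initializes the TRTLLM runtime and
// plugins once, then builds the backend from the engine folder and executor worker paths.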
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
    // Unconditionally call this to initialize and discover TRTLLM plugins
    InitializeBackend();

    const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
    const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
    return std::make_unique<TensorRtLlmBackendImpl>(enginePath, executorPath);
}