ffi.h 1.9 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
//
// Created by mfuntowicz on 7/11/24.
//

#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H

#include <cstddef>
#include "backend.h"

// Forward declaration only: the class is introduced ahead of the "lib.rs.h" include that follows,
// so the name is already known when that header is parsed.
// NOTE(review): presumably "lib.rs.h" is the generated Rust/C++ bridge header and refers to this type -- confirm.
namespace huggingface::tgi::backends { class TensorRtLlmBackendImpl; }

#include "backends/trtllm/src/lib.rs.h"


namespace huggingface::tgi::backends {

//    struct GenerationContext;

    /***
     * Backend class exposed across the Rust/C++ boundary: it publicly extends TensorRtLlmBackend and
     * declares entry points whose signatures use `rust::` types.
     *
     * NOTE(review): TensorRtLlmBackend is presumably declared in "backend.h", and RequestId,
     * GenerationContext and GenerationStep presumably come from "backend.h" / "lib.rs.h" -- confirm.
     */
    class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
    public:
        /***
         * Construct the backend from two string views.
         *
         * @param engineFolder   presumably the folder containing the compiled engine -- TODO confirm
         * @param executorWorker presumably the location of the executor worker binary -- TODO confirm
         */
        TensorRtLlmBackendImpl(
                const std::string_view &engineFolder,
                const std::string_view &executorWorker
        );

        /***
         * Readiness query (const member function).
         *
         * @return boolean flag; presumably true when the backend can accept requests -- TODO confirm
         */
        bool IsReady() const;

        /***
         * Submit a request made of a token sequence plus a set of numeric parameters.
         *
         * NOTE(review): the parameter meanings below are inferred from their names only -- confirm
         * against the implementation.
         *
         * @param tokens             slice of 32-bit unsigned values, presumably the prompt token ids
         * @param topK               presumably the top-k sampling parameter
         * @param topP               presumably the top-p (nucleus) sampling parameter
         * @param temperature        presumably the sampling temperature
         * @param repetition_penalty presumably the repetition penalty
         * @param frequency_penalty  presumably the frequency penalty
         * @param seed               presumably the random seed used for sampling
         * @return request id, to be kept by the caller in order to refer to this request's
         *         generation result later on (the result is marked nodiscard)
         */
        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
        uint64_t Submit(
                rust::Slice<const uint32_t> tokens,
                int32_t topK,
                float_t topP,
                float_t temperature,
                float_t repetition_penalty,
                float_t frequency_penalty,
                uint64_t seed
        );

        /***
         * Stream generation output for a request through a Rust function handle.
         *
         * @param requestId identifier of the request; presumably a value previously returned by
         *                  Submit -- TODO confirm
         * @param ctx       pointer to a GenerationContext; the callback's first parameter has the same
         *                  pointer type, so presumably this pointer is handed back to it -- TODO confirm
         * @param callback  Rust function taking a GenerationContext pointer and a GenerationStep by value
         * @return a size_t count; presumably the number of steps/tokens delivered -- TODO confirm
         */
        size_t StreamTokens(
                const RequestId requestId,
                huggingface::tgi::backends::GenerationContext *ctx,
                rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
                              huggingface::tgi::backends::GenerationStep)> callback
        );
    };

    /***
     * Factory function handing out a TensorRtLlmBackendImpl from two `rust::Str` arguments.
     *
     * @param engineFolder   presumably the folder containing the compiled engine -- TODO confirm
     * @param executorWorker presumably the location of the executor worker binary -- TODO confirm
     * @return std::unique_ptr owning the backend instance; ownership belongs to the caller
     */
    std::unique_ptr<TensorRtLlmBackendImpl>
    CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
}

#endif //TGI_TRTLLM_BACKEND_FFI_H