qwen3vl_impl.hpp 3.9 KB
Newer Older
hejianlin's avatar
hejianlin committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#ifndef QWEN3VL_IMPL_H
#define QWEN3VL_IMPL_H

#include "infinicore_infer.h"

#include "../../allocator.hpp"
#include "../../tensor.hpp"

#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

/// Tensor handles for a single layer of the language model.
///
/// Every member is a shared handle that is null until assigned. Member order
/// is kept as-is because positional aggregate initialization depends on it.
struct Qwen3vlLayerWeight {
    // Attention tensors (names suggest: input norm, fused QKV projection,
    // per-Q / per-K norms, output projection — confirm against the loader).
    std::shared_ptr<Tensor> attn_norm, attn_qkv_proj;
    std::shared_ptr<Tensor> attn_q_norm, attn_k_norm;
    std::shared_ptr<Tensor> attn_o_proj;

    // MLP tensors (names suggest: input norm, fused gate/up projection,
    // down projection — confirm against the loader).
    std::shared_ptr<Tensor> mlp_norm;
    std::shared_ptr<Tensor> mlp_gate_up, mlp_down;
};

/// Tensor handles for the language-model half of the network: three
/// model-level tensors plus one Qwen3vlLayerWeight entry per element of
/// `layers`.
struct Qwen3vlLanguageModelWeight {
    std::shared_ptr<Tensor> in_embd;  // presumably the input embedding table — confirm
    std::shared_ptr<Tensor> out_embd; // presumably the output embedding / LM head — confirm
    std::shared_ptr<Tensor> out_norm; // presumably the final norm weight — confirm

    // Per-layer weights; the indexing convention is defined by the loader.
    std::vector<Qwen3vlLayerWeight> layers;
};

/// Tensor handles for one block of the visual encoder.
///
/// Each parameter is stored as a weight/bias pair; all handles are null
/// until assigned.
struct Qwen3vlVisBlockWeight {
    // Attention projections.
    std::shared_ptr<Tensor> attn_proj_weight;
    std::shared_ptr<Tensor> attn_proj_bias;
    std::shared_ptr<Tensor> attn_qkv_weight;
    std::shared_ptr<Tensor> attn_qkv_bias;

    // Two-layer MLP (fc1, fc2).
    std::shared_ptr<Tensor> mlp_linear_fc1_weight;
    std::shared_ptr<Tensor> mlp_linear_fc1_bias;
    std::shared_ptr<Tensor> mlp_linear_fc2_weight;
    std::shared_ptr<Tensor> mlp_linear_fc2_bias;

    // Two norm parameter sets (norm1, norm2).
    std::shared_ptr<Tensor> norm1_weight;
    std::shared_ptr<Tensor> norm1_bias;
    std::shared_ptr<Tensor> norm2_weight;
    std::shared_ptr<Tensor> norm2_bias;
};

/// Tensor handles for one "deepstack" merger: two linear layers (fc1, fc2)
/// and one norm, each stored as a weight/bias pair.
///
/// The member list matches MergerWeight; the two are kept as distinct types.
struct DeepstackMergerWeight {
    std::shared_ptr<Tensor> linear_fc1_weight;
    std::shared_ptr<Tensor> linear_fc1_bias;
    std::shared_ptr<Tensor> linear_fc2_weight;
    std::shared_ptr<Tensor> linear_fc2_bias;

    std::shared_ptr<Tensor> norm_weight;
    std::shared_ptr<Tensor> norm_bias;
};

/// Tensor handles for the visual encoder's merger: two linear layers
/// (fc1, fc2) and one norm, each stored as a weight/bias pair.
struct MergerWeight {
    std::shared_ptr<Tensor> linear_fc1_weight;
    std::shared_ptr<Tensor> linear_fc1_bias;
    std::shared_ptr<Tensor> linear_fc2_weight;
    std::shared_ptr<Tensor> linear_fc2_bias;

    std::shared_ptr<Tensor> norm_weight;
    std::shared_ptr<Tensor> norm_bias;
};


/// Tensor handles for the visual encoder: patch/position embedding tensors,
/// a list of encoder blocks, a list of deepstack mergers, and one merger.
struct Qwen3vlVisualEncoderWeight {
    // Embedding tensors.
    std::shared_ptr<Tensor> patch_embed_weight;
    std::shared_ptr<Tensor> patch_embed_bias;
    std::shared_ptr<Tensor> pos_embed_weight;

    // Per-block weights.
    std::vector<Qwen3vlVisBlockWeight> blocks;
    // Deepstack merger weights; how entries map to blocks is not visible
    // in this header.
    std::vector<DeepstackMergerWeight> deepstack_mergers;
    // Single merger; null until assigned.
    std::shared_ptr<MergerWeight> merger;
};


/// All weight handles associated with one device, together with the device
/// identity and the stream handle stored alongside them.
///
/// NOTE(review): `device`, `dev_id` and `load_stream` have no initializers,
/// so whoever creates this object must assign them before reading.
struct Qwen3vlDeviceWeights {
    // Sine / cosine lookup tables (presumably for rotary position
    // embedding — confirm against the code that fills them).
    std::shared_ptr<Tensor> sin_table;
    std::shared_ptr<Tensor> cos_table;

    // Language-model and visual-encoder weight sets.
    std::shared_ptr<Qwen3vlLanguageModelWeight> w_lang;
    std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis;

    // Device type and id these weights belong to.
    infiniDevice_t device;
    int dev_id;

    // Stream handle; the name suggests it is used while loading weights.
    infinirtStream_t load_stream;
};

/// Top-level weight container: a pointer to the model meta, a transpose
/// flag, and a list of per-device weight sets.
struct Qwen3vlWeights {
    // Non-owning pointer to the model metadata (read-only through this handle).
    const Qwen3vlMeta *meta;
    // Flag supplied at construction; its exact effect lives in the loader.
    bool transpose_weight;
    // Per-device weight sets (presumably `ndev` entries — confirm in the
    // constructor definition, which is not in this header).
    std::vector<std::shared_ptr<Qwen3vlDeviceWeights>> device_weights;

    /// Declared here, defined elsewhere.
    /// @param meta             model metadata
    /// @param device           device type
    /// @param ndev             number of entries in @p dev_ids
    /// @param dev_ids          array of device ids
    /// @param transpose_weight see the member of the same name
    Qwen3vlWeights(const Qwen3vlMeta *meta, infiniDevice_t device, int ndev,
                   const int *dev_ids, bool transpose_weight);
};

/// Everything held for one device: its identity, an operator handle, the
/// weights, a stream, a communicator, and a memory pool.
///
/// NOTE(review): the raw handles below have no initializers; the creator of
/// this object is expected to fill in every field.
struct Qwen3vlDeviceResource {
    infiniDevice_t device;   // device type
    int device_id;           // device id
    infiniopHandle_t handle; // operator-library handle

    std::shared_ptr<Qwen3vlDeviceWeights> weights; // weight set for this device

    infinirtStream_t stream; // stream handle
    infinicclComm_t comm;    // communicator handle

    std::shared_ptr<MemoryPool> memory_pool; // memory pool handle
};

/// Synchronization state: a set of class-wide (static) primitives shared by
/// every InferState, plus a per-instance mutex, three condition variables,
/// and three boolean flags that all start out false.
///
/// The type holds std::mutex / std::condition_variable members, so it is
/// neither copyable nor movable.
///
/// NOTE(review): the original "qwen3vl namespace" remark suggests this
/// generic name was meant to live in a qwen3vl namespace; it is currently
/// declared at global scope.
struct InferState { // qwen3vl namespace
    // Shared across all instances (C++17 inline statics).
    inline static std::mutex mtx_sync;
    inline static int sync_cnt = 0; // static storage: zero either way
    inline static std::condition_variable cv_sync;

    // Per-instance lock and the condition variables used with it.
    std::mutex mtx;
    std::condition_variable cv_load;
    std::condition_variable cv_start;
    std::condition_variable cv_done;

    // Per-instance flags; all begin cleared.
    bool loaded = false;
    bool proceed = false;
    bool exit_flag = false;
};

/// Plain bundle of non-owning pointers and counts describing one inference
/// call.
///
/// Every field has a default member initializer (null / 0), so a
/// default-constructed request — such as the `req` member embedded in
/// Qwen3vlModel — never exposes indeterminate pointers or counts. The type
/// remains an aggregate, so positional brace initialization keeps working,
/// and member order is unchanged.
///
/// None of the pointers are owned by this struct. The per-field notes below
/// are inferred from names only; confirm against the code that fills and
/// consumes the request.
///
/// NOTE(review): the original "qwen3vl namespace" remark suggests this
/// generic name was meant to live in a qwen3vl namespace; it is currently
/// declared at global scope.
struct InferRequest { // qwen3vl namespace
    // Token input.
    const uint32_t *tokens = nullptr;
    uint32_t ntok = 0;

    // Image input (presumably patch data, patch count, per-image grid
    // description, image count — confirm).
    void *pixel_values = nullptr;
    uint32_t total_patches = 0;
    uint32_t *image_grid_thw = nullptr;
    uint32_t num_images = 0;

    // Video input, mirroring the image fields.
    void *pixel_values_videos = nullptr;
    uint32_t total_patches_videos = 0;
    uint32_t *video_grid_thw = nullptr;
    uint32_t num_videos = 0;

    uint32_t patch_features = 0;

    // Per-request batching information.
    const uint32_t *req_lens = nullptr;
    uint32_t nreq = 0;
    const uint32_t *req_pos = nullptr;
    struct Qwen3vlCache **kv_caches = nullptr;

    // Sampling parameters.
    const float *temperature = nullptr;
    const uint32_t *topk = nullptr;
    const float *topp = nullptr;

    // Output buffers.
    uint32_t *output = nullptr;
    void *logits = nullptr;
};

/// Top-level model object: a copy of the meta, the device type and ids, one
/// resource bundle per entry of `dev_resources`, synchronization states,
/// threads, and a single request slot.
struct Qwen3vlModel {
    Qwen3vlMeta meta;      // stored by value
    infiniDevice_t device; // device type
    std::vector<int> dev_ids;

    std::vector<Qwen3vlDeviceResource> dev_resources;
    // InferState holds std::mutex members and is therefore neither copyable
    // nor movable, so this vector has to be sized in one shot (e.g. via the
    // count constructor) rather than grown element by element.
    std::vector<InferState> states;
    std::vector<std::thread> threads;

    // Single request slot (presumably read by the threads above — confirm
    // in the implementation file).
    InferRequest req;

    /// Declared here, defined elsewhere.
    Qwen3vlModel(const Qwen3vlMeta *meta, const Qwen3vlWeights *weights);
};

/// Cache storage: two independent two-level tables of tensor handles.
///
/// The meaning of the two index levels is not visible in this header
/// (presumably device and layer — confirm against the cache creation code).
struct Qwen3vlCache {
    std::vector<std::vector<std::shared_ptr<Tensor>>> k_rot; // presumably keys after rotary embedding — confirm
    std::vector<std::vector<std::shared_ptr<Tensor>>> v;     // presumably values — confirm
};

#endif