// qwen3vl_impl.hpp
#ifndef QWEN3VL_IMPL_H
#define QWEN3VL_IMPL_H

#include "infinicore_infer.h"

#include "../../allocator.hpp"
#include "../../tensor.hpp"

#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

// Tensor handles for one element of Qwen3vlLanguageModelWeight::layers.
//
// Every member is a std::shared_ptr<Tensor>, so a default-constructed
// instance holds only null handles. Declaration order is kept as-is because
// it defines both the layout and the aggregate-initialization order.
// NOTE(review): going by the member names, the first group belongs to the
// attention sub-block and the second to the MLP sub-block; tensor shapes and
// dtypes are not visible in this header -- confirm against the weight loader.
struct Qwen3vlLayerWeight {
    // attn_* handles
    std::shared_ptr<Tensor> attn_norm, attn_qkv_proj, attn_q_norm, attn_k_norm, attn_o_proj;

    // mlp_* handles
    std::shared_ptr<Tensor> mlp_norm, mlp_gate_up, mlp_down;
};

// Tensor handles grouped under the language model: three standalone tensors
// followed by a vector with one Qwen3vlLayerWeight per element.
//
// NOTE(review): `in_embd` / `out_embd` / `out_norm` presumably are the input
// embedding, output embedding and final norm -- confirm against the loader.
struct Qwen3vlLanguageModelWeight {
    std::shared_ptr<Tensor> in_embd;
    std::shared_ptr<Tensor> out_embd;
    std::shared_ptr<Tensor> out_norm;
    std::vector<Qwen3vlLayerWeight> layers;
};

// Tensor handles for one element of Qwen3vlVisualEncoderWeight::blocks.
//
// Members come in weight/bias pairs. Declaration order is preserved because
// it fixes the layout and the aggregate-initialization order.
struct Qwen3vlVisBlockWeight {
    // attn_* pairs
    std::shared_ptr<Tensor> attn_proj_weight;
    std::shared_ptr<Tensor> attn_proj_bias;
    std::shared_ptr<Tensor> attn_qkv_weight;
    std::shared_ptr<Tensor> attn_qkv_bias;

    // mlp_linear_fc1 / mlp_linear_fc2 pairs
    std::shared_ptr<Tensor> mlp_linear_fc1_weight;
    std::shared_ptr<Tensor> mlp_linear_fc1_bias;
    std::shared_ptr<Tensor> mlp_linear_fc2_weight;
    std::shared_ptr<Tensor> mlp_linear_fc2_bias;

    // norm1 / norm2 pairs
    std::shared_ptr<Tensor> norm1_weight;
    std::shared_ptr<Tensor> norm1_bias;
    std::shared_ptr<Tensor> norm2_weight;
    std::shared_ptr<Tensor> norm2_bias;
};

// Tensor handles for one element of
// Qwen3vlVisualEncoderWeight::deepstack_mergers.
//
// The member list matches MergerWeight name-for-name; the two remain
// distinct types.
struct DeepstackMergerWeight {
    // linear_fc1 / linear_fc2 weight-bias pairs
    std::shared_ptr<Tensor> linear_fc1_weight;
    std::shared_ptr<Tensor> linear_fc1_bias;
    std::shared_ptr<Tensor> linear_fc2_weight;
    std::shared_ptr<Tensor> linear_fc2_bias;

    // norm weight-bias pair
    std::shared_ptr<Tensor> norm_weight;
    std::shared_ptr<Tensor> norm_bias;
};

// Tensor handles referenced through Qwen3vlVisualEncoderWeight::merger.
//
// The member list matches DeepstackMergerWeight name-for-name; the two
// remain distinct types.
struct MergerWeight {
    // linear_fc1 / linear_fc2 weight-bias pairs
    std::shared_ptr<Tensor> linear_fc1_weight;
    std::shared_ptr<Tensor> linear_fc1_bias;
    std::shared_ptr<Tensor> linear_fc2_weight;
    std::shared_ptr<Tensor> linear_fc2_bias;

    // norm weight-bias pair
    std::shared_ptr<Tensor> norm_weight;
    std::shared_ptr<Tensor> norm_bias;
};

// Tensor handles grouped under the visual encoder: three standalone tensors,
// a vector of per-block weights, a vector of deepstack merger weights, and a
// shared handle to a single MergerWeight.
struct Qwen3vlVisualEncoderWeight {
    std::shared_ptr<Tensor> patch_embed_weight;
    std::shared_ptr<Tensor> patch_embed_bias;
    std::shared_ptr<Tensor> pos_embed_weight;

    std::vector<Qwen3vlVisBlockWeight> blocks;
    std::vector<DeepstackMergerWeight> deepstack_mergers;
    std::shared_ptr<MergerWeight> merger; // null until something assigns it
};

struct Qwen3vlDeviceWeights {
PanZezhong's avatar
PanZezhong committed
56
    std::shared_ptr<Tensor> sin_table, cos_table;
hejianlin's avatar
hejianlin committed
57
58
59
60
61
62
63
64
65
66
67
68
69
    std::shared_ptr<Qwen3vlLanguageModelWeight> w_lang;
    std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis;
    infiniDevice_t device;
    int dev_id;
    infinirtStream_t load_stream;
};

struct Qwen3vlWeights {
    Qwen3vlMeta const *meta;
    bool transpose_weight;
    std::vector<std::shared_ptr<Qwen3vlDeviceWeights>> device_weights;

    Qwen3vlWeights(const Qwen3vlMeta *meta,
PanZezhong's avatar
PanZezhong committed
70
71
72
73
                   infiniDevice_t device,
                   int ndev,
                   const int *dev_ids,
                   bool transpose_weight);
hejianlin's avatar
hejianlin committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
};

// Handles and shared objects grouped per device; Qwen3vlModel keeps a vector
// of these in `dev_resources`.
//
// None of the raw handle members has a default member initializer, and this
// struct declares no destructor, so it neither initializes nor releases them
// itself.
struct Qwen3vlDeviceResource {
    // Device kind and device index
    infiniDevice_t device;
    int device_id;
    // Operator-library handle (type infiniopHandle_t)
    infiniopHandle_t handle;
    // Shared handle to the weights bundle for this device
    std::shared_ptr<Qwen3vlDeviceWeights> weights;
    // Runtime stream handle
    infinirtStream_t stream;
    // Communicator handle (type infinicclComm_t)
    infinicclComm_t comm;

    // Shared MemoryPool handle. NOTE(review): presumably backs temporary
    // buffers during inference -- confirm against the implementation.
    std::shared_ptr<MemoryPool> memory_pool;
};

// Synchronization state: three static members shared by every InferState,
// plus a per-instance mutex, three condition variables and three flags.
//
// NOTE(review): the original author tagged this type "qwen3vl namespace",
// but it is declared at global scope in this header.
struct InferState {
    // --- shared across all instances (C++17 inline static members) ---
    inline static std::mutex mtx_sync;
    inline static int sync_cnt; // static storage duration, so it starts at 0
    inline static std::condition_variable cv_sync;

    // --- per instance ---
    std::mutex mtx;
    std::condition_variable cv_load;
    std::condition_variable cv_start;
    std::condition_variable cv_done;
    // All three flags start out false.
    bool loaded{false};
    bool proceed{false};
    bool exit_flag{false};
};

// Plain bundle of raw pointers and counts describing one inference call.
//
// Every member has a default initializer (null pointer or zero count), so a
// default-constructed request -- such as the `req` member embedded by value
// in Qwen3vlModel -- never exposes indeterminate values. The struct is still
// an aggregate in C++17, so brace initialization and member-wise assignment
// work as before.
//
// The struct declares no destructor, so it never frees anything it points
// to. NOTE(review): the original author tagged this type "qwen3vl
// namespace", but it is declared at global scope in this header.
struct InferRequest {
    // Token ids and their count.
    const uint32_t *tokens = nullptr;
    uint32_t ntok = 0;

    // Image inputs. NOTE(review): presumably `image_grid_thw` holds a
    // (t, h, w) triple per image -- confirm against the caller.
    void *pixel_values = nullptr;
    uint32_t total_patches = 0;
    uint32_t *image_grid_thw = nullptr;
    uint32_t num_images = 0;

    // Video inputs, mirroring the image members above.
    void *pixel_values_videos = nullptr;
    uint32_t total_patches_videos = 0;
    uint32_t *video_grid_thw = nullptr;
    uint32_t num_videos = 0;

    uint32_t patch_features = 0;

    // Per-request arrays. NOTE(review): presumably each has `nreq` entries --
    // confirm against the caller.
    const uint32_t *req_lens = nullptr;
    uint32_t nreq = 0;
    const uint32_t *req_pos = nullptr;
    struct Qwen3vlCache **kv_caches = nullptr;

    // Sampling parameters.
    const float *temperature = nullptr;
    const uint32_t *topk = nullptr;
    const float *topp = nullptr;

    // Output locations; the element type behind `logits` is not visible here.
    uint32_t *output = nullptr;
    void *logits = nullptr;
};

// Model object: a by-value copy of the metadata, the device kind and ids,
// per-device resources and synchronization states, worker threads, and one
// InferRequest slot.
struct Qwen3vlModel {
    // Metadata stored by value.
    Qwen3vlMeta meta;
    // Device kind and device indices.
    infiniDevice_t device;
    std::vector<int> dev_ids;
    // NOTE(review): `dev_resources`, `states` and `threads` presumably hold
    // one entry per element of `dev_ids` -- confirm in the constructor.
    std::vector<Qwen3vlDeviceResource> dev_resources;
    // InferState holds std::mutex members and is therefore neither copyable
    // nor movable; this vector cannot be grown with push_back/emplace_back.
    std::vector<InferState> states;
    std::vector<std::thread> threads;
    // Request slot stored by value.
    InferRequest req;

    // Declared here, defined elsewhere. The first parameter is unnamed in
    // this declaration.
    Qwen3vlModel(const Qwen3vlMeta *, const Qwen3vlWeights *weights);
};

struct Qwen3vlCache {
PanZezhong's avatar
PanZezhong committed
138
    std::vector<std::vector<std::shared_ptr<Tensor>>> k_rot, v;
hejianlin's avatar
hejianlin committed
139
140
};

// Closes the QWEN3VL_IMPL_H include guard.
#endif