#ifndef QWEN3VL_IMPL_H #define QWEN3VL_IMPL_H #include "infinicore_infer.h" #include "../../allocator.hpp" #include "../../tensor.hpp" #include #include #include #include #include struct Qwen3vlLayerWeight { std::shared_ptr attn_norm; std::shared_ptr attn_qkv_proj; std::shared_ptr attn_q_norm; std::shared_ptr attn_k_norm; std::shared_ptr attn_o_proj; std::shared_ptr mlp_norm; std::shared_ptr mlp_gate_up; std::shared_ptr mlp_down; }; struct Qwen3vlLanguageModelWeight { std::shared_ptr in_embd, out_embd, out_norm; std::vector layers; }; struct Qwen3vlVisBlockWeight { std::shared_ptr attn_proj_weight, attn_proj_bias, attn_qkv_weight, attn_qkv_bias; std::shared_ptr mlp_linear_fc1_weight, mlp_linear_fc1_bias, mlp_linear_fc2_weight, mlp_linear_fc2_bias; std::shared_ptr norm1_weight, norm1_bias, norm2_weight, norm2_bias; }; struct DeepstackMergerWeight { std::shared_ptr linear_fc1_weight, linear_fc1_bias, linear_fc2_weight, linear_fc2_bias; std::shared_ptr norm_weight, norm_bias; }; struct MergerWeight { std::shared_ptr linear_fc1_weight, linear_fc1_bias, linear_fc2_weight, linear_fc2_bias; std::shared_ptr norm_weight, norm_bias; }; struct Qwen3vlVisualEncoderWeight { std::shared_ptr patch_embed_weight, patch_embed_bias, pos_embed_weight; std::vector blocks; std::vector deepstack_mergers; std::shared_ptr merger; }; struct Qwen3vlDeviceWeights { std::shared_ptr sin_table, cos_table; std::shared_ptr w_lang; std::shared_ptr w_vis; infiniDevice_t device; int dev_id; infinirtStream_t load_stream; }; struct Qwen3vlWeights { Qwen3vlMeta const *meta; bool transpose_weight; std::vector> device_weights; Qwen3vlWeights(const Qwen3vlMeta *meta, infiniDevice_t device, int ndev, const int *dev_ids, bool transpose_weight); }; struct Qwen3vlDeviceResource { // Device infiniDevice_t device; int device_id; infiniopHandle_t handle; // Weights std::shared_ptr weights; // Streams infinirtStream_t stream; // Communicator infinicclComm_t comm; std::shared_ptr memory_pool; }; struct InferState { // qwen3vl namespace inline static std::mutex mtx_sync; inline static int sync_cnt; inline static std::condition_variable cv_sync; std::mutex mtx; std::condition_variable cv_load, cv_start, cv_done; bool loaded = false; bool proceed = false; bool exit_flag = false; }; struct InferRequest { // qwen3vl namespace const uint32_t *tokens; uint32_t ntok; void *pixel_values; uint32_t total_patches; uint32_t *image_grid_thw; uint32_t num_images; void *pixel_values_videos; uint32_t total_patches_videos; uint32_t *video_grid_thw; uint32_t num_videos; uint32_t patch_features; const uint32_t *req_lens; uint32_t nreq; const uint32_t *req_pos; struct Qwen3vlCache **kv_caches; const float *temperature; const uint32_t *topk; const float *topp; uint32_t *output; void *logits; }; struct Qwen3vlModel { Qwen3vlMeta meta; infiniDevice_t device; std::vector dev_ids; std::vector dev_resources; std::vector states; std::vector threads; InferRequest req; Qwen3vlModel(const Qwen3vlMeta *, const Qwen3vlWeights *weights); }; struct Qwen3vlCache { std::vector>> k_rot, v; }; #endif