#ifndef JIUGE_IMPL_H #define JIUGE_IMPL_H #include "infinicore_infer.h" #include "../../allocator.hpp" #include "../../tensor.hpp" #include #include #include #include #include struct DeviceResource { // Device infiniDevice_t device; int device_id; infiniopHandle_t handle; // Weights std::shared_ptr w_in_embd, w_out_norm, w_out_embd, sin_table, cos_table; std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out, w_ffn_norm, w_ffn_gate_up, w_ffn_down; // Streams infinirtStream_t stream; // Communicator infinicclComm_t comm; std::shared_ptr memory_pool; }; struct InferState { std::mutex mtx; std::condition_variable cv_load, cv_start, cv_done; bool loaded = false; bool proceed = false; bool exit_flag = false; }; struct InferRequest { const uint32_t *tokens; uint32_t ntok; const uint32_t *req_lens; uint32_t nreq; const uint32_t *req_pos; struct KVCache **kv_caches; const float *temperature; const uint32_t *topk; const float *topp; uint32_t *output; void *logits; }; struct JiugeModel { JiugeMeta meta; infiniDevice_t device; std::vector dev_ids; std::vector dev_resources; std::vector states; std::vector threads; InferRequest req; JiugeModel(const JiugeMeta *, const JiugeWeights *, infiniDevice_t device, std::vector device_ids); }; struct KVCache { std::vector>> k, v; }; #endif