#pragma once #include "cache_manager.hpp" #include "jiuge/jiuge_impl.hpp" #include "jiuge/jiuge_weight.hpp" #include struct InferenceContext { DeviceResource *rsrc; CacheManager *cache_manager; infinirtStream_t stream; std::shared_ptr workspace_storage; size_t current_workspace_size = 0; InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream); void ensure_workspace(size_t required_size); void add(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b); void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, float epsilon); void gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta); void rearrange(std::shared_ptr dst, std::shared_ptr src); void rope(std::shared_ptr q, std::shared_ptr k, std::shared_ptr pos, std::shared_ptr sin, std::shared_ptr cos); void causalSoftmax(std::shared_ptr y, std::shared_ptr x); void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate); void randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature); void linear(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta, std::shared_ptr residual); }; namespace { thread_local InferenceContext *tls_inference_context = nullptr; } inline InferenceContext &getInferenceContext() { assert(tls_inference_context != nullptr && "InferenceContext not set for this thread"); return *tls_inference_context; } inline void setInferenceContext(InferenceContext *ctx) { tls_inference_context = ctx; } inline void add(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b) { getInferenceContext().add(c, a, b); } inline void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, float epsilon) { getInferenceContext().rmsnorm(y, x, w, epsilon); } inline void gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta) { getInferenceContext().gemm(c, a, b, alpha, beta); } inline void rearrange(std::shared_ptr dst, std::shared_ptr src) { getInferenceContext().rearrange(dst, src); } inline void rope(std::shared_ptr q, std::shared_ptr k, std::shared_ptr pos, std::shared_ptr sin, std::shared_ptr cos) { getInferenceContext().rope(q, k, pos, sin, cos); } inline void causalSoftmax(std::shared_ptr y, std::shared_ptr x) { getInferenceContext().causalSoftmax(y, x); } inline void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate) { getInferenceContext().swiglu(out, up, gate); } inline void randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature) { getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature); } inline void linear(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta, std::shared_ptr residual) { getInferenceContext().linear(c, a, b, alpha, beta, residual); }