#pragma once

#include "../cache_manager/opcache_manager.hpp"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>

// Per-device state needed to launch operators: the infiniop handle, the memory
// pool for temporary buffers, the operator-descriptor cache, the stream that
// ops are issued on, and a lazily grown shared workspace.
struct InferenceContext {
    infiniopHandle_t op_handle;
    std::shared_ptr<MemoryPool> memory_pool; // pointee type assumed from usage
    CacheManager *cache_manager;
    infinirtStream_t stream;

    std::shared_ptr<Storage> workspace_storage; // pointee type assumed from usage
    size_t current_workspace_size = 0;

    InferenceContext(infiniopHandle_t op_handle,
                     std::shared_ptr<MemoryPool> memory_pool,
                     CacheManager *cache_manager,
                     infinirtStream_t stream);

    // Grows workspace_storage when required_size exceeds current_workspace_size.
    void ensure_workspace(size_t required_size);

    void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b);
    void conv(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w,
              std::shared_ptr<Tensor> bias, void *pads, void *strides, void *dilations, size_t n);
    void mul(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b);
    void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w, float epsilon);
    void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b,
              float alpha, float beta);
    void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src);
    void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k, std::shared_ptr<Tensor> pos,
              std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> cos, infiniopRoPEAlgo_t algo);
    void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x);
    void topkrouter(std::shared_ptr<Tensor> values,          // F32
                    std::shared_ptr<Tensor> indices,         // I32
                    std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> correction_bias, // F32
                    float routed_scaling_factor,
                    size_t topk);
    void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up, std::shared_ptr<Tensor> gate);
    void silu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> input);
    void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
                      float random_val, float top_p, uint32_t top_k, float temperature);
    void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b,
                float alpha, float beta,
                std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias);
    void dequant(std::shared_ptr<Tensor> weight, std::shared_ptr<Tensor> in_w,
                 std::shared_ptr<Tensor> in_s, std::shared_ptr<Tensor> in_z);
};

// One context pointer per thread. Declared inline (C++17) so every translation
// unit that includes this header shares the same thread-local variable; an
// anonymous-namespace variable here would give each TU its own copy, and the
// inline accessors below would silently read different pointers.
inline thread_local InferenceContext *tls_inference_context = nullptr;

inline InferenceContext &getInferenceContext() {
    assert(tls_inference_context != nullptr && "InferenceContext not set for this thread");
    return *tls_inference_context;
}

inline void setInferenceContext(InferenceContext *ctx) {
    tls_inference_context = ctx;
}
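// Usage sketch (illustrative only; `handle`, `pool`, `cache`, `stream`, and the
// tensors are assumed to come from the caller's device setup, not this header):
//
//   InferenceContext ctx(handle, pool, &cache, stream);
//   setInferenceContext(&ctx);   // bind to the current thread
//   rmsnorm(y, x, w, 1e-6f);     // free wrappers below route through the TLS context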
// Thin free-function wrappers that forward to the current thread's context.

inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    getInferenceContext().add(c, a, b);
}

inline void conv(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w,
                 std::shared_ptr<Tensor> bias, void *pads, void *strides, void *dilations, size_t n) {
    getInferenceContext().conv(y, x, w, bias, pads, strides, dilations, n);
}

inline void mul(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    getInferenceContext().mul(c, a, b);
}

inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w, float epsilon) {
    getInferenceContext().rmsnorm(y, x, w, epsilon);
}

inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b,
                 float alpha, float beta) {
    getInferenceContext().gemm(c, a, b, alpha, beta);
}

inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
    getInferenceContext().rearrange(dst, src);
}

// GPT-J-style (interleaved) rotary embedding.
inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k, std::shared_ptr<Tensor> pos,
                 std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> cos) {
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_J);
}

// GPT-NeoX-style (half-rotated) rotary embedding.
inline void rope_v2(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k, std::shared_ptr<Tensor> pos,
                    std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> cos) {
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_NEOX);
}

inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
    getInferenceContext().causalSoftmax(y, x);
}

inline void topkrouter(std::shared_ptr<Tensor> values,          // F32
                       std::shared_ptr<Tensor> indices,         // I32
                       std::shared_ptr<Tensor> x,
                       std::shared_ptr<Tensor> correction_bias, // F32
                       float routed_scaling_factor, size_t topk) {
    getInferenceContext().topkrouter(values, indices, x, correction_bias, routed_scaling_factor, topk);
}

inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up, std::shared_ptr<Tensor> gate) {
    getInferenceContext().swiglu(out, up, gate);
}

inline void silu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> input) {
    getInferenceContext().silu(out, input);
}

inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
                         float random_val, float top_p, uint32_t top_k, float temperature) {
    getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature);
}

inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b,
                   float alpha, float beta,
                   std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    getInferenceContext().linear(c, a, b, alpha, beta, residual, bias);
}

// Dequantizes the packed weight (w_w, with scales w_s and zero points w_z) into
// a temporary buffer, then forwards to linear with the same alpha/beta/residual/bias.
inline void dequant_linear(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> x,
                           std::shared_ptr<Tensor> w_w, std::shared_ptr<Tensor> w_s, std::shared_ptr<Tensor> w_z,
                           float alpha, float beta,
                           std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    // x is [m, k] and out is [m, n], so the dequantized weight is [k, n].
    auto w = Tensor::buffer(x->dtype(), {x->shape()[1], out->shape()[1]},
                            getInferenceContext().memory_pool);
    getInferenceContext().dequant(w, w_w, w_s, w_z);
    getInferenceContext().linear(out, x, w, alpha, beta, residual, bias);
}
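// Example call (illustrative; `y`, `x`, `w_q`, `w_scale`, `w_zero`, and `bias`
// are assumed tensors, with the quantized layout expected by dequant):
//
//   // y = x @ dequant(w_q, w_scale, w_zero) + bias
//   dequant_linear(y, x, w_q, w_scale, w_zero, /*alpha=*/1.0f, /*beta=*/0.0f,
//                  /*residual=*/nullptr, bias);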