// inference_context.hpp #pragma once #include "cache_manager.hpp" #include "jiuge/jiuge_impl.hpp" #include "jiuge/jiuge_weight.hpp" struct InferenceContext { DeviceResource *rsrc; CacheManager *cache_manager; infinirtStream_t stream; std::shared_ptr workspace_storage; size_t current_workspace_size = 0; InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream); void ensure_workspace(size_t required_size); void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, float epsilon); void gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta); void rearrange(std::shared_ptr dst, std::shared_ptr src); void rope(std::shared_ptr q, std::shared_ptr k, std::shared_ptr pos, std::shared_ptr sin, std::shared_ptr cos); void causalSoftmax(std::shared_ptr y, std::shared_ptr x); void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate); void randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature); };