// inference_context.hpp #pragma once #include "cache_manager.hpp" #include "jiuge/jiuge_impl.hpp" #include "jiuge/jiuge_weight.hpp" struct InferenceContext { DeviceResource *rsrc; CacheManager *cache_manager; infinirtStream_t stream; std::shared_ptr workspace_storage; size_t current_workspace_size = 0; InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream); void ensure_workspace(size_t required_size); void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, float epsilon); void gemm(std::shared_ptr c, std::shared_ptr c_desc_overwrite, std::shared_ptr a, std::shared_ptr a_desc_overwrite, std::shared_ptr b, std::shared_ptr b_desc_overwrite, float alpha, float beta); void rearrange(std::shared_ptr dst, std::shared_ptr dst_desc_overwrite, std::shared_ptr src, std::shared_ptr src_desc_overwrite); void rope(std::shared_ptr q, std::shared_ptr k, std::shared_ptr pos, std::shared_ptr sin, std::shared_ptr cos); void causalSoftmax(std::shared_ptr y, std::shared_ptr y_desc_overwrite, std::shared_ptr x, std::shared_ptr x_desc_overwrite); void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate); void randomSample(std::shared_ptr out, std::shared_ptr out_desc_overwrite, std::shared_ptr prob, std::shared_ptr prob_desc_overwrite, float random_val, float top_p, uint32_t top_k, float temperature); };