Commit 115badb9 authored by wooway777

issue/21 - Made InferenceContext thread-local to allow cleaner operator calls.

parent 2a2ddc57
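
The change in a nutshell, as an illustrative before/after sketch (identifiers taken from the diff below):

// Before: every operator call threads the context through explicitly
ctx.rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);

// After: the context is installed once per thread, and operators become free functions
setInferenceContext(&ctx);
rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
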
// inference_context.hpp
#pragma once
#include "cache_manager.hpp"
#include "jiuge/jiuge_impl.hpp"
#include "jiuge/jiuge_weight.hpp"
#include <cassert>
struct InferenceContext {
DeviceResource *rsrc;
@@ -49,3 +49,60 @@ struct InferenceContext {
float alpha, float beta,
std::shared_ptr<Tensor> residual);
};
namespace {
thread_local InferenceContext *tls_inference_context = nullptr;
}
inline InferenceContext &getInferenceContext() {
assert(tls_inference_context != nullptr && "InferenceContext not set for this thread");
return *tls_inference_context;
}
inline void setInferenceContext(InferenceContext *ctx) {
tls_inference_context = ctx;
}
inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
getInferenceContext().add(c, a, b);
}
inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w, float epsilon) {
getInferenceContext().rmsnorm(y, x, w, epsilon);
}
inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, float alpha, float beta) {
getInferenceContext().gemm(c, a, b, alpha, beta);
}
inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
getInferenceContext().rearrange(dst, src);
}
inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) {
getInferenceContext().rope(q, k, pos, sin, cos);
}
inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
getInferenceContext().causalSoftmax(y, x);
}
inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
std::shared_ptr<Tensor> gate) {
getInferenceContext().swiglu(out, up, gate);
}
inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature) {
getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature);
}
inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, float alpha, float beta,
std::shared_ptr<Tensor> residual) {
getInferenceContext().linear(c, a, b, alpha, beta, residual);
}
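
For reference, a minimal sketch of the intended per-thread lifecycle (it mirrors launchDevice further down; deviceWorker and runBatches are hypothetical names, not part of this commit):

void deviceWorker(DeviceResource *rsrc, CacheManager *cache) {
    // One context per device worker thread; thread_local keeps it invisible to other threads.
    InferenceContext ctx(rsrc, cache, rsrc->stream);
    setInferenceContext(&ctx);
    runBatches();                 // body may now call rmsnorm(), linear(), rope(), ... without passing ctx
    setInferenceContext(nullptr); // clear before the thread exits
}
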
@@ -117,7 +117,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output, InferenceContext &ctx) {
uint32_t *output) {
auto nlayer = meta.nlayer;
auto nkvh = meta.nkvh / ndev;
auto nh = meta.nh / ndev;
@@ -191,16 +191,16 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention
// rms norm
ctx.rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj
if (has_qkv_bias) {
ctx.rearrange(qkv_buf, rsrc.b_attn_qkv[layer]->view({ntok, (nh + nkvh * 2) * dh}, {0, 1}));
rearrange(qkv_buf, rsrc.b_attn_qkv[layer]->view({ntok, (nh + nkvh * 2) * dh}, {0, 1}));
}
ctx.linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, has_qkv_bias ? qkv_buf : nullptr);
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, has_qkv_bias ? qkv_buf : nullptr);
// rope
auto qkv_rope = qkv_buf->viewReshaped({ntok, nh + nkvh * 2, dh});
ctx.rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
ctx.rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
size_t token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
@@ -214,28 +214,28 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// self attention
// concat
ctx.rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
ctx.rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk
auto q_rearrange = rearrange_q_buf->viewReshaped({nkvh, ngroup, seq_len, dh});
ctx.rearrange(q_rearrange, q);
rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->viewReshaped({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
ctx.linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr);
linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr);
// softmax
auto qk_softmax = qk_buf->viewReshaped({nh, seq_len, total_len});
ctx.causalSoftmax(qk_softmax, qk_softmax);
causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
ctx.linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr);
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr);
// rearrange attn val
auto attn_val_gemm = attn_val_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh});
ctx.rearrange(o, attn_val_gemm);
rearrange(o, attn_val_gemm);
token_offset += seq_len;
}
// o_proj
ctx.linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
@@ -245,10 +245,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinirtStreamSynchronize(stream));
}
// 2. FFN
ctx.rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
ctx.linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr);
ctx.swiglu(gate_buf, up_buf, gate_buf);
ctx.linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr);
swiglu(gate_buf, up_buf, gate_buf);
linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
@@ -264,21 +264,21 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
token_offset += seq_len;
ctx.rmsnorm(logits_out->slice(0, req, 1),
logits_in->slice(0, token_offset - 1, 1),
rsrc.w_out_norm,
meta.epsilon);
rmsnorm(logits_out->slice(0, req, 1),
logits_in->slice(0, token_offset - 1, 1),
rsrc.w_out_norm,
meta.epsilon);
}
ctx.linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr);
linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr);
std::random_device _rd;
std::mt19937 gen(_rd());
token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
ctx.randomSample(result_buf->view({}, {}),
prob_buf->view({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]);
randomSample(result_buf->view({}, {}),
prob_buf->view({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]);
token_offset += seq_len;
}
RUN_INFINI(infinirtStreamSynchronize(stream));
@@ -327,6 +327,9 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc,
CacheManager cache_manager(256);
InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
// Set the inference context for this thread
setInferenceContext(&ctx);
// Create Device Resource
createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
{
@@ -347,8 +350,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc,
inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok,
req.req_lens, req.nreq, req.req_pos, req.kv_caches,
req.temperature, req.topk, req.topp, req.output,
ctx);
req.temperature, req.topk, req.topp, req.output);
state.proceed = false;
lock.unlock();
@@ -357,6 +359,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc,
// Clean-Up
releaseDeviceResource(*rsrc);
setInferenceContext(nullptr); // Clear the context when done
}
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
......
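
A possible follow-up, not part of this commit: since launchDevice sets and clears the thread-local pointer manually, an RAII guard would also restore it on early return or exception. A minimal sketch, assuming only the setInferenceContext helper above:

struct ScopedInferenceContext {
    explicit ScopedInferenceContext(InferenceContext *ctx) { setInferenceContext(ctx); }
    ~ScopedInferenceContext() { setInferenceContext(nullptr); }
    ScopedInferenceContext(const ScopedInferenceContext &) = delete;
    ScopedInferenceContext &operator=(const ScopedInferenceContext &) = delete;
};

// Usage inside launchDevice: ScopedInferenceContext guard(&ctx);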